In [21]:
# %% [code]
"""
Minimal test to verify Kaggle environment setup with uv.
"""
import sys
import subprocess
import os

print("=" * 80)
print("Kaggle Environment Test with uv")
print("=" * 80)

# Step 0: Install uv and sync the project's dependencies.
# NOTE(review): `uv sync` installs into the *project* environment (by default a
# `.venv` next to pyproject.toml), not necessarily into the interpreter running
# this kernel. The import checks below use the kernel's interpreter, so packages
# added only by `uv sync` may not be visible to them — confirm this is intended.
print("\n[0/5] Installing uv and syncing dependencies...")
try:
    # Install uv with the kernel's own interpreter so it lands in the same
    # environment this notebook runs in; a bare `pip` on PATH may belong to a
    # different Python.
    print("  Installing uv...")
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "uv"],
        capture_output=True,
        text=True,
        timeout=60,
    )
    if result.returncode == 0:
        print("  ✓ uv installed")
    else:
        print(f"  ✗ Failed to install uv: {result.stderr}")
        # SystemExit is not an Exception subclass, so the handler below does not
        # swallow it: the cell stops here when uv cannot be installed.
        sys.exit(1)

    # Invoke uv through the same interpreter: the `uv` console script may have
    # been installed into a bin directory that is not on PATH.
    print("  Running uv sync...")
    result = subprocess.run(
        [sys.executable, "-m", "uv", "sync"],
        capture_output=True,
        text=True,
        timeout=300,
    )
    # uv writes its progress/summary to stderr, so stdout is typically empty;
    # fall back to stderr so the log line actually shows something.
    sync_log = result.stdout or result.stderr
    if result.returncode == 0:
        print("  ✓ Dependencies synced successfully")
        print(f"  Output: {sync_log[:200]}...")
    else:
        print(f"  ✗ uv sync failed: {result.stderr[:500]}")
        # Continue anyway to see what's available

except Exception as e:
    # Best-effort: report setup failures (including subprocess timeouts) and
    # keep going so the remaining checks still show what is installed.
    print(f"  ✗ Error during uv setup: {e}")
    # Continue anyway

# Step 1: Report the full version string of the interpreter running this kernel.
print("\n[1/5] Python version: " + sys.version)

# Step 2: Ask JAX which accelerator devices it can see. Despite the "GPU" label,
# this reports whichever platform JAX's first device has (e.g. `tpu`).
# `jax` and `devices` remain bound in the notebook namespace afterwards.
print("\n[2/5] Checking GPU...")
try:
    import jax
    print("  JAX version: {}".format(jax.__version__))
    devices = jax.devices()
    print("  Devices: {}".format(devices))
    print("  Device type: {}".format(devices[0].platform))
    print("  Number of devices: {}".format(len(devices)))
except Exception as err:
    # Any failure — including JAX not being installed — is reported, and the
    # notebook moves on to the next check.
    print("  Error: {}".format(err))

# Step 3: Import the core training libraries one at a time, so a single missing
# package is reported without hiding the status of the others. The imported
# names (optax, nnx, snapshot_download) remain bound in the notebook namespace.
print("\n[3/5] Checking imports...")

try:
    import optax
    # Kept inside the try: a missing __version__ is reported like a failed import.
    print("  ✓ optax {}".format(optax.__version__))
except Exception as err:
    print("  ✗ optax: {}".format(err))

try:
    from flax import nnx
    print("  ✓ flax.nnx")
except Exception as err:
    print("  ✗ flax.nnx: {}".format(err))

try:
    from huggingface_hub import snapshot_download
    print("  ✓ huggingface_hub")
except Exception as err:
    print("  ✗ huggingface_hub: {}".format(err))

# Step 4: Check that tunix is importable. If it is not, pip-install it from
# GitHub into the kernel's environment and re-check the import, so that
# `rl_cluster_lib` is bound whenever this step reports success.
import importlib

# Generous upper bound for a git checkout + build; without a timeout a stalled
# install would hang the cell indefinitely.
TUNIX_INSTALL_TIMEOUT_S = 900

print("\n[4/5] Checking tunix library...")
try:
    from tunix.rl import rl_cluster as rl_cluster_lib
    print("  ✓ tunix.rl")
except Exception as e:
    print(f"  ✗ tunix.rl: {e}")
    print("  Installing tunix...")
    # NOTE(review): confirm this repository URL — the upstream Tunix project
    # may be hosted under a different GitHub organization.
    try:
        # `sys.executable -m pip` installs into the interpreter running this
        # kernel; a bare `pip` on PATH may target a different Python.
        result = subprocess.run(
            [
                sys.executable, "-m", "pip", "install", "-q",
                "git+https://github.com/google-deepmind/tunix.git",
            ],
            capture_output=True,
            text=True,
            timeout=TUNIX_INSTALL_TIMEOUT_S,
        )
    except subprocess.TimeoutExpired:
        print(f"  ✗ Failed to install tunix: timed out after {TUNIX_INSTALL_TIMEOUT_S}s")
    else:
        if result.returncode == 0:
            print("  ✓ tunix installed successfully")
            # Packages installed while the kernel is running are only found after
            # the import system's finder caches are refreshed.
            importlib.invalidate_caches()
            try:
                from tunix.rl import rl_cluster as rl_cluster_lib
                print("  ✓ tunix.rl")
            except Exception as import_err:
                print(f"  ✗ tunix.rl still not importable after install: {import_err}")
        else:
            print(f"  ✗ Failed to install tunix: {result.stderr}")

# Step 5: Test complete
print("\n[5/5] Environment check complete!")
print("=" * 80)
print("✅ Ready to run GRPO training pipeline")
print("=" * 80)


Kaggle Environment Test with uv

[0/5] Installing uv and syncing dependencies...
  Installing uv...
  ✓ uv installed
  Running uv sync...
  ✓ Dependencies synced successfully
  Output: ...

[1/5] Python version: 3.12.12 (main, Dec  9 2025, 02:04:51) [GCC 14.2.0]

[2/5] Checking GPU...
  JAX version: 0.8.1
  Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,2,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,2,0), core_on_chip=0), TpuDevice(id=6, process_index=0, coords=(0,3,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,3,0), core_on_chip=0)]
  Device type: tpu
  Number of devices: 8

[3/5] Checking imports...
  ✓ optax 0.2.6
  ✓ flax.nnx
  ✓ huggingface_hub

[4/5] Checking tunix library...
  