# HackMatrix PureJaxRL Training on TPU (Colab)

This notebook trains an agent to play the HackMatrix game using PureJaxRL on Google Colab's free TPUs.

**Before running:**
1. Verify that the TPU v5e-1 runtime is selected (Runtime > Change runtime type); it should be selected automatically
2. Run the cells in order
3. The training script automatically verifies that a TPU is detected

## 1. Clone Repository

In [None]:
!git clone https://github.com/charleseff/hack-matrix.git
%cd hack-matrix/python

## 2. Install Dependencies

In [None]:
# Install dependencies and enable JAX compilation cache
!pip install -q -r requirements.txt

# Enable compilation caching (speeds up subsequent runs)
import os

os.environ["JAX_COMPILATION_CACHE_DIR"] = "/content/jax_cache"
os.makedirs("/content/jax_cache", exist_ok=True)

print("✅ Dependencies installed")
print("✅ JAX compilation cache enabled at /content/jax_cache")
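
To confirm that the cache is actually being populated after a training run, a small helper like the one below can count the files in the cache directory. This is only a sketch: `cache_stats` is a hypothetical helper, not part of the repository, and it assumes the cache lives at the path set in the cell above. An empty directory is not necessarily a failure, since JAX may skip caching programs that compile very quickly.

```python
import os
from pathlib import Path


def cache_stats(cache_dir):
    """Return (file_count, total_bytes) for everything under cache_dir."""
    root = Path(cache_dir)
    if not root.is_dir():
        return 0, 0
    files = [p for p in root.rglob("*") if p.is_file()]
    return len(files), sum(p.stat().st_size for p in files)


# Fall back to the same default path the cell above uses.
cache_dir = os.environ.get("JAX_COMPILATION_CACHE_DIR", "/content/jax_cache")
count, size = cache_stats(cache_dir)
print(f"{count} cached files, {size / 1e6:.1f} MB in {cache_dir}")
```

Running it once before and once after the quick test should show the file count growing if caching is active.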

## 3. Quick Test (~100K timesteps)

This run verifies that the TPU is working and prints device information. It takes about 30 seconds.

In [None]:
# Quick test - uses smaller batch to verify everything works
!python scripts/train_purejaxrl.py \
  --num-envs 64 \
  --num-steps 64 \
  --total-timesteps 100000 \
  --seed 42
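
The training script performs its own TPU check. If you want to inspect the devices yourself from a notebook cell, the following sketch lists what JAX can see. `describe_devices` is a hypothetical helper written for this notebook and is not the check the script runs.

```python
def describe_devices():
    """Return one line per device JAX can see, or a note if JAX is missing."""
    try:
        import jax
    except ImportError:
        return ["jax is not installed in this environment"]
    # platform is e.g. "tpu", "gpu" or "cpu"; device_kind names the hardware.
    return [f"{d.platform}:{d.id} ({d.device_kind})" for d in jax.devices()]


for line in describe_devices():
    print(line)
```

On a TPU runtime every line should start with `tpu`. A line starting with `cpu` suggests that the runtime type was not switched.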

## 4. Medium Training (500K timesteps)

This run uses a larger batch for better TPU utilization. A second run should start faster because compiled programs are reused from the cache set up in step 2.

In [None]:
# Medium training - moderate batch size (256 * 128 = 32K batch)
!python scripts/train_purejaxrl.py \
  --num-envs 256 \
  --num-steps 128 \
  --total-timesteps 500000 \
  --save-interval 10 \
  --log-interval 5 \
  --checkpoint-dir checkpoints/colab_medium \
  --seed 123
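
The batch-size arithmetic in the comments can be extended to estimate how many updates each run performs. The sketch below assumes the script follows the usual PureJaxRL convention of floor-dividing the timestep budget by `num_envs * num_steps`; `num_updates` is a hypothetical helper, and the script may round differently.

```python
def num_updates(total_timesteps, num_envs, num_steps):
    """Whole updates that fit in the timestep budget (floor division)."""
    return total_timesteps // (num_envs * num_steps)


for name, total, envs, steps in [
    ("quick", 100_000, 64, 64),
    ("medium", 500_000, 256, 128),
]:
    batch = envs * steps
    print(f"{name}: batch={batch}, updates={num_updates(total, envs, steps)}")
# quick: batch=4096, updates=24
# medium: batch=32768, updates=15
```

Under this assumption the medium run collects slightly fewer than 500K timesteps (15 * 32,768 = 491,520), because a partial final update is dropped.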

## 5. Full Training (10M timesteps)

This should take 5-10 minutes on TPU. Adjust parameters as needed.

In [None]:
# Full training - large batch for TPU (1024 * 256 = 262K batch)
!python scripts/train_purejaxrl.py \
  --num-envs 1024 \
  --num-steps 256 \
  --total-timesteps 10000000 \
  --lr 0.0003 \
  --num-minibatches 8 \
  --update-epochs 4 \
  --hidden-dim 512 \
  --num-layers 3 \
  --save-interval 100 \
  --log-interval 10 \
  --checkpoint-dir checkpoints/colab_full \
  --seed 42
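
To see how these flags interact, the sketch below works out the minibatch size and the number of gradient steps per update for the command above. It assumes standard PPO bookkeeping (the rollout batch is split evenly into `--num-minibatches` pieces and reused for `--update-epochs` passes); `ppo_update_shape` is a hypothetical helper, not a function from the repository.

```python
def ppo_update_shape(num_envs, num_steps, num_minibatches, update_epochs):
    """Return (minibatch_size, gradient_steps_per_update)."""
    batch = num_envs * num_steps
    if batch % num_minibatches != 0:
        raise ValueError("batch must divide evenly into minibatches")
    return batch // num_minibatches, num_minibatches * update_epochs


minibatch, grad_steps = ppo_update_shape(1024, 256, 8, 4)
updates = 10_000_000 // (1024 * 256)  # assumes floor division of the budget
print(minibatch, grad_steps, updates)  # 32768 32 38
```

If `--save-interval` and `--log-interval` are counted in updates (an assumption worth checking against the script), a run of about 38 updates would finish before a save interval of 100 is ever reached, so only the final checkpoint would be written.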

## 6. Download Checkpoints

Download the trained model to your local machine.

In [None]:
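# List available checkpoints (recursively, since the runs above pass --checkpoint-dir subdirectories)
!ls -lhR checkpoints/

# Download final checkpoint
from google.colab import files

# Adjust the path if the listing shows the file under a subdirectory such as checkpoints/colab_full/
files.download("checkpoints/final_params.npz")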
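
Because the medium and full runs write to their own `--checkpoint-dir`, the file you want may not sit directly under `checkpoints/`. The following sketch searches the tree and lists candidates newest first. `find_checkpoints` is a hypothetical helper, and the `*.npz` pattern is an assumption based on the `final_params.npz` filename above.

```python
from pathlib import Path


def find_checkpoints(root="checkpoints", pattern="*.npz"):
    """Return matching files under root, newest first."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []
    files = [p for p in root_path.rglob(pattern) if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)


for path in find_checkpoints():
    print(path)
```

In Colab, any of the printed paths can then be passed to `files.download(str(path))`.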

## Tips

### TPU Performance
- The TPU type Colab assigns depends on availability: this notebook targets v5e-1 (see "Before running"), while older Colab runtimes used v2-8 (8 cores)
- Expected throughput: 50K-100K steps/second
- 10M timesteps should take 5-10 minutes
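
As a sanity check on these figures, the sketch below converts the quoted throughput range into time spent stepping the environment. `rollout_seconds` is a hypothetical helper; it ignores startup, compilation and checkpointing.

```python
def rollout_seconds(total_timesteps, steps_per_second):
    """Time spent collecting timesteps at a constant throughput."""
    return total_timesteps / steps_per_second


for rate in (50_000, 100_000):
    secs = rollout_seconds(10_000_000, rate)
    print(f"{rate} steps/s -> {secs:.0f} s ({secs / 60:.1f} min)")
# 50000 steps/s -> 200 s (3.3 min)
# 100000 steps/s -> 100 s (1.7 min)
```

At the quoted throughput, stepping alone accounts for roughly 2 to 3 minutes of a 10M-timestep run, so the 5-10 minute estimate presumably includes compilation and other overhead.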

### Hyperparameter Tuning
- `--num-envs`: More envs = better TPU utilization (try 1024-4096)
- `--num-steps`: Longer rollouts = more stable gradients (try 256-1024)
- `--lr`: Learning rate (the flag used in the full training command); start with 0.0003 and reduce it if training is unstable
- `--hidden-dim`: Larger network = more capacity (try 256-1024)
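
To get a feel for how `--hidden-dim` and `--num-layers` affect model size, the sketch below counts the parameters of a plain MLP. It assumes the policy is a simple stack of `num_layers` dense hidden layers, which may not match the actual network. `mlp_param_count`, `OBS_DIM` and `NUM_ACTIONS` are hypothetical, and the two dimensions are placeholder values rather than HackMatrix's real sizes.

```python
def mlp_param_count(in_dim, hidden_dim, num_layers, out_dim):
    """Weights plus biases for in -> hidden x num_layers -> out."""
    dims = [in_dim] + [hidden_dim] * num_layers + [out_dim]
    return sum(a * b + b for a, b in zip(dims, dims[1:]))


OBS_DIM, NUM_ACTIONS = 128, 16  # hypothetical placeholder sizes
for hidden in (256, 512, 1024):
    print(hidden, mlp_param_count(OBS_DIM, hidden, 3, NUM_ACTIONS))
```

The hidden-to-hidden layers grow with the square of `hidden_dim`, so doubling the width roughly quadruples the parameter count once those layers dominate.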

### Memory Issues
If you run out of memory:
- Reduce `--num-envs` (try 1024 instead of 2048)
- Reduce `--hidden-dim` (try 256 instead of 512)
- Reduce `--num-steps` (try 256 instead of 512)
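
The reductions above all shrink the rollout buffer, which scales with `num_envs * num_steps`. The sketch below estimates the size of the stored observations alone. It assumes float32 observations and uses a placeholder `OBS_DIM`; both `rollout_obs_bytes` and `OBS_DIM` are hypothetical and do not reflect the real HackMatrix observation size.

```python
def rollout_obs_bytes(num_envs, num_steps, obs_dim, bytes_per_value=4):
    """Bytes needed to store one rollout of observations (float32 by default)."""
    return num_envs * num_steps * obs_dim * bytes_per_value


OBS_DIM = 128  # hypothetical placeholder size
for envs, steps in [(2048, 512), (1024, 512), (1024, 256)]:
    mib = rollout_obs_bytes(envs, steps, OBS_DIM) / 2**20
    print(f"num_envs={envs}, num_steps={steps}: {mib:.0f} MiB")
# num_envs=2048, num_steps=512: 512 MiB
# num_envs=1024, num_steps=512: 256 MiB
# num_envs=1024, num_steps=256: 128 MiB
```

Halving either flag halves this buffer. Activations and optimizer state add to the total, so treat the numbers as a lower bound.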

### Session Limits
- Colab has a 12-hour session limit (free tier)
- Download checkpoints regularly to avoid losing progress
- For longer training, consider the TPU Research Cloud (TRC) program or Colab Pro
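
One way to make regular downloads less tedious is to bundle the whole checkpoint directory into a single archive. The sketch below uses only the standard library. `archive_checkpoints` and the `checkpoints_backup` name are hypothetical, and the default source directory is assumed to be the `checkpoints/` folder used above.

```python
import shutil
from pathlib import Path


def archive_checkpoints(src="checkpoints", dest_stem="checkpoints_backup"):
    """Zip src into dest_stem.zip; return the archive path, or None if src is missing."""
    if not Path(src).is_dir():
        return None
    return shutil.make_archive(dest_stem, "zip", root_dir=src)


archive = archive_checkpoints()
print(archive or "no checkpoints directory yet")
```

In Colab, the returned path can be handed to `files.download(archive)`. Copying it to a mounted Google Drive folder is another option.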