# HackMatrix PureJaxRL Training on TPU (Colab)

This notebook trains an agent to play the HackMatrix game using PureJaxRL on Google Colab's free TPUs.

**Before running:**
1. Select Runtime → Change runtime type → Hardware accelerator → TPU
2. Run the cells in order; later cells rely on the repository cloned in section 2

## 1. Setup Environment

In [None]:
# Check that a TPU is available
import jax

devices = jax.devices()  # query the backend once and reuse the result
print("JAX version:", jax.__version__)
print("Devices:", devices)
print("Device count:", len(devices))
print("Backend:", devices[0].platform)

if devices[0].platform != 'tpu':
    print("\n⚠️ WARNING: TPU not detected!")
    print("Go to: Runtime → Change runtime type → Hardware accelerator → TPU")
else:
    print("\n✅ TPU detected! Ready to train.")

## 2. Clone Repository and Install Dependencies

In [None]:
# Clone the repository
!git clone https://github.com/charleseff/hack-matrix.git
%cd hack-matrix/python

In [None]:
# Install dependencies
!pip install -q -r requirements.txt
print("✅ Dependencies installed")

## 3. Quick Test (1K timesteps)

In [None]:
# Quick test to verify everything works
!python scripts/train_purejaxrl.py \
  --num-envs 256 \
  --num-steps 128 \
  --total-timesteps 1000 \
  --seed 42

## 4. Medium Training (100K timesteps)

This should take 1-2 minutes on TPU.

In [None]:
!python scripts/train_purejaxrl.py \
  --num-envs 1024 \
  --num-steps 256 \
  --total-timesteps 100000 \
  --save-interval 10 \
  --log-interval 5 \
  --checkpoint-dir checkpoints/colab_medium \
  --seed 123

## 5. Full Training (10M timesteps)

This run should take about 5-10 minutes on a TPU. Adjust the flags to fit your time and memory budget; the Tips section describes the main ones.

In [None]:
!python scripts/train_purejaxrl.py \
  --num-envs 2048 \
  --num-steps 512 \
  --total-timesteps 10000000 \
  --learning-rate 0.0003 \
  --num-minibatches 8 \
  --update-epochs 4 \
  --hidden-dim 512 \
  --num-layers 3 \
  --save-interval 100 \
  --log-interval 10 \
  --checkpoint-dir checkpoints/colab_full \
  --seed 42
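
Before launching a long run, it can help to work out how many environment steps a single PPO update consumes. The sketch below assumes the script follows the common PureJaxRL convention `num_updates = total_timesteps // (num_envs * num_steps)` and splits each rollout batch evenly into minibatches. `rollout_summary` is a hypothetical helper, not part of the repository, so check `scripts/train_purejaxrl.py` for the rule it actually uses.

```python
def rollout_summary(num_envs, num_steps, total_timesteps, num_minibatches=1):
    """Return batch size, update count, and minibatch size for a PPO run.

    Assumption: num_updates = total_timesteps // (num_envs * num_steps).
    The real training script may compute this differently.
    """
    batch_size = num_envs * num_steps  # env steps collected per update
    return {
        "batch_size": batch_size,
        "num_updates": total_timesteps // batch_size,
        "minibatch_size": batch_size // num_minibatches,
    }


# Flags from the full training cell above
full = rollout_summary(2048, 512, 10_000_000, num_minibatches=8)
print(full)
```

Under this assumption, a rollout batch larger than `--total-timesteps` rounds down to zero updates, so it is worth confirming how the script treats small budgets and what unit `--save-interval` and `--log-interval` count in.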

## 6. Download Checkpoints

Download the trained model to your local machine.

In [None]:
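# List available checkpoints, including the per-run subdirectories
!ls -lhR checkpoints/

# Download the final checkpoint from the full run.
# Section 5 wrote to checkpoints/colab_full; adjust the path if the listing
# above shows final_params.npz somewhere else.
from google.colab import files
files.download('checkpoints/colab_full/final_params.npz')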
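
Before downloading, it can be useful to confirm what a checkpoint contains. The sketch below assumes the checkpoint is a plain NumPy `.npz` archive of named arrays, as the file extension suggests. The array names `kernel` and `bias`, the helper `summarize_npz`, and the stand-in file written here are all hypothetical, chosen so the example runs on its own.

```python
import os
import tempfile

import numpy as np


def summarize_npz(path):
    """Return {array_name: shape} for every array stored in an .npz file."""
    with np.load(path) as data:
        return {name: data[name].shape for name in data.files}


# Write a small stand-in checkpoint (hypothetical parameter names and shapes).
tmp_dir = tempfile.mkdtemp()
demo_path = os.path.join(tmp_dir, "demo_params.npz")
np.savez(demo_path, kernel=np.zeros((4, 8)), bias=np.zeros(8))

summary = summarize_npz(demo_path)
for name, shape in summary.items():
    print(name, shape)
```

To inspect a real checkpoint, pass its path to `summarize_npz` in place of `demo_path`.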

## 7. Evaluate Trained Agent (Optional)

Test the trained agent's performance.

In [None]:
# TODO: Add evaluation code here
# This would load the checkpoint and run test episodes
print("Evaluation script coming soon!")

## Tips

### TPU Performance
- Colab TPUs are v2-8 (8 cores) or v3-8 depending on availability
- Expected throughput: 50K-100K steps/second
- 10M timesteps should take 5-10 minutes
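
The throughput figure converts to wall-clock time by simple division. The sketch below uses only the numbers quoted above and counts environment stepping alone, ignoring JIT compilation, logging, and checkpointing; that scope is an assumption about what the throughput figure measures. `estimated_minutes` is a hypothetical helper.

```python
def estimated_minutes(total_timesteps, steps_per_second):
    """Minutes needed to collect total_timesteps at a steady throughput."""
    return total_timesteps / steps_per_second / 60.0


slow = estimated_minutes(10_000_000, 50_000)   # lower end of the quoted range
fast = estimated_minutes(10_000_000, 100_000)  # upper end of the quoted range
print(f"10M timesteps: {fast:.1f} to {slow:.1f} minutes of pure stepping")
```

At the quoted rates the stepping itself comes to roughly 1.7 to 3.3 minutes, so the 5-10 minute estimate leaves headroom for overhead that this calculation does not count.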

### Hyperparameter Tuning
- `--num-envs`: More envs = better GPU/TPU utilization (try 1024-4096)
- `--num-steps`: Longer rollouts = more stable gradients (try 256-1024)
- `--learning-rate`: Start with 0.0003; reduce it if training is unstable
- `--hidden-dim`: Larger network = more capacity (try 256-1024)

### Memory Issues
If you run out of memory:
- Reduce `--num-envs` (try 1024 instead of 2048)
- Reduce `--hidden-dim` (try 256 instead of 512)
- Reduce `--num-steps` (try 256 instead of 512)
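
As a rough illustration of why these reductions help, the observation buffer for one rollout grows with `num_envs * num_steps`, so halving either flag halves it. The sketch assumes float32 observations and a hypothetical flat observation size of 256 values; HackMatrix's real observation shape is not stated here, and `rollout_obs_bytes` is an illustrative helper.

```python
def rollout_obs_bytes(num_envs, num_steps, obs_dim, bytes_per_value=4):
    """Bytes needed to store one rollout of observations (float32 by default)."""
    return num_envs * num_steps * obs_dim * bytes_per_value


OBS_DIM = 256  # hypothetical observation size, for illustration only

full_mib = rollout_obs_bytes(2048, 512, OBS_DIM) / 2**20
reduced_mib = rollout_obs_bytes(1024, 256, OBS_DIM) / 2**20
print(f"full: {full_mib:.0f} MiB, reduced: {reduced_mib:.0f} MiB")
```

This covers only the stored observations. `--hidden-dim` affects parameter and activation memory instead, which this estimate does not include.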

### Session Limits
- Free-tier Colab sessions are capped at roughly 12 hours and can disconnect sooner when idle
- Download checkpoints regularly to avoid losing progress
- For longer training, consider Google's TPU Research Cloud (TRC) program or Colab Pro
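
One way to download checkpoints regularly is to bundle the whole checkpoint directory into a single zip file first. The sketch below uses the standard library's `shutil.make_archive`. `archive_checkpoints` is a hypothetical helper, and the throwaway directory and placeholder file stand in for a real `checkpoints/` folder so the example runs anywhere.

```python
import os
import shutil
import tempfile


def archive_checkpoints(checkpoint_dir, archive_base):
    """Zip checkpoint_dir into archive_base + '.zip' and return the archive path."""
    return shutil.make_archive(archive_base, "zip", root_dir=checkpoint_dir)


# Demo on a throwaway directory with a placeholder file.
work_dir = tempfile.mkdtemp()
ckpt_dir = os.path.join(work_dir, "checkpoints")
os.makedirs(ckpt_dir)
with open(os.path.join(ckpt_dir, "placeholder.npz"), "wb") as f:
    f.write(b"placeholder")

archive_path = archive_checkpoints(ckpt_dir, os.path.join(work_dir, "checkpoints_backup"))
print(archive_path)

# In Colab, the archive could then be downloaded with:
#   from google.colab import files
#   files.download(archive_path)
```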