# IaC-GPT TPU Training (Colab/Kaggle Prototype)

Test nanochat training on TPUs before production deployment.

**Setup:**
1. Runtime → Change runtime type → TPU
2. Run all cells

**TPU Support:**
- Colab: TPU v2-8 (8 cores, 64GB HBM) or v3-8 (8 cores, 128GB HBM)
- Kaggle: TPU v5e-1 (1 core, 16GB HBM) or v5e-8 (8 cores, 128GB HBM)
- Native bfloat16 support
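- Roughly 5-10x faster than a T4 GPU for transformer training (rough estimate; the actual speedup depends on model size, batch size, and XLA recompilation overhead)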
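The sketch below puts the HBM figures above in context by estimating how much memory the depth-12 model trained later in this notebook needs before activations. It assumes upstream nanochat's sizing: model width = 64 × depth, a 65,536-token vocabulary, untied embeddings, and about 12·dim² weights per transformer block. It also treats every parameter as fp32 with AdamW's two moment buffers, which likely overestimates nanochat's Muon + AdamW mix. The HBM table only restates the list above.

```python
# Rough static-memory budget for a nanochat-style model on one TPU core.
# ASSUMPTIONS (upstream nanochat defaults, not verified for this fork):
#   model_dim = 64 * depth, vocab = 65,536, untied input/output embeddings,
#   ~12 * dim^2 weights per block (4*dim^2 attention + 8*dim^2 MLP).

HBM_GB = {  # accelerator: (cores, total HBM in GB), as listed above
    "v2-8": (8, 64),
    "v3-8": (8, 128),
    "v5e-1": (1, 16),
    "v5e-8": (8, 128),
}


def estimate_params(depth: int, vocab_size: int = 65_536, dim_per_layer: int = 64) -> int:
    """Approximate parameter count: transformer blocks plus untied embeddings."""
    dim = depth * dim_per_layer
    blocks = 12 * depth * dim * dim
    embeddings = 2 * vocab_size * dim  # token embedding + separate output head
    return blocks + embeddings


def static_memory_gb(n_params: int, bytes_per_param: int = 16) -> float:
    """fp32 weights (4 B) + fp32 grads (4 B) + two AdamW moments (8 B); excludes activations."""
    return n_params * bytes_per_param / 1e9


if __name__ == "__main__":
    n = estimate_params(12)
    need = static_memory_gb(n)
    print(f"depth=12: ~{n / 1e6:.0f}M params, ~{need:.1f} GB before activations")
    for name, (cores, total) in HBM_GB.items():
        per_core = total / cores
        print(f"{name}: {per_core:.0f} GB/core, ~{per_core - need:.1f} GB left for activations")
```

Data-parallel training replicates the full model on every core, so the per-core figure, not the total, is what has to fit.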

In [None]:
# Install dependencies using uv (better dependency resolution than pip)
# Step 1: Install uv. Each `!` line runs in its own subshell, so `!source $HOME/.cargo/env`
# would not persist, and recent uv installers place the binary in ~/.local/bin rather
# than ~/.cargo/bin. Installing uv with pip puts the `uv` binary directly on PATH.
!pip install -q uv

# Step 2: Use uv to install all dependencies (it resolves version conflicts up front)
# The [tpu] extra of torch-xla pulls in libtpu, which the TPU runtime needs
!uv pip install --system \
    torch==2.9.0 \
    "torch-xla[tpu]==2.9.0" \
    cloud-tpu-client \
    "google-api-core>=2.27.0" \
    "google-cloud-storage>=3.9.0" \
    "protobuf>=4.25.2,<6.0" \
    tiktoken pyarrow filelock rustbpe wandb tabulate regex zstandard pyyaml

print("✅ Installation complete via uv")
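
torch-xla wheels are built against one specific torch release, so a mismatch (for example, a preinstalled torch winning over the pinned version) usually surfaces as an import error. The check below uses only the standard library; comparing the major.minor release line is our own heuristic, not an official compatibility rule.

```python
# Verify that torch and torch-xla come from the same release line (major.minor).
from importlib.metadata import PackageNotFoundError, version


def release(v: str) -> tuple:
    """Major.minor from a version string, ignoring local tags such as '+cu121'."""
    major, minor = v.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def versions_match(torch_v: str, xla_v: str) -> bool:
    return release(torch_v) == release(xla_v)


if __name__ == "__main__":
    try:
        tv, xv = version("torch"), version("torch_xla")
    except PackageNotFoundError as e:
        print(f"⚠️  Missing package: {e}")
    else:
        status = "✅" if versions_match(tv, xv) else "❌ mismatch:"
        print(f"{status} torch {tv}, torch-xla {xv}")
```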

In [None]:
# Clone nanochat repo
!git clone https://github.com/holynakamoto/iacgpt.git nanochat 2>/dev/null || \
    (cd nanochat && git pull origin master)
%cd nanochat

In [None]:
# Verify TPU detection (updated for torch-xla 2.9.0 API)
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

print("=" * 70)
print("TPU DETECTION TEST")
print("=" * 70)

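# Use the torch_xla 2.9.0 API
device = torch_xla.device()
print(f"TPU device: {device}")
# world_size() counts participating processes, not chips: it is typically 1 unless the
# script is launched with xmp.spawn, so report the addressable device count separately
print(f"Addressable TPU devices: {xr.global_runtime_device_count()}")
print(f"World size (processes): {xr.world_size()}")
print(f"Local ordinal: {xr.local_ordinal()}")
print(f"Global ordinal: {xr.global_ordinal()}")

# Test tensor operation. XLA tensors are lazy; .item() forces the graph to actually execute
x = torch.randn(3, 3, device=device)
y = x @ x.t()
print(f"\nTest matmul: {y.shape}, trace = {y.trace().item():.4f}")
print(f"Device type: {y.device}")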
print("=" * 70)

## Add TPU Support to common.py

nanochat's `autodetect_device_type()` only recognizes CUDA, MPS, and CPU, so we patch it to return `"xla"` when a TPU is available.

In [None]:
# Patch common.py to add TPU support (torch-xla 2.9.0 API)
import os

tpu_patch = '''
def autodetect_device_type():
    # Check for TPU first (Colab, Kaggle)
    try:
        import torch_xla
        import torch_xla.runtime as xr
        device = torch_xla.device()
        device_type = "xla"
        print0(f"Autodetected device type: {device_type} ({xr.global_runtime_device_count()} TPU devices visible)")
        return device_type
    except ImportError:
        pass
    except Exception as e:
        print0(f"TPU detection failed: {e}")
    
    # Fallback to CUDA/MPS/CPU
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    print0(f"Autodetected device type: {device_type}")
    return device_type
'''

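import re

# Locate common.py: upstream nanochat keeps it in the nanochat/ package; fall back to the repo root
COMMON_PY = next((p for p in ("nanochat/common.py", "common.py") if os.path.exists(p)), None)
if COMMON_PY is None:
    raise FileNotFoundError("Could not find common.py (looked in nanochat/ and the repo root)")

with open(COMMON_PY, 'r') as f:
    content = f.read()

if 'device_type = "xla"' in content:
    # Already patched. Re-running the non-greedy regex would stop at the first
    # `return device_type` inside the TPU branch and corrupt the function.
    print(f"ℹ️  {COMMON_PY} already has TPU support, skipping")
else:
    # Replace autodetect_device_type (from its def up to its single return statement)
    pattern = r'def autodetect_device_type\(\):.*?return device_type'
    # A function replacement stops re.sub from interpreting backslashes in the patch
    content, n = re.subn(pattern, lambda m: tpu_patch.strip(), content, count=1, flags=re.DOTALL)
    if n == 0:
        raise RuntimeError(f"autodetect_device_type() not found in {COMMON_PY}")
    with open(COMMON_PY, 'w') as f:
        f.write(content)
    print(f"✅ Patched {COMMON_PY} with TPU support (torch-xla 2.9.0 API)")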
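
Regex surgery on Python source can leave a file that no longer parses, or a function that was only half replaced. This optional check, a sketch rather than part of the repo, parses the patched file with `ast` and confirms that exactly one `autodetect_device_type()` remains and that it references `torch_xla`. The candidate paths are an assumption: upstream nanochat keeps `common.py` inside the `nanochat/` package.

```python
# Post-patch sanity check using only the standard library (ast.unparse needs Python 3.9+).
import ast
import os

COMMON_PY_CANDIDATES = ("nanochat/common.py", "common.py")  # assumed locations


def check_patched_source(source: str, func_name: str = "autodetect_device_type") -> list:
    """Return a list of problems found in the patched source (empty list = looks fine)."""
    try:
        tree = ast.parse(source)
    except SyntaxError as e:
        return [f"syntax error at line {e.lineno}: {e.msg}"]
    defs = [n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) and n.name == func_name]
    problems = []
    if len(defs) != 1:
        problems.append(f"expected 1 definition of {func_name}, found {len(defs)}")
    elif "torch_xla" not in ast.unparse(defs[0]):
        problems.append(f"{func_name} does not mention torch_xla (patch not applied?)")
    return problems


if __name__ == "__main__":
    path = next((p for p in COMMON_PY_CANDIDATES if os.path.exists(p)), None)
    if path is None:
        print("⚠️  common.py not found")
    else:
        with open(path) as f:
            issues = check_patched_source(f.read())
        print("✅ patch looks sane" if not issues else f"❌ {path}: {issues}")
```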

## Prepare IaC Training Data

This uses the same data pipeline as GPU training, with a minimal dataset for quick testing.

In [None]:
import os, glob, subprocess

CACHE_DIR = os.path.expanduser("~/.cache/nanochat")
DATA_DIR = os.path.join(CACHE_DIR, "iac_data")
BASE_DATA = os.path.join(CACHE_DIR, "base_data")

# Quick data prep (minimal dataset for testing)
print("Preparing minimal IaC dataset for TPU testing...")
subprocess.run(["bash", "dev/fast_scrape_iac.sh"], input=b"n", check=True)

# Convert to training shards
subprocess.run([
    "python3", "dev/repackage_iac_data.py",
    "--input-dir", "data/iac_raw_cloned",
    "--output-dir", DATA_DIR,
    "--include-synthetic", "--include-docs"
], check=True)

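# Point nanochat's base_data directory at the IaC shards via a symlink
os.makedirs(CACHE_DIR, exist_ok=True)
if os.path.islink(BASE_DATA):
    os.unlink(BASE_DATA)
elif os.path.exists(BASE_DATA):
    # A real directory (e.g. a previous dataset download) would make os.symlink fail
    raise FileExistsError(f"{BASE_DATA} is a real directory; move it aside before linking")
os.symlink(DATA_DIR, BASE_DATA)

print(f"✅ Data ready: {len(glob.glob(f'{BASE_DATA}/*.parquet'))} shards")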
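
A shard count alone does not guarantee training can start. Assuming this fork keeps upstream nanochat's convention of reading parquet shards in sorted order and reserving the last shard for validation, at least two shards are needed. `split_shards` below is a hypothetical pre-flight helper, not a function from the repo.

```python
# Hypothetical pre-flight check for the train/validation shard split.
# ASSUMPTION: shards are read in sorted order and the last one is the validation split.
import glob
import os


def split_shards(data_dir: str) -> tuple:
    """Return (train_shards, val_shards) using the last-shard-is-validation convention."""
    shards = sorted(glob.glob(os.path.join(data_dir, "*.parquet")))
    if len(shards) < 2:
        raise ValueError(f"need >= 2 parquet shards in {data_dir}, found {len(shards)}")
    return shards[:-1], shards[-1:]


if __name__ == "__main__":
    base = os.path.expanduser("~/.cache/nanochat/base_data")
    try:
        train, val = split_shards(base)
        print(f"✅ {len(train)} train shard(s), validation shard: {os.path.basename(val[0])}")
    except ValueError as e:
        print(f"⚠️  {e}")
```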

## Train Tokenizer

In [None]:
# Train BPE tokenizer
!python3 -m scripts.tok_train
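
Before moving on to training, it can help to confirm the tokenizer step actually wrote its output. The directory and file names below are assumptions based on upstream nanochat (`tokenizer.pkl` and `token_bytes.pt` under `~/.cache/nanochat/tokenizer`, with the base directory overridable via `NANOCHAT_BASE_DIR`); adjust them if this fork differs.

```python
# Hypothetical check that the tokenizer artifacts exist and are non-empty.
# ASSUMPTION: file names and locations follow upstream nanochat.
import os

EXPECTED_TOKENIZER_FILES = ("tokenizer.pkl", "token_bytes.pt")


def missing_tokenizer_files(tok_dir: str, expected=EXPECTED_TOKENIZER_FILES) -> list:
    """Return the expected artifact names that are absent or empty in tok_dir."""
    missing = []
    for name in expected:
        path = os.path.join(tok_dir, name)
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            missing.append(name)
    return missing


if __name__ == "__main__":
    base_dir = os.environ.get("NANOCHAT_BASE_DIR", os.path.expanduser("~/.cache/nanochat"))
    missing = missing_tokenizer_files(os.path.join(base_dir, "tokenizer"))
    print("✅ tokenizer artifacts found" if not missing else f"⚠️  missing: {missing}")
```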

## Train on TPU (XLA)

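On multi-core TPUs (v2-8, v3-8, v5e-8), torch_xla's multiprocessing launcher (`xmp.spawn` from `torch_xla.distributed.xla_multiprocessing`) takes the place of torchrun. A single-core v5e-1 needs no launcher; plain `python3` is enough.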

In [None]:
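# TPU training command (single-core v5e-1)
# Note: this will not run on TPU until base_train.py is modified to use XLA (see Next Steps)

MODEL_DEPTH = 12
BATCH_SIZE = 8  # per-device batch size; v5e-1 has 16GB HBM, so larger batches may fit
NUM_CORES = 1   # v5e-1 is single-core (not used in the command below)

# For a single-core TPU, launch with plain python (no multiprocessing launcher needed).
# Run as a module, like scripts.tok_train above, so the repo root is importable.
cmd = f"""python3 -m scripts.base_train \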
    --depth={MODEL_DEPTH} \
    --device-batch-size={BATCH_SIZE} \
    --window-pattern=L \
    --target-param-data-ratio=8 \
    --run=dummy \
    --model-tag=iac-gpt-tpu-d{MODEL_DEPTH} \
    --eval-every=100 \
    --sample-every=100 \
    --save-every=100"""

print("=" * 80)
print("TPU v5e-1 TRAINING COMMAND (Single Core):")
print(cmd)
print("=" * 80)
print("\n⚠️  Note: base_train.py needs XLA modifications first!")
print("Next step: Patch base_train.py for XLA compatibility\n")
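
`BATCH_SIZE` only sets how many sequences each core processes per forward/backward pass; nanochat then accumulates gradients until a fixed token budget is reached. The sketch below reproduces that arithmetic under assumed upstream defaults (2,048-token sequences, a 524,288-token total batch); check `scripts/base_train.py` in this fork for the actual values.

```python
# Gradient-accumulation arithmetic, assuming upstream nanochat's defaults.


def grad_accum_steps(device_batch_size: int, num_devices: int,
                     max_seq_len: int = 2048, total_batch_tokens: int = 524_288) -> int:
    """Micro-steps needed per optimizer step to reach total_batch_tokens."""
    tokens_per_step = device_batch_size * max_seq_len * num_devices  # one fwd/bwd on all cores
    if total_batch_tokens % tokens_per_step != 0:
        raise ValueError(f"{total_batch_tokens} tokens not divisible by {tokens_per_step} per step")
    return total_batch_tokens // tokens_per_step


if __name__ == "__main__":
    for name, devices in [("v5e-1", 1), ("v2-8 / v3-8 / v5e-8", 8)]:
        steps = grad_accum_steps(8, devices)
        print(f"{name}: batch 8 x 2048 tokens x {devices} core(s) -> {steps} accumulation steps")
```

Under these assumptions, moving from v5e-1 to an 8-core TPU cuts accumulation from 32 to 4 micro-steps per optimizer step while keeping the effective batch the same.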

## Next Steps

To complete TPU support:

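1. **Modify base_train.py:**
   - Initialize `torch.distributed` with the XLA backend (`import torch_xla.distributed.xla_backend`, then `dist.init_process_group("xla", init_method="xla://")`) instead of NCCL
   - Use `xm.optimizer_step(optimizer)` instead of `optimizer.step()`; in multi-core runs it all-reduces gradients before stepping
   - Use `xm.all_reduce()` for any other cross-core reductions, such as averaging the loss for logging
   - Call `xm.mark_step()` once per iteration, after the optimizer step, so each training step compiles to a single graph

2. **Modify engine.py:**
   - Add XLA-specific compilation options where needed (e.g. via the `XLA_FLAGS` environment variable)
   - Call `xm.mark_step()` after each decoding step so generation does not build one ever-growing lazy graph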

3. **Test on Colab TPU v2-8**

4. **Port to Kaggle for TPU v5e-8**
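
The checklist above can be turned into a rough, regex-based audit that lists lines in a training script likely to need changes for XLA. This is a heuristic starting point rather than a parser: the patterns are our own guesses and will produce both false positives (e.g. learning-rate scheduler `.step()` calls) and misses.

```python
# Heuristic audit: flag lines in a training script that probably need XLA changes.
import os
import re

XLA_TODO_PATTERNS = {
    r"\.step\(\s*\)": "optimizer/scheduler .step(): optimizers likely need xm.optimizer_step(optimizer)",
    r"init_process_group\(": "init_process_group: use the 'xla' backend instead of NCCL",
    r"torch\.cuda\.": "CUDA-specific call: needs an XLA equivalent or a device-type guard",
    r"\.item\(\)": ".item() forces the pending lazy graph to execute and copy to host",
}


def audit_source(source: str) -> list:
    """Return (line_number, stripped_line, hint) tuples for lines matching any pattern."""
    findings = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for pattern, hint in XLA_TODO_PATTERNS.items():
            if re.search(pattern, line):
                findings.append((lineno, line.strip(), hint))
    return findings


if __name__ == "__main__":
    path = "scripts/base_train.py"
    if os.path.exists(path):
        with open(path) as f:
            for lineno, line, hint in audit_source(f.read()):
                print(f"{path}:{lineno}: {line}  # {hint}")
```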