# IaC-GPT TPU Training (Colab Prototype)

Test nanochat training on Google Colab TPUs before deploying to Kaggle TPU v5e-8.

**Setup:**
1. Runtime → Change runtime type → TPU (v2-8 or v3-8)
2. Run all cells

**TPU Specs:**
- Colab offers TPU v2-8 (8 cores, 64GB HBM total, 8GB per core) or v3-8 (8 cores, 128GB HBM total, 16GB per core)
- Native bfloat16 support
- Expected to be roughly 10x faster than dual T4 GPUs (rough estimate)
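
The HBM totals above are shared by 8 cores, and in data-parallel training each replica only gets its own core's share. Here is a quick sketch of that arithmetic. The spec table simply restates the figures above.

```python
# Total HBM (GB) and TensorCore count for the Colab TPU types listed above.
TPU_SPECS = {"v2-8": (64, 8), "v3-8": (128, 8)}


def hbm_per_core_gb(tpu_type: str) -> float:
    """Per-core HBM in GB: the memory budget one data-parallel replica actually gets."""
    total_gb, cores = TPU_SPECS[tpu_type]
    return total_gb / cores


for name in TPU_SPECS:
    print(f"{name}: {hbm_per_core_gb(name):.0f} GB HBM per core")
```

Per-device batch sizes should therefore be chosen against 8GB (v2) or 16GB (v3), not against the 64GB or 128GB totals.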

In [None]:
# Install dependencies in order to avoid conflicts
# Step 1: Upgrade Google Cloud libraries to versions compatible with protobuf 4.x+
!pip install -q --upgrade "google-api-core>=2.27.0" "google-cloud-storage>=3.9.0"

# Step 2: Install torch + torch-xla for TPU in one pip call so the resolver keeps
# their versions aligned; the [tpu] extra pulls in the libtpu runtime
!pip install -q torch==2.9.0 "torch-xla[tpu]==2.9.0"

# Step 3: Install cloud-tpu-client (needs protobuf 4.x+)
!pip install -q cloud-tpu-client

# Step 4: Install remaining dependencies
!pip install -q tiktoken pyarrow filelock rustbpe wandb tabulate regex zstandard pyyaml
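
torch_xla wheels are built against a specific torch release, so a mismatched pair is a common cause of import errors. The check below uses only the standard library and confirms that the installed `torch` and `torch_xla` share the same major.minor version. The helper names are illustrative.

```python
from importlib import metadata
from typing import Optional, Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """'2.9.0+cu128' -> (2, 9); the local '+...' suffix is ignored."""
    release = version.split("+", 1)[0]
    major, minor = release.split(".")[:2]
    return int(major), int(minor)


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_torch_xla_pair() -> bool:
    torch_v = installed_version("torch")
    xla_v = installed_version("torch_xla")
    print(f"torch={torch_v} torch_xla={xla_v}")
    if torch_v is None or xla_v is None:
        return False
    return major_minor(torch_v) == major_minor(xla_v)


if not check_torch_xla_pair():
    print("⚠️ torch or torch_xla is missing, or their major.minor versions differ")
```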

In [None]:
# Clone nanochat repo
!git clone https://github.com/holynakamoto/iacgpt.git nanochat 2>/dev/null || \
    (cd nanochat && git pull origin master)
%cd nanochat

In [None]:
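# Verify TPU detection
import torch
import torch_xla
import torch_xla.runtime as xr

print("=" * 70)
print("TPU DETECTION TEST")
print("=" * 70)

device = torch_xla.device()
print(f"TPU device: {device}")
# xr.world_size() is 1 outside a spawned process, so count runtime devices instead
print(f"Number of TPU devices: {xr.global_runtime_device_count()}")

# Test tensor operation; XLA is lazy, so .cpu() forces the graph to compile and run
x = torch.randn(3, 3).to(device)
y = x @ x.t()
y_cpu = y.cpu()
print(f"\nTest matmul: {y_cpu.shape}")
print(f"Device type: {y.device}")
print("=" * 70)
# Caution: this kernel now holds the TPU. If a later subprocess cannot acquire it,
# restart the runtime and skip this cell before launching training.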

## Add TPU Support to common.py

We need to patch the device detection to recognize TPUs.

In [None]:
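# Patch nanochat's common.py to add TPU support
import os
import re

# nanochat keeps common.py inside its package directory; fall back to the repo root
COMMON_PY = next((p for p in ("nanochat/common.py", "common.py") if os.path.exists(p)), None)
if COMMON_PY is None:
    raise FileNotFoundError("common.py not found; run this cell from the repo root")

tpu_patch = '''
def autodetect_device_type():
    # Check for TPU first (Colab, Kaggle)
    try:
        import torch_xla
        import torch_xla.runtime as xr
        torch_xla.device()  # raises if no usable XLA device is present
        device_type = "xla"
        print0(f"Autodetected device type: {device_type} (TPU with {xr.global_runtime_device_count()} devices)")
        return device_type
    except ImportError:
        pass
    except Exception as e:
        print0(f"TPU detection failed: {e}")

    # Fallback to CUDA/MPS/CPU
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    print0(f"Autodetected device type: {device_type}")
    return device_type
'''

with open(COMMON_PY, 'r') as f:
    content = f.read()

if 'device_type = "xla"' in content:
    # Re-running the regex on an already patched file would splice in a second copy
    print(f"✅ {COMMON_PY} already has TPU support")
else:
    # Replace autodetect_device_type (non-greedy: up to its first "return device_type").
    # A function as the replacement stops re from interpreting backslashes in the patch text.
    pattern = r'def autodetect_device_type\(\):.*?return device_type'
    content, n = re.subn(pattern, lambda m: tpu_patch.strip(), content, count=1, flags=re.DOTALL)
    if n != 1:
        raise RuntimeError(f"autodetect_device_type() not found in {COMMON_PY}")
    with open(COMMON_PY, 'w') as f:
        f.write(content)
    print(f"✅ Patched {COMMON_PY} with TPU support")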
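
The regex patch stops at the first `return device_type` after the `def`, so it quietly depends on the exact shape of the original function. The alternative sketch below uses Python's `ast` module in place of the regex. It finds the exact line span of a top-level function and then re-parses the result, so a bad patch raises an error instead of writing a broken file. `replace_function` is a hypothetical helper and not part of nanochat.

```python
import ast


def replace_function(source: str, name: str, new_def: str) -> str:
    """Replace the top-level function `name` in `source` with the text `new_def`."""
    tree = ast.parse(source)
    matches = [n for n in tree.body if isinstance(n, ast.FunctionDef) and n.name == name]
    if len(matches) != 1:
        raise ValueError(f"expected one top-level def {name!r}, found {len(matches)}")
    node = matches[0]
    # Start at the first decorator (if any); lineno is 1-indexed, slices are 0-indexed
    start = min([node.lineno] + [d.lineno for d in node.decorator_list]) - 1
    end = node.end_lineno  # 1-indexed inclusive, so it doubles as the exclusive slice end
    lines = source.splitlines(keepends=True)
    patched = "".join(lines[:start]) + new_def.strip("\n") + "\n" + "".join(lines[end:])
    ast.parse(patched)  # raises SyntaxError if the patched file is not valid Python
    return patched
```

In the cell above, this would replace the `re.subn` call with `content = replace_function(content, "autodetect_device_type", tpu_patch)`.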

## Prepare IaC Training Data

Same data pipeline as GPU training.

In [None]:
import os, glob, subprocess

CACHE_DIR = os.path.expanduser("~/.cache/nanochat")
DATA_DIR = os.path.join(CACHE_DIR, "iac_data")
BASE_DATA = os.path.join(CACHE_DIR, "base_data")

# Quick data prep (minimal dataset for testing)
print("Preparing minimal IaC dataset for TPU testing...")
subprocess.run(["bash", "dev/fast_scrape_iac.sh"], input=b"n", check=True)

# Convert to training shards
subprocess.run([
    "python3", "dev/repackage_iac_data.py",
    "--input-dir", "data/iac_raw_cloned",
    "--output-dir", DATA_DIR,
    "--include-synthetic", "--include-docs"
], check=True)

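# Link base_data -> iac_data (replace an old symlink, but never clobber a real directory)
os.makedirs(CACHE_DIR, exist_ok=True)
if os.path.islink(BASE_DATA):
    os.unlink(BASE_DATA)
elif os.path.exists(BASE_DATA):
    raise RuntimeError(f"{BASE_DATA} exists and is not a symlink; move it aside first")
os.symlink(DATA_DIR, BASE_DATA)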

print(f"✅ Data ready: {len(glob.glob(f'{BASE_DATA}/*.parquet'))} shards")
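
A shard count alone does not catch empty output. The small check below uses only the standard library and reports the count, the total size, and any zero-byte shards. `summarize_shards` is a hypothetical helper. It looks only at file names and sizes, not at parquet contents.

```python
import glob
import os


def summarize_shards(data_dir: str, pattern: str = "*.parquet") -> dict:
    """Count shards in data_dir, total their size, and list zero-byte files."""
    paths = sorted(glob.glob(os.path.join(data_dir, pattern)))
    sizes = [os.path.getsize(p) for p in paths]  # follows the base_data symlink
    return {
        "count": len(paths),
        "total_bytes": sum(sizes),
        "empty": [p for p, size in zip(paths, sizes) if size == 0],
    }


summary = summarize_shards(os.path.expanduser("~/.cache/nanochat/base_data"))
print(f"{summary['count']} shards, {summary['total_bytes'] / 1e6:.1f} MB, "
      f"{len(summary['empty'])} empty")
```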

## Train Tokenizer

In [None]:
# Train BPE tokenizer
!python3 -m scripts.tok_train
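
For intuition about what BPE tokenizer training does, here is a single byte-pair-encoding merge step in pure Python. It counts adjacent token pairs and then replaces every occurrence of the most frequent pair with a new token id. This is a toy illustration only, not the implementation behind `scripts.tok_train`. `most_frequent_pair` and `merge` are illustrative names.

```python
from collections import Counter


def most_frequent_pair(ids: list[int]) -> tuple[int, int]:
    """Most common adjacent pair (ties go to the pair seen first)."""
    pairs = Counter(zip(ids, ids[1:]))
    return max(pairs, key=pairs.get)


def merge(ids: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
    """Replace non-overlapping occurrences of `pair`, scanning left to right."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


ids = list(b"aaabdaaabac")      # start from raw UTF-8 bytes (ids 0-255)
pair = most_frequent_pair(ids)  # (97, 97), i.e. "aa"
merged = merge(ids, pair, 256)  # first new token id after the 256 byte values
print(pair, merged)
```

A full trainer repeats this loop until it reaches the target vocabulary size, recording each merge in order.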

## Train on TPU (XLA)

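Launch training through torch_xla rather than torchrun. The command below uses the `torch_xla.distributed.xla_dist` launcher, which comes from torch_xla's older XRT runtime and may not exist in recent releases. With the PJRT runtime, multi-process training is usually started from inside the training script via `torch_xla.distributed.xla_multiprocessing.spawn`.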

In [None]:
# TPU training command
# Note: This requires modifications to base_train.py to use XLA

MODEL_DEPTH = 12
BATCH_SIZE = 4  # Per-core batch: bounded by 8GB (v2) or 16GB (v3) HBM per core, not the 64-128GB total
NUM_CORES = 8   # TPU v2-8 or v3-8

cmd = f"""python3 -m torch_xla.distributed.xla_dist \
    --tpu-vm --num-cores={NUM_CORES} \
    scripts/base_train.py -- \
    --depth={MODEL_DEPTH} \
    --device-batch-size={BATCH_SIZE} \
    --window-pattern=L \
    --target-param-data-ratio=8 \
    --run=dummy \
    --model-tag=iac-gpt-tpu-d{MODEL_DEPTH} \
    --eval-every=100 \
    --sample-every=100 \
    --save-every=100"""

print("=" * 80)
print("TPU TRAINING COMMAND:")
print(cmd)
print("=" * 80)
print("\n⚠️  Note: base_train.py needs XLA modifications first!")
print("Next step: Patch base_train.py for XLA compatibility\n")
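
Once base_train.py supports XLA, you can pass the same flags as an argument list. This avoids the backslash continuations and shell quoting of the f-string above. The sketch assumes the script runs as `python3 -m scripts.base_train`, mirroring how `scripts.tok_train` is invoked. `build_train_argv` is a hypothetical helper.

```python
import shlex


def build_train_argv(depth: int, batch_size: int, run: str = "dummy",
                     every: int = 100) -> list[str]:
    """Argument list for base_train using the flags from the command above."""
    return [
        "python3", "-m", "scripts.base_train",
        f"--depth={depth}",
        f"--device-batch-size={batch_size}",
        "--window-pattern=L",
        "--target-param-data-ratio=8",
        f"--run={run}",
        f"--model-tag=iac-gpt-tpu-d{depth}",
        f"--eval-every={every}",
        f"--sample-every={every}",
        f"--save-every={every}",
    ]


argv = build_train_argv(depth=12, batch_size=4)
print(shlex.join(argv))  # copy-pasteable shell form, quoted only where needed
```

You would then launch it with `subprocess.run(argv, check=True)`. How the per-device XLA processes get spawned is still decided inside base_train.py.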

## Next Steps

To complete TPU support:

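1. **Modify base_train.py:**
   - Set up `torch.distributed` with the XLA backend: `import torch_xla.distributed.xla_backend`, then `dist.init_process_group("xla", init_method="xla://")`
   - Use `xm.optimizer_step(optimizer)` instead of `optimizer.step()`
   - `xm.optimizer_step()` already all-reduces gradients across replicas, so reserve `xm.all_reduce()` for values it does not cover (e.g. losses or metrics to log)

2. **Modify engine.py:**
   - Add XLA-specific compilation flags
   - Call `xm.mark_step()` once per iteration, after the optimizer step rather than right after the backward pass, so XLA compiles and runs one graph per step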

3. **Test on Colab TPU v2-8**

4. **Port to Kaggle for TPU v5e-8**