# IaC-GPT TPU Training (Colab/Kaggle Prototype)

Test nanochat training on TPUs before production deployment.

**Setup:**
1. Runtime → Change runtime type → TPU
2. Run all cells

**TPU Support:**
- Colab: TPU v2-8 (8 cores, 64GB HBM) or v3-8 (8 cores, 128GB HBM)
- Kaggle: TPU v5e-1 (1 core, 16GB HBM) or v5e-8 (8 cores, 128GB HBM)
- Native bfloat16 support
- Typically several times faster than a T4 GPU for transformer training (roughly 5-10x; depends on model size, batch size, and TPU type)
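Per-device batch size is bounded by the HBM attached to each core, not by the board total. A quick sketch of that division, using only the figures listed above (the dictionary is just those numbers, not an API):

```python
# Figures copied from the TPU list above: total HBM (GB) and core count per TPU type.
TPU_SPECS = {
    "v2-8": {"cores": 8, "hbm_gb": 64},
    "v3-8": {"cores": 8, "hbm_gb": 128},
    "v5e-1": {"cores": 1, "hbm_gb": 16},
    "v5e-8": {"cores": 8, "hbm_gb": 128},
}


def hbm_per_core(name: str) -> float:
    """Return the HBM (GB) available to each core for a given TPU type."""
    spec = TPU_SPECS[name]
    return spec["hbm_gb"] / spec["cores"]


for name in TPU_SPECS:
    print(f"{name}: {hbm_per_core(name):.0f} GB per core")
```

By this arithmetic a v2-8 core has half the memory of a v3-8 or v5e core, so a `--device-batch-size` that fits on v5e-1 may need to be halved on v2-8.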

In [None]:
# Install dependencies with uv (stricter dependency resolution than pip)
# Step 1: Install uv from PyPI (each `!` line runs in its own shell, so
# sourcing an env file would not persist; the pip install puts uv on PATH)
!pip install -q uv

# Step 2: Resolve everything in one pass; torch and torch-xla versions must match.
# The [tpu] extra pulls in the TPU runtime library (libtpu).
!uv pip install --system \
    torch==2.9.0 \
    "torch-xla[tpu]==2.9.0" \
    cloud-tpu-client \
    "google-api-core>=2.27.0" \
    "google-cloud-storage>=3.9.0" \
    "protobuf>=4.25.2,<6.0" \
    tiktoken pyarrow filelock rustbpe wandb tabulate regex zstandard pyyaml

print("✅ Installation complete via uv")
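Because torch and torch-xla must come from the same release, it can help to confirm what Python will actually import after the install. A minimal sketch using only the standard library's `importlib.metadata`; `PINS` and `check_pins` are illustrative names, and the `lookup` parameter exists only so the check can be exercised without the real packages:

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell above.
PINS = {"torch": "2.9.0", "torch-xla": "2.9.0"}


def check_pins(pins, lookup=version):
    """Return {package: installed_version_or_None} for every pin that is not satisfied."""
    mismatches = {}
    for pkg, wanted in pins.items():
        try:
            got = lookup(pkg)
        except PackageNotFoundError:
            got = None
        # Ignore local build tags such as "+cpu" when comparing versions.
        if got is None or got.split("+")[0] != wanted:
            mismatches[pkg] = got
    return mismatches


bad = check_pins(PINS)
print("All pins satisfied" if not bad else f"Version mismatches: {bad}")
```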

In [None]:
# Clone nanochat repo
!git clone https://github.com/holynakamoto/iacgpt.git nanochat 2>/dev/null || \
    (cd nanochat && git pull origin master)
%cd nanochat

In [None]:
# Verify TPU detection (torch-xla 2.9.0 API)
import torch
import torch_xla
import torch_xla.runtime as xr

print("=" * 70)
print("TPU DETECTION TEST")
print("=" * 70)

device = torch_xla.device()
print(f"TPU device: {device}")
print(f"XLA devices visible: {xr.global_runtime_device_count()}")
# world_size() counts processes, so it is typically 1 in a single-process notebook
print(f"World size (processes): {xr.world_size()}")
print(f"Local ordinal: {xr.local_ordinal()}")
print(f"Global ordinal: {xr.global_ordinal()}")

# XLA executes lazily: copying the result to CPU forces the graph to compile and run
x = torch.randn(3, 3, device=device)
y = x @ x.t()
print(f"\nTest matmul: {y.shape}, all finite: {torch.isfinite(y.cpu()).all().item()}")
print(f"Device type: {y.device}")
print("=" * 70)

## TPU Support Status

✅ **TPU/XLA support is now built into nanochat!**

The following files have native TPU support:
- `common.py`: Auto-detects TPU and handles device initialization
- `scripts/base_train.py`: Uses XLA-specific optimizer and synchronization

No manual patching is needed: just run the training command below!
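The detection logic itself is not shown in this notebook. As a rough illustration only, a priority order of XLA, then CUDA, then CPU could look like the sketch below. `autodetect_device_type_sketch` is a hypothetical stand-in, not nanochat's `autodetect_device_type`, and finding an importable `torch_xla` does not prove a TPU is attached (a fuller check would also query the XLA runtime):

```python
import importlib.util


def autodetect_device_type_sketch(cuda_available: bool = False,
                                  find_spec=importlib.util.find_spec) -> str:
    """Hypothetical priority order: 'xla' if torch_xla is installed, then 'cuda', then 'cpu'."""
    # find_spec returns None when a top-level module is not installed.
    if find_spec("torch_xla") is not None:
        return "xla"
    if cuda_available:
        return "cuda"
    return "cpu"


print(autodetect_device_type_sketch())
```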

In [None]:
# Verify nanochat has TPU support (should auto-detect)
import sys
sys.path.insert(0, '.')

from common import autodetect_device_type

device_type = autodetect_device_type()
print(f"\nNanochat detected device type: {device_type}")

if device_type == "xla":
    print("✅ TPU (XLA) detected")
else:
    print(f"⚠️  Warning: Expected device_type='xla' but got '{device_type}'")
    print("Make sure you selected a TPU runtime and installed torch-xla correctly")

## Prepare IaC Training Data

**Expanded corpus: 110+ repos across Terraform, Kubernetes, Ansible, Crossplane, Helm, Docker, Pulumi**

This takes ~15-30 minutes and produces ~100-200MB of IaC code → 8-15 parquet shards.

In [None]:
import os, glob, subprocess

# Pull latest nanochat code (includes expanded 110+ repo list)
print("=" * 80)
print("Updating nanochat to latest version...")
print("=" * 80)
subprocess.run(["git", "pull", "origin", "master"], cwd=".", check=True)
print("\n✅ Updated to latest version with 110+ IaC repos\n")

CACHE_DIR = os.path.expanduser("~/.cache/nanochat")
DATA_DIR = os.path.join(CACHE_DIR, "iac_data")
BASE_DATA = os.path.join(CACHE_DIR, "base_data")

# Scrape 110+ IaC repositories (expanded corpus)
print("=" * 80)
print("Scraping 110+ IaC repositories...")
print("This will take ~15-30 minutes")
print("=" * 80)
subprocess.run(["bash", "dev/fast_scrape_iac.sh"], input=b"n", check=True)

# Convert to training shards
print("\n" + "=" * 80)
print("Converting to parquet shards...")
print("=" * 80)
subprocess.run([
    "python3", "dev/repackage_iac_data.py",
    "--input-dir", "data/iac_raw_cloned",
    "--output-dir", DATA_DIR,
    "--include-synthetic", "--include-docs"
], check=True)

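# Point base_data (where training reads shards) at the IaC shards
os.makedirs(CACHE_DIR, exist_ok=True)
if os.path.islink(BASE_DATA):
    os.unlink(BASE_DATA)
elif os.path.exists(BASE_DATA):
    raise RuntimeError(f"{BASE_DATA} exists and is not a symlink; move it aside first")
os.symlink(DATA_DIR, BASE_DATA)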

shard_count = len(glob.glob(f'{BASE_DATA}/*.parquet'))
print("\n" + "=" * 80)
print(f"✅ Data ready: {shard_count} shards")
print(f"Location: {BASE_DATA}")
print("=" * 80)
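Before training, it is worth confirming that the shard count and on-disk size look like the expected 8-15 shards. A standard-library sketch; `summarize_shards` is a hypothetical helper, and the path is the `base_data` link created in the cell above. Parquet files are typically compressed, so the total will usually be smaller than the raw 100-200MB of source text.

```python
import glob
import os


def summarize_shards(data_dir: str) -> tuple[int, float]:
    """Return (number of .parquet files, total size in MB) directly under data_dir."""
    # glob follows symlinks, so a symlinked base_data directory works as-is.
    paths = sorted(glob.glob(os.path.join(data_dir, "*.parquet")))
    total_bytes = sum(os.path.getsize(p) for p in paths)
    return len(paths), total_bytes / 1e6


count, size_mb = summarize_shards(os.path.expanduser("~/.cache/nanochat/base_data"))
print(f"{count} shards, {size_mb:.1f} MB on disk")
if count == 0:
    print("No shards found: re-run the data preparation cell.")
```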

## Train Tokenizer

In [None]:
# Train BPE tokenizer
!python3 -m scripts.tok_train
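The pipeline notes quote a compression ratio of 3-4x, i.e. about 3-4 UTF-8 bytes per token. One way to measure it on a sample of the corpus is sketched below. `bytes_per_token` is a hypothetical helper and `toy_encode` (whitespace splitting) is only a stand-in; to measure the real ratio, pass the trained tokenizer's encode function instead (its exact name in this repo is not shown here).

```python
def bytes_per_token(text: str, encode) -> float:
    """UTF-8 bytes per token; a higher value means the tokenizer compresses better."""
    n_tokens = len(encode(text))
    if n_tokens == 0:
        raise ValueError("text produced no tokens")
    return len(text.encode("utf-8")) / n_tokens


# Stand-in encoder for illustration only: one token per whitespace-separated chunk.
toy_encode = str.split
sample = 'resource "aws_s3_bucket" "logs" {\n  bucket = "my-logs"\n}\n'
print(f"{bytes_per_token(sample, toy_encode):.2f} bytes/token (toy encoder)")
```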

## Train on TPU (XLA)

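Run `scripts/base_train.py` directly with `python3` rather than `torchrun`; a single-core v5e-1 needs no distributed launcher. Multi-core TPUs use torch_xla's launcher instead (see the end of this notebook).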

In [None]:
# TPU training command (v5e-1, single core)
MODEL_DEPTH = 12
BATCH_SIZE = 8  # per-device batch size; v5e-1 has 16GB HBM, so lower this if you hit OOM

# Device type auto-detects as 'xla'.
# --run=dummy: upstream nanochat treats the run name "dummy" as "skip wandb logging".
cmd = f"""python3 scripts/base_train.py \
    --depth={MODEL_DEPTH} \
    --device-batch-size={BATCH_SIZE} \
    --window-pattern=L \
    --target-param-data-ratio=8 \
    --run=dummy \
    --model-tag=iac-gpt-tpu-d{MODEL_DEPTH} \
    --eval-every=100 \
    --sample-every=100 \
    --save-every=100"""

print("=" * 80)
print("TPU v5e-1 TRAINING COMMAND:")
print(cmd)
print("=" * 80)
print("\n⚠️  IMPORTANT: Run cell 7 (data prep) and cell 9 (tokenizer) first!")
print("Then paste the command below into a new cell and run it:\n")
print(f"!{cmd}")
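`--target-param-data-ratio=8` sets the training budget relative to model size. Assuming this fork keeps upstream nanochat's meaning (training tokens = ratio × parameters), the sketch below works out how many passes over the corpus that implies for a ~124M-parameter d12 model. `training_budget` is a hypothetical helper, and the corpus sizes are illustrative assumptions, not measurements.

```python
def training_budget(num_params: float, param_data_ratio: float,
                    corpus_tokens: float) -> tuple[float, float]:
    """Return (training tokens, passes over the corpus),
    assuming training tokens = param_data_ratio * num_params."""
    train_tokens = param_data_ratio * num_params
    return train_tokens, train_tokens / corpus_tokens


# ~124M params for d12, ratio 8 from the command above; corpus sizes are assumed.
for corpus in (25e6, 50e6, 100e6):
    tokens, epochs = training_budget(124e6, 8, corpus)
    print(f"corpus {corpus / 1e6:.0f}M tokens -> {tokens / 1e9:.2f}B training tokens, ~{epochs:.1f} epochs")
```

At roughly 1B training tokens against a corpus of tens of millions of tokens, the model sees each token many times, so the periodic eval loss (`--eval-every=100`) is worth watching for signs of overfitting.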

## Full Training Pipeline

**IMPORTANT: Run cells in this order** (cell numbers count from 0, starting with the title cell):

1. **Cells 1-5**: Setup (install dependencies, clone repo, verify TPU)
2. **Cell 7**: üî¥ Scrape 110+ IaC repos + create training shards (~15-30 min, expect 8-15 shards)
3. **Cell 9**: üî¥ Train BPE tokenizer on IaC data (~2-3 min)
4. **Cell 11**: Copy the training command and run it

**What each step does:**
- **Cell 7**: Clones 110+ repos (Terraform, K8s, Ansible, Crossplane, Helm, Docker, Pulumi) → parquet shards at `~/.cache/nanochat/base_data/`
- **Cell 9**: Trains a 49K-vocab BPE tokenizer on the IaC corpus → saves to `~/.cache/nanochat/tokenizer/`
- **Training**: Pretrains a d12 model (~124M params) on IaC data with the Muon optimizer

**Expected corpus size:**
- 110+ repos ‚Üí ~100-200MB raw IaC code
- ~25-70M tokens from the raw code (100-200MB at ~3-4 bytes per token), more if synthetic and docs data are included
- 8-15 parquet shards for training
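The token estimate is plain division: raw bytes over average bytes per token. A sketch with the numbers above, treating 1 MB as 10^6 bytes and covering raw repository code only (synthetic and docs data added via `--include-synthetic` and `--include-docs` would come on top); `estimate_tokens` is a hypothetical helper:

```python
def estimate_tokens(raw_mb: float, bytes_per_token: float) -> float:
    """Rough token count: raw UTF-8 bytes divided by average bytes per token."""
    return raw_mb * 1e6 / bytes_per_token


low = estimate_tokens(100, 4)   # smallest corpus, best compression
high = estimate_tokens(200, 3)  # largest corpus, worst compression
print(f"~{low / 1e6:.0f}M to ~{high / 1e6:.0f}M tokens")
```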

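**For multi-core TPU (v2-8, v3-8, v5e-8):** the command below uses `xla_dist`, a launcher from torch-xla's older XRT runtime. Recent torch-xla releases use the PJRT runtime, where it may no longer be available; check the torch-xla docs for your version (multi-process training on a single TPU VM is usually started from inside the script with `torch_xla.distributed.xla_multiprocessing.spawn`).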
```bash
python3 -m torch_xla.distributed.xla_dist \
    --tpu-vm --num-cores=8 \
    scripts/base_train.py -- [args...]
```
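Moving from one core to eight changes how much gradient accumulation each core does if the total batch size is held fixed. Upstream nanochat derives accumulation steps as total batch tokens ÷ (device batch size × sequence length × world size); assuming this fork does the same, the arithmetic looks like this. `grad_accum_steps` is a hypothetical helper, and the 524,288-token batch and 2,048-token sequence length are assumed values, not read from this notebook.

```python
def grad_accum_steps(total_batch_tokens: int, device_batch_size: int,
                     seq_len: int, world_size: int) -> int:
    """Micro-batches each core runs before one optimizer step."""
    tokens_per_step = device_batch_size * seq_len * world_size
    if total_batch_tokens % tokens_per_step != 0:
        raise ValueError(f"{total_batch_tokens} is not divisible by {tokens_per_step}")
    return total_batch_tokens // tokens_per_step


# Assumed values: 524,288-token batches, 2,048-token sequences, --device-batch-size=8.
for cores in (1, 8):
    print(f"{cores} core(s): {grad_accum_steps(524_288, 8, 2048, cores)} accumulation steps")
```

Under these assumptions, 8 cores cut the accumulation steps from 32 to 4 at the same `--device-batch-size`, so each core runs fewer sequential forward/backward passes per optimizer step.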