# 🧩 HRM Sudoku-Extreme 1 k Demo
**Google Colab PRO (High-RAM) + T4 GPU – single-GPU reproduction of the paper’s 1 k-shot run.**  
Runtime: ~50 min on A100-high-ram, ~55 min on T4-high-ram.

In [None]:
#@title 0️⃣ Check GPU
!nvidia-smi

In [None]:
#@title 1️⃣ One-liner installs (PyTorch cu126 wheels + Flash-Attn 2)
import subprocess

def run(cmd):
    """Run a shell command and raise CalledProcessError on a non-zero exit."""
    subprocess.run(cmd, shell=True, check=True)

# PyTorch wheels built against CUDA 12.6.
# Left unpinned: the cu126 index does not carry 2.4.0 builds.
run("pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126")

# Ninja + setuptools for compilation
run("pip install packaging ninja wheel setuptools setuptools-scm")

# Flash-Attention 2. Its README lists Ampere or newer (e.g. A100) as supported;
# Turing cards such as the T4 are not on that list.
run("pip install flash-attn --no-build-isolation")
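
FlashAttention-2's README lists Ampere, Ada and Hopper GPUs as supported, which corresponds to CUDA compute capability 8.0 or higher. A T4 reports 7.5 and an A100 reports 8.0. The following is a minimal sketch of a pre-flight check that could run before the slow build. The helper names `supports_flash_attn2` and `current_capability` are hypothetical, and the 8.0 threshold is a reading of that README, not something the HRM repo defines.

```python
def supports_flash_attn2(capability):
    """Return True if a (major, minor) compute capability is Ampere (8.x) or newer."""
    major, _minor = capability
    return major >= 8


def current_capability():
    """Query the active GPU. torch is imported lazily so the check above has no dependencies."""
    import torch

    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA device is visible to PyTorch")
    return torch.cuda.get_device_capability()  # e.g. (7, 5) on a T4
```

In a notebook cell, `supports_flash_attn2(current_capability())` gives a quick yes/no before committing to the compile.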

In [None]:
#@title 2️⃣ Clone HRM repo + submodules
run("git clone --recursive https://github.com/sapientinc/HRM.git")
%cd HRM

In [None]:
#@title 3️⃣ Python deps
run("pip install -r requirements.txt")

## 4️⃣ Build the Sudoku-Extreme 1 k dataset  
This uses the same flags as the paper's 1k setup: `--subsample-size 1000 --num-aug 1000`.
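
`--num-aug 1000` presumably requests 1000 augmented copies of each sampled puzzle. The sketch below illustrates the kind of validity-preserving transform that Sudoku augmentation usually relies on: digit relabeling, band and stack shuffles, and an optional transpose. It is an illustration under those assumptions, not the code in `dataset/build_sudoku_dataset.py`, and `random_transform` and `is_valid_solution` are hypothetical names.

```python
import random


def random_transform(rng):
    """Draw one random symmetry and return a function that applies it to a 9x9 grid.

    Grids are lists of 9 lists of ints, with 0 meaning "blank".
    """
    digits = list(range(1, 10))
    rng.shuffle(digits)
    relabel = [0] + digits  # blanks stay blank, digits 1-9 are permuted

    def shuffled_lines():
        # Shuffle the three bands (or stacks), then the three lines inside each.
        groups = [0, 1, 2]
        rng.shuffle(groups)
        order = []
        for g in groups:
            lines = [3 * g, 3 * g + 1, 3 * g + 2]
            rng.shuffle(lines)
            order.extend(lines)
        return order

    row_order = shuffled_lines()
    col_order = shuffled_lines()
    transpose = rng.random() < 0.5

    def apply(grid):
        out = [[relabel[grid[r][c]] for c in col_order] for r in row_order]
        return [list(col) for col in zip(*out)] if transpose else out

    return apply


def is_valid_solution(grid):
    """Check that every row, column and 3x3 box contains the digits 1-9 exactly once."""
    want = set(range(1, 10))
    rows = all(set(row) == want for row in grid)
    cols = all(set(col) == want for col in zip(*grid))
    boxes = all(
        {grid[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)} == want
        for br in (0, 3, 6)
        for bc in (0, 3, 6)
    )
    return rows and cols and boxes
```

Returning a closure lets the same transform be applied to a puzzle and to its solution, so the pair stays consistent after augmentation.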

In [None]:
#@title 4️⃣ Build dataset (~30 s)
run("python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000")
!ls data/sudoku-extreme-1k-aug-1000

## 5️⃣ Train (single GPU, small batch)
We halve the batch size (192 instead of 384) to fit in the T4's 16 GB of memory.  
The run will auto-log to Weights & Biases if you’re logged in (`wandb login`).
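
To keep the batch-size change in one place, the `key=value` overrides can be assembled from a dict. This is a small sketch only: `build_pretrain_cmd` is a hypothetical helper, and the keys passed to it must be ones that `pretrain.py` actually accepts.

```python
import shlex


def build_pretrain_cmd(overrides, omp_threads=8):
    """Join key=value overrides into a single pretrain.py command line."""
    args = " ".join(shlex.quote(f"{key}={value}") for key, value in overrides.items())
    return f"OMP_NUM_THREADS={omp_threads} python pretrain.py {args}"


PAPER_BATCH_SIZE = 384  # batch size quoted in the text above

example_cmd = build_pretrain_cmd(
    {
        "data_path": "data/sudoku-extreme-1k-aug-1000",
        "global_batch_size": PAPER_BATCH_SIZE // 2,  # halved for the T4
    }
)
```

The resulting string can be passed to the notebook's `run(...)` helper in place of a hand-written command.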

In [None]:
#@title 5️⃣ Launch training
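# Hydra rejects override keys that are missing from the repo's config, so check each
# key (notably `wandb_project`) against config/cfg_pretrain.yaml before launching.
# The trailing backslashes are consumed by Python, so the shell receives one line.
cmd = """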
OMP_NUM_THREADS=8 python pretrain.py \
  data_path=data/sudoku-extreme-1k-aug-1000 \
  epochs=2000 \
  eval_interval=500 \
  global_batch_size=192 \
  lr=7e-5 \
  puzzle_emb_lr=7e-5 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0 \
  wandb_project="hrm-colab-sudoku1k"
"""
run(cmd)

## 6️⃣ Evaluate
After training finishes (around step 1500), we run the repo's `evaluate.py`, which reports exact accuracy.
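
"Exact accuracy" is read here as the fraction of puzzles whose predicted grid matches the target in every cell, so a single wrong cell makes the whole puzzle count as a miss. The sketch below shows that metric on plain Python lists. `exact_accuracy` is a hypothetical name and this is not the repo's evaluator.

```python
def exact_accuracy(preds, targets):
    """Fraction of puzzles where the prediction equals the target in every cell."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must contain the same number of puzzles")
    if not preds:
        return 0.0
    solved = sum(1 for p, t in zip(preds, targets) if list(p) == list(t))
    return solved / len(preds)
```

This is stricter than per-cell accuracy: a model that gets 80 of 81 cells right on every puzzle would score 0 here.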

In [None]:
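#@title 6️⃣ Evaluate last checkpoint
import os

# Newest checkpoint first. The glob assumes a checkpoints/<run>/ckpt.pt layout.
ckpt_path = !ls -t checkpoints/*/ckpt.pt | head -1
assert ckpt_path and os.path.isfile(ckpt_path[0]), "No checkpoint found; adjust the glob above."
ckpt_path = ckpt_path[0]
print("Evaluating", ckpt_path)
run(f"python evaluate.py checkpoint={ckpt_path}")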

## 7️⃣ Show one solved grid
We decode the first validation sample back to a human-readable Sudoku.
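
If the repo's own helper is not available, a flat list of 81 cell values can be rendered with a few lines of Python. This is a sketch under assumptions: `format_grid` is a hypothetical helper, and the `offset` parameter is there because the dataset may shift digits by a constant to reserve token ids for padding or blanks. Check the dataset builder before choosing a value.

```python
def format_grid(cells, offset=0, blank="."):
    """Render 81 cell values as a 9x9 grid with 3x3 box separators.

    Values that are <= 0 after subtracting `offset` are shown as `blank`.
    """
    if len(cells) != 81:
        raise ValueError("expected 81 cells")
    lines = []
    for r in range(9):
        if r in (3, 6):
            lines.append("------+-------+------")
        row = [v - offset for v in cells[r * 9:(r + 1) * 9]]
        digits = [str(v) if v > 0 else blank for v in row]
        lines.append(" | ".join(" ".join(digits[i:i + 3]) for i in (0, 3, 6)))
    return "\n".join(lines)
```

For a tensor, something like `format_grid(pred.flatten().tolist())` would produce the same layout.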

In [None]:
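#@title 7️⃣ Pretty print a solved puzzle
import torch

# These import paths are assumptions; verify them against the cloned repo.
from src.utils.sudoku import Sudoku
from src.data.sudoku_dataset import SudokuDataset

# Keep the model and its inputs on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumes the checkpoint stores the full module under "model", not a bare state_dict.
# Full-module pickles need weights_only=False on recent PyTorch releases.
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
model = ckpt["model"].to(device)
model.eval()

ds = SudokuDataset("data/sudoku-extreme-1k-aug-1000", split="val")
sample = ds[0]

with torch.no_grad():
    logits = model(sample["input_ids"].unsqueeze(0).to(device))
pred = logits.argmax(-1).cpu()  # most likely token for each of the 81 cells

print("Input puzzle:\n", Sudoku(sample["input_ids"].view(9,9)).grid)
print("Model solution:\n", Sudoku(pred.view(9,9)).grid)
print("Target:\n", Sudoku(sample["target"].view(9,9)).grid)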

## 8️⃣ Save checkpoint to Drive (optional)
Mount your Drive and copy the 120 MB checkpoint so others can load it instantly.
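
The copy can also be done from Python, which makes it easy to report how many bytes ended up on Drive and compare that with the expected checkpoint size. This is a minimal sketch using only the standard library, and `copy_checkpoints` is a hypothetical helper name.

```python
import os
import shutil


def copy_checkpoints(src, dst):
    """Copy the directory tree `src` into `dst` and return the total bytes now under `dst`."""
    # dirs_exist_ok (Python 3.8+) lets repeated runs overwrite an earlier copy.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    total = 0
    for root, _dirs, files in os.walk(dst):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total
```

With Drive mounted, `copy_checkpoints("checkpoints", os.path.join(save_dir, "checkpoints"))` mirrors what `cp -r checkpoints {save_dir}` does.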

In [None]:
#@title 8️⃣ Mount Drive & save
from google.colab import drive
drive.mount('/content/drive')

save_dir = "/content/drive/MyDrive/hrm_sudoku1k_t4"
run(f"mkdir -p {save_dir}")
run(f"cp -r checkpoints {save_dir}")
print("Checkpoint saved to", save_dir)