# Tako HRM - Training

Train the Hierarchical Reasoning Model (HRM) on different games using self-play reinforcement learning (RL).

## Games

- **TicTacToe** - Simple 3x3 game (1.1M params, ~30min to convergence)
- **Othello** - 8x8 board (8.4M params, ~2-3 hours)
- **Hex** - 11x11 board (Coming soon)
- **Chess** - Full chess (27M params, requires pretraining)

---

## Verify Setup

**Run `setup.ipynb` first if you haven't already!**

In [None]:
# Verify the environment: pick a compute device and confirm the repo is set up.
import os
import sys

import torch


def detect_device():
    """Return the best available torch device name: 'cuda', 'mps', or 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    # torch.backends.mps is missing on older PyTorch builds, so probe defensively
    # instead of assuming the attribute exists.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'


# Detect the device exactly once so the report below cannot contradict itself
# (e.g. "No GPU detected" followed by "MPS" on Apple Silicon).
# `device` is a notebook-level name printed by the training cells further down.
device = detect_device()
if device == 'cuda':
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print("\n   Note: Ray workers will share GPU using fractional allocation")
elif device == 'mps':
    print("✅ MPS: Apple Silicon GPU")
else:
    print("⚠️  No GPU detected - training will be slower")
    print("   Enable: Runtime → Change runtime type → GPU")

# Verify setup has been run: the training script is addressed relative to the
# current working directory, so it must be visible from here.
if not os.path.exists('scripts/train.py'):
    print("❌ ERROR: Not in tako-v2 directory")
    print("   Run setup.ipynb first!")
    raise FileNotFoundError("Run setup.ipynb first")

# Put the working directory at the front of the import path; skip the insert
# when it is already present so re-running the cell does not pile up duplicates.
repo_root = os.getcwd()
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

print("✅ Setup verified - ready to train!")

In [None]:
# Plot training curves
import re
from pathlib import Path

STEP_PATTERN = re.compile(r'step=(\d+)')
# Accepts plain and scientific notation with an optional sign, so "1.5e-03" is
# not truncated to 1.5, and stops before trailing punctuation ("0.25." -> 0.25)
# so float() never receives an unparsable token.
LOSS_PATTERN = re.compile(r'loss=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)')


def parse_loss_log(lines):
    """Collect training steps and losses from an iterable of log lines.

    A line contributes only when it contains both ``step=<int>`` and
    ``loss=<number>``; every other line is ignored.

    Args:
        lines: Iterable of strings (an open text file works).

    Returns:
        A ``(steps, losses)`` tuple of equal-length lists, in input order.
    """
    steps, losses = [], []
    for line in lines:
        if 'loss=' not in line:
            continue
        step_match = STEP_PATTERN.search(line)
        loss_match = LOSS_PATTERN.search(line)
        if step_match and loss_match:
            steps.append(int(step_match.group(1)))
            losses.append(float(loss_match.group(1)))
    return steps, losses


log_dir = Path('logs')
if log_dir.exists():
    # Oldest first, so the last entry is the most recently modified log.
    log_files = sorted(log_dir.glob('*tictactoe*.log'), key=lambda p: p.stat().st_mtime)
    if log_files:
        latest_log = log_files[-1]
        print(f"Reading: {latest_log.name}")

        # Explicit encoding with replacement: a stray undecodable byte in the
        # log should not abort plotting.
        with open(latest_log, encoding='utf-8', errors='replace') as f:
            steps, losses = parse_loss_log(f)

        if steps:
            # Imported only when there is something to draw, which keeps the
            # parsing path free of the plotting dependency.
            import matplotlib.pyplot as plt

            plt.figure(figsize=(10, 4))
            plt.plot(steps, losses, alpha=0.6)
            plt.xlabel('Training Step')
            plt.ylabel('Loss')
            plt.title('TicTacToe Training Loss')
            plt.grid(True, alpha=0.3)
            plt.show()
            print(f"\n✅ {len(steps)} steps, latest loss: {losses[-1]:.4f}")
        else:
            print("⚠️  No training metrics found")
    else:
        print("⚠️  No log files found")
else:
    print("⚠️  Logs directory not found")

---

## Othello Training

**Model:** 8.4M parameters  
**Time:** ~2-3 hours to competent play  
**Target:** Beat Edax level 3

In [None]:
# Othello training configuration
GAME = "othello"
CONFIG = f"config/{GAME}.yaml"
EPOCHS = 10

print(f"Training: {GAME}")
print(f"Config: {CONFIG}")
print(f"Epochs: {EPOCHS}")
print(f"Device: {device}")
print("\n" + "="*80)

# Start training
!uv run python scripts/train.py --config {CONFIG} --epochs {EPOCHS}

### Monitor Othello Progress

In [None]:
# Plot training curves
import re
from pathlib import Path

STEP_PATTERN = re.compile(r'step=(\d+)')
# Accepts plain and scientific notation with an optional sign, so "1.5e-03" is
# not truncated to 1.5, and stops before trailing punctuation ("0.25." -> 0.25)
# so float() never receives an unparsable token.
LOSS_PATTERN = re.compile(r'loss=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)')


def parse_loss_log(lines):
    """Collect training steps and losses from an iterable of log lines.

    A line contributes only when it contains both ``step=<int>`` and
    ``loss=<number>``; every other line is ignored.

    Args:
        lines: Iterable of strings (an open text file works).

    Returns:
        A ``(steps, losses)`` tuple of equal-length lists, in input order.
    """
    steps, losses = [], []
    for line in lines:
        if 'loss=' not in line:
            continue
        step_match = STEP_PATTERN.search(line)
        loss_match = LOSS_PATTERN.search(line)
        if step_match and loss_match:
            steps.append(int(step_match.group(1)))
            losses.append(float(loss_match.group(1)))
    return steps, losses


log_dir = Path('logs')
if log_dir.exists():
    # Oldest first, so the last entry is the most recently modified log.
    log_files = sorted(log_dir.glob('*othello*.log'), key=lambda p: p.stat().st_mtime)
    if log_files:
        latest_log = log_files[-1]
        print(f"Reading: {latest_log.name}")

        # Explicit encoding with replacement: a stray undecodable byte in the
        # log should not abort plotting.
        with open(latest_log, encoding='utf-8', errors='replace') as f:
            steps, losses = parse_loss_log(f)

        if steps:
            # Imported only when there is something to draw, which keeps the
            # parsing path free of the plotting dependency.
            import matplotlib.pyplot as plt

            plt.figure(figsize=(10, 4))
            plt.plot(steps, losses, alpha=0.6)
            plt.xlabel('Training Step')
            plt.ylabel('Loss')
            plt.title('Othello Training Loss')
            plt.grid(True, alpha=0.3)
            plt.show()
            print(f"\n✅ {len(steps)} steps, latest loss: {losses[-1]:.4f}")
        else:
            print("⚠️  No training metrics found")
    else:
        print("⚠️  No log files found")
else:
    print("⚠️  Logs directory not found")

---

## Hex Training

**Model:** ~8M parameters  
**Time:** ~3-4 hours  
**Target:** Strong tactical play on 11x11 board

In [None]:
# Hex training configuration
GAME = "hex"
CONFIG = f"config/{GAME}.yaml"
EPOCHS = 10

print(f"Training: {GAME}")
print(f"Config: {CONFIG}")
print(f"Epochs: {EPOCHS}")
print(f"Device: {device}")
print("\n" + "="*80)

# Start training
!uv run python scripts/train.py --config {CONFIG} --epochs {EPOCHS}

---

## Chess Training

**Model:** 27M parameters  
**Time:** Days (requires pretraining)  
**Target:** 2500+ Elo (GM level)

**Note:** Chess requires supervised pretraining on PGN data before self-play.

In [None]:
# Chess pretraining (run first)
print("Chess pretraining...")
print("This requires PGN data in data/chess/")
print("\n" + "="*80)

!uv run python scripts/pretrain.py --config config/chess.yaml --data data/chess/games.pgn

In [None]:
# Chess self-play training (run after pretraining)
GAME = "chess"
CONFIG = f"config/{GAME}.yaml"
EPOCHS = 20
RESUME = "checkpoints/chess/pretrain_final.pt"  # Load pretrained checkpoint

print(f"Training: {GAME}")
print(f"Config: {CONFIG}")
print(f"Resume from: {RESUME}")
print(f"Epochs: {EPOCHS}")
print(f"Device: {device}")
print("\n" + "="*80)

# Start training
!uv run python scripts/train.py --config {CONFIG} --epochs {EPOCHS} --resume {RESUME}

---

## List Checkpoints

In [None]:
# List all checkpoints
import datetime
from pathlib import Path


def recent_checkpoints(root, limit=5):
    """Map each game directory under *root* to its most recent ``*.pt`` files.

    Args:
        root: Directory holding one sub-directory per game.
        limit: Maximum number of checkpoints kept per game (the newest ones).

    Returns:
        A dict ``{game_dir_name: [Path, ...]}`` with games in sorted order and
        each list ordered oldest to newest by modification time. Games with no
        ``*.pt`` files are omitted. The dict is empty when *root* is missing
        or is not a directory.
    """
    root = Path(root)
    if not root.is_dir():
        return {}

    found = {}
    for game_dir in sorted(root.iterdir()):
        if not game_dir.is_dir():
            continue
        ordered = sorted(game_dir.glob('*.pt'), key=lambda p: p.stat().st_mtime)
        # A plain [-limit:] slice would return everything for limit == 0.
        newest = ordered[-limit:] if limit > 0 else []
        if newest:
            found[game_dir.name] = newest
    return found


checkpoint_dir = Path('checkpoints')
checkpoints_by_game = recent_checkpoints(checkpoint_dir)

for game_name, checkpoints in checkpoints_by_game.items():
    print(f"\n{game_name.upper()}:")
    print("="*80)
    for ckpt in checkpoints:  # Show last 5
        info = ckpt.stat()  # one stat call covers both size and mtime
        size_mb = info.st_size / 1e6
        mtime = datetime.datetime.fromtimestamp(info.st_mtime)
        print(f"  {ckpt.name:<40} {size_mb:>6.1f} MB   {mtime.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"  Latest: {checkpoints[-1].name}")

# Also covers an existing directory that holds no checkpoint files, not only a
# missing one.
if not checkpoints_by_game:
    print("⚠️  No checkpoints found")