# Tako HRM - Training

Train the Hierarchical Reasoning Model (HRM) on different games using self-play RL.

## Games

- **TicTacToe** - Simple 3x3 game (1.1M params, ~30min to convergence)
- **Othello** - 8x8 board (8.4M params, ~2-3 hours)
- **Hex** - 11x11 board (~8M params, ~3-4 hours; coming soon)
- **Chess** - Full chess (27M params, requires pretraining)

---

## Setup (Run Once)

In [None]:
# Install uv package manager
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add to PATH
import os
os.environ['PATH'] = f"{os.path.expanduser('~/.cargo/bin')}:{os.environ['PATH']}"

print("‚úÖ uv installed")

In [None]:
# Clone repository (organization repo with authentication)
import os
import subprocess
from google.colab import userdata

# GitHub organization that owns the repository. The value below is a
# placeholder and must be edited before running this cell.
ORG_NAME = "YOUR_ORG_NAME"  # ← UPDATE THIS!
# Repository name; also used as the name of the local clone directory.
REPO_NAME = "tako-v2"

def clone_repo():
    """Clone the private organization repository into ``REPO_NAME``.

    The GitHub token is read from the Colab secret ``GITHUB_TOKEN`` and is
    embedded in the clone URL only for the clone itself; afterwards the
    ``origin`` remote is rewritten to a token-free URL.

    Returns:
        bool: ``True`` if the repository directory already exists or the
        clone succeeded, ``False`` if the clone failed or timed out.

    Raises:
        Exception: Whatever ``userdata.get`` raised when the secret could
        not be read (re-raised after printing setup instructions).
    """
    # Local import: shutil is not among this cell's top-level imports.
    import shutil

    # Checked first so that re-running the cell does not require the secret.
    if os.path.exists(REPO_NAME):
        print("✅ Repository already exists")
        return True

    try:
        github_token = userdata.get('GITHUB_TOKEN')
    except Exception:
        print("\n❌ ERROR: Could not access GITHUB_TOKEN from Colab Secrets")
        print("\nSetup Instructions:")
        print("1. Create token: https://github.com/settings/tokens")
        print("   → Scopes: ✅ repo, ✅ read:org")
        print("2. Authorize for org: Click 'Configure SSO' → 'Authorize'")
        print("3. Add to Colab: 🔑 → GITHUB_TOKEN = ghp_...")
        print("4. Update ORG_NAME in cell above")
        print("\n📚 Full guide: notebooks/ORG_REPO_SETUP.md")
        # Bare raise keeps the original traceback intact.
        raise
    print("✅ Retrieved GITHUB_TOKEN from Colab Secrets")

    public_url = f"https://github.com/{ORG_NAME}/{REPO_NAME}.git"
    repo_url = f"https://{github_token}@github.com/{ORG_NAME}/{REPO_NAME}.git"
    print(f"🔄 Cloning {ORG_NAME}/{REPO_NAME}...")

    try:
        result = subprocess.run(
            ['git', 'clone', repo_url, REPO_NAME],
            capture_output=True,
            text=True,
            timeout=60
        )
    except subprocess.TimeoutExpired:
        # The killed git process cannot clean up after itself; a leftover
        # partial checkout would be mistaken for a finished clone next run.
        shutil.rmtree(REPO_NAME, ignore_errors=True)
        print("\n❌ Clone timed out after 60 seconds")
        return False

    if result.returncode != 0:
        # git may echo the remote URL in its error text; never print the token.
        stderr = result.stderr
        if github_token:
            stderr = stderr.replace(github_token, '***')
        print("\n❌ Clone failed!")
        if 'not found' in stderr.lower():
            print("Fix: Authorize token at https://github.com/settings/tokens")
        print(f"\nError: {stderr}")
        return False

    print("✅ Repository cloned successfully")
    # Rewrite the remote so the token is not persisted in .git/config.
    scrub = subprocess.run(
        ['git', '-C', REPO_NAME, 'remote', 'set-url', 'origin', public_url],
        capture_output=True
    )
    if scrub.returncode == 0:
        print("✅ Token removed from git config")
    else:
        print("⚠️  Warning: could not remove token from git config")
    return True

# Run the clone, enter the repository, and install its dependencies with uv.
import shutil

# When the cell is re-run the working directory is already the repository;
# cloning or chdir-ing again would create/enter a nested copy.
already_inside = os.path.basename(os.getcwd()) == REPO_NAME and os.path.isdir('.git')

if already_inside or clone_repo():
    if not already_inside:
        os.chdir(REPO_NAME)
    print(f"\n📂 Changed to: {os.getcwd()}")

    # Locate uv: PATH first, then the directories the installer is known to
    # use (~/.local/bin for current releases, ~/.cargo/bin for older ones).
    uv_bin = shutil.which('uv')
    if uv_bin is None:
        for candidate in ('~/.local/bin/uv', '~/.cargo/bin/uv'):
            candidate = os.path.expanduser(candidate)
            if os.path.exists(candidate):
                uv_bin = candidate
                break

    # Install dependencies
    print("\n📦 Installing dependencies...")
    if uv_bin is None:
        print("⚠️  Warning: Dependency installation had issues")
        print("uv executable not found - run the uv install cell first")
    else:
        result = subprocess.run(
            [uv_bin, 'sync'],
            capture_output=True,
            text=True
        )

        if result.returncode == 0:
            print("✅ Dependencies installed")
        else:
            print("⚠️  Warning: Dependency installation had issues")
            print(result.stderr)

    print("\n✅ Setup complete!")
else:
    print("\n❌ Setup failed")

In [None]:
# Mount Google Drive for checkpoint persistence
import time
from pathlib import Path

from google.colab import drive
drive.mount('/content/drive')

# Link checkpoint directories: ./checkpoints becomes a symlink into Drive.
drive_ckpt_dir = Path('/content/drive/MyDrive/tako_checkpoints')
drive_ckpt_dir.mkdir(parents=True, exist_ok=True)

local_ckpt = Path('checkpoints')
if local_ckpt.is_symlink():
    # Stale link from a previous run; removing it does not touch the target.
    local_ckpt.unlink()
elif local_ckpt.is_dir() and not any(local_ckpt.iterdir()):
    local_ckpt.rmdir()
elif local_ckpt.exists():
    # A real directory/file with content: keep it under a new name instead of
    # deleting checkpoints that were written before Drive was linked.
    backup = Path(f'checkpoints_local_{int(time.time())}')
    local_ckpt.rename(backup)
    print(f"⚠️  Existing local checkpoints kept in: {backup}")
local_ckpt.symlink_to(drive_ckpt_dir, target_is_directory=True)

print("✅ Checkpoints will be saved to Google Drive")

In [None]:
# Check GPU availability
import torch

# `device` is read by the training cells when they print their settings.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is in bytes; shown in decimal gigabytes.
    print(f"   Memory: {gpu_props.total_memory / 1e9:.1f} GB")
    print("\n   Note: Ray workers will share GPU using fractional allocation")
else:
    print("⚠️  No GPU detected - training will be slower")
    print("   Enable: Runtime → Change runtime type → GPU")

---

## TicTacToe Training

**Model:** 1.1M parameters  
**Time:** ~30 minutes to convergence  
**Target:** 90%+ win rate vs random

In [None]:
# TicTacToe training configuration
GAME = "tictactoe"
CONFIG = f"config/{GAME}.yaml"
EPOCHS = 5

print(f"Training: {GAME}")
print(f"Config: {CONFIG}")
print(f"Epochs: {EPOCHS}")
print(f"Device: {device}")
print("\n" + "="*80)

# Start training
!~/.cargo/bin/uv run python scripts/train.py --config {CONFIG} --epochs {EPOCHS}

### Monitor TicTacToe Progress

In [None]:
# Plot training curves
import re
from pathlib import Path

# Decimal or scientific notation, e.g. 0.0123, 3, 1.5e-05.
_FLOAT = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
STEP_RE = re.compile(r'step=(\d+)')
LOSS_RE = re.compile(r'loss=(' + _FLOAT + r')')


def parse_loss_log(lines):
    """Extract training metrics from log lines.

    Args:
        lines: Iterable of strings (an open text file works).

    Returns:
        tuple[list[int], list[float]]: Steps and losses, taken from every
        line that contains both ``step=<int>`` and ``loss=<number>``.
    """
    steps, losses = [], []
    for line in lines:
        step_match = STEP_RE.search(line)
        loss_match = LOSS_RE.search(line)
        if step_match and loss_match:
            steps.append(int(step_match.group(1)))
            losses.append(float(loss_match.group(1)))
    return steps, losses


log_dir = Path('logs')
if log_dir.exists():
    # Most recently modified log whose name mentions the game.
    latest_log = max(log_dir.glob('*tictactoe*.log'),
                     key=lambda p: p.stat().st_mtime, default=None)
    if latest_log is not None:
        print(f"Reading: {latest_log.name}")

        # Explicit encoding so parsing does not depend on the locale; stray
        # bytes are replaced rather than aborting the plot.
        with open(latest_log, encoding='utf-8', errors='replace') as f:
            steps, losses = parse_loss_log(f)

        if steps:
            # Imported only when there is something to draw.
            import matplotlib.pyplot as plt

            plt.figure(figsize=(10, 4))
            plt.plot(steps, losses, alpha=0.6)
            plt.xlabel('Training Step')
            plt.ylabel('Loss')
            plt.title('TicTacToe Training Loss')
            plt.grid(True, alpha=0.3)
            plt.show()
            print(f"\n✅ {len(steps)} steps, latest loss: {losses[-1]:.4f}")
        else:
            print("⚠️  No training metrics found")
    else:
        print("⚠️  No log files found")
else:
    print("⚠️  Logs directory not found")

---

## Othello Training

**Model:** 8.4M parameters  
**Time:** ~2-3 hours to competent play  
**Target:** Beat Edax level 3

In [None]:
# Othello training configuration
GAME = "othello"
CONFIG = f"config/{GAME}.yaml"
EPOCHS = 10

print(f"Training: {GAME}")
print(f"Config: {CONFIG}")
print(f"Epochs: {EPOCHS}")
print(f"Device: {device}")
print("\n" + "="*80)

# Start training
!~/.cargo/bin/uv run python scripts/train.py --config {CONFIG} --epochs {EPOCHS}

### Monitor Othello Progress

In [None]:
# Plot training curves
import re
from pathlib import Path

# Decimal or scientific notation, e.g. 0.0123, 3, 1.5e-05.
_FLOAT = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
STEP_RE = re.compile(r'step=(\d+)')
LOSS_RE = re.compile(r'loss=(' + _FLOAT + r')')


def parse_loss_log(lines):
    """Extract training metrics from log lines.

    Args:
        lines: Iterable of strings (an open text file works).

    Returns:
        tuple[list[int], list[float]]: Steps and losses, taken from every
        line that contains both ``step=<int>`` and ``loss=<number>``.
    """
    steps, losses = [], []
    for line in lines:
        step_match = STEP_RE.search(line)
        loss_match = LOSS_RE.search(line)
        if step_match and loss_match:
            steps.append(int(step_match.group(1)))
            losses.append(float(loss_match.group(1)))
    return steps, losses


log_dir = Path('logs')
if log_dir.exists():
    # Most recently modified log whose name mentions the game.
    latest_log = max(log_dir.glob('*othello*.log'),
                     key=lambda p: p.stat().st_mtime, default=None)
    if latest_log is not None:
        print(f"Reading: {latest_log.name}")

        # Explicit encoding so parsing does not depend on the locale; stray
        # bytes are replaced rather than aborting the plot.
        with open(latest_log, encoding='utf-8', errors='replace') as f:
            steps, losses = parse_loss_log(f)

        if steps:
            # Imported only when there is something to draw.
            import matplotlib.pyplot as plt

            plt.figure(figsize=(10, 4))
            plt.plot(steps, losses, alpha=0.6)
            plt.xlabel('Training Step')
            plt.ylabel('Loss')
            plt.title('Othello Training Loss')
            plt.grid(True, alpha=0.3)
            plt.show()
            print(f"\n✅ {len(steps)} steps, latest loss: {losses[-1]:.4f}")
        else:
            print("⚠️  No training metrics found")
    else:
        print("⚠️  No log files found")
else:
    print("⚠️  Logs directory not found")

---

## Hex Training

**Model:** ~8M parameters  
**Time:** ~3-4 hours  
**Target:** Strong tactical play on 11x11 board

In [None]:
# Hex training configuration
GAME = "hex"
CONFIG = f"config/{GAME}.yaml"
EPOCHS = 10

print(f"Training: {GAME}")
print(f"Config: {CONFIG}")
print(f"Epochs: {EPOCHS}")
print(f"Device: {device}")
print("\n" + "="*80)

# Start training
!~/.cargo/bin/uv run python scripts/train.py --config {CONFIG} --epochs {EPOCHS}

---

## Chess Training

**Model:** 27M parameters  
**Time:** Days (requires pretraining)  
**Target:** 2500+ Elo (GM level)

**Note:** Chess requires supervised pretraining on PGN data before self-play.

In [None]:
# Chess pretraining (run first)
print("Chess pretraining...")
print("This requires PGN data in data/chess/")
print("\n" + "="*80)

!~/.cargo/bin/uv run python scripts/pretrain.py --config config/chess.yaml --data data/chess/games.pgn

In [None]:
# Chess self-play training (run after pretraining)
GAME = "chess"
CONFIG = f"config/{GAME}.yaml"
EPOCHS = 20
RESUME = "checkpoints/chess/pretrain_final.pt"  # Load pretrained checkpoint

print(f"Training: {GAME}")
print(f"Config: {CONFIG}")
print(f"Resume from: {RESUME}")
print(f"Epochs: {EPOCHS}")
print(f"Device: {device}")
print("\n" + "="*80)

# Start training
!~/.cargo/bin/uv run python scripts/train.py --config {CONFIG} --epochs {EPOCHS} --resume {RESUME}

---

## List Checkpoints

In [None]:
# List all checkpoints
from pathlib import Path
import datetime


def collect_checkpoints(root):
    """Group ``*.pt`` files by their game sub-directory.

    Args:
        root: Directory whose immediate sub-directories each hold the
            checkpoints of one game.

    Returns:
        dict[str, list[Path]]: Sub-directory name mapped to its ``*.pt``
        files ordered oldest to newest by modification time. Sub-directories
        without checkpoints are omitted; a missing ``root`` yields ``{}``.
    """
    root = Path(root)
    if not root.is_dir():
        return {}
    found = {}
    for game_dir in sorted(root.iterdir()):
        if not game_dir.is_dir():
            continue
        ckpts = sorted(game_dir.glob('*.pt'), key=lambda p: p.stat().st_mtime)
        if ckpts:
            found[game_dir.name] = ckpts
    return found


checkpoint_dir = Path('checkpoints')
checkpoints_by_game = collect_checkpoints(checkpoint_dir)

# Also covers the case where the directory exists but holds no checkpoints.
if not checkpoints_by_game:
    print("⚠️  No checkpoints found")

for game_name, checkpoints in checkpoints_by_game.items():
    print(f"\n{game_name.upper()}:")
    print("="*80)
    for ckpt in checkpoints[-5:]:  # Show last 5
        info = ckpt.stat()  # one stat call supplies both size and mtime
        size_mb = info.st_size / 1e6
        mtime = datetime.datetime.fromtimestamp(info.st_mtime)
        print(f"  {ckpt.name:<40} {size_mb:>6.1f} MB   {mtime.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"  Latest: {checkpoints[-1].name}")