# Tako HRM - TicTacToe Training

Train the Hierarchical Reasoning Model (HRM) on TicTacToe using self-play reinforcement learning.

## 🚀 Quick Start

1. **Enable GPU:** Runtime → Change runtime type → GPU (T4 or better)
2. **Run setup cells** (sections 1-2)
3. **Start training** (section 3)
4. **Monitor progress** with live visualizations

---

## Training Pipeline

- **Self-play workers:** Generate games using MCTS + current model
- **Replay buffer:** Store positions from recent games (~500K positions)
- **Learner:** Train model on sampled batches (policy + value + ACT loss)
- **Checkpointing:** Save the model every 500 steps (`checkpointing.save_interval` in the config)
- **Evaluation:** Test against random play periodically
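The replay buffer in this pipeline can be pictured as a fixed-capacity FIFO: once it holds its limit (~500K positions here), each new position evicts the oldest, and the learner draws uniform random minibatches from whatever is stored. The sketch below is a minimal illustration, not tako-v2's implementation. The `ReplayBuffer` class, its methods, and the `(state, policy_target, value_target)` record layout are all hypothetical.

```python
import random
from typing import Any, List, Optional, Tuple

# Hypothetical record: (board state, MCTS visit distribution, final game outcome)
Position = Tuple[Any, List[float], float]


class ReplayBuffer:
    """Fixed-capacity FIFO of self-play positions (illustrative sketch)."""

    def __init__(self, capacity: int = 500_000) -> None:
        self.capacity = capacity
        self._data: List[Position] = []
        self._next = 0  # slot to overwrite once the buffer is full

    def add(self, position: Position) -> None:
        if len(self._data) < self.capacity:
            self._data.append(position)
        else:
            self._data[self._next] = position  # evict the oldest position
        self._next = (self._next + 1) % self.capacity

    def add_game(self, positions: List[Position]) -> None:
        # Store every position from one finished self-play game
        for position in positions:
            self.add(position)

    def sample(self, batch_size: int, rng: Optional[random.Random] = None) -> List[Position]:
        # Uniform sampling without replacement within one batch
        rng = rng if rng is not None else random.Random()
        return rng.sample(self._data, min(batch_size, len(self._data)))

    def __len__(self) -> int:
        return len(self._data)


buf = ReplayBuffer(capacity=3)
buf.add_game([("s1", [1.0], 1.0), ("s2", [1.0], -1.0)])
buf.add_game([("s3", [1.0], 1.0), ("s4", [1.0], 0.0)])
print(len(buf))  # 3 ("s1" was evicted to make room for "s4")
```

The sketch uses a list-backed ring buffer rather than `collections.deque(maxlen=...)`. Random sampling needs fast indexed access, and a deque only provides O(1) indexing near its ends.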

## 1. Setup Environment

In [None]:
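# Install uv
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add uv to PATH: recent installers use ~/.local/bin, older ones ~/.cargo/bin
import os
uv_dirs = [os.path.expanduser('~/.local/bin'), os.path.expanduser('~/.cargo/bin')]
os.environ['PATH'] = os.pathsep.join(uv_dirs + [os.environ['PATH']])

# Clone the repo if needed (the cwd check makes re-running this cell safe)
if os.path.basename(os.getcwd()) != 'tako-v2':
    if not os.path.exists('tako-v2'):
        !git clone https://github.com/zfdupont/tako-v2.git
    %cd tako-v2

# Install dependencies (fast when they are already in sync)
!uv sync

print("✅ Environment ready")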

In [None]:
# Mount Google Drive for checkpoint persistence
from google.colab import drive
drive.mount('/content/drive')

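# Link ./checkpoints to Drive so checkpoints survive runtime resets
# (rm -rf removes an existing symlink, but also any local checkpoints/ directory)
!mkdir -p /content/drive/MyDrive/tako_checkpoints/tictactoe
!rm -rf checkpoints
!ln -s /content/drive/MyDrive/tako_checkpoints checkpoints

print("✅ Checkpoints will be saved to Google Drive")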

In [None]:
import torch

if torch.cuda.is_available():
    device = 'cuda'
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = 'cpu'
    print("⚠️ No GPU detected; training will be slower")
    print("   Go to Runtime → Change runtime type → GPU")

## 2. Training Configuration

View and optionally modify the training config.

In [None]:
import yaml
from pprint import pprint

# Load config
with open('config/tictactoe.yaml') as f:
    config = yaml.safe_load(f)

print("Current TicTacToe Configuration:")
print("="*80)
pprint(config)
print("="*80)

In [None]:
# OPTIONAL: Override config parameters for faster experimentation
# Uncomment and modify as needed

# config['selfplay']['num_workers'] = 4  # Reduce workers if OOM
# config['training']['batch_size'] = 256  # Reduce batch size if OOM
# config['checkpointing']['save_interval'] = 100  # Save more frequently

# Save modified config
# with open('config/tictactoe_colab.yaml', 'w') as f:
#     yaml.dump(config, f)
# print("✅ Config overrides saved to config/tictactoe_colab.yaml")
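If you override more than a couple of keys, consider a small helper that sets values by dotted path and rejects unknown keys. It catches typos (e.g. `batch_sise`) that would otherwise create a new, silently ignored entry. This is a minimal sketch: `set_config_value` is a hypothetical helper, not part of tako-v2. The dotted paths assume the nested layout used by the `config[...]` overrides in this cell, and the example values are made up.

```python
from typing import Any, Dict


def set_config_value(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set config['a']['b'] = value for dotted_key 'a.b', rejecting unknown keys."""
    *parents, leaf = dotted_key.split('.')
    node = config
    for part in parents:
        if part not in node or not isinstance(node[part], dict):
            raise KeyError(f"Unknown config section {part!r} in {dotted_key!r}")
        node = node[part]
    if leaf not in node:
        # Refuse to create new keys: a misspelled key would otherwise be ignored
        raise KeyError(f"Unknown config key {dotted_key!r}")
    node[leaf] = value


# Hypothetical values in an in-memory dict with the same nesting as the config
cfg = {'selfplay': {'num_workers': 8}, 'training': {'batch_size': 512}}
set_config_value(cfg, 'training.batch_size', 256)
print(cfg['training']['batch_size'])  # 256
```

In the notebook you would call it on the loaded `config` (e.g. `set_config_value(config, 'selfplay.num_workers', 4)`) before saving the YAML file.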

## 3. Start Training

Launch the training process with Ray distributed workers.

**Note:** Training runs in a separate process, and this cell blocks until that process exits. Progress is printed in the cell output.

In [None]:
# Training parameters
NUM_EPOCHS = 5  # Number of epochs to train
CONFIG_FILE = 'config/tictactoe.yaml'  # Use tictactoe_colab.yaml if you modified config

print(f"Starting training for {NUM_EPOCHS} epochs...")
print(f"Config: {CONFIG_FILE}")
print(f"Device: {device}")
print("\n" + "="*80)
print("Training logs will appear below. To stop early, use Runtime → Interrupt execution.")
print("="*80 + "\n")

# Run training (blocks this cell until training finishes)
!uv run python scripts/train.py --config {CONFIG_FILE} --epochs {NUM_EPOCHS}
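The `!` command above blocks the notebook until training exits, so the plotting cell in section 4 cannot run in the meantime. One alternative is to start the trainer in the background and send its output to a file that you can inspect while it runs. This is a minimal sketch using `subprocess.Popen`. The `launch_in_background` helper and the `logs/train_colab.log` path are hypothetical, and the commented usage assumes `uv` is on `PATH` with the same arguments as the cell above.

```python
import os
import subprocess
from pathlib import Path
from typing import List


def launch_in_background(cmd: List[str], log_path: str) -> subprocess.Popen:
    """Start cmd without blocking the notebook; stdout and stderr go to log_path."""
    path = Path(log_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # PYTHONUNBUFFERED makes Python children flush output promptly, so the log
    # fills in while training runs instead of in large delayed chunks
    env = {**os.environ, 'PYTHONUNBUFFERED': '1'}
    with open(path, 'w') as log_file:
        # stderr is merged into stdout so tracebacks land in the same file;
        # the child keeps its own copy of the descriptor after we close ours
        proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, env=env)
    return proc


# Hypothetical usage (assumes uv is on PATH):
# proc = launch_in_background(
#     ['uv', 'run', 'python', 'scripts/train.py',
#      '--config', CONFIG_FILE, '--epochs', str(NUM_EPOCHS)],
#     'logs/train_colab.log')
# proc.poll()       # None while training is still running
# proc.terminate()  # stop training early
```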

## 4. Monitor Training Progress

Load and visualize training metrics from logs.

In [None]:
import re
import matplotlib.pyplot as plt
from pathlib import Path

# Matches ints, decimals, and scientific notation such as 1.5e-05
FLOAT = r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'

def find_metric(name, line):
    # Return the float after `name=` in line, or None.
    # The leading \b keeps 'loss=' from matching inside 'policy_loss='.
    match = re.search(rf'\b{name}={FLOAT}', line)
    return float(match.group(1)) if match else None

# Find latest log file
log_dir = Path('logs')
if log_dir.exists():
    log_files = list(log_dir.glob('*.log'))
    if log_files:
        latest_log = max(log_files, key=lambda p: p.stat().st_mtime)
        print(f"Reading log: {latest_log}")

        # Each series keeps its own step list, so a line that lacks one metric
        # cannot shift the x-axis of the others
        series = {'loss': ([], []), 'policy': ([], []), 'value': ([], [])}

        with open(latest_log) as f:
            for line in f:
                step = find_metric('step', line)
                if step is None or find_metric('loss', line) is None:
                    continue  # not a learner line with a step and total loss
                for name, (xs, ys) in series.items():
                    value = find_metric(name, line)
                    if value is not None:
                        xs.append(int(step))
                        ys.append(value)

        steps, losses = series['loss']
        if steps:
            # One panel per metric: total, policy, and value loss
            fig, axes = plt.subplots(1, 3, figsize=(15, 4))
            panels = [('loss', 'Total Loss', 'tab:blue'),
                      ('policy', 'Policy Loss', 'orange'),
                      ('value', 'Value Loss', 'green')]
            for ax, (name, title, color) in zip(axes, panels):
                xs, ys = series[name]
                if xs:
                    ax.plot(xs, ys, alpha=0.6, color=color)
                ax.set_xlabel('Training Step')
                ax.set_ylabel('Loss')
                ax.set_title(title)
                ax.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.show()

            print(f"\n✅ Parsed {len(steps)} training steps")
            print(f"   Latest loss: {losses[-1]:.4f}")
        else:
            print("⚠️ No training metrics found in log")
    else:
        print("⚠️ No log files found")
else:
    print("⚠️ Logs directory not found. Run training first.")
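Per-step minibatch losses tend to be noisy, and an exponential moving average makes the trend easier to read. This is a minimal sketch. The `ema` function and its default `alpha` are illustrative choices, not something the tako-v2 logs provide. After running the cell above, you could overlay it with e.g. `axes[0].plot(steps, ema(losses))`.

```python
from typing import List, Sequence


def ema(values: Sequence[float], alpha: float = 0.1) -> List[float]:
    """Exponential moving average: s_t = alpha * x_t + (1 - alpha) * s_{t-1}."""
    if not 0 < alpha <= 1:
        raise ValueError("alpha must be in (0, 1]")
    smoothed: List[float] = []
    for x in values:
        # The first point seeds the average; later points blend in gradually
        smoothed.append(x if not smoothed else alpha * x + (1 - alpha) * smoothed[-1])
    return smoothed


print(ema([1.0, 0.0, 0.0], alpha=0.5))  # [1.0, 0.5, 0.25]
```

Smaller `alpha` gives a smoother curve that lags further behind the raw values.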

## 5. List Checkpoints

In [None]:
from pathlib import Path
import datetime

checkpoint_dir = Path('checkpoints/tictactoe')

if checkpoint_dir.exists():
    checkpoints = sorted(checkpoint_dir.glob('*.pt'), key=lambda p: p.stat().st_mtime)
    
    if checkpoints:
        print(f"Found {len(checkpoints)} checkpoint(s):\n")
        print(f"{'Name':<40} {'Size':<10} {'Modified'}")
        print("="*80)
        
        for ckpt in checkpoints:
            size_mb = ckpt.stat().st_size / 1e6
            mtime = datetime.datetime.fromtimestamp(ckpt.stat().st_mtime)
            # Size column is 10 characters wide to line up with the header
            print(f"{ckpt.name:<40} {size_mb:>7.1f} MB {mtime.strftime('%Y-%m-%d %H:%M:%S')}")
        
        print("\n✅ Latest checkpoint:", checkpoints[-1].name)
    else:
        print("⚠️ No checkpoints found. Train first.")
else:
    print("⚠️ Checkpoint directory not found")
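Checkpoints saved every 500 steps accumulate in Google Drive and count against your Drive storage. This minimal sketch keeps only the newest N `.pt` files, ranked by modification time as in the listing above. `prune_checkpoints` is a hypothetical helper. It defaults to a dry run, so nothing is deleted until you pass `dry_run=False`.

```python
from pathlib import Path
from typing import List


def prune_checkpoints(directory: str, keep: int = 5, dry_run: bool = True) -> List[Path]:
    """Delete all but the `keep` most recently modified .pt files; return the victims."""
    if keep < 1:
        raise ValueError("keep must be at least 1")
    ckpts = sorted(Path(directory).glob('*.pt'), key=lambda p: p.stat().st_mtime)
    victims = ckpts[:-keep]  # everything except the newest `keep` files
    for path in victims:
        if dry_run:
            print(f"would delete {path.name}")
        else:
            path.unlink()
    return victims


# Preview first, then delete for real:
# prune_checkpoints('checkpoints/tictactoe', keep=3)
# prune_checkpoints('checkpoints/tictactoe', keep=3, dry_run=False)
```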

## 🎯 Next Steps

After training:

1. **Evaluate model:** Open `02_evaluate_model.ipynb` to test against random/perfect play
2. **Play interactively:** Open `03_interactive_play.ipynb` to play against your trained model
3. **Continue training:** Re-run section 3 to train for more epochs

---

### Training Tips

- **GPU memory issues?** Reduce `num_workers` or `batch_size` in config
- **Slow convergence?** TicTacToe is simple, so training should converge within roughly 1,000 games
- **Check win rate:** Should reach >90% vs random play within 30 minutes
- **Checkpoints persist** in Google Drive even if session disconnects
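The >90% target is easier to judge if you know how often a purely random player wins. This minimal sketch plays random-vs-random TicTacToe and reports the X-win, O-win, and draw rates. The board encoding and the functions `winner`, `play_random_game`, and `random_baseline` are hypothetical and independent of tako-v2's game implementation.

```python
import random
from typing import Dict, List, Optional

# The 8 winning lines on a 3x3 board indexed 0..8 row by row
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
         (0, 3, 6), (1, 4, 7), (2, 5, 8),
         (0, 4, 8), (2, 4, 6)]


def winner(board: List[str]) -> Optional[str]:
    """Return 'X' or 'O' if someone has three in a row, else None."""
    for a, b, c in LINES:
        if board[a] != ' ' and board[a] == board[b] == board[c]:
            return board[a]
    return None


def play_random_game(rng: random.Random) -> str:
    """Both sides move uniformly at random; return 'X', 'O', or 'draw'."""
    board = [' '] * 9
    player = 'X'
    for _ in range(9):
        move = rng.choice([i for i, cell in enumerate(board) if cell == ' '])
        board[move] = player
        if winner(board):
            return player
        player = 'O' if player == 'X' else 'X'
    return 'draw'


def random_baseline(num_games: int = 10_000, seed: int = 0) -> Dict[str, float]:
    """Fraction of X wins, O wins, and draws over num_games random games."""
    rng = random.Random(seed)
    counts = {'X': 0, 'O': 0, 'draw': 0}
    for _ in range(num_games):
        counts[play_random_game(rng)] += 1
    return {k: v / num_games for k, v in counts.items()}


print(random_baseline())
```

Even with random play, the first mover wins noticeably more often than the second. It is therefore worth measuring a trained model's win rate separately when it plays X and when it plays O.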