# Tako HRM - MCTS Benchmark

Benchmark MCTS search performance across different configurations and games.

## What This Measures

- **MCTS searches/second** - How fast can we search game trees?
- **Forward pass time** - Neural network inference speed
- **Batching efficiency** - Speedup from batched evaluation
- **GPU vs CPU** - Device comparison
- **Game complexity** - How game size affects performance

---

## Verify Setup

**Run `setup.ipynb` first if you haven't already!**

In [2]:
# Verify setup and import libraries
import os
import sys
import torch
import yaml
import time
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

os.chdir('/content/tako-v2')

if not os.path.exists('scripts/train.py'):
    print("‚ùå ERROR: Not in tako-v2 directory")
    print("   Run setup.ipynb first!")
    raise FileNotFoundError("Run setup.ipynb first")

sys.path.insert(0, os.getcwd())

from model.hrm import HRM
from training.mcts import MCTS
from games.tictactoe import TicTacToeGame
from games.othello import OthelloGame

# Detect available devices
devices = []
if torch.cuda.is_available():
    devices.append('cuda')
    print(f"‚úÖ CUDA GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
if torch.backends.mps.is_available():
    devices.append('mps')
    print(f"‚úÖ Apple MPS")
devices.append('cpu')
print(f"‚úÖ CPU")

print(f"\n‚úÖ Setup verified - ready to benchmark!")
print(f"Devices to test: {devices}")

‚úÖ CUDA GPU: Tesla T4
   Memory: 15.6 GB
‚úÖ CPU

‚úÖ Setup verified - ready to benchmark!
Devices to test: ['cuda', 'cpu']


In [3]:
# Check available devices
import torch

devices = []

if torch.cuda.is_available():
    devices.append('cuda')
    print(f"‚úÖ CUDA GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

if torch.backends.mps.is_available():
    devices.append('mps')
    print(f"‚úÖ Apple MPS (Metal Performance Shaders)")

devices.append('cpu')
print(f"‚úÖ CPU")

print(f"\nDevices to benchmark: {devices}")

‚úÖ CUDA GPU: Tesla T4
   Memory: 15.6 GB
‚úÖ CPU

Devices to benchmark: ['cuda', 'cpu']


In [4]:
def benchmark_forward_pass(game_name, device, num_trials=100, use_optimizations=True):
    """Benchmark forward pass speed with optional optimizations.
    
    Args:
        game_name: Game to benchmark ('tictactoe', 'othello')
        device: Device to use ('cuda', 'cpu', 'mps')
        num_trials: Number of trials to run
        use_optimizations: Enable torch.compile + bfloat16 + inference_mode (default: True)
    """
    # Load config
    with open(f'config/{game_name}.yaml') as f:
        config = yaml.safe_load(f)
    
    # Create model
    model = HRM(**config['model'])
    model.to(device)
    model.eval()
    
    # Apply optimizations if requested
    if use_optimizations:
        if device == 'cuda':
            dtype = torch.bfloat16
            use_compile = True
        else:
            dtype = None  # Keep float32 on CPU/MPS
            use_compile = False  # torch.compile less beneficial on CPU
        
        if dtype is not None or use_compile:
            model.optimize_for_inference(use_compile=use_compile, dtype=dtype)
    
    # Create game for tokens
    if game_name == 'tictactoe':
        game = TicTacToeGame()
    elif game_name == 'othello':
        game = OthelloGame()
    else:
        raise ValueError(f"Unknown game: {game_name}")
    
    tokens = game.to_tokens().unsqueeze(0).to(device)
    max_segments = config['mcts'].get('max_segments_inference', 1)
    
    # Warmup (extra warmup for torch.compile first-run compilation)
    warmup_iterations = 20 if use_optimizations else 10
    with torch.no_grad():
        for _ in range(warmup_iterations):
            policy, value, _ = model.predict(tokens, use_act=True, max_segments=max_segments)
    
    # Benchmark
    times = []
    with torch.no_grad():
        for _ in tqdm(range(num_trials), desc=f"{game_name} on {device}"):
            start = time.time()
            policy, value, _ = model.predict(tokens, use_act=True, max_segments=max_segments)
            if device == 'cuda':
                torch.cuda.synchronize()
            times.append(time.time() - start)
    
    avg_time = np.mean(times) * 1000  # Convert to ms
    std_time = np.std(times) * 1000
    
    return {
        'game': game_name,
        'device': device,
        'avg_ms': avg_time,
        'std_ms': std_time,
        'params': sum(p.numel() for p in model.parameters()) / 1e6,
        'max_segments': max_segments,
        'optimized': use_optimizations
    }

print("‚úÖ Forward pass benchmark function ready")

‚úÖ Forward pass benchmark function ready


In [None]:
# Run forward pass benchmarks
forward_results = []

games_to_test = ['tictactoe', 'othello']

for game in games_to_test:
    for device in devices:
        try:
            result = benchmark_forward_pass(game, device, num_trials=100)
            forward_results.append(result)
            print(f"\n{game} on {device}: {result['avg_ms']:.2f} ¬± {result['std_ms']:.2f} ms")
        except Exception as e:
            print(f"\n‚ö†Ô∏è  {game} on {device} failed: {e}")

print("\n" + "="*80)
print("Forward Pass Benchmark Results:")
print("="*80)
for r in forward_results:
    print(f"{r['game']:<12} {r['device']:<6} {r['avg_ms']:>7.2f} ms  ({r['params']:.1f}M params, {r['max_segments']} seg)")


‚ö†Ô∏è  tictactoe on cuda failed: 'HRM' object has no attribute 'optimize_for_inference'


tictactoe on cpu: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 100/100 [00:00<00:00, 141.60it/s]



tictactoe on cpu: 7.01 ¬± 1.05 ms

‚ö†Ô∏è  othello on cuda failed: 'HRM' object has no attribute 'optimize_for_inference'


othello on cpu:  47%|‚ñà‚ñà‚ñà‚ñà‚ñã     | 47/100 [00:47<01:00,  1.14s/it]

---

## Benchmark 2: MCTS Search Speed

Measure complete MCTS search including tree traversal + neural evaluations.

In [11]:
def benchmark_mcts_search(game_name, device, num_searches=50, simulations=25):
    """Benchmark MCTS search speed."""
    # Load config
    with open(f'config/{game_name}.yaml') as f:
        config = yaml.safe_load(f)
    
    # Override simulations
    config['mcts']['simulations'] = simulations
    
    # Create model
    model = HRM(**config['model'])
    model.to(device)
    model.eval()
    
    # Create game class
    if game_name == 'tictactoe':
        game_class = TicTacToeGame
    elif game_name == 'othello':
        game_class = OthelloGame
    else:
        raise ValueError(f"Unknown game: {game_name}")
    
    # Create MCTS
    mcts = MCTS(model, game_class, config['mcts'], device=device)
    
    # Warmup
    game = game_class()
    for _ in range(5):
        _ = mcts.search(game, move_num=0)
    
    # Benchmark
    times = []
    for _ in tqdm(range(num_searches), desc=f"{game_name} MCTS on {device}"):
        game = game_class()
        start = time.time()
        policy = mcts.search(game, move_num=0)  # Returns policy distribution
        times.append(time.time() - start)
    
    avg_time = np.mean(times)
    std_time = np.std(times)
    
    return {
        'game': game_name,
        'device': device,
        'simulations': simulations,
        'avg_sec': avg_time,
        'std_sec': std_time,
        'searches_per_sec': 1.0 / avg_time,
        'batch_size': config['mcts'].get('batch_size', 1)
    }

print("‚úÖ MCTS benchmark function ready")

In [12]:
# Run MCTS benchmarks with default config
mcts_results = []

games_to_test = ['tictactoe', 'othello']

for game in games_to_test:
    for device in devices:
        try:
            result = benchmark_mcts_search(game, device, num_searches=50, simulations=25)
            mcts_results.append(result)
            print(f"\n{game} on {device}:")
            print(f"  {result['avg_sec']:.3f} ¬± {result['std_sec']:.3f} sec/search")
            print(f"  {result['searches_per_sec']:.1f} searches/sec")
        except Exception as e:
            print(f"\n‚ö†Ô∏è  {game} on {device} failed: {e}")

print("\n" + "="*80)
print("MCTS Search Benchmark Results (25 simulations):")
print("="*80)
for r in mcts_results:
    print(f"{r['game']:<12} {r['device']:<6} {r['avg_sec']:>6.3f} s  ({r['searches_per_sec']:>6.1f} searches/s, batch={r['batch_size']})")

---

## Benchmark 3: Scaling with Simulations

How does MCTS performance scale with different numbers of simulations?

In [None]:
# Benchmark MCTS with different simulation counts
GAME = "tictactoe"  # Change to 'othello' for larger game
DEVICE = devices[0]  # Use best available device

simulation_counts = [10, 25, 50, 100, 200, 400]
scaling_results = []

print(f"Benchmarking {GAME} on {DEVICE} with varying simulations...\n")

for sims in simulation_counts:
    result = benchmark_mcts_search(GAME, DEVICE, num_searches=30, simulations=sims)
    scaling_results.append(result)
    print(f"{sims:>4} sims: {result['avg_sec']:.3f} s ({result['searches_per_sec']:.1f} searches/s)")

print("\n‚úÖ Scaling benchmark complete")

In [None]:
# Plot scaling results
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

sims = [r['simulations'] for r in scaling_results]
times = [r['avg_sec'] for r in scaling_results]
throughput = [r['searches_per_sec'] for r in scaling_results]

# Time vs simulations
axes[0].plot(sims, times, 'o-', linewidth=2, markersize=8)
axes[0].set_xlabel('MCTS Simulations', fontsize=12)
axes[0].set_ylabel('Time per Search (seconds)', fontsize=12)
axes[0].set_title(f'{GAME.capitalize()} - Search Time vs Simulations', fontsize=14)
axes[0].grid(True, alpha=0.3)
axes[0].set_xscale('log')

# Throughput vs simulations
axes[1].plot(sims, throughput, 'o-', linewidth=2, markersize=8, color='green')
axes[1].set_xlabel('MCTS Simulations', fontsize=12)
axes[1].set_ylabel('Searches per Second', fontsize=12)
axes[1].set_title(f'{GAME.capitalize()} - Throughput vs Simulations', fontsize=14)
axes[1].grid(True, alpha=0.3)
axes[1].set_xscale('log')

plt.tight_layout()
plt.show()

print(f"\nüìä Observations:")
print(f"  - Time scales approximately linearly with simulations")
print(f"  - Throughput (searches/sec) decreases with more simulations")
print(f"  - For training, balance: more sims = stronger play, fewer sims = faster iteration")

---

## Benchmark 4: Batching Efficiency

Compare batched vs non-batched MCTS evaluation.

In [None]:
def benchmark_batching(game_name, device, batch_sizes=[1, 4, 8, 16, 32]):
    """Benchmark different batch sizes."""
    results = []
    
    for batch_size in batch_sizes:
        # Load config
        with open(f'config/{game_name}.yaml') as f:
            config = yaml.safe_load(f)
        
        # Override batch size
        config['mcts']['batch_size'] = batch_size
        config['mcts']['simulations'] = 100  # Fixed for comparison
        
        # Create model
        model = HRM(**config['model'])
        model.to(device)
        model.eval()
        
        # Create game class
        if game_name == 'tictactoe':
            game_class = TicTacToeGame
        elif game_name == 'othello':
            game_class = OthelloGame
        
        # Create MCTS
        mcts = MCTS(model, game_class, config['mcts'], device=device)
        
        # Warmup
        game = game_class()
        for _ in range(3):
            _ = mcts.search(game, move_num=0)
        
        # Benchmark
        times = []
        for _ in range(20):
            game = game_class()
            start = time.time()
            policy = mcts.search(game, move_num=0)  # Returns policy distribution
            times.append(time.time() - start)
        
        avg_time = np.mean(times)
        results.append({
            'batch_size': batch_size,
            'avg_sec': avg_time,
            'searches_per_sec': 1.0 / avg_time
        })
        
        print(f"Batch size {batch_size:>2}: {avg_time:.3f} s ({1.0/avg_time:.1f} searches/s)")
    
    return results

print("‚úÖ Batching benchmark function ready")

In [None]:
# Run batching benchmark
GAME = "tictactoe"
DEVICE = devices[0]

print(f"Benchmarking batching for {GAME} on {DEVICE}...\n")
batching_results = benchmark_batching(GAME, DEVICE, batch_sizes=[1, 4, 8, 16, 32])

print("\n‚úÖ Batching benchmark complete")

In [None]:
# Plot batching efficiency
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

batch_sizes = [r['batch_size'] for r in batching_results]
times = [r['avg_sec'] for r in batching_results]
throughput = [r['searches_per_sec'] for r in batching_results]

# Speedup vs batch size
baseline_time = batching_results[0]['avg_sec']  # batch_size=1
speedups = [baseline_time / t for t in times]

axes[0].plot(batch_sizes, speedups, 'o-', linewidth=2, markersize=8)
axes[0].axhline(y=1, color='r', linestyle='--', alpha=0.5, label='No speedup')
axes[0].set_xlabel('Batch Size', fontsize=12)
axes[0].set_ylabel('Speedup vs Batch=1', fontsize=12)
axes[0].set_title(f'{GAME.capitalize()} - Batching Speedup', fontsize=14)
axes[0].grid(True, alpha=0.3)
axes[0].legend()
axes[0].set_xscale('log', base=2)

# Efficiency (speedup / batch_size)
efficiency = [speedup / bs for speedup, bs in zip(speedups, batch_sizes)]
axes[1].plot(batch_sizes, efficiency, 'o-', linewidth=2, markersize=8, color='orange')
axes[1].axhline(y=1, color='r', linestyle='--', alpha=0.5, label='Perfect scaling')
axes[1].set_xlabel('Batch Size', fontsize=12)
axes[1].set_ylabel('Batching Efficiency', fontsize=12)
axes[1].set_title(f'{GAME.capitalize()} - Batching Efficiency', fontsize=14)
axes[1].grid(True, alpha=0.3)
axes[1].legend()
axes[1].set_xscale('log', base=2)

plt.tight_layout()
plt.show()

print(f"\nüìä Observations:")
print(f"  - Best batch size: {batch_sizes[speedups.index(max(speedups))]} ({max(speedups):.1f}x speedup)")
print(f"  - GPU batching is most effective with larger batch sizes")
print(f"  - CPU benefits less from batching (overhead dominates)")

---

## Benchmark 5: End-to-End Game Generation

Measure complete game generation time (full playthrough).

In [None]:
def benchmark_game_generation(game_name, device, num_games=20):
    """Benchmark full game generation."""
    # Load config
    with open(f'config/{game_name}.yaml') as f:
        config = yaml.safe_load(f)
    
    # Create model
    model = HRM(**config['model'])
    model.to(device)
    model.eval()
    
    # Create game class
    if game_name == 'tictactoe':
        game_class = TicTacToeGame
    elif game_name == 'othello':
        game_class = OthelloGame
    
    # Create MCTS
    mcts = MCTS(model, game_class, config['mcts'], device=device)
    
    # Generate games
    times = []
    move_counts = []
    
    for _ in tqdm(range(num_games), desc=f"Generating {game_name} games"):
        game = game_class()
        move_num = 0
        
        start = time.time()
        while not game.is_terminal():
            policy = mcts.search(game, move_num=move_num)  # Returns policy distribution
            # Select move from policy
            legal_moves = game.legal_moves()
            legal_policy = policy[legal_moves]
            move = legal_moves[np.argmax(legal_policy)]
            game.make_move(move)
            move_num += 1
        elapsed = time.time() - start
        
        times.append(elapsed)
        move_counts.append(move_num)
    
    avg_time = np.mean(times)
    avg_moves = np.mean(move_counts)
    games_per_hour = 3600 / avg_time
    
    return {
        'game': game_name,
        'device': device,
        'avg_time': avg_time,
        'avg_moves': avg_moves,
        'games_per_hour': games_per_hour,
        'simulations': config['mcts']['simulations']
    }

print("‚úÖ Game generation benchmark function ready")

In [None]:
# Run game generation benchmarks
game_gen_results = []

games_to_test = ['tictactoe', 'othello']

for game in games_to_test:
    for device in devices:
        try:
            result = benchmark_game_generation(game, device, num_games=20)
            game_gen_results.append(result)
            print(f"\n{game} on {device}:")
            print(f"  {result['avg_time']:.2f} sec/game ({result['avg_moves']:.1f} moves avg)")
            print(f"  {result['games_per_hour']:.0f} games/hour")
        except Exception as e:
            print(f"\n‚ö†Ô∏è  {game} on {device} failed: {e}")

print("\n" + "="*80)
print("Game Generation Benchmark Results:")
print("="*80)
for r in game_gen_results:
    print(f"{r['game']:<12} {r['device']:<6} {r['avg_time']:>6.2f} s/game  ({r['games_per_hour']:>7.0f} games/hr, {r['simulations']} sims)")

---

## Summary Report

Comprehensive comparison across all benchmarks.

In [None]:
# Generate summary report
print("="*80)
print("TAKO HRM - MCTS BENCHMARK SUMMARY")
print("="*80)

print("\n1. FORWARD PASS PERFORMANCE")
print("-" * 80)
print(f"{'Game':<12} {'Device':<8} {'Latency (ms)':<15} {'Params':<12} {'Segments'}")
print("-" * 80)
for r in forward_results:
    print(f"{r['game']:<12} {r['device']:<8} {r['avg_ms']:>7.2f} ¬± {r['std_ms']:<5.2f} {r['params']:>7.1f}M     {r['max_segments']:>2}")

print("\n2. MCTS SEARCH PERFORMANCE")
print("-" * 80)
print(f"{'Game':<12} {'Device':<8} {'Time/Search (s)':<18} {'Searches/sec':<15} {'Batch'}")
print("-" * 80)
for r in mcts_results:
    print(f"{r['game']:<12} {r['device']:<8} {r['avg_sec']:>8.3f} ¬± {r['std_sec']:<6.3f} {r['searches_per_sec']:>10.1f}      {r['batch_size']:>2}")

print("\n3. GAME GENERATION THROUGHPUT")
print("-" * 80)
print(f"{'Game':<12} {'Device':<8} {'Time/Game (s)':<16} {'Games/Hour':<15} {'Avg Moves'}")
print("-" * 80)
for r in game_gen_results:
    print(f"{r['game']:<12} {r['device']:<8} {r['avg_time']:>10.2f}       {r['games_per_hour']:>10.0f}      {r['avg_moves']:>6.1f}")

print("\n" + "="*80)
print("KEY INSIGHTS:")
print("="*80)

# Find best performers
if forward_results:
    best_forward = min(forward_results, key=lambda x: x['avg_ms'])
    print(f"\n‚úÖ Fastest forward pass: {best_forward['game']} on {best_forward['device']} ({best_forward['avg_ms']:.2f} ms)")

if mcts_results:
    best_mcts = min(mcts_results, key=lambda x: x['avg_sec'])
    print(f"‚úÖ Fastest MCTS search: {best_mcts['game']} on {best_mcts['device']} ({best_mcts['searches_per_sec']:.1f} searches/sec)")

if game_gen_results:
    best_gen = max(game_gen_results, key=lambda x: x['games_per_hour'])
    print(f"‚úÖ Highest throughput: {best_gen['game']} on {best_gen['device']} ({best_gen['games_per_hour']:.0f} games/hour)")

print("\nüìä Recommendations:")
if 'cuda' in devices:
    print("  ‚Ä¢ Use CUDA for training (best performance)")
    print("  ‚Ä¢ Use batch_size=16 for optimal GPU utilization")
elif 'mps' in devices:
    print("  ‚Ä¢ Use MPS on Apple Silicon (good performance)")
    print("  ‚Ä¢ Batching provides moderate speedup on MPS")
else:
    print("  ‚Ä¢ CPU only - consider reducing num_workers and batch_size")
    print("  ‚Ä¢ Training will be slower; use simpler games (TicTacToe)")

print("\n" + "="*80)

---

## Export Results

Save benchmark results for future reference.

In [None]:
# Save results to file
import json
from datetime import datetime

results = {
    'timestamp': datetime.now().isoformat(),
    'devices': devices,
    'forward_pass': forward_results,
    'mcts_search': mcts_results,
    'game_generation': game_gen_results,
    'scaling': scaling_results if 'scaling_results' in dir() else [],
    'batching': batching_results if 'batching_results' in dir() else []
}

output_file = f"benchmark_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"‚úÖ Results saved to: {output_file}")
print(f"\nFile size: {Path(output_file).stat().st_size / 1024:.1f} KB")