# MCTS Algorithm Comparison: PUCT vs BayesianMCTS

This notebook compares BatchedMCTS (PUCT) vs BayesianMCTS (Thompson Sampling + IDS) on Connect4.

**Metrics compared:**
- Training throughput (games/sec, examples/sec)
- Inference speed (searches/sec)
- Policy quality (win rate vs random)
- Policy agreement between algorithms

**Recommended:** Run on an A100 GPU for best performance.

In [None]:
# Check GPU: run `nvidia-smi` through an IPython shell escape to show which
# NVIDIA device (if any) is attached to this runtime
!nvidia-smi

In [None]:
# Clone nanozero repo, then make the checkout the working directory.
# NOTE(review): presumably the `nanozero.*` imports below resolve from this
# directory, since nothing in this notebook pip-installs the package — confirm.
!git clone https://github.com/calaldred/nanozero.git
%cd nanozero

In [None]:
# Install dependencies
!pip install -q torch numpy

In [None]:
import time
import numpy as np
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Tuple, Dict

from nanozero.game import get_game
from nanozero.model import AlphaZeroTransformer
from nanozero.mcts import BatchedMCTS, MCTSConfig, sample_action
from nanozero.bayesian_mcts import BayesianMCTS, BayesianMCTSConfig
from nanozero.replay import ReplayBuffer
from nanozero.config import get_model_config
from nanozero.common import set_seed

# Set seed for reproducibility (fixed value so repeated runs are comparable)
set_seed(42)

# Device setup: use CUDA when PyTorch reports it as available, else the CPU.
# `device` is a notebook-level global reused by the model-construction and
# training cells below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name of CUDA device 0
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Setup: Game and Model

In [None]:
# Create Connect4 game and print the board dimensions / action count taken
# from its config
game = get_game('connect4')
print(f"Game: Connect4")
print(f"Board: {game.config.board_height}x{game.config.board_width}")
print(f"Actions: {game.config.action_size}")

# Create model (4 layers for Connect4) and move it to the selected device.
# This untrained model is the one used by the inference and self-play
# benchmarks; the training benchmark builds its own fresh model per run.
model_config = get_model_config(game.config, n_layer=4)
model = AlphaZeroTransformer(model_config).to(device)
print(f"Model parameters: {model.count_parameters():,}")

In [None]:
# Create both MCTS variants with the same simulation budget so the
# comparison is like-for-like
NUM_SIMS = 100  # Simulations per search

# PUCT search
puct_config = MCTSConfig(num_simulations=NUM_SIMS)
puct_mcts = BatchedMCTS(game, puct_config)

# Bayesian search, configured with early stopping enabled and a confidence
# threshold of 0.95.
# NOTE(review): with early stopping on, a search may presumably finish before
# NUM_SIMS simulations — confirm against BayesianMCTSConfig.
bayesian_config = BayesianMCTSConfig(
    num_simulations=NUM_SIMS,
    early_stopping=True,
    confidence_threshold=0.95,
)
bayesian_mcts = BayesianMCTS(game, bayesian_config)

print(f"PUCT MCTS: {NUM_SIMS} simulations")
print(f"Bayesian MCTS: {NUM_SIMS} simulations (with early stopping)")

## Benchmark 1: Inference Speed

Compare raw search speed for both algorithms.

In [None]:
def _random_benchmark_state(game):
    """Return a non-terminal state reached by 3-6 random moves (or the initial state)."""
    state = game.initial_state()
    # Play 3-6 random moves
    for _ in range(np.random.randint(3, 7)):
        if game.is_terminal(state):
            break
        legal = game.legal_actions(state)
        state = game.next_state(state, np.random.choice(legal))
    # A finished game cannot be searched; fall back to the opening position
    if game.is_terminal(state):
        state = game.initial_state()
    return state


@torch.inference_mode()
def benchmark_inference(mcts, model, game, batch_sizes=(1, 8, 32, 64), num_searches=50):
    """Benchmark inference speed at various batch sizes.

    For each batch size, a batch of random early-game states is built, one
    untimed warmup search is run, and then `num_searches` searches are timed
    with the MCTS cache cleared before each one.

    Args:
        mcts: Search object exposing `clear_cache()` and `search(states, model)`.
        model: Network passed through to `mcts.search`; switched to eval mode.
        game: Game object used to generate the benchmark states.
        batch_sizes: Iterable of batch sizes to benchmark.
        num_searches: Number of timed searches per batch size.

    Returns:
        dict mapping each batch size to a dict with 'avg_time' and 'std_time'
        (seconds per batched search) and 'searches_per_sec'
        (batch_size / avg_time).
    """
    model.eval()
    results = {}

    # CUDA kernels are launched asynchronously, so the wall clock is only
    # meaningful when the device is idle at both ends of the timed region.
    if torch.cuda.is_available():
        sync = torch.cuda.synchronize
    else:
        def sync():
            return None

    for batch_size in batch_sizes:
        # Create batch of random states (a few moves in)
        states = np.stack([_random_benchmark_state(game) for _ in range(batch_size)])
        mcts.clear_cache()

        # Warmup
        _ = mcts.search(states, model)
        mcts.clear_cache()

        # Timed runs
        times = []
        for _ in range(num_searches):
            mcts.clear_cache()
            sync()  # do not charge previously queued GPU work to this run
            start = time.perf_counter()
            _ = mcts.search(states, model)
            sync()  # wait for the search's GPU work to finish
            times.append(time.perf_counter() - start)

        avg_time = np.mean(times)
        std_time = np.std(times)
        searches_per_sec = batch_size / avg_time

        results[batch_size] = {
            'avg_time': avg_time,
            'std_time': std_time,
            'searches_per_sec': searches_per_sec,
        }

    return results

In [None]:
# Time each search implementation on the same (untrained) model
print("Benchmarking PUCT inference...")
puct_results = benchmark_inference(mcts=puct_mcts, model=model, game=game)

print("Benchmarking Bayesian inference...")
bayesian_results = benchmark_inference(mcts=bayesian_mcts, model=model, game=game)

In [None]:
print("\n" + "="*70)
print("                    INFERENCE SPEED COMPARISON")
print("="*70)
print(f"{'Batch':<8} {'PUCT (s/s)':<15} {'Bayesian (s/s)':<15} {'Ratio':<10}")
print("-"*70)

# Ratio = Bayesian searches/sec divided by PUCT searches/sec
# (> 1 means Bayesian was faster at that batch size)
for batch_size, puct_row in puct_results.items():
    puct_sps = puct_row['searches_per_sec']
    bay_sps = bayesian_results[batch_size]['searches_per_sec']
    ratio = bay_sps / puct_sps
    # Attach the "x" suffix before padding; padding first printed "1.23      x"
    ratio_text = f"{ratio:.2f}x"
    print(f"{batch_size:<8} {puct_sps:<15.1f} {bay_sps:<15.1f} {ratio_text:<10}")

print("\n(s/s = searches per second, higher is better)")

## Benchmark 2: Training Throughput

Compare self-play speed (games per second, examples per second).

In [None]:
@torch.inference_mode()
def self_play_games(game, model, mcts, num_games, parallel_games=32,
                    temperature_threshold=15, mcts_type='puct'):
    """
    Run batched self-play and return the training examples with timing stats.

    Up to `parallel_games` games are advanced in lock-step, one batched search
    per move. A slot whose game ends is restarted while fewer than `num_games`
    games have finished. For `mcts_type == 'puct'` the search is called with
    `add_noise=True` for games on their first move and `add_noise=False` for
    the rest; any other `mcts_type` gets a single plain `search` call.

    Actions are sampled with temperature 1.0 for the first
    `temperature_threshold` moves of a game and 0.0 afterwards.

    Returns a dict with 'examples' (list of (state, policy, value) tuples,
    expanded through `game.symmetries`), 'games', 'searches', 'elapsed' and
    the derived 'games_per_sec', 'examples_per_sec', 'searches_per_sec'.
    """
    model.eval()

    slots = min(parallel_games, num_games)
    boards = [game.initial_state() for _ in range(slots)]
    plies = [0 for _ in range(slots)]
    histories = [[] for _ in range(slots)]

    collected = []
    finished = 0
    searches_run = 0

    t0 = time.perf_counter()

    while finished < num_games:
        live = [slot for slot in range(slots) if not game.is_terminal(boards[slot])]
        if len(live) == 0:
            break

        # One batched search over every game still in progress
        live_boards = np.stack([boards[slot] for slot in live])

        if mcts_type != 'puct':
            policies = mcts.search(live_boards, model)
        else:
            policies = np.zeros((len(live), game.config.action_size), dtype=np.float32)

            # Rows of the batch at move 0 are searched with noise, the rest without
            first_move_rows = [row for row, slot in enumerate(live) if plies[slot] == 0]
            later_rows = [row for row, slot in enumerate(live) if plies[slot] != 0]

            for rows, noisy in ((first_move_rows, True), (later_rows, False)):
                if not rows:
                    continue
                found = mcts.search(live_boards[rows], model, add_noise=noisy)
                for k, row in enumerate(rows):
                    policies[row] = found[k]

        searches_run += len(live)

        # Record the position, then sample and play a move in every live game
        for row, slot in enumerate(live):
            board = boards[slot]
            move_policy = policies[row]
            mover = game.current_player(board)

            snapshot = game.canonical_state(board).copy()
            histories[slot].append((snapshot, move_policy.copy(), mover))

            temp = 0.0 if plies[slot] >= temperature_threshold else 1.0
            chosen = sample_action(move_policy, temperature=temp)

            boards[slot] = game.next_state(board, chosen)
            plies[slot] += 1

        # Harvest games that just ended and recycle their slots
        for slot in range(slots):
            if not (game.is_terminal(boards[slot]) and histories[slot]):
                continue

            outcome = game.terminal_reward(boards[slot])
            last_to_move = game.current_player(boards[slot])

            for canon, move_policy, mover in histories[slot]:
                # Reward is used as-is for the player to move at the terminal
                # state and negated for the other player
                target = outcome if mover == last_to_move else -outcome
                collected.extend(
                    (sym_state, sym_policy, target)
                    for sym_state, sym_policy in game.symmetries(canon, move_policy)
                )

            finished += 1

            if finished < num_games:
                boards[slot] = game.initial_state()
                plies[slot] = 0
                histories[slot] = []

    elapsed = time.perf_counter() - t0

    return {
        'examples': collected,
        'games': finished,
        'searches': searches_run,
        'elapsed': elapsed,
        'games_per_sec': finished / elapsed,
        'examples_per_sec': len(collected) / elapsed,
        'searches_per_sec': searches_run / elapsed,
    }

In [None]:
# Benchmark size: total games to finish and how many run side by side
NUM_GAMES = 50
PARALLEL_GAMES = 32

print(f"Running self-play benchmark: {NUM_GAMES} games, {PARALLEL_GAMES} parallel\n")

# Start both searches from an empty cache
puct_mcts.clear_cache()
bayesian_mcts.clear_cache()

print("PUCT self-play...")
puct_selfplay = self_play_games(
    game, model, puct_mcts,
    num_games=NUM_GAMES,
    parallel_games=PARALLEL_GAMES,
    mcts_type='puct',
)

print("Bayesian self-play...")
bayesian_selfplay = self_play_games(
    game, model, bayesian_mcts,
    num_games=NUM_GAMES,
    parallel_games=PARALLEL_GAMES,
    mcts_type='bayesian',
)

In [None]:
# Side-by-side table of the two self-play result dicts
print("\n" + "="*70)
print("                   SELF-PLAY THROUGHPUT COMPARISON")
print("="*70)
print(f"{'Metric':<25} {'PUCT':<20} {'Bayesian':<20}")
print("-"*70)

print(f"{'Games completed':<25} {puct_selfplay['games']:<20} {bayesian_selfplay['games']:<20}")
print(f"{'Total time (s)':<25} {puct_selfplay['elapsed']:<20.2f} {bayesian_selfplay['elapsed']:<20.2f}")
print(f"{'Games/sec':<25} {puct_selfplay['games_per_sec']:<20.2f} {bayesian_selfplay['games_per_sec']:<20.2f}")
print(f"{'Examples generated':<25} {len(puct_selfplay['examples']):<20} {len(bayesian_selfplay['examples']):<20}")
print(f"{'Examples/sec':<25} {puct_selfplay['examples_per_sec']:<20.1f} {bayesian_selfplay['examples_per_sec']:<20.1f}")
print(f"{'Searches/sec':<25} {puct_selfplay['searches_per_sec']:<20.1f} {bayesian_selfplay['searches_per_sec']:<20.1f}")

# speedup = PUCT games/sec divided by Bayesian games/sec, so > 1 means PUCT
# is faster. The ratio is printed un-inverted in both cases: "0.50x slower"
# means PUCT completed half as many games per second.
speedup = puct_selfplay['games_per_sec'] / bayesian_selfplay['games_per_sec']
print(f"\nPUCT is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than Bayesian for self-play")

## Benchmark 3: Training Loop Comparison

Train a fresh model with each search algorithm for a few iterations and compare the resulting loss, win rate, and timing.

In [None]:
def train_step(model, optimizer, scaler, batch, game, device, use_amp=True):
    """Run one optimisation step (optionally under CUDA autocast).

    `batch` is a `(states, policies, values)` triple; `policies` and `values`
    must be NumPy arrays (they go through `torch.from_numpy`). States are
    encoded with `game.to_tensor` and masked with `game.legal_actions_mask`.

    The loss is the cross-entropy between the target policies and the model's
    log-policies plus the MSE between predicted and target values.

    Returns `(loss, policy_loss, value_loss)` as Python floats.
    """
    states, target_policies, target_values = batch

    # Assemble the batch tensors on the target device
    encoded_states = torch.stack(list(map(game.to_tensor, states))).to(device)
    pi_target = torch.from_numpy(target_policies).float().to(device)
    v_target = torch.from_numpy(target_values).float().to(device)
    mask_rows = [torch.from_numpy(game.legal_actions_mask(s)) for s in states]
    legal_masks = torch.stack(mask_rows).float().to(device)

    model.train()
    optimizer.zero_grad()

    with torch.amp.autocast('cuda', enabled=use_amp):
        log_pi, v_pred = model(encoded_states, legal_masks)
        policy_loss = -(pi_target * log_pi).sum(dim=1).mean()
        value_loss = F.mse_loss(v_pred.squeeze(-1), v_target)
        loss = policy_loss + value_loss

    # Backward and optimizer step are routed through the gradient scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item(), policy_loss.item(), value_loss.item()


@torch.inference_mode()
def evaluate_vs_random(game, model, mcts, num_games=30, mcts_type='puct', num_simulations=50):
    """Evaluate against random player.

    The model side picks the highest-rated action from an MCTS search
    (`sample_action` with temperature 0); the opponent picks uniformly among
    the legal actions. The model plays as player 1 in even-numbered games and
    as player -1 in odd-numbered games.

    Args:
        game: Game object providing the rules.
        model: Network passed to `mcts.search`; switched to eval mode.
        mcts: Search object.
        num_games: Number of evaluation games.
        mcts_type: 'puct' additionally passes `add_noise=False` to the search;
            any other value calls the search without that argument.
        num_simulations: Simulation budget per search (default 50, the value
            that was previously hard-coded).

    Returns:
        Fraction of games the model won. Draws and losses both count as
        non-wins. Returns 0.0 when `num_games` is not positive.
    """
    model.eval()

    # No games played: report 0.0 rather than dividing by zero
    if num_games <= 0:
        return 0.0

    # Only the PUCT search is called with `add_noise`
    search_kwargs = {'num_simulations': num_simulations}
    if mcts_type == 'puct':
        search_kwargs['add_noise'] = False

    wins = 0

    for i in range(num_games):
        state = game.initial_state()
        # Alternate sides so neither side's first-move advantage dominates
        model_player = 1 if i % 2 == 0 else -1

        while not game.is_terminal(state):
            if game.current_player(state) == model_player:
                policy = mcts.search(state[np.newaxis, ...], model, **search_kwargs)[0]
                action = sample_action(policy, temperature=0)
            else:
                action = np.random.choice(game.legal_actions(state))

            state = game.next_state(state, action)

        # The reward is taken as-is when the player to move at the terminal
        # state is the model, and negated otherwise
        reward = game.terminal_reward(state)
        final_player = game.current_player(state)
        model_result = reward if final_player == model_player else -reward

        if model_result > 0:
            wins += 1

    return wins / num_games

In [None]:
# Training configuration shared by the PUCT and Bayesian training runs
NUM_ITERATIONS = 10    # self-play / train cycles per run
GAMES_PER_ITER = 30    # self-play games generated per iteration
TRAINING_STEPS = 100   # optimizer steps per iteration
BATCH_SIZE = 64        # examples sampled from the replay buffer per step
PARALLEL_GAMES = 32    # self-play games advanced together (re-declares the earlier value)

print(f"Training config:")
print(f"  Iterations: {NUM_ITERATIONS}")
print(f"  Games/iteration: {GAMES_PER_ITER}")
print(f"  Training steps/iteration: {TRAINING_STEPS}")
print(f"  Batch size: {BATCH_SIZE}")

In [None]:
def run_training(game, model, mcts, mcts_type, num_iterations, games_per_iter,
                 training_steps, batch_size, parallel_games, device):
    """Run training loop and collect metrics.

    Each iteration generates self-play games, pushes the examples into a
    replay buffer, runs `training_steps` optimizer steps (skipped while the
    buffer holds fewer than `batch_size` examples) and, on even-numbered
    iterations and on the last iteration, measures the win rate against a
    random player.

    Args:
        game: Game object passed to self-play, training and evaluation.
        model: Unused; kept so existing callers keep working. A new 4-layer
            network is built inside this function and trained instead.
        mcts: Search object used for self-play and evaluation.
        mcts_type: Forwarded to `self_play_games` and `evaluate_vs_random`.
        num_iterations: Number of self-play / train cycles.
        games_per_iter: Self-play games per iteration.
        training_steps: Optimizer steps per iteration.
        batch_size: Replay-buffer sample size per step.
        parallel_games: Self-play games advanced together.
        device: torch device for the model and batches; AMP and gradient
            scaling are enabled only when its type is 'cuda'.

    Returns:
        dict with per-iteration lists 'iteration', 'selfplay_time',
        'train_time', 'loss', 'win_rate', 'examples', plus 'total_time'
        (seconds, including evaluation) and 'model' (the trained network).
        'win_rate' is 0 for iterations on which no evaluation was run.
    """

    # Fresh model: the `model` argument is replaced by a newly initialised network
    model_config = get_model_config(game.config, n_layer=4)
    model = AlphaZeroTransformer(model_config).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))
    use_amp = device.type == 'cuda'

    buffer = ReplayBuffer(50000)

    metrics = {
        'iteration': [],
        'selfplay_time': [],
        'train_time': [],
        'loss': [],
        'win_rate': [],
        'examples': [],
    }

    total_start = time.perf_counter()

    for iteration in range(num_iterations):
        print(f"  Iteration {iteration + 1}/{num_iterations}", end=" ")

        # Self-play
        mcts.clear_cache()
        sp_start = time.perf_counter()
        result = self_play_games(game, model, mcts, games_per_iter, parallel_games, mcts_type=mcts_type)
        sp_time = time.perf_counter() - sp_start

        for state, policy, value in result['examples']:
            buffer.push(state, policy, value)

        # Training
        train_start = time.perf_counter()
        total_loss = 0

        if len(buffer) >= batch_size:
            for _ in range(training_steps):
                batch = buffer.sample(batch_size)
                loss, _, _ = train_step(model, optimizer, scaler, batch, game, device, use_amp)
                total_loss += loss

        train_time = time.perf_counter() - train_start
        avg_loss = total_loss / training_steps if training_steps > 0 else 0

        # Evaluate every 2 iterations, and always after the last one so that
        # metrics['win_rate'][-1] is a measured value rather than the
        # placeholder 0 when num_iterations is odd
        win_rate = 0
        is_last_iteration = iteration == num_iterations - 1
        if (iteration + 1) % 2 == 0 or is_last_iteration:
            mcts.clear_cache()
            win_rate = evaluate_vs_random(game, model, mcts, num_games=20, mcts_type=mcts_type)

        metrics['iteration'].append(iteration + 1)
        metrics['selfplay_time'].append(sp_time)
        metrics['train_time'].append(train_time)
        metrics['loss'].append(avg_loss)
        metrics['win_rate'].append(win_rate)
        metrics['examples'].append(len(result['examples']))

        print(f"| SP: {sp_time:.1f}s | Train: {train_time:.1f}s | Loss: {avg_loss:.3f} | WR: {win_rate:.0%}")

    total_time = time.perf_counter() - total_start
    metrics['total_time'] = total_time
    metrics['model'] = model

    return metrics

In [None]:
print("="*70)
print("Training with PUCT MCTS")
print("="*70)
# Train a model using the PUCT search for self-play and evaluation
puct_metrics = run_training(
    game, model, puct_mcts, 'puct',
    num_iterations=NUM_ITERATIONS,
    games_per_iter=GAMES_PER_ITER,
    training_steps=TRAINING_STEPS,
    batch_size=BATCH_SIZE,
    parallel_games=PARALLEL_GAMES,
    device=device,
)

In [None]:
print("="*70)
print("Training with Bayesian MCTS")
print("="*70)
# Train a model using the Bayesian search for self-play and evaluation
bayesian_metrics = run_training(
    game, model, bayesian_mcts, 'bayesian',
    num_iterations=NUM_ITERATIONS,
    games_per_iter=GAMES_PER_ITER,
    training_steps=TRAINING_STEPS,
    batch_size=BATCH_SIZE,
    parallel_games=PARALLEL_GAMES,
    device=device,
)

In [None]:
# Side-by-side table of the two training runs
print("\n" + "="*70)
print("                    TRAINING COMPARISON SUMMARY")
print("="*70)

print(f"\n{'Metric':<30} {'PUCT':<20} {'Bayesian':<20}")
print("-"*70)

# "Final" rows read the last entry of each per-iteration metric list
print(f"{'Total training time (s)':<30} {puct_metrics['total_time']:<20.1f} {bayesian_metrics['total_time']:<20.1f}")
print(f"{'Avg self-play time/iter (s)':<30} {np.mean(puct_metrics['selfplay_time']):<20.2f} {np.mean(bayesian_metrics['selfplay_time']):<20.2f}")
print(f"{'Avg training time/iter (s)':<30} {np.mean(puct_metrics['train_time']):<20.2f} {np.mean(bayesian_metrics['train_time']):<20.2f}")
print(f"{'Final loss':<30} {puct_metrics['loss'][-1]:<20.3f} {bayesian_metrics['loss'][-1]:<20.3f}")
print(f"{'Final win rate vs random':<30} {puct_metrics['win_rate'][-1]:<20.0%} {bayesian_metrics['win_rate'][-1]:<20.0%}")
print(f"{'Total examples generated':<30} {sum(puct_metrics['examples']):<20} {sum(bayesian_metrics['examples']):<20}")

# speedup = PUCT total time divided by Bayesian total time, so > 1 means the
# Bayesian run took less wall-clock time. The ratio is printed un-inverted in
# the "slower" case as well.
speedup = puct_metrics['total_time'] / bayesian_metrics['total_time']
print(f"\nBayesian is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PUCT overall")

## Benchmark 4: Policy Agreement Analysis

Compare policies produced by both algorithms on the same positions.

In [None]:
def entropy(probs, eps=1e-8):
    """Shannon entropy (natural log) of a probability distribution.

    Args:
        probs: Array-like of non-negative weights (ndarray, list or tuple).
            It is smoothed by `eps` and renormalised before use, so it does
            not have to sum to exactly 1.
        eps: Added to every entry to keep `log` finite for zero entries.

    Returns:
        The entropy as a Python float.
    """
    # asarray lets plain lists/tuples through; `list + float` would raise
    smoothed = np.asarray(probs) + eps
    smoothed = smoothed / smoothed.sum()
    return float(-np.sum(smoothed * np.log(smoothed)))


@torch.inference_mode()
def compare_policies(game, model, puct_mcts, bayesian_mcts, num_positions=50, num_sims=100):
    """Compare policies from both algorithms on same positions.

    For each of `num_positions` random positions (2-9 random moves from the
    initial state, falling back to the initial state if the game ended), both
    searches are run from a cleared cache and their policies are compared.

    Args:
        game: Game object used to generate positions.
        model: Network passed to both searches; switched to eval mode.
        puct_mcts: PUCT search object (called with `add_noise=False`).
        bayesian_mcts: Bayesian search object.
        num_positions: Number of positions to compare.
        num_sims: Simulation budget passed to both searches.

    Returns:
        dict with 'agree' (count of positions where both argmax actions
        match), per-position lists 'puct_entropy', 'bayesian_entropy',
        'puct_confidence', 'bayesian_confidence' (max policy probability),
        and 'agreement_rate' (agree / num_positions, or 0.0 when
        `num_positions` is not positive).
    """
    model.eval()

    results = {
        'agree': 0,
        'puct_entropy': [],
        'bayesian_entropy': [],
        'puct_confidence': [],
        'bayesian_confidence': [],
    }

    for _ in range(num_positions):
        # Generate random position
        state = game.initial_state()
        num_random_moves = np.random.randint(2, 10)
        for _move in range(num_random_moves):
            if game.is_terminal(state):
                break
            legal = game.legal_actions(state)
            state = game.next_state(state, np.random.choice(legal))

        # A finished game cannot be searched; use the opening position instead
        if game.is_terminal(state):
            state = game.initial_state()

        # Get policies, each from a cold cache so neither search reuses
        # results from an earlier position
        puct_mcts.clear_cache()
        bayesian_mcts.clear_cache()

        batch = state[np.newaxis, ...]
        puct_policy = puct_mcts.search(batch, model,
                                       num_simulations=num_sims, add_noise=False)[0]
        bayesian_policy = bayesian_mcts.search(batch, model,
                                               num_simulations=num_sims)[0]

        # Compare the preferred action of each search
        if np.argmax(puct_policy) == np.argmax(bayesian_policy):
            results['agree'] += 1

        results['puct_entropy'].append(entropy(puct_policy))
        results['bayesian_entropy'].append(entropy(bayesian_policy))
        results['puct_confidence'].append(float(np.max(puct_policy)))
        results['bayesian_confidence'].append(float(np.max(bayesian_policy)))

    # No positions compared: report 0.0 rather than dividing by zero
    if num_positions > 0:
        results['agreement_rate'] = results['agree'] / num_positions
    else:
        results['agreement_rate'] = 0.0

    return results

In [None]:
# Both searches are driven by the same network: the one trained with PUCT
trained_model = puct_metrics['model']

print("Comparing policies on 50 random positions...")
policy_comparison = compare_policies(
    game=game,
    model=trained_model,
    puct_mcts=puct_mcts,
    bayesian_mcts=bayesian_mcts,
    num_positions=50,
    num_sims=100,
)

In [None]:
# Report how often both searches pick the same best action, plus the mean
# entropy and mean max-probability of each search's policies
print("\n" + "="*70)
print("                      POLICY COMPARISON")
print("="*70)

print(f"\nAgreement rate (same best action): {policy_comparison['agreement_rate']:.1%}")

print(f"\n{'Metric':<25} {'PUCT':<15} {'Bayesian':<15}")
print("-"*55)
print(f"{'Avg entropy':<25} {np.mean(policy_comparison['puct_entropy']):<15.3f} {np.mean(policy_comparison['bayesian_entropy']):<15.3f}")
print(f"{'Avg confidence (max prob)':<25} {np.mean(policy_comparison['puct_confidence']):<15.3f} {np.mean(policy_comparison['bayesian_confidence']):<15.3f}")

print("\n(Lower entropy = more confident, higher confidence = stronger preference)")

## Summary

In [None]:
# Recap of the headline number from each benchmark
print("\n" + "="*70)
print("                         FINAL SUMMARY")
print("="*70)

# Reads the batch-size-32 row, so 32 must be among the benchmarked batch sizes
print("\n1. INFERENCE SPEED (searches/sec at batch=32):")
print(f"   PUCT:     {puct_results[32]['searches_per_sec']:.1f}")
print(f"   Bayesian: {bayesian_results[32]['searches_per_sec']:.1f}")

# Reported as total wall-clock seconds of each training run (lower is better)
print(f"\n2. TRAINING THROUGHPUT:")
print(f"   PUCT:     {puct_metrics['total_time']:.1f}s total")
print(f"   Bayesian: {bayesian_metrics['total_time']:.1f}s total")

# Uses the last recorded win-rate entry of each run
print(f"\n3. LEARNING (win rate vs random after {NUM_ITERATIONS} iterations):")
print(f"   PUCT:     {puct_metrics['win_rate'][-1]:.0%}")
print(f"   Bayesian: {bayesian_metrics['win_rate'][-1]:.0%}")

print(f"\n4. POLICY AGREEMENT: {policy_comparison['agreement_rate']:.1%}")

print("\n" + "="*70)