# Rust MCTS Virtual Loss Batching Test

Tests whether the PUCT and TTTS Rust MCTS backends maintain playing quality when virtual loss is used to batch many leaf evaluations per network call (large `leaves_per_batch`).

**Key Questions:**
1. Does increasing `leaves_per_batch` hurt game-playing performance?
2. How does training speed differ between PUCT and TTTS at various batch sizes?
3. Do models trained with high batch multipliers learn as well?

**Setup:** Use `Runtime > Change runtime type > GPU` for best performance.

## 1. Environment Setup

In [None]:
# Check GPU: report whether CUDA is usable and, if so, which device and how much memory
import torch

cuda_ok = torch.cuda.is_available()
print(f"GPU available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {total_gb:.1f} GB")

In [None]:
# Install Rust toolchain
# Runs the official rustup installer non-interactively (`-y` is forwarded to the script).
!curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
import os
# Prepend ~/.cargo/bin to this kernel's PATH so later `!` shell commands
# (rustc, cargo, maturin's Rust build) can find the freshly installed toolchain.
os.environ["PATH"] = f"{os.environ['HOME']}/.cargo/bin:" + os.environ["PATH"]

# Verify Rust installation
!rustc --version

In [None]:
# Clone repository
!git clone https://github.com/caldred/nanozero.git
# Work from the checkout root. NOTE(review): the `nanozero` Python package is not
# pip-installed below, so later `from nanozero...` imports presumably resolve
# because the kernel's working directory is this checkout — confirm.
%cd nanozero

# Install Python dependencies
# Pin sympy<1.13 to avoid PyTorch compatibility issues
!pip install -q numpy scipy maturin "sympy<1.13"

# Build and install Rust extension
# maturin builds a release wheel into target/wheels/, which is then pip-installed.
%cd nanozero-mcts-rs
!maturin build --release
!pip install target/wheels/nanozero_mcts_rs-*.whl
%cd ..

# Verify Rust backend is available
# Runs in a subprocess, so it does not bind these names in the notebook namespace.
!python -c "from nanozero_mcts_rs import RustBatchedMCTS, RustBayesianMCTS; print('Rust backends loaded!')"

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import time
from scipy import stats
from nanozero.game import get_game
from nanozero.model import AlphaZeroTransformer
from nanozero.mcts import BatchedMCTS, BayesianMCTS
from nanozero.common import sample_action
from nanozero.config import get_model_config, MCTSConfig, BayesianMCTSConfig
from nanozero.replay import ReplayBuffer

# Aliases for backwards compatibility with notebook
# Later cells construct searches via the names RustBatchedMCTS / RustBayesianMCTS;
# here those names are bound to the nanozero.mcts classes (PUCT and Bayesian/TTTS).
# NOTE(review): this notebook does not show whether nanozero.mcts dispatches to the
# nanozero_mcts_rs extension built above — verify, since the experiment is meant to
# exercise the Rust backends.
RustBatchedMCTS = BatchedMCTS
RustBayesianMCTS = BayesianMCTS

# All models and training tensors are placed on this device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

## 2. Training Infrastructure

In [None]:
def _self_play_search(game, mcts, model, active_states, noise_mask, is_bayesian):
    """Run one batched search; for PUCT, add root exploration noise only where noise_mask is True."""
    if is_bayesian:
        # Bayesian MCTS doesn't use add_noise parameter
        return mcts.search(active_states, model)
    if noise_mask.all():
        return mcts.search(active_states, model, add_noise=True)
    if not noise_mask.any():
        return mcts.search(active_states, model, add_noise=False)
    # Mixed batch: search the noisy and noise-free positions separately, then
    # scatter both results back into one array aligned with active_states.
    policies = np.zeros((len(active_states), game.config.action_size), dtype=np.float32)
    policies[noise_mask] = mcts.search(active_states[noise_mask], model, add_noise=True)
    policies[~noise_mask] = mcts.search(active_states[~noise_mask], model, add_noise=False)
    return policies


@torch.inference_mode()
def self_play_games(game, model, mcts, num_games, temperature_threshold=15,
                    parallel_games=32, is_bayesian=False):
    """
    Play exactly ``num_games`` self-play games, ``parallel_games`` at a time.

    Works with both PUCT (``is_bayesian=False``) and Bayesian MCTS
    (``is_bayesian=True``). All unfinished games are searched together in one
    batched ``mcts.search`` call per move. For PUCT, root noise is requested only
    on each game's first move. Moves are sampled with temperature 1.0 for the
    first ``temperature_threshold`` moves of a game and greedily afterwards.

    A finished slot is refilled with a new game only while fewer than
    ``num_games`` games have been started. Every started game is therefore played
    to completion. None is abandoned half-way, and none is counted twice.

    Args:
        game: game implementation (initial_state, next_state, symmetries, ...).
        model: network passed to ``mcts.search``; put into eval mode.
        mcts: batched search object exposing ``search(states, model[, add_noise])``.
        num_games: number of complete games to play.
        temperature_threshold: moves played with temperature 1.0 before switching to 0.
        parallel_games: maximum number of games advanced concurrently.
        is_bayesian: True for Bayesian MCTS, whose ``search`` takes no add_noise.

    Returns:
        list of ``(state, policy, value)`` examples, expanded with
        ``game.symmetries``. ``value`` is ``terminal_reward`` sign-flipped for
        positions where the stored player differs from the player to move at the
        end. This presumes terminal_reward is from the final mover's perspective —
        confirm against the game implementation.
    """
    model.eval()
    all_examples = []
    games_completed = 0

    n_parallel = min(parallel_games, num_games)
    states = [game.initial_state() for _ in range(n_parallel)]
    move_counts = [0] * n_parallel
    game_examples = [[] for _ in range(n_parallel)]
    games_started = max(n_parallel, 0)

    while games_completed < num_games:
        active_indices = [i for i, s in enumerate(states) if not game.is_terminal(s)]
        if not active_indices:
            break

        # Batch all active states for MCTS
        active_states = np.stack([states[i] for i in active_indices])
        noise_mask = np.array([move_counts[i] == 0 for i in active_indices], dtype=bool)
        policies = _self_play_search(game, mcts, model, active_states, noise_mask, is_bayesian)

        # Process each active game
        for idx, game_idx in enumerate(active_indices):
            state = states[game_idx]
            policy = policies[idx]
            player = game.current_player(state)

            # Store example
            canonical = game.canonical_state(state)
            game_examples[game_idx].append((canonical.copy(), policy.copy(), player))

            # Sample action
            temperature = 1.0 if move_counts[game_idx] < temperature_threshold else 0.0
            action = sample_action(policy, temperature=temperature)

            states[game_idx] = game.next_state(state, action)
            move_counts[game_idx] += 1

        # Harvest finished games
        for i in range(n_parallel):
            if not (game.is_terminal(states[i]) and game_examples[i]):
                continue
            reward = game.terminal_reward(states[i])
            final_player = game.current_player(states[i])

            for canonical, policy, player in game_examples[i]:
                value = reward if player == final_player else -reward
                for sym_state, sym_policy in game.symmetries(canonical, policy):
                    all_examples.append((sym_state, sym_policy, value))

            games_completed += 1
            # Mark this slot as harvested so its game is never counted again.
            game_examples[i] = []

            # Refill the slot only if more games still need to be started.
            if games_started < num_games:
                states[i] = game.initial_state()
                move_counts[i] = 0
                games_started += 1

    return all_examples


def train_step(model, optimizer, states, policies, values, action_masks, device):
    """
    Apply one optimizer update on a single minibatch.

    The loss is the policy cross-entropy (target ``policies`` against the model's
    log-policy output) plus the MSE between the predicted and target values.

    Returns:
        (total_loss, policy_loss, value_loss) as Python floats.
    """
    model.train()

    states, policies, values, action_masks = (
        t.to(device) for t in (states, policies, values, action_masks)
    )

    optimizer.zero_grad()

    log_pi, value_pred = model(states, action_masks)
    # Cross-entropy: negative mean over the batch of sum_a target(a) * log pi(a)
    policy_term = (policies * log_pi).sum(dim=1).mean().neg()
    value_term = F.mse_loss(value_pred.squeeze(-1), values)
    objective = policy_term + value_term

    objective.backward()
    optimizer.step()

    return objective.item(), policy_term.item(), value_term.item()


@torch.inference_mode()
def evaluate_vs_random(game, model, mcts, num_games=50, is_bayesian=False):
    """
    Return the fraction of ``num_games`` the model wins against a uniform-random opponent.

    The model plays colour 1 in even-indexed games and colour -1 in odd-indexed
    ones. All unfinished games advance together each round. First, one batched
    MCTS search covers every position where the model is to move, with moves
    chosen greedily (temperature 0). Then each remaining game gets one uniformly
    random legal move. Draws and losses both count as non-wins.
    """
    model.eval()

    states = [game.initial_state() for _ in range(num_games)]
    model_colour = [(-1) ** g for g in range(num_games)]
    outcome = [None] * num_games  # None while ongoing, else reward from the model's side

    while any(o is None for o in outcome):
        to_search, to_randomise = [], []

        for g, state in enumerate(states):
            if outcome[g] is not None:
                continue
            if game.is_terminal(state):
                # Convert the terminal reward to the model's perspective
                reward = game.terminal_reward(state)
                outcome[g] = reward if game.current_player(state) == model_colour[g] else -reward
                continue
            if game.current_player(state) == model_colour[g]:
                to_search.append(g)
            else:
                to_randomise.append(g)

        if to_search:
            batch = np.stack([states[g] for g in to_search])
            if is_bayesian:
                policies = mcts.search(batch, model)
            else:
                policies = mcts.search(batch, model, add_noise=False)
            for g, policy in zip(to_search, policies):
                states[g] = game.next_state(states[g], sample_action(policy, temperature=0))

        for g in to_randomise:
            choice = np.random.choice(game.legal_actions(states[g]))
            states[g] = game.next_state(states[g], choice)

    return sum(1 for o in outcome if o > 0) / num_games


print("Training infrastructure ready!")

In [None]:
import copy

def train_model(game, model, mcts, num_iterations=10, games_per_iter=50,
                training_steps=50, batch_size=32, parallel_games=32,
                eval_interval=5, is_bayesian=False, verbose=True, save_checkpoints=True,
                buffer_capacity=100000, lr=1e-3, weight_decay=1e-4, eval_games=50):
    """
    Train a model using self-play.

    Each iteration runs the following steps:
      1. Play ``games_per_iter`` self-play games with ``mcts``.
      2. Push every resulting example into a replay buffer.
      3. Run ``training_steps`` AdamW updates on minibatches of ``batch_size``.
         This only happens once the buffer holds at least one minibatch.
      4. Every ``eval_interval`` iterations (and on the last one), measure the
         win rate against a random player over ``eval_games`` games.

    Args:
        buffer_capacity: replay buffer capacity passed to ``ReplayBuffer``.
        lr, weight_decay: AdamW hyperparameters.
        eval_games: number of games per win-rate evaluation.
        (Remaining arguments are forwarded to self_play_games / evaluate_vs_random.)

    Returns:
        dict with training metrics and checkpoints:
          'self_play_times': seconds of self-play per iteration;
          'train_times' / 'losses': one entry per iteration that actually trained;
          'win_rates': list of (iteration, win_rate);
          'checkpoints': iteration -> deep copy of model.state_dict() at each
            evaluation (only when save_checkpoints).
    """
    buffer = ReplayBuffer(buffer_capacity)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    metrics = {
        'self_play_times': [],
        'train_times': [],
        'losses': [],
        'win_rates': [],
        'checkpoints': {},  # iteration -> model state dict
    }

    for iteration in range(num_iterations):
        it = iteration + 1
        if verbose:
            print(f"  Iteration {it}/{num_iterations}", end=" ")

        # Self-play
        sp_start = time.perf_counter()
        examples = self_play_games(
            game, model, mcts,
            num_games=games_per_iter,
            parallel_games=parallel_games,
            is_bayesian=is_bayesian
        )
        sp_time = time.perf_counter() - sp_start
        metrics['self_play_times'].append(sp_time)

        # Add to buffer
        for state, policy, value in examples:
            buffer.push(state, policy, value)

        # Training (guarding training_steps > 0 avoids a 0/0 average below)
        avg_loss = None
        if training_steps > 0 and len(buffer) >= batch_size:
            train_start = time.perf_counter()
            total_loss = 0.0

            for _ in range(training_steps):
                states, policies, values = buffer.sample(batch_size)

                state_tensors = torch.stack([game.to_tensor(s) for s in states])
                policy_tensors = torch.from_numpy(policies).float()
                value_tensors = torch.from_numpy(values).float()
                action_masks = torch.stack([
                    torch.from_numpy(game.legal_actions_mask(s)) for s in states
                ]).float()

                loss, _, _ = train_step(
                    model, optimizer,
                    state_tensors, policy_tensors, value_tensors,
                    action_masks, device
                )
                total_loss += loss

            train_time = time.perf_counter() - train_start
            avg_loss = total_loss / training_steps
            metrics['train_times'].append(train_time)
            metrics['losses'].append(avg_loss)

            # The weights just changed: drop whatever the MCTS has cached from
            # the previous weights before it is used again.
            if hasattr(mcts, 'clear_cache'):
                mcts.clear_cache()

        # Evaluate at specified interval (eval_interval <= 0 means "final iteration only")
        periodic = eval_interval > 0 and it % eval_interval == 0
        if periodic or it == num_iterations:
            win_rate = evaluate_vs_random(game, model, mcts, num_games=eval_games,
                                          is_bayesian=is_bayesian)
            metrics['win_rates'].append((it, win_rate))

            # Save checkpoint
            if save_checkpoints:
                metrics['checkpoints'][it] = copy.deepcopy(model.state_dict())

            if verbose:
                print(f"| WR: {win_rate:.0%}", end="")

        if verbose:
            loss_msg = f" | loss: {avg_loss:.3f}" if avg_loss is not None else ""
            print(f" | SP: {sp_time:.1f}s{loss_msg}")

    return metrics


def load_checkpoint(model, state_dict):
    """
    Return an independent, eval-mode copy of ``model`` carrying ``state_dict``'s weights.

    ``model`` itself is used only as an architecture template and is not modified.
    """
    clone = copy.deepcopy(model)
    clone.load_state_dict(state_dict)
    return clone.eval()  # nn.Module.eval() returns the module itself


print("Training function ready!")

## 3. Setup Game

In [None]:
# Connect4 gives a more meaningful comparison than a tiny game
game = get_game('connect4')
game_cfg = game.config

print("Game: Connect4")
print(f"Board size: {game_cfg.board_height}x{game_cfg.board_width}")
print(f"Action size: {game_cfg.action_size}")

## 4. Training Comparison: PUCT vs TTTS at Different Batch Sizes

Train models with different `leaves_per_batch` settings and compare:
1. Training speed (games/second)
2. Final model quality (win rate vs random)

In [None]:
# Training configuration shared by every PUCT and TTTS run below
NUM_ITERATIONS = 100
GAMES_PER_ITER = 100
TRAINING_STEPS = 100
MCTS_SIMS = 500
PARALLEL_GAMES = 64
BATCH_SIZE = 128
EVAL_INTERVAL = 10

# (label, value) pairs in the order they are reported
_config_summary = (
    ("Iterations", NUM_ITERATIONS),
    ("Games/iter", GAMES_PER_ITER),
    ("Training steps", TRAINING_STEPS),
    ("Batch size", BATCH_SIZE),
    ("MCTS sims", MCTS_SIMS),
    ("Parallel games", PARALLEL_GAMES),
    ("Eval interval", EVAL_INTERVAL),
)
print("Training config:")
for _label, _value in _config_summary:
    print(f"  {_label}: {_value}")

In [None]:
def create_fresh_model(game):
    """Build a randomly initialised 4-layer AlphaZeroTransformer for ``game`` on ``device``."""
    return AlphaZeroTransformer(get_model_config(game.config, n_layer=4)).to(device)

# Store results
all_results = {}

### 4.1 Train PUCT Models

In [None]:
print("=" * 60)
print("PUCT Training at Different Batch Sizes")
print("=" * 60)

puct_models = {}
puct_metrics = {}

for mult in [1, 4, 8]:
    print(f"\nTraining PUCT with {mult}x leaves_per_batch...")

    # Seed BEFORE building the model so every multiplier starts from the same
    # initial weights. Seeding after construction would give each run a
    # different random init and confound the batch-size comparison.
    np.random.seed(42)
    torch.manual_seed(42)

    # Fresh model
    model = create_fresh_model(game)

    # Create MCTS with specified batch size (scaled by the number of parallel games)
    # Use game-appropriate Dirichlet alpha: ~10/action_size
    leaves_per_batch = PARALLEL_GAMES * mult
    config = MCTSConfig(
        num_simulations=MCTS_SIMS,
        dirichlet_alpha=10.0 / game.config.action_size,  # ~1.4 for Connect4
        c_puct=1.5,
    )
    mcts = RustBatchedMCTS(game, config, leaves_per_batch=leaves_per_batch)

    # Train
    metrics = train_model(
        game, model, mcts,
        num_iterations=NUM_ITERATIONS,
        games_per_iter=GAMES_PER_ITER,
        training_steps=TRAINING_STEPS,
        batch_size=BATCH_SIZE,
        parallel_games=PARALLEL_GAMES,
        eval_interval=EVAL_INTERVAL,
        is_bayesian=False
    )

    puct_models[mult] = model
    puct_metrics[mult] = metrics

    avg_sp_time = np.mean(metrics['self_play_times'])
    final_wr = metrics['win_rates'][-1][1] if metrics['win_rates'] else 0
    print(f"  Avg self-play time: {avg_sp_time:.2f}s | Final win rate: {final_wr:.0%}")

all_results['puct'] = {'models': puct_models, 'metrics': puct_metrics}

### 4.2 Train TTTS Models

In [None]:
print("=" * 60)
print("TTTS Training at Different Batch Sizes")
print("=" * 60)

ttts_models = {}
ttts_metrics = {}

for mult in [1, 4, 8]:
    print(f"\nTraining TTTS with {mult}x leaves_per_batch...")

    # Seed BEFORE building the model so every multiplier starts from the same
    # initial weights. Seeding after construction would give each run a
    # different random init and confound the batch-size comparison.
    np.random.seed(42)
    torch.manual_seed(42)

    # Fresh model
    model = create_fresh_model(game)

    # Create MCTS with specified batch size (scaled by the number of parallel games)
    leaves_per_batch = PARALLEL_GAMES * mult
    config = BayesianMCTSConfig(num_simulations=MCTS_SIMS)
    mcts = RustBayesianMCTS(game, config, leaves_per_batch=leaves_per_batch)

    # Train
    metrics = train_model(
        game, model, mcts,
        num_iterations=NUM_ITERATIONS,
        games_per_iter=GAMES_PER_ITER,
        training_steps=TRAINING_STEPS,
        batch_size=BATCH_SIZE,
        parallel_games=PARALLEL_GAMES,
        eval_interval=EVAL_INTERVAL,
        is_bayesian=True
    )

    ttts_models[mult] = model
    ttts_metrics[mult] = metrics

    avg_sp_time = np.mean(metrics['self_play_times'])
    final_wr = metrics['win_rates'][-1][1] if metrics['win_rates'] else 0
    print(f"  Avg self-play time: {avg_sp_time:.2f}s | Final win rate: {final_wr:.0%}")

all_results['ttts'] = {'models': ttts_models, 'metrics': ttts_metrics}

## 5. Compare Trained Models in Arena

In [None]:
@torch.inference_mode()
def run_arena(game, model1, mcts1, model2, mcts2, num_games,
              is_bayesian1=False, is_bayesian2=False,
              player1_name="P1", player2_name="P2", verbose=True,
              opening_moves=0):
    """
    Run arena matches between two models, advancing all games in lockstep.

    Player 1 takes colour 1 in even-indexed games and -1 in odd-indexed ones, so
    each side moves first in half the games. Each round has two phases. First,
    every unfinished game where player 1 is to move is searched in one batched
    ``mcts1.search`` call. Then the same is done for player 2 with ``mcts2``.
    Moves are chosen greedily (temperature 0).

    When ``mcts1 is mcts2`` (how the cells below call this), the search cache is
    cleared before every search. That stops one model from reusing results
    cached while searching with the other. ``train_model`` clears the cache after
    weight updates, which suggests cached entries depend on the network. Both
    caches are also cleared on entry, in case the MCTS objects were previously
    used with other models.

    Args:
        game: game implementation.
        model1, mcts1: player 1's network and search object.
        model2, mcts2: player 2's network and search object (may share mcts1).
        num_games: number of games to play.
        is_bayesian1, is_bayesian2: True if that player's search takes no add_noise.
        player1_name, player2_name: labels for the printed summary.
        verbose: print a one-line W/D/L summary.
        opening_moves: number of initial plies (either side) played uniformly at
            random via ``np.random`` before the models take over. The default of 0
            keeps every move model-driven. NOTE(review): if both searches are
            deterministic, colour-alternated games may repeat identically with
            opening_moves=0, inflating the apparent sample size.

    Returns:
        (wins, draws, losses) from player 1's perspective.
    """
    model1.eval()
    model2.eval()
    shared_mcts = mcts1 is mcts2

    def _clear_cache(mcts):
        """Drop any cached search state, if this MCTS implementation keeps one."""
        if hasattr(mcts, 'clear_cache'):
            mcts.clear_cache()

    def _search(mcts, model, batch, is_bayesian):
        """Batched search for one player; resets a shared cache before switching models."""
        if shared_mcts:
            _clear_cache(mcts)
        if is_bayesian:
            return mcts.search(batch, model)
        return mcts.search(batch, model, add_noise=False)

    _clear_cache(mcts1)
    if not shared_mcts:
        _clear_cache(mcts2)

    # Initialize all games
    states = [game.initial_state() for _ in range(num_games)]
    p1_colors = [1 if i % 2 == 0 else -1 for i in range(num_games)]
    plies = [0] * num_games
    results = [None] * num_games  # None = ongoing, else reward from P1's perspective

    while any(r is None for r in results):
        # Separate games by whose turn it is
        p1_turn_indices = []
        p2_turn_indices = []
        opening_indices = []

        for i, (state, result) in enumerate(zip(states, results)):
            if result is not None:
                continue
            if game.is_terminal(state):
                # Game just ended - record result from P1's perspective
                reward = game.terminal_reward(state)
                final_player = game.current_player(state)
                results[i] = reward if final_player == p1_colors[i] else -reward
                continue

            if plies[i] < opening_moves:
                opening_indices.append(i)
            elif game.current_player(state) == p1_colors[i]:
                p1_turn_indices.append(i)
            else:
                p2_turn_indices.append(i)

        # Batched MCTS for player 1's moves, then player 2's
        for turn_indices, mcts, model, is_bayesian in (
            (p1_turn_indices, mcts1, model1, is_bayesian1),
            (p2_turn_indices, mcts2, model2, is_bayesian2),
        ):
            if not turn_indices:
                continue
            batch = np.stack([states[i] for i in turn_indices])
            policies = _search(mcts, model, batch, is_bayesian)
            for idx, game_idx in enumerate(turn_indices):
                action = sample_action(policies[idx], temperature=0)
                states[game_idx] = game.next_state(states[game_idx], action)
                plies[game_idx] += 1

        # Random opening plies (only when opening_moves > 0)
        for game_idx in opening_indices:
            legal = game.legal_actions(states[game_idx])
            states[game_idx] = game.next_state(states[game_idx], np.random.choice(legal))
            plies[game_idx] += 1

    # Count results
    wins = sum(1 for r in results if r > 0)
    draws = sum(1 for r in results if r == 0)
    losses = sum(1 for r in results if r < 0)

    if verbose:
        print(f"  {num_games} games: {player1_name} {wins}W/{draws}D/{losses}L")

    return wins, draws, losses

### 5.1 PUCT: Compare 1x vs 8x trained models

In [None]:
print("=" * 60)
print("PUCT Arena: 1x-trained vs 8x-trained models")
print("=" * 60)
print("\nBoth use 1x leaves_per_batch for fair evaluation.\n")

# One MCTS configuration (training hyperparameters, 1x leaves_per_batch) for
# both sides, so only the trained weights differ between the players.
eval_config = MCTSConfig(
    num_simulations=MCTS_SIMS,
    dirichlet_alpha=10.0 / game.config.action_size,
    c_puct=1.5,
)
eval_mcts = RustBatchedMCTS(game, eval_config, leaves_per_batch=PARALLEL_GAMES)

np.random.seed(42)
wins, draws, losses = run_arena(
    game,
    model1=puct_models[8], mcts1=eval_mcts,
    model2=puct_models[1], mcts2=eval_mcts,
    num_games=100,
    is_bayesian1=False,
    is_bayesian2=False,
    player1_name="8x-trained",
    player2_name="1x-trained",
)

# Win rate over decisive games only (draws excluded); 50% if every game was drawn
decisive = wins + losses
wr_8x = 0.5 if decisive == 0 else wins / decisive

print(f"\nResults (8x-trained perspective): {wins}W / {draws}D / {losses}L")
print(f"8x-trained decisive win rate: {wr_8x:.1%}")

if decisive:
    # Exact binomial test of H0: each decisive game is a fair coin flip
    p_value = stats.binomtest(wins, decisive, 0.5).pvalue
    print(f"p-value: {p_value:.4f}")

### 5.2 TTTS: Compare 1x vs 8x trained models

In [None]:
print("=" * 60)
print("TTTS Arena: 1x-trained vs 8x-trained models")
print("=" * 60)
print("\nBoth use 1x leaves_per_batch for fair evaluation.\n")

# One Bayesian MCTS configuration (1x leaves_per_batch) for both sides
eval_config = BayesianMCTSConfig(num_simulations=MCTS_SIMS)
eval_mcts = RustBayesianMCTS(game, eval_config, leaves_per_batch=PARALLEL_GAMES)

np.random.seed(42)
wins, draws, losses = run_arena(
    game,
    model1=ttts_models[8], mcts1=eval_mcts,
    model2=ttts_models[1], mcts2=eval_mcts,
    num_games=100,
    is_bayesian1=True,
    is_bayesian2=True,
    player1_name="8x-trained",
    player2_name="1x-trained",
)

# Win rate over decisive games only (draws excluded); 50% if every game was drawn
decisive = wins + losses
wr_8x = 0.5 if decisive == 0 else wins / decisive

print(f"\nResults (8x-trained perspective): {wins}W / {draws}D / {losses}L")
print(f"8x-trained decisive win rate: {wr_8x:.1%}")

if decisive:
    # Exact binomial test of H0: each decisive game is a fair coin flip
    p_value = stats.binomtest(wins, decisive, 0.5).pvalue
    print(f"p-value: {p_value:.4f}")

### 5.3 Cross-algorithm: Best PUCT vs Best TTTS

In [None]:
print("=" * 60)
print("PUCT vs TTTS: Best models from each algorithm")
print("=" * 60)


def _final_win_rate(metrics):
    """Last recorded win rate vs random for one training run (0 if it was never evaluated)."""
    return metrics['win_rates'][-1][1] if metrics['win_rates'] else 0


# Per algorithm, pick the multiplier whose model ended with the best win rate vs random
best_puct_mult = max(puct_metrics, key=lambda m: _final_win_rate(puct_metrics[m]))
best_ttts_mult = max(ttts_metrics, key=lambda m: _final_win_rate(ttts_metrics[m]))

print(f"\nBest PUCT: {best_puct_mult}x (WR: {puct_metrics[best_puct_mult]['win_rates'][-1][1]:.0%})")
print(f"Best TTTS: {best_ttts_mult}x (WR: {ttts_metrics[best_ttts_mult]['win_rates'][-1][1]:.0%})")
print("\nPlaying 100 games...\n")

# Each algorithm searches with its own MCTS at 1x leaves_per_batch
puct_eval_config = MCTSConfig(
    num_simulations=MCTS_SIMS,
    dirichlet_alpha=10.0 / game.config.action_size,
    c_puct=1.5,
)
puct_eval = RustBatchedMCTS(game, puct_eval_config, leaves_per_batch=PARALLEL_GAMES)
ttts_eval = RustBayesianMCTS(
    game, BayesianMCTSConfig(num_simulations=MCTS_SIMS), leaves_per_batch=PARALLEL_GAMES
)

np.random.seed(42)
wins, draws, losses = run_arena(
    game,
    model1=ttts_models[best_ttts_mult], mcts1=ttts_eval,
    model2=puct_models[best_puct_mult], mcts2=puct_eval,
    num_games=100,
    is_bayesian1=True,
    is_bayesian2=False,
    player1_name="TTTS",
    player2_name="PUCT",
)

# Win rate over decisive games only (draws excluded); 50% if every game was drawn
decisive = wins + losses
ttts_wr = 0.5 if decisive == 0 else wins / decisive

print(f"\nResults (TTTS perspective): {wins}W / {draws}D / {losses}L")
print(f"TTTS decisive win rate: {ttts_wr:.1%}")

if decisive:
    # Exact binomial test of H0: each decisive game is a fair coin flip
    p_value = stats.binomtest(wins, decisive, 0.5).pvalue
    print(f"p-value: {p_value:.4f}")

### 5.4 Checkpoint Progression: Early vs Late Training

In [None]:
print("=" * 60)
print("Checkpoint Progression Arena")
print("=" * 60)
print("\nCompare early checkpoints vs final models to verify learning.\n")


def _final_win_rate(metrics):
    """Last recorded win rate vs random for one training run (0 if it was never evaluated)."""
    return metrics['win_rates'][-1][1] if metrics['win_rates'] else 0


# Use the best batch multiplier (highest final win rate vs random) for each algorithm
best_puct_mult = max(puct_metrics, key=lambda m: _final_win_rate(puct_metrics[m]))
best_ttts_mult = max(ttts_metrics, key=lambda m: _final_win_rate(ttts_metrics[m]))

# Get checkpoints
puct_checkpoints = puct_metrics[best_puct_mult]['checkpoints']
ttts_checkpoints = ttts_metrics[best_ttts_mult]['checkpoints']

# Only iterations checkpointed by BOTH runs can be indexed below
checkpoint_iters = sorted(puct_checkpoints.keys() & ttts_checkpoints.keys())
print(f"Available checkpoints: {checkpoint_iters}")

if len(checkpoint_iters) >= 2:
    early_iter = checkpoint_iters[0]
    final_iter = checkpoint_iters[-1]

    # Architecture template that checkpoints are loaded into
    template_model = create_fresh_model(game)

    # PUCT: Early vs Final
    print(f"\n--- PUCT: Iteration {early_iter} vs Iteration {final_iter} ---")
    puct_early = load_checkpoint(template_model, puct_checkpoints[early_iter])
    puct_final = load_checkpoint(template_model, puct_checkpoints[final_iter])

    puct_eval = RustBatchedMCTS(game, MCTSConfig(
        num_simulations=MCTS_SIMS,
        dirichlet_alpha=10.0 / game.config.action_size,
        c_puct=1.5,
    ), leaves_per_batch=PARALLEL_GAMES)

    np.random.seed(42)
    wins, draws, losses = run_arena(
        game,
        puct_final, puct_eval,
        puct_early, puct_eval,
        num_games=100,
        is_bayesian1=False, is_bayesian2=False,
        player1_name=f"iter-{final_iter}", player2_name=f"iter-{early_iter}"
    )

    decisive = wins + losses
    final_wr = wins / decisive if decisive > 0 else 0.5
    print(f"  Final model decisive win rate: {final_wr:.1%}")

    # TTTS: Early vs Final
    print(f"\n--- TTTS: Iteration {early_iter} vs Iteration {final_iter} ---")
    ttts_early = load_checkpoint(template_model, ttts_checkpoints[early_iter])
    ttts_final = load_checkpoint(template_model, ttts_checkpoints[final_iter])

    ttts_eval = RustBayesianMCTS(game, BayesianMCTSConfig(num_simulations=MCTS_SIMS), leaves_per_batch=PARALLEL_GAMES)

    np.random.seed(42)
    wins, draws, losses = run_arena(
        game,
        ttts_final, ttts_eval,
        ttts_early, ttts_eval,
        num_games=100,
        is_bayesian1=True, is_bayesian2=True,
        player1_name=f"iter-{final_iter}", player2_name=f"iter-{early_iter}"
    )

    decisive = wins + losses
    final_wr = wins / decisive if decisive > 0 else 0.5
    print(f"  Final model decisive win rate: {final_wr:.1%}")

    # Full checkpoint ladder for the TTTS run selected above (not necessarily
    # the stronger algorithm overall)
    print(f"\n--- TTTS Checkpoint Ladder ({best_ttts_mult}x-trained run) ---")
    print("Each checkpoint plays against all earlier checkpoints:\n")

    # Load each checkpoint once rather than re-copying it for every pairing;
    # run_arena only reads the models, so reusing them is safe.
    ladder_models = {
        it: load_checkpoint(template_model, ttts_checkpoints[it]) for it in checkpoint_iters
    }

    ladder_results = {}
    for i, later_iter in enumerate(checkpoint_iters[1:], 1):
        for earlier_iter in checkpoint_iters[:i]:
            wins, draws, losses = run_arena(
                game,
                ladder_models[later_iter], ttts_eval,
                ladder_models[earlier_iter], ttts_eval,
                num_games=50,
                is_bayesian1=True, is_bayesian2=True,
                player1_name=f"iter-{later_iter}", player2_name=f"iter-{earlier_iter}",
                verbose=False
            )

            decisive = wins + losses
            wr = wins / decisive if decisive > 0 else 0.5
            ladder_results[(later_iter, earlier_iter)] = (wins, draws, losses, wr)
            print(f"  iter-{later_iter} vs iter-{earlier_iter}: {wins}W/{draws}D/{losses}L (WR: {wr:.0%})")

else:
    print("Not enough checkpoints to compare.")

## 6. Summary

In [None]:
print("=" * 70)
print("SUMMARY: Training with Different Batch Sizes")
print("=" * 70)

summary_mults = [1, 4, 8]


def _final_win_rate(metrics):
    """Last recorded win rate vs random for one training run (0 if it was never evaluated)."""
    return metrics['win_rates'][-1][1] if metrics['win_rates'] else 0


print("\n Training Speed (avg self-play time per iteration):")
print("-" * 50)
print(f"{'Multiplier':<12} {'PUCT (s)':<15} {'TTTS (s)':<15}")
for mult in summary_mults:
    puct_time = np.mean(puct_metrics[mult]['self_play_times'])
    ttts_time = np.mean(ttts_metrics[mult]['self_play_times'])
    print(f"{mult}x{'':<10} {puct_time:<15.2f} {ttts_time:<15.2f}")

print("\n Final Win Rate vs Random:")
print("-" * 50)
print(f"{'Multiplier':<12} {'PUCT':<15} {'TTTS':<15}")
for mult in summary_mults:
    puct_wr = _final_win_rate(puct_metrics[mult])
    ttts_wr = _final_win_rate(ttts_metrics[mult])
    print(f"{mult}x{'':<10} {puct_wr:<15.0%} {ttts_wr:<15.0%}")

# Conclusions are derived from the measurements above rather than asserted.
print("\n Conclusions:")
print("-" * 50)
base_mult, top_mult = summary_mults[0], summary_mults[-1]
for algo, metrics_by_mult in (("PUCT", puct_metrics), ("TTTS", ttts_metrics)):
    base_time = np.mean(metrics_by_mult[base_mult]['self_play_times'])
    top_time = np.mean(metrics_by_mult[top_mult]['self_play_times'])
    speedup = base_time / top_time if top_time > 0 else float('nan')
    wr_change = _final_win_rate(metrics_by_mult[top_mult]) - _final_win_rate(metrics_by_mult[base_mult])
    print(f"- {algo}: {top_mult}x self-play speedup over {base_mult}x: {speedup:.2f}x | "
          f"final win-rate change vs random: {wr_change:+.0%}")
print("- Win rates vs random are single-run estimates; see the arena p-values "
      "above for significance.")

## 7. Visualization

In [None]:
import matplotlib.pyplot as plt

mults = [1, 4, 8]
x = np.arange(len(mults))
width = 0.35
# (label, per-multiplier metrics, bar colour, bar offset), in drawing order
series = [
    ('PUCT', puct_metrics, '#3498DB', -width / 2),
    ('TTTS', ttts_metrics, '#E67E22', width / 2),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax1, ax2 = axes[0]
ax3, ax4 = axes[1]

# Top-left: mean self-play time per iteration, grouped by multiplier
for label, metrics_by_mult, color, offset in series:
    mean_times = [np.mean(metrics_by_mult[m]['self_play_times']) for m in mults]
    ax1.bar(x + offset, mean_times, width, label=label, color=color)
ax1.set_xlabel('leaves_per_batch Multiplier')
ax1.set_ylabel('Self-play Time (s)')
ax1.set_title('Training Speed by Batch Size')
ax1.set_xticks(x)
ax1.set_xticklabels([f'{m}x' for m in mults])
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Top-right: final win rate vs random (0 when a run was never evaluated)
for label, metrics_by_mult, color, offset in series:
    final_wrs = [
        metrics_by_mult[m]['win_rates'][-1][1] if metrics_by_mult[m]['win_rates'] else 0
        for m in mults
    ]
    ax2.bar(x + offset, final_wrs, width, label=label, color=color)
ax2.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax2.set_xlabel('leaves_per_batch Multiplier')
ax2.set_ylabel('Win Rate vs Random')
ax2.set_title('Final Model Quality')
ax2.set_xticks(x)
ax2.set_xticklabels([f'{m}x' for m in mults])
ax2.set_ylim(0, 1)
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

# Bottom row: learning curves, one panel per algorithm
for curve_ax, (label, metrics_by_mult, _color, _offset) in zip((ax3, ax4), series):
    for mult in mults:
        history = metrics_by_mult[mult]['win_rates']
        curve_ax.plot([step for step, _wr in history], [wr for _step, wr in history],
                      'o-', label=f'{mult}x', linewidth=2)
    curve_ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
    curve_ax.set_xlabel('Iteration')
    curve_ax.set_ylabel('Win Rate vs Random')
    curve_ax.set_title(f'{label} Learning Curves')
    curve_ax.legend()
    curve_ax.grid(alpha=0.3)
    curve_ax.set_ylim(0, 1)

plt.tight_layout()
plt.show()