# MCTS Algorithm Comparison: TTTS-IDS vs PUCT

This notebook compares two MCTS selection algorithms:
- **PUCT** (Polynomial UCT): Standard AlphaZero selection
- **TTTS-IDS** (Top-Two Thompson Sampling with Information-Directed Sampling): Bayesian selection optimized for best-arm identification (BAI)
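
For reference, the PUCT rule used by AlphaZero-style search scores each child as its mean value Q plus an exploration bonus proportional to the network prior. The sketch below is a minimal standalone illustration, not nanozero's implementation; the function names and the `c_puct=1.5` default are assumptions chosen for the example.

```python
import math

def puct_scores(q_values, priors, visit_counts, c_puct=1.5):
    """Return Q + U per child, where U = c_puct * P * sqrt(N_parent) / (1 + N_child)."""
    sqrt_total = math.sqrt(sum(visit_counts))
    return [
        q + c_puct * p * sqrt_total / (1 + n)
        for q, p, n in zip(q_values, priors, visit_counts)
    ]

def select_child(q_values, priors, visit_counts, c_puct=1.5):
    """Pick the index of the child with the highest PUCT score."""
    scores = puct_scores(q_values, priors, visit_counts, c_puct)
    return max(range(len(scores)), key=scores.__getitem__)

# Three children: the second has a lower Q than the first but is barely visited,
# so its exploration bonus dominates.
q = [0.20, 0.10, -0.30]
p = [0.50, 0.40, 0.10]
n = [30, 2, 1]
print(select_child(q, p, n))
```

The bonus shrinks as a child accumulates visits, so the rule gradually shifts from following the prior to following the observed values.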

We test on:
1. **Connect4** - 6x7 board, 7 actions
2. **Go 9x9** - 81 positions, 82 actions (81 board points plus pass)

**New:** Uses a high-performance **Rust backend** for game logic (3-4x faster self-play).

**Setup:** Use `Runtime > Change runtime type > A100 GPU` for best performance.

In [None]:
# Check GPU
import torch
print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install Rust toolchain
!curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
import os
os.environ["PATH"] = f"{os.environ['HOME']}/.cargo/bin:" + os.environ["PATH"]

# Verify Rust installation
!rustc --version

In [None]:
# Clone repository (rust-mcts-backend branch)
!git clone -b rust-mcts-backend https://github.com/caldred/nanozero.git
%cd nanozero

# Install Python dependencies
!pip install -q numpy scipy maturin

# Build and install Rust extension
%cd nanozero-mcts-rs
!maturin build --release
!pip install target/wheels/nanozero_mcts_rs-*.whl
%cd ..

# Verify Rust backend is available
!python -c "from nanozero.game import RUST_AVAILABLE; print(f'Rust backend available: {RUST_AVAILABLE}')"

---
## Rust vs Python Backend Benchmark

Time raw game operations and MCTS self-play on the Rust backend. The Python backend is not timed in this cell because the Rust backend is now required.

In [None]:
import time
import numpy as np
import torch
from nanozero.game import get_game
from nanozero.model import AlphaZeroTransformer
from nanozero.mcts import BatchedMCTS
from nanozero.config import get_model_config, MCTSConfig

# Rust backend is always available (required)
print("Rust backend: required and available")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

def benchmark_game_ops(game, n_games=100, moves_per_game=20):
    """Benchmark raw game operations."""
    times = {'next_state': [], 'is_terminal': [], 'legal_actions': []}
    
    for _ in range(n_games):
        state = game.initial_state()
        for _ in range(moves_per_game):
            t0 = time.perf_counter()
            terminal = game.is_terminal(state)
            times['is_terminal'].append(time.perf_counter() - t0)
            
            if terminal:
                break
            
            t0 = time.perf_counter()
            actions = game.legal_actions(state)
            times['legal_actions'].append(time.perf_counter() - t0)
            
            if len(actions) == 0:  # works for both lists and NumPy arrays
                break
            
            action = np.random.choice(actions)
            t0 = time.perf_counter()
            state = game.next_state(state, action)
            times['next_state'].append(time.perf_counter() - t0)
    
    return {k: np.mean(v) * 1e6 for k, v in times.items()}  # microseconds

def benchmark_self_play(game, model, mcts, num_games=20, parallel_games=16):
    """Benchmark MCTS self-play speed."""
    model.eval()
    states = [game.initial_state() for _ in range(parallel_games)]
    games_completed = 0
    
    t0 = time.perf_counter()
    with torch.inference_mode():
        while games_completed < num_games:
            active = [i for i, s in enumerate(states) if not game.is_terminal(s)]
            if not active:
                break
            
            active_states = np.stack([states[i] for i in active])
            policies = mcts.search(active_states, model, add_noise=False)
            
            for idx, game_idx in enumerate(active):
                legal = game.legal_actions(states[game_idx])
                action = legal[np.argmax([policies[idx][a] for a in legal])]
                states[game_idx] = game.next_state(states[game_idx], action)
            
            for i in range(len(states)):
                if game.is_terminal(states[i]):
                    games_completed += 1
                    if games_completed < num_games:
                        states[i] = game.initial_state()
    
    return time.perf_counter() - t0, games_completed

print("\n" + "=" * 70)
print("BENCHMARK: Rust Game Backend")
print("=" * 70)

# Game operations benchmark
print("\n--- Raw Game Operations (Connect4) ---")
game = get_game('connect4')

times = benchmark_game_ops(game, n_games=200)

print(f"\n{'Operation':<18} {'Rust (μs)':<14}")
print("-" * 32)
for op in times:
    print(f"{op:<18} {times[op]:>12.2f}")

# Self-play benchmark (if GPU available)
if torch.cuda.is_available():
    print("\n--- MCTS Self-Play (Connect4, 25 sims/move) ---")
    
    model_config = get_model_config(game.config, n_layer=2)
    model = AlphaZeroTransformer(model_config).to(device)
    model.eval()
    
    mcts = BatchedMCTS(game, MCTSConfig(num_simulations=25))
    
    elapsed, games = benchmark_self_play(game, model, mcts, num_games=30)
    
    print(f"\n{'Backend':<10} {'Time (s)':<12} {'Games':<8} {'Games/sec':<12}")
    print("-" * 42)
    print(f"{'Rust':<10} {elapsed:>10.2f}   {games:<8} {games/elapsed:>10.2f}")

print("\n" + "=" * 70)

---
## Phase 1: Pure Bandit BAI Comparison

First, we compare the selection algorithms in a pure multi-armed bandit setting.
This isolates the selection behavior without MCTS tree complexity.
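
As a rough picture of what the bandit setting exercises, here is a minimal sketch of top-two Thompson sampling on a Bernoulli bandit with Beta(1, 1) priors. It is an illustration only and is not the code in `scripts.benchmark_bandits`; the function name, the `beta=0.5` leader probability, and the `max_resample` cap are assumptions, and the information-directed part of TTTS-IDS is not modeled.

```python
import numpy as np

def top_two_thompson(true_means, n_pulls=100, beta=0.5, seed=0, max_resample=100):
    """Run top-two Thompson sampling; return (recommended arm, pulls per arm)."""
    rng = np.random.default_rng(seed)
    k = len(true_means)
    wins = np.ones(k)    # Beta posterior alpha (prior Beta(1, 1))
    losses = np.ones(k)  # Beta posterior beta
    for _ in range(n_pulls):
        # Leader: the arm that looks best under one posterior sample
        leader = int(np.argmax(rng.beta(wins, losses)))
        arm = leader
        if rng.random() >= beta:
            # Challenger: resample until a different arm looks best
            for _ in range(max_resample):
                challenger = int(np.argmax(rng.beta(wins, losses)))
                if challenger != leader:
                    arm = challenger
                    break
        reward = float(rng.random() < true_means[arm])
        wins[arm] += reward
        losses[arm] += 1.0 - reward
    posterior_mean = wins / (wins + losses)
    return int(np.argmax(posterior_mean)), wins + losses - 2

best, pulls = top_two_thompson([0.2, 0.5, 0.8], n_pulls=300)
print(f"recommended arm: {best}, pulls per arm: {pulls}")
```

Playing the challenger part of the time keeps sampling the runner-up, which is what separates best-arm identification from plain Thompson sampling's focus on cumulative reward.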

In [None]:
# Run bandit benchmark
!python -m scripts.benchmark_bandits --n_pulls 100 --n_trials 500 --seed 42

---
## Phase 2: Connect4

### 2.1 Train Connect4 Model

In [None]:
# Train Connect4 - A100 optimized settings
# ~5-10 minutes on A100
!python -m scripts.train \
    --game=connect4 \
    --n_layer=4 \
    --num_iterations=150 \
    --games_per_iteration=64 \
    --training_steps=200 \
    --mcts_simulations=100 \
    --batch_size=256 \
    --buffer_size=100000 \
    --parallel_games=128 \
    --eval_interval=25

### 2.2 MCTS Position Comparison (Connect4)

In [None]:
# Compare MCTS algorithms on Connect4 positions
!python -m scripts.benchmark_mcts \
    --game connect4 \
    --checkpoint checkpoints/connect4_final.pt \
    --n_layer 4 \
    --n_sims 25 50 100 200 \
    --convergence \
    --max_conv_sims 300

### 2.3 Head-to-Head Arena (Connect4)
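
To judge whether an arena result is distinguishable from a coin flip, a confidence interval on the win rate is a useful sanity check. The sketch below computes a Wilson score interval; it is a hypothetical helper, not part of `scripts.arena`, and it assumes draws have already been excluded or split before counting wins.

```python
import math

def wilson_interval(wins, games, z=1.96):
    """Wilson score interval for a win rate (z=1.96 gives roughly 95% coverage)."""
    if games <= 0:
        raise ValueError("games must be positive")
    p_hat = wins / games
    denom = 1 + z**2 / games
    center = (p_hat + z**2 / (2 * games)) / denom
    half = z * math.sqrt(p_hat * (1 - p_hat) / games + z**2 / (4 * games**2)) / denom
    return center - half, center + half

# 110 wins out of 200: the interval still contains 0.5
low, high = wilson_interval(110, 200)
print(f"110/200 -> [{low:.3f}, {high:.3f}]")

# 130 wins out of 200: the interval lies above 0.5
low, high = wilson_interval(130, 200)
print(f"130/200 -> [{low:.3f}, {high:.3f}]")
```

With 200 games, a 55% win rate is not enough to rule out equal strength, which is worth keeping in mind when reading the arena output.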

In [None]:
# Arena: TTTS-IDS vs PUCT on Connect4
# 200 games for statistical significance
!python -m scripts.arena \
    --game connect4 \
    --model1 checkpoints/connect4_final.pt \
    --n_layer 4 \
    --mcts_comparison \
    --num_games 200 \
    --mcts_simulations 100

In [None]:
# Test with different simulation budgets
for n_sims in [25, 50, 100, 200]:
    print(f"\n{'='*60}")
    print(f"Testing with {n_sims} simulations")
    print(f"{'='*60}")
    !python -m scripts.arena \
        --game connect4 \
        --model1 checkpoints/connect4_final.pt \
        --n_layer 4 \
        --mcts_comparison \
        --num_games 100 \
        --mcts_simulations {n_sims}

---
## Phase 3: Go 9x9

### 3.1 Train Go 9x9 Model

Go 9x9 is much more complex than Connect4 (82 actions versus 7), so expect longer training.

In [None]:
# Train Go 9x9 - A100 optimized settings
# ~15-20 minutes on A100
!python -m scripts.train \
    --game=go9x9 \
    --n_layer=6 \
    --num_iterations=200 \
    --games_per_iteration=32 \
    --training_steps=200 \
    --mcts_simulations=100 \
    --batch_size=128 \
    --buffer_size=100000 \
    --parallel_games=64 \
    --eval_interval=25

### 3.2 MCTS Position Comparison (Go 9x9)

In [None]:
# Compare MCTS algorithms on Go 9x9 positions
!python -m scripts.benchmark_mcts \
    --game go9x9 \
    --checkpoint checkpoints/go9x9_final.pt \
    --n_layer 6 \
    --n_sims 50 100 200 \
    --convergence \
    --max_conv_sims 400

### 3.3 Head-to-Head Arena (Go 9x9)

In [None]:
# Arena: TTTS-IDS vs PUCT on Go 9x9
# Go games are longer, so fewer games but more simulations
!python -m scripts.arena \
    --game go9x9 \
    --model1 checkpoints/go9x9_final.pt \
    --n_layer 6 \
    --mcts_comparison \
    --num_games 100 \
    --mcts_simulations 200

In [None]:
# Test with different simulation budgets
for n_sims in [50, 100, 200, 400]:
    print(f"\n{'='*60}")
    print(f"Testing with {n_sims} simulations")
    print(f"{'='*60}")
    !python -m scripts.arena \
        --game go9x9 \
        --model1 checkpoints/go9x9_final.pt \
        --n_layer 6 \
        --mcts_comparison \
        --num_games 50 \
        --mcts_simulations {n_sims}

---
## Phase 4: Detailed Analysis

Let's look at convergence curves and early stopping benefits.
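
To help read the entropy plots, the following standalone sketch evaluates the same entropy formula as the `entropy` helper in the next cell on two hand-made policies. The example policies are made up for illustration.

```python
import numpy as np

def entropy(probs, eps=1e-8):
    """Shannon entropy in nats, with a small epsilon to avoid log(0)."""
    probs = np.asarray(probs, dtype=float) + eps
    probs = probs / probs.sum()
    return float(-np.sum(probs * np.log(probs)))

uniform = np.full(7, 1 / 7)                        # no preference among 7 moves
peaked = np.array([0.9, 0.05, 0.05, 0, 0, 0, 0])   # one clearly preferred move

print(f"uniform: {entropy(uniform):.3f} (upper bound ln 7 = {np.log(7):.3f})")
print(f"peaked:  {entropy(peaked):.3f}")
```

A uniform policy over Connect4's 7 actions gives the maximum of ln 7 ≈ 1.95 nats, and the corresponding ceiling for Go 9x9's 82 actions is ln 82 ≈ 4.41 nats, so curves for the two games are not directly comparable in absolute terms.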

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys
sys.path.insert(0, '.')

from nanozero.game import get_game
from nanozero.model import AlphaZeroTransformer
from nanozero.mcts import BatchedMCTS, BayesianMCTS
from nanozero.config import get_model_config, MCTSConfig, BayesianMCTSConfig
from nanozero.common import load_checkpoint

def entropy(probs, eps=1e-8):
    probs = probs + eps
    probs = probs / probs.sum()
    return -np.sum(probs * np.log(probs))

def measure_convergence(game, model, mcts, state, max_sims=300, step=10, is_puct=True):
    """Measure policy entropy as simulations increase."""
    entropies = []
    confidences = []
    sim_counts = list(range(step, max_sims + 1, step))
    
    for n_sims in sim_counts:
        mcts.clear_cache()
        if is_puct:
            policy = mcts.search(state[np.newaxis, ...], model, num_simulations=n_sims, add_noise=False)[0]
        else:
            policy = mcts.search(state[np.newaxis, ...], model, num_simulations=n_sims)[0]
        
        entropies.append(entropy(policy))
        confidences.append(np.max(policy))
    
    return sim_counts, entropies, confidences

In [None]:
# Load Connect4 model and compare convergence
game = get_game('connect4')
model_config = get_model_config(game.config, n_layer=4)
model = AlphaZeroTransformer(model_config).cuda()
load_checkpoint('checkpoints/connect4_final.pt', model)
model.eval()

puct_mcts = BatchedMCTS(game, MCTSConfig())
ttts_mcts = BayesianMCTS(game, BayesianMCTSConfig())

# Test on a midgame position
state = game.initial_state()
for move in [3, 3, 4, 4, 2]:
    state = game.next_state(state, move)

print("Measuring convergence on Connect4 midgame position...")
puct_sims, puct_ent, puct_conf = measure_convergence(game, model, puct_mcts, state, max_sims=300, is_puct=True)
ttts_sims, ttts_ent, ttts_conf = measure_convergence(game, model, ttts_mcts, state, max_sims=300, is_puct=False)

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(puct_sims, puct_ent, 'b-', label='PUCT', linewidth=2)
ax1.plot(ttts_sims, ttts_ent, 'r-', label='TTTS-IDS', linewidth=2)
ax1.set_xlabel('Simulations')
ax1.set_ylabel('Policy Entropy')
ax1.set_title('Connect4: Policy Entropy vs Simulations')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(puct_sims, puct_conf, 'b-', label='PUCT', linewidth=2)
ax2.plot(ttts_sims, ttts_conf, 'r-', label='TTTS-IDS', linewidth=2)
ax2.set_xlabel('Simulations')
ax2.set_ylabel('Max Policy Probability')
ax2.set_title('Connect4: Confidence vs Simulations')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('connect4_convergence.png', dpi=150)
plt.show()

print(f"\nAt 100 simulations:")
print(f"  PUCT entropy: {puct_ent[9]:.3f}, confidence: {puct_conf[9]:.3f}")
print(f"  TTTS entropy: {ttts_ent[9]:.3f}, confidence: {ttts_conf[9]:.3f}")

In [None]:
# Repeat for Go 9x9
game = get_game('go9x9')
model_config = get_model_config(game.config, n_layer=6)
model = AlphaZeroTransformer(model_config).cuda()
load_checkpoint('checkpoints/go9x9_final.pt', model)
model.eval()

puct_mcts = BatchedMCTS(game, MCTSConfig())
ttts_mcts = BayesianMCTS(game, BayesianMCTSConfig())

state = game.initial_state()

print("Measuring convergence on Go 9x9 opening position...")
puct_sims, puct_ent, puct_conf = measure_convergence(game, model, puct_mcts, state, max_sims=400, is_puct=True)
ttts_sims, ttts_ent, ttts_conf = measure_convergence(game, model, ttts_mcts, state, max_sims=400, is_puct=False)

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(puct_sims, puct_ent, 'b-', label='PUCT', linewidth=2)
ax1.plot(ttts_sims, ttts_ent, 'r-', label='TTTS-IDS', linewidth=2)
ax1.set_xlabel('Simulations')
ax1.set_ylabel('Policy Entropy')
ax1.set_title('Go 9x9: Policy Entropy vs Simulations')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(puct_sims, puct_conf, 'b-', label='PUCT', linewidth=2)
ax2.plot(ttts_sims, ttts_conf, 'r-', label='TTTS-IDS', linewidth=2)
ax2.set_xlabel('Simulations')
ax2.set_ylabel('Max Policy Probability')
ax2.set_title('Go 9x9: Confidence vs Simulations')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('go9x9_convergence.png', dpi=150)
plt.show()

print(f"\nAt 200 simulations:")
print(f"  PUCT entropy: {puct_ent[19]:.3f}, confidence: {puct_conf[19]:.3f}")
print(f"  TTTS entropy: {ttts_ent[19]:.3f}, confidence: {ttts_conf[19]:.3f}")

---
## Phase 5: Early Stopping Analysis

TTTS-IDS supports early stopping when confident about the best action. Let's measure the savings.
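
One plausible form of such a stopping rule is sketched below: estimate, by sampling from per-action Gaussian posteriors, the probability that each action is best, and stop once the top action clears a threshold. This is an assumption about how a confidence-based rule could work, not a description of `BayesianMCTS` internals; only the parameter names `confidence_threshold` and `min_simulations` are taken from the config used in the next cell, and the functions themselves are hypothetical.

```python
import numpy as np

def prob_best(means, stds, n_samples=4000, seed=0):
    """Monte Carlo estimate of P(action i has the highest value) under Gaussian posteriors."""
    rng = np.random.default_rng(seed)
    samples = rng.normal(means, stds, size=(n_samples, len(means)))
    winners = samples.argmax(axis=1)
    return np.bincount(winners, minlength=len(means)) / n_samples

def should_stop(means, stds, sims_done, confidence_threshold=0.95, min_simulations=10):
    """Stop once enough simulations have run and one action is best with high probability."""
    if sims_done < min_simulations:
        return False
    return bool(prob_best(means, stds).max() >= confidence_threshold)

# A clear-cut position: one action dominates, so the search can stop early
print(should_stop([0.9, 0.1, 0.0], [0.05, 0.05, 0.05], sims_done=50))

# An unclear position: two actions are indistinguishable, so the search continues
print(should_stop([0.5, 0.5], [0.3, 0.3], sims_done=50))
```

Under a rule like this, savings depend heavily on the position mix: forced or lopsided positions stop quickly, while balanced ones use the full budget.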

In [None]:
def measure_early_stopping_savings(game, model, n_positions=20, n_sims=100):
    """
    Estimate the wall-clock time saved by early stopping.
    
    Returns (fractional time savings, time with early stopping,
    time without early stopping, number of positions tested).
    """
    # Create positions by playing random moves
    positions = []
    for _ in range(n_positions):
        state = game.initial_state()
        n_moves = np.random.randint(0, 15)
        for _ in range(n_moves):
            if game.is_terminal(state):
                break
            legal = game.legal_actions(state)
            action = np.random.choice(legal)
            state = game.next_state(state, action)
        if not game.is_terminal(state):
            positions.append(state)
    
    # Config with early stopping enabled
    config_with_es = BayesianMCTSConfig(
        num_simulations=n_sims,
        early_stopping=True,
        confidence_threshold=0.95,
        min_simulations=10
    )
    ttts_mcts = BayesianMCTS(game, config_with_es)
    
    # The number of simulations actually used is not directly observable here,
    # so wall-clock time is used as a proxy for simulations saved.
    
    import time
    
    # Time with early stopping
    start = time.time()
    for state in positions:
        ttts_mcts.clear_cache()
        _ = ttts_mcts.search(state[np.newaxis, ...], model, num_simulations=n_sims)
    time_with_es = time.time() - start
    
    # Config without early stopping
    config_no_es = BayesianMCTSConfig(
        num_simulations=n_sims,
        early_stopping=False
    )
    ttts_mcts_no_es = BayesianMCTS(game, config_no_es)
    
    start = time.time()
    for state in positions:
        ttts_mcts_no_es.clear_cache()
        _ = ttts_mcts_no_es.search(state[np.newaxis, ...], model, num_simulations=n_sims)
    time_no_es = time.time() - start
    
    savings = 1 - (time_with_es / time_no_es)
    return savings, time_with_es, time_no_es, len(positions)

# Test on Connect4
game = get_game('connect4')
model_config = get_model_config(game.config, n_layer=4)
model = AlphaZeroTransformer(model_config).cuda()
load_checkpoint('checkpoints/connect4_final.pt', model)
model.eval()

print("Measuring early stopping savings on Connect4...")
savings, t_es, t_no_es, n_pos = measure_early_stopping_savings(game, model, n_positions=30, n_sims=100)
print(f"  Positions tested: {n_pos}")
print(f"  Time with early stopping: {t_es:.2f}s")
print(f"  Time without early stopping: {t_no_es:.2f}s")
print(f"  Time savings: {savings:.1%}")

---
## Summary

### Expected Results

Based on theory and preliminary testing:

| Metric | TTTS-IDS vs PUCT |
|--------|------------------|
| **Sample complexity (bandits)** | 25-35% fewer samples to reach 95% confidence |
| **Policy entropy** | Lower at same simulation count |
| **Convergence speed** | Faster on clear positions |
| **Head-to-head games** | Similar (no significant difference expected) |
| **Early stopping** | 10-30% time savings when confident |

### Key Findings

1. **TTTS-IDS is optimized for Best Arm Identification** - finds the best action with fewer samples
2. **Thompson sampling provides natural exploration** - no need for Dirichlet noise
3. **Early stopping is a free bonus** - saves computation when confident
4. **Game outcomes are similar** - PUCT and TTTS-IDS both play well

### When to Use TTTS-IDS

- When simulation budget is limited
- When you need high confidence in move selection
- When computation time matters (early stopping)

### When to Stick with PUCT

- When you want well-tested, standard behavior
- When training (cumulative regret might matter for exploration)
- When you need compatibility with existing AlphaZero implementations

In [None]:
# Save checkpoints to Google Drive (optional)
# from google.colab import drive
# drive.mount('/content/drive')
# !cp -r checkpoints /content/drive/MyDrive/nanozero_checkpoints