# Rust MCTS Benchmark

This notebook benchmarks the Rust MCTS backend vs Python MCTS on a GPU.

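With a powerful GPU, neural network inference is fast, so tree operations account
for a larger share of the search time and the speedup from the Rust implementation
becomes visible. If no GPU is available, the notebook falls back to the CPU.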

## 1. Setup Environment

In [None]:
# Check GPU: list the NVIDIA device(s) visible to this runtime before benchmarking
!nvidia-smi

In [None]:
# Clone the repository and make it the working directory, so that later cells
# can import the `nanozero` package and reach `nanozero-mcts-rs/` by relative path.
# NOTE(review): not idempotent — on a re-run `git clone` fails because the
# target directory already exists; run once per fresh runtime.
!git clone https://github.com/caldred/nanozero.git
%cd nanozero

In [None]:
# Install Rust toolchain
!curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
import os
os.environ['PATH'] = f"/root/.cargo/bin:{os.environ['PATH']}"
!rustc --version

In [None]:
# Install Python dependencies
!pip install -q torch numpy maturin

In [None]:
# Build Rust MCTS backend: compile the `nanozero-mcts-rs` crate in release mode
# with maturin, then step back up to the repository root.
# NOTE(review): `maturin develop` normally expects an active virtualenv/conda
# environment — confirm it works on the target runtime; otherwise
# `maturin build --release` followed by a pip install of the wheel is the usual fallback.
%cd nanozero-mcts-rs
!maturin develop --release
%cd ..

In [None]:
# Verify installation: report the PyTorch/CUDA setup and check that the
# compiled Rust extension can be imported.
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Import the extension module itself instead of binding `RustBatchedMCTS` here:
# the benchmark cells import a class of the same name from `nanozero.mcts`, and
# binding it in this cell too would make the class in scope depend on the order
# in which cells were executed.
import nanozero_mcts_rs
assert hasattr(nanozero_mcts_rs, "RustBatchedMCTS"), \
    "nanozero_mcts_rs was imported but does not expose RustBatchedMCTS"
print("Rust MCTS backend loaded successfully!")

## 2. Setup Game and Model

In [None]:
import numpy as np
import torch
import time
from nanozero.game import get_game
from nanozero.model import AlphaZeroTransformer
from nanozero.mcts import RustBatchedMCTS, BatchedMCTS
from nanozero.config import get_model_config, MCTSConfig

# Run on the GPU when CUDA is available; otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Build the TicTacToe game and a 4-layer AlphaZeroTransformer on the selected
# device, switched to eval mode for inference-only benchmarking.
game = get_game('tictactoe')
model_config = get_model_config(
    game.config,
    n_layer=4,  # presumably larger than the default config — TODO confirm
)
model = AlphaZeroTransformer(model_config)
model = model.to(device)
model.eval()

print(
    f"Game: TicTacToe",
    f"Model parameters: {model.count_parameters():,}",
    f"Device: {device}",
    sep="\n",
)

## 3. Benchmark: Varying Batch Size

In [None]:
def benchmark_mcts(game, model, batch_sizes, num_simulations=100, n_runs=5):
    """Benchmark Rust vs Python MCTS across different batch sizes.

    For every batch size, both backends search the same batch made of copies
    of the game's initial state. Each backend gets one untimed warmup search,
    followed by ``n_runs`` timed searches. When CUDA is available, the device
    is synchronized immediately before and after each timed search so queued
    GPU work is not attributed to the wrong measurement.

    Args:
        game: Game object; only ``initial_state()`` is called here, the object
            is otherwise handed to the MCTS constructors.
        model: Network passed through unchanged to ``search``.
        batch_sizes: Iterable of batch sizes to benchmark.
        num_simulations: Simulations per search, forwarded to ``MCTSConfig``.
        n_runs: Number of timed searches per backend; must be at least 1.

    Returns:
        A list with one dict per batch size, holding ``batch_size``,
        ``rust_ms`` and ``python_ms`` (mean wall-clock milliseconds per
        search) and ``speedup`` (``python_ms / rust_ms``).

    Raises:
        ValueError: If ``n_runs`` is less than 1 (the mean of zero timings
            would otherwise silently be NaN).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    use_cuda = torch.cuda.is_available()

    def _mean_search_ms(mcts, states):
        """Return the mean wall-clock time in ms of ``n_runs`` searches."""
        elapsed = []
        for _ in range(n_runs):
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            with torch.inference_mode():
                mcts.search(states, model, add_noise=False)
            if use_cuda:
                torch.cuda.synchronize()
            elapsed.append(time.perf_counter() - start)
        return np.mean(elapsed) * 1000

    results = []

    for batch_size in batch_sizes:
        config = MCTSConfig(num_simulations=num_simulations)

        rust_mcts = RustBatchedMCTS(game, config)
        python_mcts = BatchedMCTS(game, config, use_transposition_table=False)

        # Create batch of initial states
        state = game.initial_state()
        states = np.stack([state] * batch_size)

        # Warmup: one untimed search per backend so one-off setup costs are
        # not included in the timed runs
        with torch.inference_mode():
            rust_mcts.search(states, model, add_noise=False)
            python_mcts.search(states, model, add_noise=False)

        rust_avg = _mean_search_ms(rust_mcts, states)
        python_avg = _mean_search_ms(python_mcts, states)
        speedup = python_avg / rust_avg

        results.append({
            'batch_size': batch_size,
            'rust_ms': rust_avg,
            'python_ms': python_avg,
            'speedup': speedup
        })

        print(f"Batch {batch_size:3d}: Rust={rust_avg:7.1f}ms, Python={python_avg:7.1f}ms, Speedup={speedup:.2f}x")

    return results

In [None]:
# Banner, then sweep TicTacToe batch sizes at a fixed 100 simulations per search
print("=" * 60, "Benchmark: Varying Batch Size (100 simulations)", "=" * 60, sep="\n")

batch_results = benchmark_mcts(
    game=game,
    model=model,
    batch_sizes=[8, 16, 32, 64, 128, 256],
    num_simulations=100,
)

## 4. Benchmark: Varying Simulation Count

In [None]:
def benchmark_simulations(game, model, batch_size, sim_counts, n_runs=3):
    """Benchmark Rust vs Python MCTS across different simulation counts.

    The batch (``batch_size`` copies of the game's initial state) is fixed;
    only the number of simulations per search varies. For every simulation
    count each backend gets one untimed warmup search, followed by ``n_runs``
    timed searches. When CUDA is available, the device is synchronized right
    before and after each timed search.

    Args:
        game: Game object; only ``initial_state()`` is called here, the object
            is otherwise handed to the MCTS constructors.
        model: Network passed through unchanged to ``search``.
        batch_size: Number of root states searched per call.
        sim_counts: Iterable of simulation counts to benchmark.
        n_runs: Number of timed searches per backend; must be at least 1.

    Returns:
        A list with one dict per simulation count, holding ``num_sims``,
        ``rust_ms`` and ``python_ms`` (mean wall-clock milliseconds per
        search) and ``speedup`` (``python_ms / rust_ms``).

    Raises:
        ValueError: If ``n_runs`` is less than 1 (the mean of zero timings
            would otherwise silently be NaN).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    use_cuda = torch.cuda.is_available()

    # The batch does not depend on the simulation count, but it is rebuilt
    # inside the loop to keep one initial_state() call per configuration.
    def _mean_search_ms(mcts, states):
        """Return the mean wall-clock time in ms of ``n_runs`` searches."""
        elapsed = []
        for _ in range(n_runs):
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            with torch.inference_mode():
                mcts.search(states, model, add_noise=False)
            if use_cuda:
                torch.cuda.synchronize()
            elapsed.append(time.perf_counter() - start)
        return np.mean(elapsed) * 1000

    results = []

    for num_sims in sim_counts:
        config = MCTSConfig(num_simulations=num_sims)

        rust_mcts = RustBatchedMCTS(game, config)
        python_mcts = BatchedMCTS(game, config, use_transposition_table=False)

        state = game.initial_state()
        states = np.stack([state] * batch_size)

        # Warmup: one untimed search per backend
        with torch.inference_mode():
            rust_mcts.search(states, model, add_noise=False)
            python_mcts.search(states, model, add_noise=False)

        rust_avg = _mean_search_ms(rust_mcts, states)
        python_avg = _mean_search_ms(python_mcts, states)
        speedup = python_avg / rust_avg

        results.append({
            'num_sims': num_sims,
            'rust_ms': rust_avg,
            'python_ms': python_avg,
            'speedup': speedup
        })

        print(f"{num_sims:4d} sims: Rust={rust_avg:7.1f}ms, Python={python_avg:7.1f}ms, Speedup={speedup:.2f}x")

    return results

In [None]:
# Banner, then sweep the simulation count on TicTacToe at a fixed batch of 64
print("=" * 60, "Benchmark: Varying Simulations (batch_size=64)", "=" * 60, sep="\n")

sim_results = benchmark_simulations(
    game=game,
    model=model,
    batch_size=64,
    sim_counts=[25, 50, 100, 200, 400, 800],
)

## 5. Benchmark: Connect4 (Larger Game)

In [None]:
# Build the Connect4 game and a 4-layer AlphaZeroTransformer on the selected
# device, switched to eval mode for inference-only benchmarking.
game_c4 = get_game('connect4')
model_config_c4 = get_model_config(game_c4.config, n_layer=4)
model_c4 = AlphaZeroTransformer(model_config_c4)
model_c4 = model_c4.to(device)
model_c4.eval()

print(
    f"Game: Connect4",
    f"Board size: {game_c4.config.board_size}",
    f"Action size: {game_c4.config.action_size}",
    f"Model parameters: {model_c4.count_parameters():,}",
    sep="\n",
)

In [None]:
# Banner, then sweep Connect4 batch sizes at a fixed 100 simulations per search
print("=" * 60, "Benchmark: Connect4 - Varying Batch Size (100 sims)", "=" * 60, sep="\n")

c4_results = benchmark_mcts(
    game=game_c4,
    model=model_c4,
    batch_sizes=[8, 16, 32, 64, 128],
    num_simulations=100,
)

## 6. Throughput Analysis

In [None]:
def measure_throughput(game, model, batch_size, num_simulations, duration_seconds=10):
    """Measure searches per second for both backends.

    Each backend gets one untimed warmup search and is then run back-to-back
    until ``duration_seconds`` of wall-clock time have elapsed. The counter
    advances by ``batch_size`` per ``search`` call, so the reported figure is
    root states searched per second, not ``search`` calls per second.

    At least one search is always executed per backend, so the returned
    throughputs are never zero even for a zero or negative
    ``duration_seconds`` (callers divide one throughput by the other).

    Args:
        game: Game object; only ``initial_state()`` is called here, the object
            is otherwise handed to the MCTS constructors.
        model: Network passed through unchanged to ``search``.
        batch_size: Number of root states searched per call.
        num_simulations: Simulations per search, forwarded to ``MCTSConfig``.
        duration_seconds: Minimum wall-clock time to keep searching, per backend.

    Returns:
        Tuple ``(rust_throughput, python_throughput)`` in states per second.
    """
    config = MCTSConfig(num_simulations=num_simulations)

    rust_mcts = RustBatchedMCTS(game, config)
    python_mcts = BatchedMCTS(game, config, use_transposition_table=False)

    state = game.initial_state()
    states = np.stack([state] * batch_size)

    use_cuda = torch.cuda.is_available()

    def _states_per_second(mcts):
        """Run searches for ``duration_seconds`` and return states per second."""
        searched = 0
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        # Check the deadline after each search (not before) so that one
        # search always runs, whatever the requested duration.
        while True:
            with torch.inference_mode():
                mcts.search(states, model, add_noise=False)
            searched += batch_size
            if time.perf_counter() - start >= duration_seconds:
                break
        if use_cuda:
            torch.cuda.synchronize()
        return searched / (time.perf_counter() - start)

    # Warmup
    with torch.inference_mode():
        rust_mcts.search(states, model, add_noise=False)
        python_mcts.search(states, model, add_noise=False)

    rust_throughput = _states_per_second(rust_mcts)
    python_throughput = _states_per_second(python_mcts)

    return rust_throughput, python_throughput

In [None]:
# Banner, then measure throughput for both games at three batch sizes
# (100 simulations per search, 5 seconds per backend and configuration)
print("=" * 60, "Throughput Test (searches/second)", "=" * 60, sep="\n")

for game_name, bench_game, bench_model in (
    ('TicTacToe', game, model),
    ('Connect4', game_c4, model_c4),
):
    print(f"\n{game_name}:")
    for batch_size in (32, 64, 128):
        rust_tp, python_tp = measure_throughput(
            bench_game, bench_model, batch_size,
            num_simulations=100, duration_seconds=5,
        )
        print(f"  Batch {batch_size:3d}: Rust={rust_tp:6.1f}/s, Python={python_tp:6.1f}/s, Speedup={rust_tp/python_tp:.2f}x")

## 7. Visualization

In [None]:
import matplotlib.pyplot as plt

# Plot batch size results: absolute search time (left) and speedup (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Time comparison: one Rust/Python pair of bars per batch size
ax1 = axes[0]
batch_sizes = [r['batch_size'] for r in batch_results]
rust_times = [r['rust_ms'] for r in batch_results]
python_times = [r['python_ms'] for r in batch_results]

# Bars are placed at evenly spaced positions (not at the batch-size values)
# so every batch size gets the same width; tick labels carry the real values.
x = np.arange(len(batch_sizes))
width = 0.35

ax1.bar(x - width/2, rust_times, width, label='Rust MCTS', color='#E67E22')
ax1.bar(x + width/2, python_times, width, label='Python MCTS', color='#3498DB')
ax1.set_xlabel('Batch Size')
ax1.set_ylabel('Time (ms)')
ax1.set_title('MCTS Search Time by Batch Size (100 sims)')
ax1.set_xticks(x)
ax1.set_xticklabels(batch_sizes)
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Speedup: green where Rust is at least as fast as Python, red otherwise
ax2 = axes[1]
speedups = [r['speedup'] for r in batch_results]
colors = ['#2ECC71' if s >= 1 else '#E74C3C' for s in speedups]
ax2.bar(x, speedups, color=colors)
ax2.axhline(y=1, color='black', linestyle='--', alpha=0.5)  # break-even line
ax2.set_xlabel('Batch Size')
# 'speedup' is python_ms / rust_ms, so label the ratio in that direction
ax2.set_ylabel('Speedup (Python time / Rust time)')
ax2.set_title('Rust MCTS Speedup vs Python')
ax2.set_xticks(x)
ax2.set_xticklabels(batch_sizes)
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Plot simulation count results: absolute search time (left) and speedup (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Time comparison
ax1 = axes[0]
sim_counts = [r['num_sims'] for r in sim_results]
rust_times = [r['rust_ms'] for r in sim_results]
python_times = [r['python_ms'] for r in sim_results]

ax1.plot(sim_counts, rust_times, 'o-', label='Rust MCTS', color='#E67E22', linewidth=2)
ax1.plot(sim_counts, python_times, 's-', label='Python MCTS', color='#3498DB', linewidth=2)
ax1.set_xlabel('Number of Simulations')
ax1.set_ylabel('Time (ms)')
ax1.set_title('MCTS Search Time by Simulation Count (batch=64)')
ax1.legend()
ax1.grid(alpha=0.3)

# Speedup
ax2 = axes[1]
speedups = [r['speedup'] for r in sim_results]
ax2.plot(sim_counts, speedups, 'o-', color='#9B59B6', linewidth=2)
ax2.axhline(y=1, color='black', linestyle='--', alpha=0.5)  # break-even line
ax2.set_xlabel('Number of Simulations')
# 'speedup' is python_ms / rust_ms, so label the ratio in that direction
ax2.set_ylabel('Speedup (Python time / Rust time)')
ax2.set_title('Rust MCTS Speedup vs Simulation Count')
ax2.grid(alpha=0.3)
# Shade the area between the curve and the break-even line. interpolate=True
# extends each fill to the exact crossing point instead of leaving a gap
# between a point above 1 and its neighbour below 1.
ax2.fill_between(sim_counts, speedups, 1, where=[s >= 1 for s in speedups],
                 interpolate=True, alpha=0.3, color='#2ECC71', label='Rust faster')
ax2.fill_between(sim_counts, speedups, 1, where=[s < 1 for s in speedups],
                 interpolate=True, alpha=0.3, color='#E74C3C', label='Python faster')
# The fills carry labels, so a legend is needed for them to be shown
ax2.legend()

plt.tight_layout()
plt.show()

## 8. Profiling: Where is Time Spent?

In [None]:
# Profile the Python backend with cProfile to see where search time goes:
# five 100-simulation searches over a batch of 64 initial TicTacToe states.
import cProfile
import pstats
from io import StringIO

config = MCTSConfig(num_simulations=100)
python_mcts = BatchedMCTS(game, config, use_transposition_table=False)

state = game.initial_state()
states = np.stack([state for _ in range(64)])

# Only the searches themselves run between enable() and disable()
pr = cProfile.Profile()
pr.enable()

with torch.inference_mode():
    for _ in range(5):
        _ = python_mcts.search(states, model, add_noise=False)

pr.disable()

# Render the report into a string buffer so the header can be printed first
s = StringIO()
ps = pstats.Stats(pr, stream=s)
ps.sort_stats('cumulative')
ps.print_stats(20)
print("Python MCTS Profile (top 20 by cumulative time):", s.getvalue(), sep="\n")

## 9. Summary

In [None]:
print("=" * 60)
print("SUMMARY")
print("=" * 60)
print()
print("Rust MCTS Backend Performance:")
print()

# Find best speedup in each sweep
best_batch = max(batch_results, key=lambda x: x['speedup'])
print(f"  Best speedup by batch size: {best_batch['speedup']:.2f}x at batch_size={best_batch['batch_size']}")

best_sim = max(sim_results, key=lambda x: x['speedup'])
print(f"  Best speedup by sim count:  {best_sim['speedup']:.2f}x at {best_sim['num_sims']} simulations")

# Derive the batch-size trend from the measurements instead of asserting it:
# compare the speedup at the smallest and at the largest benchmarked batch size.
smallest_batch = min(batch_results, key=lambda x: x['batch_size'])
largest_batch = max(batch_results, key=lambda x: x['batch_size'])
if largest_batch['speedup'] > smallest_batch['speedup']:
    batch_trend = "improves"
elif largest_batch['speedup'] < smallest_batch['speedup']:
    batch_trend = "degrades"
else:
    batch_trend = "is unchanged"

print()
print("Key Observations:")
print("  - Rust MCTS runs entire search loop in Rust")
print("  - Only calls back to Python for NN inference")
print(
    f"  - Speedup {batch_trend} from batch_size={smallest_batch['batch_size']} "
    f"({smallest_batch['speedup']:.2f}x) to batch_size={largest_batch['batch_size']} "
    f"({largest_batch['speedup']:.2f}x)"
)
print()
print("Future Optimizations:")
print("  - Rayon parallelization for tree selection")
print("  - Virtual loss for multi-simulation batching")
print("  - Bayesian MCTS (TTTS) variant")