# Optimized AlphaZero Training for Chopsticks

This notebook demonstrates the optimized training pipeline with:
- **Vectorized game execution** for parallel game processing
- **Batched model predictions** for efficient GPU utilization
- **Parallel MCTS** with batched neural network calls
- **Multi-game self-play** running multiple games simultaneously

These optimizations significantly speed up the training process compared to the sequential implementation.

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

# Import game
from games.chopsticks import ChopsticksGame
from games.chopsticks_vectorized import ChopsticksVectorized

# Import models
from models.chopsticks import ChopsticksMLP
from models.batched_model import BatchedModelWrapper, CachedBatchedModel

# Import MCTS
from sims.tree import MCTS
from sims.tree_parallel import ParallelMCTS, VirtualLossMCTS

# Import trainers
from utils.trainer import self_play, create_dataloader, train_model
from utils.trainer_parallel import (
    self_play_parallel, 
    self_play_vectorized,
    train_iteration,
    full_training_loop
)

print("All imports successful!")

## Configuration

In [None]:
# Device configuration: use CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the global NumPy and PyTorch random number generators so that a fresh
# "Restart Kernel -> Run All" starts from the same random state.
# NOTE(review): this only covers code that draws from these two global RNGs;
# if the project modules use Python's `random` or their own generators, seed
# those as well.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
print(f"Random seed: {SEED}")

# Training hyperparameters. This dict is handed as-is to the MCTS constructors
# in later cells, so its keys are part of their interface.
config = {
    # MCTS parameters
    'c_puct': 1.0,              # Exploration constant
    'num_simulations': 800,     # MCTS simulations per move
    'batch_size': 16,           # Batch size for MCTS predictions
    'virtual_loss': 3,          # Virtual loss value

    # Training parameters
    'num_games': 100,           # Self-play games per iteration
    'num_workers': 8,           # Parallel games
    'num_epochs': 50,           # Training epochs per iteration
    'train_batch_size': 64,     # Batch size for training
    'learning_rate': 0.001,

    # Overall training
    'num_iterations': 10,       # Number of training iterations
}

# Echo the settings so they are recorded alongside the notebook outputs.
print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

## Initialize Model and Game

In [None]:
# Single-game environment; its dimensions size the network below.
game = ChopsticksGame()

# MLP whose input/output widths come from the game's state and action sizes.
model = ChopsticksMLP(
    input_size=game.state_dim(),
    output_size=game.num_actions(),
    hidden_size=128,  # Increased hidden size for better capacity
)
model.to(device)

# Report the architecture followed by the total parameter count.
print("Model architecture:")
print(model)
parameter_count = sum(param.numel() for param in model.parameters())
print(f"\nTotal parameters: {parameter_count}")

## Initialize Batched Model Wrapper

In [None]:
# Wrap the network in the caching batch predictor; the cache is capped at
# PREDICTION_CACHE_SIZE entries.
PREDICTION_CACHE_SIZE = 10000
batched_model = CachedBatchedModel(
    model,
    device=device,
    cache_size=PREDICTION_CACHE_SIZE,
)

print("Batched model wrapper initialized with caching")

## Initialize Parallel MCTS

In [None]:
# Tree search that reads its settings from `config` and evaluates positions
# through the batched model wrapper.
mcts = VirtualLossMCTS(game, batched_model, config)

# Echo the search-related settings that were passed in.
print("Parallel MCTS initialized:")
for label, config_key in (
    ("Simulations", 'num_simulations'),
    ("Batch size", 'batch_size'),
    ("Virtual loss", 'virtual_loss'),
):
    print(f"  {label}: {config[config_key]}")

## Test Optimized Components

In [None]:
# --- 1. Vectorized game: reset a batch of four games and inspect the result ---
print("Testing vectorized game...")
vec_game = ChopsticksVectorized(batch_size=4, device=device)
states = vec_game.reset()
print(f"  Batch shape: {states.shape}")
print(f"  Initial state[0]: {states[0]}")

# --- 2. Batched prediction: stack eight freshly reset states into one batch ---
print("\nTesting batched model prediction...")
initial_states = [game.reset() for _ in range(8)]
test_states = torch.stack(initial_states).to(device)
policies, values = batched_model.predict_batch(test_states)
for label, tensor in (
    ("Input", test_states),
    ("Policies", policies),
    ("Values", values),
):
    print(f"  {label} shape: {tensor.shape}")

# --- 3. Parallel MCTS: time one search from the initial position ---
print("\nTesting parallel MCTS...")
start_time = time.time()
root = mcts.run(game.reset())
elapsed = time.time() - start_time
print(f"  MCTS completed in {elapsed:.2f}s")
print(f"  Root visits: {root.visit_count}")
print(f"  Root value: {root.value():.3f}")
print(f"  Number of children: {len(root.children)}")

# --- 4. Prediction cache: report what the calls above left behind ---
cache_stats = batched_model.get_cache_stats()
print("\nCache statistics:")
print(f"  Cache size: {cache_stats['size']}")
print(f"  Hit rate: {cache_stats['hit_rate']:.2%}")

print("\n✓ All components working correctly!")

## Benchmark: Compare Sequential vs Parallel MCTS

In [None]:
print("Benchmarking MCTS implementations...\n")

# Number of searches timed for each implementation.
NUM_BENCHMARK_RUNS = 5


def _time_mcts_runs(search, num_runs):
    """Time `num_runs` consecutive searches, each from a freshly reset game.

    Args:
        search: object exposing `run(state)` (the sequential or parallel MCTS).
        num_runs: how many searches to execute back to back.

    Returns:
        Elapsed seconds as a float, measured with `time.perf_counter()`, the
        monotonic high-resolution clock intended for timing short intervals
        (unlike `time.time()`, it is unaffected by system clock adjustments).
    """
    started = time.perf_counter()
    for _ in range(num_runs):
        search.run(game.reset())
    return time.perf_counter() - started


# Sequential MCTS (original)
print("Testing sequential MCTS...")
simple_model = model  # Use model directly without batching
sequential_mcts = MCTS(game, simple_model, config)
sequential_time = _time_mcts_runs(sequential_mcts, NUM_BENCHMARK_RUNS)
print(f"  {NUM_BENCHMARK_RUNS} runs: {sequential_time:.2f}s ({sequential_time / NUM_BENCHMARK_RUNS:.2f}s per run)")

# Parallel MCTS (new)
print("\nTesting parallel MCTS...")
# Empty the prediction cache first so entries left by earlier cells do not
# count towards the parallel timing.
batched_model.clear_cache()
parallel_time = _time_mcts_runs(mcts, NUM_BENCHMARK_RUNS)
print(f"  {NUM_BENCHMARK_RUNS} runs: {parallel_time:.2f}s ({parallel_time / NUM_BENCHMARK_RUNS:.2f}s per run)")

# Compare. `speedup` is the ratio of the two timings; the percentage is the
# fraction of wall-clock time saved, which is a different quantity (half the
# time is a 2.00x speedup but 50% less time), so it is labelled as such.
speedup = sequential_time / parallel_time
time_saved = sequential_time - parallel_time
time_saved_pct = (1 - parallel_time / sequential_time) * 100
print(f"\nSpeedup: {speedup:.2f}x faster")
print(f"Time saved: {time_saved:.2f}s ({time_saved_pct:.1f}% less time)")

## Training Loop

In [None]:
# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# Arguments for the training loop; note that the training batch size is
# exposed under the key 'batch_size' here.
train_args = dict(
    num_games=config['num_games'],
    num_workers=config['num_workers'],
    num_epochs=config['num_epochs'],
    batch_size=config['train_batch_size'],
)

# Summarise the run that is about to start.
print("Starting training...\n")
for label, config_key in (
    ("Total iterations", 'num_iterations'),
    ("Games per iteration", 'num_games'),
    ("Epochs per iteration", 'num_epochs'),
    ("Parallel games", 'num_workers'),
):
    print(f"{label}: {config[config_key]}")

In [None]:
# Execute every training iteration and time the whole run. `history` is used
# by the plotting, checkpoint and summary cells below.
training_start = time.time()

history = full_training_loop(
    model=model,
    mcts=mcts,
    game=game,
    optimizer=optimizer,
    args=train_args,
    num_iterations=config['num_iterations'],
    device=device,
)

training_time = time.time() - training_start
training_minutes = training_time / 60
print(f"\nTotal training time: {training_time:.2f}s ({training_minutes:.2f} minutes)")

## Visualize Training Progress

In [None]:
# Pull the per-iteration series out of the training history.
iterations = [entry['iteration'] for entry in history]
final_losses = [entry['final_loss'] for entry in history]
num_samples = [entry['num_samples'] for entry in history]

# Text styling shared by both panels.
label_font = {'fontsize': 12}
title_font = {'fontsize': 14, 'fontweight': 'bold'}

fig, (loss_ax, samples_ax) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: final loss of each iteration as a line with markers.
loss_ax.plot(iterations, final_losses, marker='o', linewidth=2, markersize=8)
loss_ax.set_xlabel('Iteration', **label_font)
loss_ax.set_ylabel('Final Loss', **label_font)
loss_ax.set_title('Training Loss Over Iterations', **title_font)
loss_ax.grid(True, alpha=0.3)

# Right panel: number of training samples of each iteration as bars.
samples_ax.bar(iterations, num_samples, color='skyblue', edgecolor='navy', alpha=0.7)
samples_ax.set_xlabel('Iteration', **label_font)
samples_ax.set_ylabel('Number of Samples', **label_font)
samples_ax.set_title('Training Samples Per Iteration', **title_font)
samples_ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('training_progress.png', dpi=150, bbox_inches='tight')
plt.show()

# Text summary: first vs. last iteration loss and the relative reduction.
print("Training completed successfully!")
print(f"Initial loss: {final_losses[0]:.4f}")
print(f"Final loss: {final_losses[-1]:.4f}")
improvement_pct = (1 - final_losses[-1] / final_losses[0]) * 100
print(f"Improvement: {improvement_pct:.1f}%")

## Detailed Loss Curves

In [None]:
# Plot the per-epoch loss curve of every training iteration, one panel each.
#
# The grid is sized from `history` instead of being fixed at 2x5, so runs with
# more than ten iterations are no longer silently truncated and shorter runs
# do not show empty frames. With ten iterations the layout and figure size are
# the same as the fixed 2x5 grid.
PANELS_PER_ROW = 5
num_panels = len(history)
# Ceiling division, with at least one row so plt.subplots gets a valid grid.
num_rows = max(1, (num_panels + PANELS_PER_ROW - 1) // PANELS_PER_ROW)

fig, axes = plt.subplots(
    num_rows,
    PANELS_PER_ROW,
    figsize=(18, 4 * num_rows),
    squeeze=False,  # always return a 2-D array, even for a single row
)
axes = axes.flatten()

for ax, h in zip(axes, history):
    epoch_losses = h['epoch_losses']
    # Epochs are numbered from 1 on the x axis.
    ax.plot(range(1, len(epoch_losses) + 1), epoch_losses, linewidth=2)
    ax.set_title(f"Iteration {h['iteration']}", fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True, alpha=0.3)

# Hide the grid cells in the last row that have no iteration to show.
for unused_ax in axes[num_panels:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.savefig('detailed_loss_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## Evaluate Trained Model

In [None]:
print("Evaluating trained model...\n")

# NOTE(review): `mcts` evaluates positions through `batched_model`, which
# caches predictions. If the training loop does not invalidate that cache
# after weight updates, this game may be served predictions made before
# training. Confirm in the trainer, or call batched_model.clear_cache() here
# (that may also reset the counters shown in the cache analysis cell).

# Play a sample game
game.reset()
move_count = 0   # moves actually played so far
max_moves = 50   # safety cap so a game that never terminates cannot loop forever

print("Playing a sample game with trained model:\n")
print("Initial state:")
game.print_state(game.state)

while move_count < max_moves:
    root = mcts.run(game.state)
    action = root.best_action()  # Use best action (greedy)

    print(f"\nMove {move_count + 1}:")
    print(f"  Action: {game.describe_action(action)}")
    print(f"  Root value: {root.value():.3f}")

    game_continues = game.play(action)
    # Count the move as soon as it has been played. Incrementing only after
    # the termination check left the final move out of the reported total.
    move_count += 1
    game.print_state(game.state)

    if not game_continues:
        reward = game.reward(game.state)
        game.print_winner_result(reward, game.state)
        print(f"\nGame ended in {move_count} moves")
        break
else:
    # The loop ran out of moves without the game reporting termination.
    print(f"\nNo result after {move_count} moves (move limit of {max_moves} reached)")

## Cache Performance Analysis

In [None]:
# Snapshot the prediction-cache counters accumulated so far.
cache_stats = batched_model.get_cache_stats()

print("Model Prediction Cache Statistics:")
for label, field in (
    ("Cache size", 'size'),
    ("Cache hits", 'hits'),
    ("Cache misses", 'misses'),
):
    print(f"  {label}: {cache_stats[field]}")
print(f"  Hit rate: {cache_stats['hit_rate']:.2%}")

# Each hit is a lookup answered from the cache.
saved_predictions = cache_stats['hits']
print(f"\nCaching saved {saved_predictions} model predictions!")

## Save Trained Model

In [None]:
# Save the model weights together with everything needed to interpret or
# resume the run: optimizer state, hyperparameters, per-iteration history and
# the measured training time.
model_path = 'chopsticks_model_optimized.pth'
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config,
    'history': history,
    'training_time': training_time,
}
torch.save(checkpoint, model_path)

# Derive the summary figures directly from `history` so this cell does not
# depend on variables created by the plotting cell.
total_samples = sum(h['num_samples'] for h in history)
# `num_samples` counts training samples, not games. The number of games is
# the configured games-per-iteration times the iterations recorded.
# NOTE(review): assumes every iteration played exactly config['num_games']
# self-play games - confirm against the training loop.
total_games = config['num_games'] * len(history)
last_loss = history[-1]['final_loss']

print(f"Model saved to {model_path}")
print("\nTraining Summary:")
print(f"  Total iterations: {config['num_iterations']}")
print(f"  Total games played: {total_games}")
print(f"  Total samples collected: {total_samples}")
print(f"  Total training time: {training_time/60:.2f} minutes")
print(f"  Final loss: {last_loss:.4f}")

## Performance Summary

In [None]:
# Final report. All figures are derived here from `history`, `config`,
# `training_time` and a fresh cache snapshot, so the cell does not rely on
# variables left behind by the plotting or cache-analysis cells.
total_samples = sum(h['num_samples'] for h in history)
# Games and samples are different quantities: `num_samples` counts training
# samples, while the game count is the configured games-per-iteration times
# the iterations recorded.
# NOTE(review): assumes every iteration played exactly config['num_games']
# self-play games - confirm against the training loop.
total_games = config['num_games'] * len(history)
initial_loss = history[0]['final_loss']
last_loss = history[-1]['final_loss']
loss_reduction_pct = (1 - last_loss / initial_loss) * 100
cache_stats = batched_model.get_cache_stats()

print("="*60)
print("OPTIMIZED TRAINING PERFORMANCE SUMMARY")
print("="*60)
print("\nOptimizations Applied:")
print("  ✓ Vectorized game execution")
print(f"  ✓ Batched model predictions (batch size: {config['batch_size']})")
print("  ✓ Parallel MCTS with virtual loss")
print(f"  ✓ Cached predictions (hit rate: {cache_stats['hit_rate']:.1%})")
print(f"  ✓ Multi-game self-play ({config['num_workers']} parallel games)")
print("\nTraining Results:")
print(f"  Total training time: {training_time/60:.2f} minutes")
print(f"  Games played: {total_games}")
print(f"  Samples collected: {total_samples}")
print(f"  Initial loss: {initial_loss:.4f}")
print(f"  Final loss: {last_loss:.4f}")
print(f"  Loss reduction: {loss_reduction_pct:.1f}%")
print("\nEfficiency Metrics:")
print(f"  Time per iteration: {training_time/config['num_iterations']:.2f}s")
print(f"  Games per second: {total_games/training_time:.2f}")
print(f"  Samples per second: {total_samples/training_time:.2f}")
print("\n" + "="*60)