# AlphaZero Chess Training on Google Colab

Train a chess AI from scratch using the AlphaZero algorithm with GPU acceleration.

## Features

- **Iterative training** for faster learning (refreshes replay buffer each iteration)
- **A100 GPU optimizations** (torch.compile, large batches, FP16)
- **Google Drive checkpoint persistence** (survives session timeouts)
- **Real-time progress monitoring** (tqdm progress bars)
- **Self-contained** (no external files needed)

## Quick Start

1. Select **Runtime → Change runtime type → GPU** (T4 or A100)
2. Run all cells in order
3. Training auto-saves to Google Drive every 50-100 steps

## Configuration Presets

- **A100 Long**: 5 iterations × 4000 steps = 20k total (8-12 hours)
- **A100 Short**: 3 iterations × 1500 steps = 4.5k total (2-3 hours)
- **T4 Free**: 3 iterations × 1500 steps = 4.5k total (optimized for 12-hour limit)

## Iterative Training Explained

Each iteration:

1. Generates fresh self-play games with the current model
2. Clears the old replay buffer
3. Trains on the new data

This prevents overfitting to weak early games and accelerates learning!

In [None]:
# Cell 2: Setup & GPU Check
#
# Reports the GPU visible to this runtime and makes sure `torch` and `chess`
# are importable, installing them with pip when they are not.
import sys
import subprocess


def _run_command(args):
    """Run ``args`` as a subprocess, echo its output, and return the exit code.

    The output is read through a pipe and re-printed so that it shows up in
    the notebook cell output. A missing executable is reported with a message
    and exit code 127 instead of raising, so an absent `nvidia-smi` does not
    abort the cell.

    Args:
        args: Command and arguments as a list of strings (no shell involved).

    Returns:
        The process exit code, or 127 when the executable was not found.
    """
    try:
        proc = subprocess.Popen(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except FileNotFoundError:
        print(f"Command not found: {args[0]}")
        return 127
    # Leaving the `with` block closes the pipe and waits for the process.
    with proc:
        for line in proc.stdout:
            print(line, end="")
    return proc.returncode


print("Checking GPU availability...")
_run_command(["nvidia-smi"])

# Check if dependencies are installed
try:
    import torch
    import chess
except ImportError:
    print("\nInstalling dependencies...")
    # `sys.executable -m pip` installs into the interpreter running this
    # kernel, which a bare `pip` on PATH does not guarantee.
    pip_install = [sys.executable, "-m", "pip", "install", "-q"]
    exit_codes = [
        _run_command(
            pip_install
            + [
                "torch",
                "torchvision",
                "torchaudio",
                "--index-url",
                "https://download.pytorch.org/whl/cu118",
            ]
        ),
        _run_command(pip_install + ["python-chess", "numpy", "tqdm"]),
    ]
    if any(exit_codes):
        print("Warning: pip reported errors while installing dependencies")
    else:
        print("Dependencies installed")
    # Re-import after installation (raises ImportError if it really failed)
    import torch
    import chess

# Report the environment whether or not an install was needed.
print(f"\nPyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Cell 3: Mount Google Drive
#
# Chooses CHECKPOINT_DIR: a folder on Google Drive when the mount succeeds,
# otherwise a local folder under /content.
import os

DRIVE_CHECKPOINT_DIR = "/content/drive/MyDrive/alphazero_checkpoints"
LOCAL_CHECKPOINT_DIR = "/content/checkpoints"

# Mount Google Drive. The import is inside the try so that a missing
# `google.colab` package falls back to local storage instead of aborting.
drive_mounted = False
try:
    from google.colab import drive

    drive.mount('/content/drive', force_remount=False)
    drive_mounted = True
    print("Google Drive mounted successfully")
except Exception as e:
    print(f"Warning: Could not mount Google Drive: {e}")
    print("Checkpoints will be saved locally only")

# Create checkpoint directory.
# The Drive path is only used when the mount actually succeeded: without a
# mount, os.makedirs() would happily create the same path on the local disk
# and the checkpoints would not be on Drive at all.
CHECKPOINT_DIR = LOCAL_CHECKPOINT_DIR
if drive_mounted:
    try:
        os.makedirs(DRIVE_CHECKPOINT_DIR, exist_ok=True)
        CHECKPOINT_DIR = DRIVE_CHECKPOINT_DIR
    except OSError as e:
        print(f"Warning: Could not create {DRIVE_CHECKPOINT_DIR}: {e}")

if CHECKPOINT_DIR == DRIVE_CHECKPOINT_DIR:
    print(f"Checkpoints will be saved to Google Drive: {CHECKPOINT_DIR}")
else:
    os.makedirs(LOCAL_CHECKPOINT_DIR, exist_ok=True)
    print(f"Checkpoints will be saved locally: {CHECKPOINT_DIR}")

In [None]:
# Cell 4: Clone Repository and Install Dependencies
#
# Clones the project, makes it the working directory, installs the Python
# dependencies and checks that the project packages can be imported.
import os
import subprocess
import sys


def _run_command(args):
    """Run ``args`` as a subprocess, echo its output, and return the exit code.

    Output is read through a pipe and re-printed so it appears in the cell
    output. A missing executable is reported and mapped to exit code 127.

    Args:
        args: Command and arguments as a list of strings (no shell involved).

    Returns:
        The process exit code, or 127 when the executable was not found.
    """
    try:
        proc = subprocess.Popen(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except FileNotFoundError:
        print(f"Command not found: {args[0]}")
        return 127
    # Leaving the `with` block closes the pipe and waits for the process.
    with proc:
        for line in proc.stdout:
            print(line, end="")
    return proc.returncode


# Clone repository if not already present
REPO_URL = "https://github.com/lirockyzhang/alpha-zero-chess.git"
REPO_DIR = "/content/alpha-zero-chess"

if not os.path.exists(REPO_DIR):
    print(f"Cloning repository from {REPO_URL}...")
    # Fail here with a clear message rather than reporting success and then
    # crashing on os.chdir() below.
    if _run_command(["git", "clone", REPO_URL, REPO_DIR]) != 0:
        raise RuntimeError(f"git clone failed for {REPO_URL}")
    print("Repository cloned successfully")
else:
    print("Repository already exists")

# Change to repository directory
os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

# Install dependencies directly (avoid editable install issues in Colab)
print("\nInstalling dependencies...")
# `sys.executable -m pip` targets the interpreter running this kernel.
pip_install = [sys.executable, "-m", "pip", "install", "-q"]
install_codes = [
    _run_command(
        pip_install
        + [
            "torch",
            "torchvision",
            "torchaudio",
            "--index-url",
            "https://download.pytorch.org/whl/cu118",
        ]
    ),
    _run_command(
        pip_install + ["python-chess", "numpy", "tqdm", "matplotlib", "scipy"]
    ),
]
if any(install_codes):
    print("Warning: pip reported errors while installing dependencies")

# Add package to Python path
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
    print(f"Added {REPO_DIR} to Python path")

# Verify installation
print("\nVerifying installation...")
try:
    import torch
    import chess
    from alphazero import AlphaZeroConfig
    from alphazero.neural.network import AlphaZeroNetwork
    from alphazero.chess_env import GameState
except ImportError as e:
    print(f"Import error: {e}")
    print("\nNote: Some imports may fail due to Python version compatibility.")
    print("The notebook will attempt to continue with available functionality.")
else:
    print("All imports successful!")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Cell 5: Configuration
#
# Defines every hyper-parameter used by the later cells as a module-level
# constant, choosing a preset from the detected GPU.
import torch

# Detect GPU type
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
is_a100 = "A100" in gpu_name
is_t4 = "T4" in gpu_name

print(f"Detected GPU: {gpu_name}")
print(f"Configuring for: {'A100' if is_a100 else 'T4' if is_t4 else 'CPU'}")

# ============================================================================
# CONFIGURATION - Modify these settings
# ============================================================================

# Choose configuration preset
if is_a100:
    # A100 (40GB) - Long Training (8-12 hours)
    ITERATIONS = 5              # Number of training iterations
    STEPS_PER_ITERATION = 4000  # Steps per iteration (20k total)
    NUM_ACTORS = 2              # Sequential actors
    NUM_FILTERS = 192           # Network width
    NUM_BLOCKS = 15             # Network depth
    BATCH_SIZE = 8192           # Training batch size
    SIMULATIONS = 400           # MCTS simulations per move
    MIN_BUFFER_SIZE = 8192      # Min positions before training
    CHECKPOINT_INTERVAL = 100   # Save every N steps
    USE_TORCH_COMPILE = True    # Enable torch.compile for 20-30% speedup

    # A100 Short Training (uncomment to use)
    # ITERATIONS = 3
    # STEPS_PER_ITERATION = 1500
    # NUM_FILTERS = 128
    # NUM_BLOCKS = 10
    # BATCH_SIZE = 4096
    # SIMULATIONS = 200
    # MIN_BUFFER_SIZE = 4096
    # CHECKPOINT_INTERVAL = 50
else:
    # T4 (16GB) - Free Tier (optimized for 12-hour limit).
    # This branch is also taken for any other GPU and for CPU-only runtimes.
    ITERATIONS = 3
    STEPS_PER_ITERATION = 1500
    NUM_ACTORS = 1              # Single actor for memory
    NUM_FILTERS = 64
    NUM_BLOCKS = 5
    BATCH_SIZE = 2048
    SIMULATIONS = 200
    MIN_BUFFER_SIZE = 2048
    CHECKPOINT_INTERVAL = 50
    USE_TORCH_COMPILE = False   # Disabled for T4

# MCTS Configuration
C_PUCT = 1.25
DIRICHLET_ALPHA = 0.3
DIRICHLET_EPSILON = 0.25
TEMPERATURE = 1.0
TEMPERATURE_THRESHOLD = 30

# Training Configuration
LEARNING_RATE = 0.2
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
# NOTE(review): every milestone lies beyond the total step count of all
# presets above (at most 20k), so the LR never decays in these runs -
# confirm this is intended.
LR_SCHEDULE_STEPS = [100000, 300000, 500000]
LR_SCHEDULE_GAMMA = 0.1
MAX_GRAD_NORM = 1.0
# Mixed precision training. The later cells use CUDA autocast and gradient
# scaling, so this is only enabled when a CUDA device is actually present;
# on a CPU-only runtime it is switched off instead of being requested.
USE_AMP = torch.cuda.is_available()

# Replay Buffer Configuration
BUFFER_CAPACITY = 500000

# Self-Play Configuration
MAX_MOVES = 512
RESIGN_THRESHOLD = -0.95
RESIGN_CHECK_MOVES = 5

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"\n{'='*60}")
print("TRAINING CONFIGURATION")
print(f"{'='*60}")
print(f"Iterations: {ITERATIONS}")
print(f"Steps per iteration: {STEPS_PER_ITERATION}")
print(f"Total steps: {ITERATIONS * STEPS_PER_ITERATION}")
print(f"Network: {NUM_FILTERS} filters, {NUM_BLOCKS} blocks")
print(f"Batch size: {BATCH_SIZE}")
print(f"MCTS simulations: {SIMULATIONS}")
print(f"Torch compile: {USE_TORCH_COMPILE}")
print(f"Device: {DEVICE}")
print(f"{'='*60}\n")

In [None]:
# Cell 6: Helper Classes - ColabStorage
import os
import glob
import torch
from pathlib import Path


class ColabStorage:
    """Manages checkpoint storage for Google Colab.

    Checkpoints are `torch.save` files named
    ``checkpoint_iter{iteration}_step{step}_f{num_filters}_b{num_blocks}.pt``
    inside ``checkpoint_dir``.
    """

    def __init__(self, checkpoint_dir):
        """Remember ``checkpoint_dir`` and create it if it does not exist."""
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

    def save_checkpoint(self, network, optimizer, scheduler, step, iteration,
                        num_filters, num_blocks, extra_state=None):
        """Save training checkpoint.

        The file is first written under a temporary ``.tmp`` name and then
        renamed into place, so an interrupted write (e.g. a session timeout)
        never leaves a truncated ``checkpoint_*.pt`` that would later be
        picked up as the latest checkpoint.

        Args:
            network, optimizer, scheduler: Objects whose ``state_dict()`` is
                stored.
            step: Global training step, stored and used in the file name.
            iteration: Training iteration, stored and used in the file name.
            num_filters, num_blocks: Network size, stored and used in the
                file name.
            extra_state: Optional dict merged into the saved state; its keys
                override the standard ones on collision.

        Returns:
            Path of the written checkpoint file.
        """
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f"checkpoint_iter{iteration}_step{step}_f{num_filters}_b{num_blocks}.pt"
        )

        state = {
            'step': step,
            'iteration': iteration,
            'network_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'num_filters': num_filters,
            'num_blocks': num_blocks,
        }
        if extra_state:
            state.update(extra_state)

        # The ".tmp" suffix keeps the partial file out of the
        # "checkpoint_*.pt" glob used by the loading methods.
        tmp_path = checkpoint_path + ".tmp"
        try:
            torch.save(state, tmp_path)
            os.replace(tmp_path, checkpoint_path)
        except BaseException:
            # Do not leave a partial temporary file behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        print(f"Saved checkpoint: {checkpoint_path}")
        return checkpoint_path

    def load_latest_checkpoint(self, device='cuda'):
        """Load the most recent checkpoint.

        "Most recent" is decided by file modification time.

        Args:
            device: Passed to ``torch.load`` as ``map_location``.

        Returns:
            ``None`` when the directory holds no checkpoint, otherwise a
            ``(checkpoint_dict, checkpoint_path)`` tuple.
        """
        checkpoints = glob.glob(os.path.join(self.checkpoint_dir, "checkpoint_*.pt"))
        if not checkpoints:
            return None

        # Sort by modification time
        latest_checkpoint = max(checkpoints, key=os.path.getmtime)
        print(f"Loading checkpoint: {latest_checkpoint}")

        checkpoint = torch.load(latest_checkpoint, map_location=device)
        return checkpoint, latest_checkpoint

    def list_checkpoints(self):
        """List all available checkpoints, newest (by modification time) first."""
        checkpoints = glob.glob(os.path.join(self.checkpoint_dir, "checkpoint_*.pt"))
        return sorted(checkpoints, key=os.path.getmtime, reverse=True)


print("ColabStorage class defined")

In [None]:
# Cell 6A: OPTION 1 - Run Full Training Script (Recommended for A100)
# This cell builds the command line for scripts/train.py and prints it,
# together with instructions for running it in the foreground or background.
# Nothing is started until one of the lines at the bottom is uncommented.

import os
import shlex
import torch

# Detect GPU type
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
is_a100 = "A100" in gpu_name
is_t4 = "T4" in gpu_name

# Pass the device that is actually available instead of always requesting
# CUDA. NOTE(review): assumes scripts/train.py accepts `--device cpu` - confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Detected GPU: {gpu_name}")
print(f"Configuring for: {'A100' if is_a100 else 'T4' if is_t4 else 'CPU'}")

# Configuration based on GPU type
if is_a100:
    # A100 (40GB) - High Performance Configuration
    config = {
        'iterations': 5,
        'steps_per_iteration': 4000,
        'num_actors': 4,           # Multiple actors for parallel self-play
        'num_filters': 192,
        'num_blocks': 15,
        'batch_size': 8192,
        'simulations': 400,
        'checkpoint_interval': 100,
        'batched_inference': True,  # Enable batched inference for 2-3x speedup
    }
    print("\n🚀 A100 Configuration: High Performance")
else:
    # T4 (16GB) - Free Tier Configuration (also used for any non-A100 runtime)
    config = {
        'iterations': 3,
        'steps_per_iteration': 1500,
        'num_actors': 2,           # Fewer actors for memory constraints
        'num_filters': 64,
        'num_blocks': 5,
        'batch_size': 2048,
        'simulations': 200,
        'checkpoint_interval': 50,
        'batched_inference': False,  # Disabled for T4
    }
    print("\n💡 T4 Configuration: Optimized for Free Tier")

# Build command
cmd_parts = [
    "cd /content/alpha-zero-chess &&",
    "python -u scripts/train.py",  # -u for unbuffered output
    f"--iterations {config['iterations']}",
    f"--steps-per-iteration {config['steps_per_iteration']}",
    f"--num-actors {config['num_actors']}",
    f"--filters {config['num_filters']}",
    f"--blocks {config['num_blocks']}",
    f"--batch-size {config['batch_size']}",
    f"--simulations {config['simulations']}",
    f"--checkpoint-interval {config['checkpoint_interval']}",
    # Quoted because the command is a shell string; a checkpoint path that
    # contains spaces would otherwise be split into several arguments.
    f"--checkpoint-dir {shlex.quote(CHECKPOINT_DIR)}",
    f"--device {device}",
]

if config['batched_inference']:
    cmd_parts.append("--batched-inference")

command = " ".join(cmd_parts)

print(f"\n{'='*60}")
print("TRAINING COMMAND")
print(f"{'='*60}")
print(command)
print(f"{'='*60}\n")

print("📋 Configuration Summary:")
print(f"  Total steps: {config['iterations'] * config['steps_per_iteration']}")
print(f"  Network: {config['num_filters']} filters, {config['num_blocks']} blocks")
print(f"  Actors: {config['num_actors']}")
print(f"  Batched inference: {config['batched_inference']}")
print(f"  Checkpoints: {CHECKPOINT_DIR}")
print()

# Option A: Run in foreground (blocks the cell until complete)
print("🔹 OPTION A: Run in foreground (recommended for short training)")
print("   Uncomment the line below to run:")
print(f"   !{command}")
print()

# Option B: Run in background (survives disconnections)
print("🔹 OPTION B: Run in background with nohup (recommended for long training)")
print("   This allows training to continue even if you disconnect!")
print("   Uncomment the lines below to run:")
nohup_cmd = f"nohup {command} > /content/training.log 2>&1 &"
print(f"   !{nohup_cmd}")
print("   !sleep 2 && tail -f /content/training.log")
print()

print("💡 To monitor background training:")
print("   !tail -f /content/training.log")
print()
print("💡 To check if training is still running:")
print("   !ps aux | grep train.py")
print()
print("💡 To stop background training:")
print("   !pkill -f train.py")
print()

# Uncomment ONE of these to start training:

# OPTION A: Foreground (blocks until complete)
# !{command}

# OPTION B: Background (continues even if disconnected)
# !{nohup_cmd}
# !sleep 2 && tail -f /content/training.log

# Training Options

You have **two ways** to train on Google Colab:

## Option 1: Run Full Training Script (Recommended for A100)
**Advantages:**
- ✅ Full feature access (batched inference, multi-actor, all CLI options)
- ✅ Better GPU utilization (2-3x faster with batched inference)
- ✅ Can run in background with `nohup` to survive disconnections
- ✅ Uses the actual codebase without code duplication

**Best for:** A100 instances where you want maximum performance

**See Cell 6A below** for how to run the full script.

## Option 2: Simplified Notebook Coordinator (Current Approach)
**Advantages:**
- ✅ Self-contained (all code visible in notebook)
- ✅ Easier to understand and modify
- ✅ No multiprocessing complexity
- ✅ Works reliably on all GPU types

**Best for:** T4 free tier, learning, or when you want to see all the code

**See Cells 6-11 below** for the simplified approach.

---

Choose one option and skip the other!

In [None]:
# Cell 7: Helper Classes - SingleProcessCoordinator with Iterative Training
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
import numpy as np
from tqdm import tqdm
import time
from collections import deque

from alphazero.neural.network import AlphaZeroNetwork
from alphazero.neural.loss import AlphaZeroLoss, compute_policy_accuracy, compute_value_accuracy
from alphazero.chess_env import GameState
from alphazero.mcts import create_mcts
from alphazero.mcts.evaluator import NetworkEvaluator
from alphazero.selfplay.game import SelfPlayGame
from alphazero.training.trajectory import Trajectory, TrajectoryState
from alphazero.config import AlphaZeroConfig, MCTSConfig, SelfPlayConfig


class SingleProcessCoordinator:
    """Single-process training coordinator for Google Colab.

    Features:
    - Sequential actor execution (no multiprocessing)
    - Iterative training (refreshes replay buffer each iteration)
    - A100 optimizations (torch.compile, large batches)
    - Progress monitoring with tqdm
    """

    def __init__(self, network, config, device='cuda', use_torch_compile=False):
        """Move the network to ``device`` and build optimizer, scheduler and buffer.

        Args:
            network: Model to train; moved to ``device``.
            config: Object exposing ``training``, ``mcts``, ``selfplay``,
                ``network`` and ``replay_buffer`` settings.
            device: Device string used for the model and the training tensors.
            use_torch_compile: Wrap the network with ``torch.compile`` when the
                installed torch provides it.
        """
        self.network = network.to(device)
        self.config = config
        self.device = device
        self.use_torch_compile = use_torch_compile

        # Compile network for A100 speedup (20-30% faster)
        if use_torch_compile and hasattr(torch, 'compile'):
            print("Compiling network with torch.compile...")
            self.network = torch.compile(self.network)
            print("Network compiled successfully")

        # Optimizer
        self.optimizer = optim.SGD(
            self.network.parameters(),
            lr=config.training.learning_rate,
            momentum=config.training.momentum,
            weight_decay=config.training.weight_decay
        )

        # Learning rate scheduler (stepped once per training step)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=config.training.lr_schedule_steps,
            gamma=config.training.lr_schedule_gamma
        )

        # Loss function
        self.loss_fn = AlphaZeroLoss()

        # Mixed precision training
        self.scaler = GradScaler() if config.training.use_amp else None

        # Training state
        self.global_step = 0
        self.current_iteration = 0

        # Replay buffer (simple deque-based)
        self.replay_buffer = deque(maxlen=config.replay_buffer.capacity)

        # Statistics
        self.total_games = 0
        self.metrics_history = []

    def generate_selfplay_games(self, num_games, num_actors=1):
        """Generate self-play games sequentially.

        Games are split across ``num_actors`` sequential "actors"; any
        remainder of ``num_games / num_actors`` is spread over the first
        actors so that exactly ``num_games`` games are attempted. A game that
        raises is reported and skipped.

        Args:
            num_games: Number of games to attempt.
            num_actors: Number of sequential actors (values below 1 count as 1).

        Returns:
            List of trajectories from the games that completed.
        """
        print(f"\nGenerating {num_games} self-play games...")

        # train_step() leaves the network in training mode; switch to
        # inference mode so self-play does not run (or update) layers such as
        # batch norm / dropout in training mode.
        self.network.eval()

        # Create MCTS and evaluator
        mcts = create_mcts(config=self.config.mcts)
        evaluator = NetworkEvaluator(self.network, self.device, use_amp=self.config.training.use_amp)

        trajectories = []
        num_actors = max(1, num_actors)
        base_games, extra_games = divmod(num_games, num_actors)

        with tqdm(total=num_games, desc="Self-play", unit="game") as pbar:
            for actor_id in range(num_actors):
                # The first `extra_games` actors play one additional game.
                games_for_actor = base_games + (1 if actor_id < extra_games else 0)
                for game_idx in range(games_for_actor):
                    try:
                        # Play game
                        game = SelfPlayGame(mcts, evaluator, self.config.selfplay)
                        trajectory, result_str = game.play()
                        trajectories.append(trajectory)
                        self.total_games += 1

                        pbar.update(1)
                        pbar.set_postfix({
                            'actor': actor_id,
                            'result': result_str,
                            'moves': len(trajectory)
                        })
                    except Exception as e:
                        # Best effort: one broken game must not stop the run.
                        print(f"\nError in game {game_idx}: {e}")
                        continue

        print(f"Generated {len(trajectories)} games successfully")
        return trajectories

    def fill_replay_buffer(self, trajectories):
        """Fill replay buffer with trajectories.

        Appends every element of each trajectory's ``states`` to the buffer;
        the deque drops the oldest entries once its ``maxlen`` is reached.
        """
        print(f"\nFilling replay buffer with {len(trajectories)} games...")

        total_positions = 0
        for trajectory in trajectories:
            for state in trajectory.states:
                self.replay_buffer.append(state)
                total_positions += 1

        print(f"Added {total_positions} positions to replay buffer")
        print(f"Buffer size: {len(self.replay_buffer)}/{self.replay_buffer.maxlen}")

    def sample_batch(self, batch_size):
        """Sample a random batch from replay buffer.

        Args:
            batch_size: Number of positions to draw (with replacement).

        Returns:
            Tuple ``(observations, legal_masks, policies, values)`` of stacked
            numpy arrays; ``values`` is float32.

        Raises:
            ValueError: If the buffer holds fewer than ``batch_size`` positions.
        """
        if len(self.replay_buffer) < batch_size:
            raise ValueError(f"Not enough samples: {len(self.replay_buffer)} < {batch_size}")

        # Random sampling with replacement
        indices = np.random.randint(0, len(self.replay_buffer), size=batch_size)
        states = [self.replay_buffer[i] for i in indices]

        # Convert to numpy arrays
        observations = np.stack([s.observation for s in states])
        legal_masks = np.stack([s.legal_mask for s in states])
        policies = np.stack([s.policy for s in states])
        values = np.array([s.value for s in states], dtype=np.float32)

        return observations, legal_masks, policies, values

    def _forward_loss(self, obs_tensor, mask_tensor, policy_tensor, value_tensor):
        """Run the network and the loss function on one batch."""
        policy_logits, value_pred = self.network(obs_tensor, mask_tensor)
        loss, metrics = self.loss_fn(
            policy_logits, policy_tensor,
            value_pred, value_tensor,
            mask_tensor
        )
        return policy_logits, value_pred, loss, metrics

    def train_step(self):
        """Execute a single training step.

        Samples a batch, runs forward/backward (with gradient scaling when
        mixed precision is enabled), clips gradients, steps the optimizer and
        the learning-rate scheduler, and records the metrics.

        Returns:
            The metrics dict from the loss function, extended with
            ``policy_accuracy``, ``value_accuracy`` and ``learning_rate`` (the
            rate used for this step).

        Raises:
            ValueError: Propagated from ``sample_batch`` when the buffer is too
                small.
        """
        self.network.train()

        # Sample batch
        observations, legal_masks, policies, values = self.sample_batch(
            self.config.training.batch_size
        )

        # Convert to tensors
        obs_tensor = torch.from_numpy(observations).float().to(self.device)
        mask_tensor = torch.from_numpy(legal_masks).float().to(self.device)
        policy_tensor = torch.from_numpy(policies).float().to(self.device)
        value_tensor = torch.from_numpy(values).float().to(self.device)

        # Learning rate in effect for this step (read before the scheduler moves on)
        step_lr = self.optimizer.param_groups[0]['lr']

        # Forward pass with optional mixed precision
        self.optimizer.zero_grad()

        if self.config.training.use_amp and self.scaler is not None:
            with autocast('cuda'):
                policy_logits, value_pred, loss, metrics = self._forward_loss(
                    obs_tensor, mask_tensor, policy_tensor, value_tensor
                )

            # Backward pass with gradient scaling
            self.scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.network.parameters(),
                self.config.training.max_grad_norm
            )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            policy_logits, value_pred, loss, metrics = self._forward_loss(
                obs_tensor, mask_tensor, policy_tensor, value_tensor
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.network.parameters(),
                self.config.training.max_grad_norm
            )
            self.optimizer.step()

        # Advance the LR schedule; without this call the milestones configured
        # in __init__ would never take effect.
        self.scheduler.step()

        # Compute additional metrics
        with torch.no_grad():
            metrics['policy_accuracy'] = compute_policy_accuracy(
                policy_logits, policy_tensor, mask_tensor
            )
            metrics['value_accuracy'] = compute_value_accuracy(
                value_pred, value_tensor
            )
            metrics['learning_rate'] = step_lr

        self.global_step += 1
        self.metrics_history.append(metrics)

        return metrics

    def train_iteration(self, steps, checkpoint_interval, storage):
        """Train for one iteration.

        Runs ``steps`` training steps, saving a checkpoint through ``storage``
        every ``checkpoint_interval`` steps. A step that raises is reported and
        skipped. ``current_iteration`` is incremented at the end.
        """
        print(f"\n{'='*60}")
        print(f"Training Iteration {self.current_iteration + 1}")
        print(f"{'='*60}")

        with tqdm(total=steps, desc=f"Iteration {self.current_iteration + 1}", unit="step") as pbar:
            for step in range(steps):
                try:
                    metrics = self.train_step()

                    pbar.update(1)
                    pbar.set_postfix({
                        'loss': f"{metrics['loss']:.4f}",
                        'p_loss': f"{metrics['policy_loss']:.4f}",
                        'v_loss': f"{metrics['value_loss']:.4f}",
                        'buffer': len(self.replay_buffer)
                    })

                    # Checkpoint
                    if (step + 1) % checkpoint_interval == 0:
                        storage.save_checkpoint(
                            self.network,
                            self.optimizer,
                            self.scheduler,
                            self.global_step,
                            self.current_iteration,
                            self.config.network.num_filters,
                            self.config.network.num_blocks
                        )
                except Exception as e:
                    # Best effort: report the failed step and keep going.
                    print(f"\nError at step {step + 1}: {e}")
                    continue

        self.current_iteration += 1

    def run_iterative_training(self, iterations, steps_per_iteration,
                               num_actors, checkpoint_interval, storage):
        """Run iterative training with buffer refresh.

        Each iteration generates fresh self-play games, replaces the replay
        buffer contents with them, trains for ``steps_per_iteration`` steps and
        saves a checkpoint.

        Training starts at ``self.current_iteration``, so a coordinator whose
        state was restored from a checkpoint continues with the remaining
        iterations instead of repeating all of them. An iteration interrupted
        midway is restarted from its beginning. To train beyond a finished
        run, pass a larger ``iterations``.

        Args:
            iterations: Total number of iterations of the whole run.
            steps_per_iteration: Training steps per iteration.
            num_actors: Sequential self-play actors per iteration.
            checkpoint_interval: Steps between checkpoints inside an iteration.
            storage: Object providing ``save_checkpoint``.
        """
        print(f"\n{'='*60}")
        print("ITERATIVE TRAINING")
        print(f"{'='*60}")
        print(f"Iterations: {iterations}")
        print(f"Steps per iteration: {steps_per_iteration}")
        print(f"Total steps: {iterations * steps_per_iteration}")
        print(f"{'='*60}\n")

        first_iteration = self.current_iteration
        if first_iteration > 0:
            print(f"Resuming at iteration {first_iteration + 1}")

        for iteration in range(first_iteration, iterations):
            print(f"\n{'='*60}")
            print(f"ITERATION {iteration + 1}/{iterations}")
            print(f"{'='*60}")

            # Step 1: Generate fresh self-play games
            games_to_generate = max(
                self.config.replay_buffer.min_size_to_train // 50,  # ~50 moves per game
                num_actors * 10  # At least 10 games per actor
            )
            trajectories = self.generate_selfplay_games(games_to_generate, num_actors)

            # Step 2: Clear old buffer and fill with new games
            print(f"\nClearing old replay buffer...")
            self.replay_buffer.clear()
            self.fill_replay_buffer(trajectories)

            # Step 3: Train on new data
            self.train_iteration(steps_per_iteration, checkpoint_interval, storage)

            # Save checkpoint after each iteration
            storage.save_checkpoint(
                self.network,
                self.optimizer,
                self.scheduler,
                self.global_step,
                self.current_iteration,
                self.config.network.num_filters,
                self.config.network.num_blocks
            )

        print(f"\n{'='*60}")
        print("TRAINING COMPLETE")
        print(f"{'='*60}")
        print(f"Total steps: {self.global_step}")
        print(f"Total games: {self.total_games}")
        print(f"{'='*60}\n")


print("SingleProcessCoordinator class defined")

In [None]:
# Cell 8: Training Setup
#
# Builds the config, network, storage and coordinator from the constants of
# the configuration cell, and restores the latest compatible checkpoint.
import torch
from alphazero.neural.network import AlphaZeroNetwork
from alphazero.config import AlphaZeroConfig, MCTSConfig, NetworkConfig, TrainingConfig, ReplayBufferConfig, SelfPlayConfig

# Create configuration
config = AlphaZeroConfig(
    network=NetworkConfig(
        num_filters=NUM_FILTERS,
        num_blocks=NUM_BLOCKS
    ),
    mcts=MCTSConfig(
        num_simulations=SIMULATIONS,
        c_puct=C_PUCT,
        dirichlet_alpha=DIRICHLET_ALPHA,
        dirichlet_epsilon=DIRICHLET_EPSILON,
        temperature=TEMPERATURE,
        temperature_threshold=TEMPERATURE_THRESHOLD
    ),
    training=TrainingConfig(
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
        lr_schedule_steps=LR_SCHEDULE_STEPS,
        lr_schedule_gamma=LR_SCHEDULE_GAMMA,
        max_grad_norm=MAX_GRAD_NORM,
        use_amp=USE_AMP
    ),
    replay_buffer=ReplayBufferConfig(
        capacity=BUFFER_CAPACITY,
        min_size_to_train=MIN_BUFFER_SIZE
    ),
    selfplay=SelfPlayConfig(
        max_moves=MAX_MOVES,
        resign_threshold=RESIGN_THRESHOLD,
        resign_check_moves=RESIGN_CHECK_MOVES
    ),
    device=DEVICE
)

# Create network
print("Creating neural network...")
network = AlphaZeroNetwork(
    num_filters=NUM_FILTERS,
    num_blocks=NUM_BLOCKS
)
network = network.to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in network.parameters())
trainable_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Create storage
storage = ColabStorage(CHECKPOINT_DIR)

# Try to load latest checkpoint
checkpoint_data = storage.load_latest_checkpoint(DEVICE)
if checkpoint_data:
    checkpoint, checkpoint_path = checkpoint_data
    print(f"\nLoaded checkpoint from: {checkpoint_path}")
    print(f"Checkpoint step: {checkpoint['step']}")
    print(f"Checkpoint iteration: {checkpoint['iteration']}")

    # A checkpoint written with a different preset (e.g. T4 vs. A100) has a
    # different network size; loading it would fail in load_state_dict, so it
    # is ignored. Checkpoints that do not record their size are accepted as-is.
    checkpoint_arch = (
        checkpoint.get('num_filters', NUM_FILTERS),
        checkpoint.get('num_blocks', NUM_BLOCKS),
    )
    if checkpoint_arch != (NUM_FILTERS, NUM_BLOCKS):
        print(
            f"Warning: checkpoint network ({checkpoint_arch[0]} filters, "
            f"{checkpoint_arch[1]} blocks) does not match the current "
            f"configuration ({NUM_FILTERS} filters, {NUM_BLOCKS} blocks); "
            "ignoring it and starting from scratch"
        )
        checkpoint_data = None
else:
    print("\nNo checkpoint found, starting from scratch")

# Create coordinator
print("\nCreating training coordinator...")
coordinator = SingleProcessCoordinator(
    network=network,
    config=config,
    device=DEVICE,
    use_torch_compile=USE_TORCH_COMPILE
)

# Load checkpoint state if available
if checkpoint_data:
    checkpoint, _ = checkpoint_data

    # NOTE(review): a module wrapped by torch.compile is understood to expose
    # the original module as `_orig_mod` and to prefix its state-dict keys
    # with "_orig_mod." - confirm for the installed torch version. The weights
    # are loaded into the underlying module with that prefix stripped, so a
    # checkpoint can be restored whether or not compilation was enabled when
    # it was written.
    compiled_prefix = "_orig_mod."
    target_module = getattr(coordinator.network, "_orig_mod", coordinator.network)
    network_state = {
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
        for key, value in checkpoint['network_state_dict'].items()
    }
    target_module.load_state_dict(network_state)

    coordinator.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    coordinator.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    coordinator.global_step = checkpoint['step']
    coordinator.current_iteration = checkpoint['iteration']
    print(f"Resumed from step {coordinator.global_step}, iteration {coordinator.current_iteration}")

print("\nSetup complete! Ready to train.")

In [None]:
# Cell 9: Training Loop (Iterative Training)
#
# Runs the coordinator, prints a summary and saves a final checkpoint. The
# summary and the final checkpoint are produced even when training is
# interrupted or fails.
import time

print("Starting iterative training...")
print(f"Configuration: {ITERATIONS} iterations × {STEPS_PER_ITERATION} steps = {ITERATIONS * STEPS_PER_ITERATION} total steps")

start_time = time.time()
# Counters at the start of this run: after resuming from a checkpoint (or
# re-running this cell) the coordinator's totals include earlier work, which
# must not be counted in this run's throughput.
start_step = coordinator.global_step
start_games = coordinator.total_games

try:
    coordinator.run_iterative_training(
        iterations=ITERATIONS,
        steps_per_iteration=STEPS_PER_ITERATION,
        num_actors=NUM_ACTORS,
        checkpoint_interval=CHECKPOINT_INTERVAL,
        storage=storage
    )
except KeyboardInterrupt:
    print("\nTraining interrupted by user")
except Exception as e:
    print(f"\nTraining error: {e}")
    import traceback
    traceback.print_exc()

end_time = time.time()
# Guard against a zero interval on coarse clocks.
elapsed_time = max(end_time - start_time, 1e-9)
session_steps = coordinator.global_step - start_step
session_games = coordinator.total_games - start_games

print(f"\n{'='*60}")
print("TRAINING SUMMARY")
print(f"{'='*60}")
print(f"Total time: {elapsed_time / 3600:.2f} hours")
print(f"Total steps: {coordinator.global_step}")
print(f"Total games: {coordinator.total_games}")
print(f"Steps/second: {session_steps / elapsed_time:.2f}")
print(f"Games/hour: {session_games / (elapsed_time / 3600):.1f}")
print(f"{'='*60}")

# Save final checkpoint
print("\nSaving final checkpoint...")
storage.save_checkpoint(
    coordinator.network,
    coordinator.optimizer,
    coordinator.scheduler,
    coordinator.global_step,
    coordinator.current_iteration,
    NUM_FILTERS,
    NUM_BLOCKS
)
print("\nTraining complete!")

In [None]:
# Cell 10: Evaluation Against Random Player
#
# Plays games between the trained network and a RandomEvaluator opponent
# (both searched with the same MCTS settings), alternating colors, and
# reports wins/losses/draws from the model's point of view.
import torch
import numpy as np
from tqdm import tqdm
from alphazero.chess_env import GameState
from alphazero.mcts import create_mcts
from alphazero.mcts.evaluator import NetworkEvaluator, RandomEvaluator
from alphazero.selfplay.game import SelfPlayGame
from alphazero.config import MCTSConfig, SelfPlayConfig

print("Evaluating trained model against random player...")

# Set network to eval mode
coordinator.network.eval()

# Create evaluators
network_evaluator = NetworkEvaluator(coordinator.network, DEVICE, use_amp=USE_AMP)
random_evaluator = RandomEvaluator()

# Create MCTS
mcts_config = MCTSConfig(
    num_simulations=200,  # Use fewer simulations for faster evaluation
    c_puct=C_PUCT,
    dirichlet_alpha=0.0,  # No exploration noise during evaluation
    dirichlet_epsilon=0.0,
    temperature=0.0  # Greedy selection
)
mcts = create_mcts(config=mcts_config)

selfplay_config = SelfPlayConfig(
    max_moves=MAX_MOVES,
    resign_threshold=RESIGN_THRESHOLD,
    resign_check_moves=RESIGN_CHECK_MOVES
)

# Play evaluation games
num_eval_games = 20
wins = 0
losses = 0
draws = 0
failed_games = 0  # games aborted by an exception; excluded from the rates

print(f"\nPlaying {num_eval_games} games against random player...")

with tqdm(total=num_eval_games, desc="Evaluation", unit="game") as pbar:
    for game_idx in range(num_eval_games):
        try:
            # Alternate colors
            model_plays_white = (game_idx % 2 == 0)

            state = GameState()
            move_count = 0

            while not state.is_terminal() and move_count < MAX_MOVES:
                # Choose evaluator based on whose turn it is.
                # NOTE(review): assumes state.turn is truthy for White - confirm.
                is_white_turn = bool(state.turn)
                if is_white_turn == model_plays_white:
                    # Model's turn
                    evaluator = network_evaluator
                else:
                    # Random player's turn
                    evaluator = random_evaluator

                # Run MCTS
                policy, root, stats = mcts.search(state, evaluator, move_number=move_count, add_noise=False)

                # Select action greedily
                action = int(np.argmax(policy))

                # Apply action
                state = state.apply_action(action)
                move_count += 1

            # Get result. A game cut off at MAX_MOVES, or one without a
            # result object, is counted as a draw.
            result = state.get_result() if state.is_terminal() else None
            if result is None:
                draws += 1
            elif result.winner is True:  # White won
                if model_plays_white:
                    wins += 1
                else:
                    losses += 1
            elif result.winner is False:  # Black won
                if model_plays_white:
                    losses += 1
                else:
                    wins += 1
            else:  # Draw
                draws += 1

            pbar.update(1)
            pbar.set_postfix({
                'wins': wins,
                'losses': losses,
                'draws': draws,
                # Rate over games that actually finished, not over the index.
                'win_rate': f"{wins / (wins + losses + draws) * 100:.1f}%"
            })
        except Exception as e:
            failed_games += 1
            print(f"\nError in evaluation game {game_idx}: {e}")

# Percentages are taken over completed games so that failed games do not
# silently lower every rate; max(..., 1) avoids dividing by zero when every
# game failed.
games_completed = wins + losses + draws
rate_base = max(games_completed, 1)

print(f"\n{'='*60}")
print("EVALUATION RESULTS")
print(f"{'='*60}")
print(f"Games played: {games_completed}")
if failed_games:
    print(f"Games failed: {failed_games} (not included in the rates)")
print(f"Wins: {wins} ({wins / rate_base * 100:.1f}%)")
print(f"Losses: {losses} ({losses / rate_base * 100:.1f}%)")
print(f"Draws: {draws} ({draws / rate_base * 100:.1f}%)")
print(f"Win rate: {wins / rate_base * 100:.1f}%")
print(f"{'='*60}")

In [None]:
# Cell 11: Play Against the Model Interactively
#
# Text-based game loop: the human enters UCI moves via input(), the model
# answers with the argmax of the MCTS policy. Relies on notebook globals
# defined by earlier cells: coordinator, DEVICE, USE_AMP, C_PUCT, MAX_MOVES.
import chess
import torch
import numpy as np
from alphazero.chess_env import GameState
from alphazero.mcts import create_mcts
from alphazero.mcts.evaluator import NetworkEvaluator
from alphazero.config import MCTSConfig

print("Interactive Chess Game")
print("="*60)
print("You can play against the trained model!")
print("Enter moves in UCI format (e.g., 'e2e4', 'e7e8q')")
print("Type 'quit' to exit, 'moves' to see legal moves")
print("="*60)

# Set network to eval mode
coordinator.network.eval()

# Create evaluator and MCTS
network_evaluator = NetworkEvaluator(coordinator.network, DEVICE, use_amp=USE_AMP)
mcts_config = MCTSConfig(
    num_simulations=400,  # High quality for interactive play
    c_puct=C_PUCT,
    dirichlet_alpha=0.0,
    dirichlet_epsilon=0.0,
    temperature=0.0
)
mcts = create_mcts(config=mcts_config)

# Choose color
print("\nChoose your color:")
print("1. White (you play first)")
print("2. Black (model plays first)")
color_choice = input("Enter 1 or 2: ").strip()
# Re-prompt instead of silently treating any non-"1" answer (typos, an empty
# line) as a request to play Black.
while color_choice not in ("1", "2"):
    color_choice = input("Invalid choice. Enter 1 or 2: ").strip()
human_plays_white = (color_choice == "1")
print(f"\nYou are playing as {'White' if human_plays_white else 'Black'}")

# Initialize game
state = GameState()
move_count = 0
user_quit = False  # Set when the human types 'quit'; drives the final report
print("\nStarting position:")
print(state.board)
print()

# Game loop
while not state.is_terminal() and move_count < MAX_MOVES:
    # state.turn is treated as truthy when White is to move.
    is_white_turn = state.turn
    is_human_turn = bool(is_white_turn) == human_plays_white

    if is_human_turn:
        # Human's turn
        print(f"\nYour turn ({'White' if is_white_turn else 'Black'})")
        print(f"Legal moves: {len(list(state.board.legal_moves))}")

        # Keep prompting until a legal move is played or the user quits
        while True:
            move_input = input("Enter move (or 'moves' to see legal moves, 'quit' to exit): ").strip().lower()

            if move_input == 'quit':
                print("Game ended by user")
                user_quit = True
                break

            if move_input == 'moves':
                print("\nLegal moves:")
                for move in state.board.legal_moves:
                    print(f"  {move.uci()}")
                continue

            # Only the parse step is guarded: a malformed UCI string raises
            # ValueError. Errors from applying the move are real failures and
            # must not be reported as a "format" problem.
            try:
                move = chess.Move.from_uci(move_input)
            except ValueError as e:
                print(f"Invalid move format: {e}")
                continue

            if move not in state.board.legal_moves:
                print("Illegal move! Try again.")
                continue

            # Apply move (immutable update)
            state = state.apply_move(move)
            move_count += 1
            print(f"\nYou played: {move.uci()}")
            print(state.board)
            break

        if user_quit:
            break

    else:
        # Model's turn
        print(f"\nModel's turn ({'White' if is_white_turn else 'Black'})")
        print("Thinking...")

        # Run MCTS
        policy, root, stats = mcts.search(state, network_evaluator, move_number=move_count, add_noise=False)

        # Select best action
        action = int(np.argmax(policy))

        # Convert action to move
        move = state.action_to_move(action)

        # Apply move (immutable update)
        state = state.apply_move(move)
        move_count += 1

        print(f"Model played: {move.uci()}")
        # NOTE(review): which side's perspective root.q_value is expressed
        # from is not visible in this cell - confirm against the MCTS code.
        print(f"Model evaluation: {root.q_value:.3f}")
        print(state.board)

# Game over
print("\n" + "="*60)
print("GAME OVER")
print("="*60)

if user_quit:
    # An abandoned game is not a draw and did not hit the move cap
    print("Result: Game abandoned by user")
elif state.is_terminal():
    result = state.get_result()
    if result is not None:
        if result.winner is True:
            # White won: that is the human only if the human played White
            winner = "You" if human_plays_white else "Model"
            print(f"Winner: {winner} (White)")
        elif result.winner is False:
            # Black won: that is the human only if the human played Black
            winner = "Model" if human_plays_white else "You"
            print(f"Winner: {winner} (Black)")
        else:
            print("Result: Draw")
        print(f"Termination: {result.termination}")
    else:
        print("Result: Draw (no outcome)")
else:
    print("Result: Draw (max moves reached)")
print("="*60)