# AlphaZero Chess Training on Google Colab

Train an AlphaZero chess engine using a C++ MCTS backend with GPU-accelerated neural networks.

**Requirements:**
- Colab GPU runtime (Runtime → Change runtime type → T4 GPU)
- Google Drive for persistent checkpoint storage

**What this notebook does:**
1. Builds the C++ backend (MCTS + self-play engine) from source
2. Runs the AlphaZero training loop with parallel self-play
3. Saves checkpoints to Google Drive so they persist across sessions

## 1. GPU Check & Google Drive Mount

In [None]:
# Verify GPU is available and mount Google Drive for checkpoint persistence
import torch

if not torch.cuda.is_available():
    raise RuntimeError(
        "No GPU detected! Go to Runtime > Change runtime type > T4 GPU"
    )

print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"CUDA: {torch.version.cuda}")
!nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive")

import os
DRIVE_DIR = "/content/drive/MyDrive/alphazero-chess"
os.makedirs(DRIVE_DIR, exist_ok=True)
print(f"\nDrive output directory: {DRIVE_DIR}")

## 2. Clone Repository

In [None]:
import os

REPO_DIR = "/content/alpha-zero-chess"

if not os.path.exists(REPO_DIR):
    !git clone https://github.com/lirockyzhang/alpha-zero-chess.git {REPO_DIR}
else:
    print(f"Repository already cloned at {REPO_DIR}")
    !cd {REPO_DIR} && git pull

# Initialize git submodules (chess-library dependency)
!cd {REPO_DIR} && git submodule update --init --recursive

os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

# Verify the chess library is present
chess_hpp = os.path.join(REPO_DIR, "alphazero-cpp", "third_party", "chess-library", "include", "chess.hpp")
if os.path.exists(chess_hpp):
    print("chess-library submodule: OK")
else:
    raise RuntimeError(
        "chess-library submodule is missing! "
        "Run: git submodule update --init --recursive"
    )

## 3. Install Dependencies

In [None]:
# torch and numpy are pre-installed on Colab
# We only need pybind11 (for building C++) and python-chess (for the training script)
!pip install pybind11 python-chess -q

# Verify imports
import pybind11, chess, torch, numpy
print(f"pybind11:     {pybind11.__version__}")
print(f"python-chess: {chess.__version__}")
print(f"torch:        {torch.__version__}")
print(f"numpy:        {numpy.__version__}")

## 4. Build C++ Module

Compiles the C++ MCTS engine and Python bindings. This takes ~2 minutes on Colab.

In [None]:
# Section 4: configure and build the C++ MCTS engine + pybind11 bindings from a
# clean build directory, expose the .so under build/Release/, and import it once
# as a smoke test.
import glob
import multiprocessing
import os
import shutil
import subprocess
import sys

BUILD_DIR = os.path.join(REPO_DIR, "alphazero-cpp", "build")

# How much of the (potentially very long) build output to echo; compiler
# errors are normally at the end.
LOG_TAIL_CHARS = 3000

# Clean previous build artifacts to avoid stale CMake cache issues
if os.path.exists(BUILD_DIR):
    shutil.rmtree(BUILD_DIR)
    print("Cleaned previous build directory")
os.makedirs(BUILD_DIR)

# Ask the kernel's own interpreter (not whatever "python3" is first on PATH)
# where pybind11's CMake config lives, so CMake uses the same pybind11 that
# this notebook imported in the dependency cell.
pybind11_dir = subprocess.check_output(
    [sys.executable, "-c", "import pybind11; print(pybind11.get_cmake_dir())"],
    text=True,
).strip()
print(f"pybind11 cmake dir: {pybind11_dir}")

# Configure - capture output so a failure can be reported with context
print("\n--- CMake Configure ---")
configure_result = subprocess.run(
    [
        "cmake", "..",
        "-DCMAKE_BUILD_TYPE=Release",
        f"-Dpybind11_DIR={pybind11_dir}",
    ],
    cwd=BUILD_DIR,
    capture_output=True,
    text=True,
)
print(configure_result.stdout)
if configure_result.returncode != 0:
    print("STDERR:", configure_result.stderr)
    raise RuntimeError(
        f"CMake configure failed (exit code {configure_result.returncode}).\n"
        "Check the output above for details."
    )

# Verify pybind11 was found and Python bindings will be built (the CMake
# project prints this marker when they are enabled)
if "Python bindings will be built" not in configure_result.stdout:
    print("\nWARNING: CMake did not find pybind11 or Python!")
    print("The alphazero_cpp module will NOT be built.")
    print("CMake output above should show why.\n")

# Build using all available cores
n_cores = multiprocessing.cpu_count()
print(f"\n--- CMake Build ({n_cores} cores) ---")
build_result = subprocess.run(
    ["cmake", "--build", ".", "--config", "Release", f"-j{n_cores}"],
    cwd=BUILD_DIR,
    capture_output=True,
    text=True,
)
# A negative slice already returns the whole string when it is shorter than
# LOG_TAIL_CHARS, so no explicit length check is needed.
print(build_result.stdout[-LOG_TAIL_CHARS:])
if build_result.returncode != 0:
    print("STDERR:", build_result.stderr[-LOG_TAIL_CHARS:])
    raise RuntimeError(
        f"C++ build failed (exit code {build_result.returncode}).\n"
        "Check the output above for the actual compiler error."
    )

# On Linux, CMake puts the .so directly in build/ (not build/Release/)
# The training script expects it in build/Release/, so create that structure
RELEASE_DIR = os.path.join(BUILD_DIR, "Release")
os.makedirs(RELEASE_DIR, exist_ok=True)

so_files = glob.glob(os.path.join(BUILD_DIR, "alphazero_cpp*.so"))
if not so_files:
    print("\nERROR: No alphazero_cpp*.so found in build directory!")
    print("Files in build dir:", os.listdir(BUILD_DIR))
    raise RuntimeError(
        "Build succeeded but no .so file was produced. "
        "Check if pybind11 was found during CMake configure."
    )
for so in so_files:
    target = os.path.join(RELEASE_DIR, os.path.basename(so))
    # lexists() also detects a dangling symlink, which os.symlink() would reject
    if not os.path.lexists(target):
        os.symlink(so, target)
        print(f"Symlinked: {os.path.basename(so)} -> Release/")

# Verify the module loads.
# Python caches imported extension modules, so on a re-run of this cell the
# import below returns the copy loaded earlier, not the freshly built one.
if "alphazero_cpp" in sys.modules:
    print(
        "\nNOTE: alphazero_cpp was already imported in this runtime; the import "
        "below reuses that copy. Restart the runtime to load the rebuilt module here."
    )
# Insert RELEASE_DIR then BUILD_DIR at the front (BUILD_DIR ends up first),
# skipping entries already present from a previous run of this cell.
for import_dir in (RELEASE_DIR, BUILD_DIR):
    if import_dir not in sys.path:
        sys.path.insert(0, import_dir)
import alphazero_cpp
print("\nalphazero_cpp loaded successfully!")

## 5. Configure Training Parameters

Adjust these parameters before starting training. The defaults are a good starting point for full training.

In [None]:
# @title Training Configuration { run: "auto" }

# --- Network Architecture ---
FILTERS = 192          # @param {type: "integer"}
BLOCKS = 15            # @param {type: "integer"}

# --- Training Loop ---
ITERATIONS = 100       # @param {type: "integer"}
GAMES_PER_ITER = 50    # @param {type: "integer"}
SIMULATIONS = 800      # @param {type: "integer"}
EPOCHS = 5             # @param {type: "integer"}
LR = 0.001             # @param {type: "number"}
TRAIN_BATCH = 256      # @param {type: "integer"}
BUFFER_SIZE = 100000   # @param {type: "integer"}

# --- Parallel Self-Play ---
WORKERS = 64           # @param {type: "integer"}
EVAL_BATCH = 1024      # @param {type: "integer"}
SEARCH_BATCH = 16      # @param {type: "integer"}
GPU_BATCH_TIMEOUT_MS = 10  # @param {type: "integer"}
C_PUCT = 2.0           # @param {type: "number"}

# --- Draw Score ---
# Value assigned to draws from White's perspective.
# 0.0 = standard symmetric draws.
# -0.2 = slightly penalize White for drawing (encourages decisive play).
DRAW_SCORE = 0.0       # @param {type: "number"}

# --- Checkpointing ---
SAVE_INTERVAL = 1      # @param {type: "integer"}

# ==============================================================================
# RESUME OPTIONS (for continuing from a previous checkpoint)
# ==============================================================================
# RESUME_RUN_DIR: The run directory to resume from (on Google Drive or local)
#   - New checkpoints will be saved here
#   - Leave empty for a fresh start
RESUME_RUN_DIR = ""    # @param {type: "string"}

# LOAD_CHECKPOINT_PATH: (Optional) Load model weights from a different location
#   - Useful for fast loading: upload checkpoint to /content/ instead of reading from Drive
#   - If empty, loads from RESUME_RUN_DIR (default behavior)
#   - Only used when RESUME_RUN_DIR is also set
#   - Example: "/content/model_iter_010.pt" (uploaded file)
LOAD_CHECKPOINT_PATH = ""  # @param {type: "string"}

# LOAD_REPLAY_BUFFER: Whether to load the replay buffer when resuming
#   - Buffer files can be 100+ MB and slow to load from Drive
#   - Set to False to skip loading (training will start with empty buffer)
LOAD_REPLAY_BUFFER = False  # @param {type: "boolean"}

# --- Checkpoint directory setup ---
import os

# Run directories are written under Drive so they survive disconnects.
SAVE_DIR = os.path.join(DRIVE_DIR, "checkpoints")
os.makedirs(SAVE_DIR, exist_ok=True)

# Point the repo's default checkpoints/ path at SAVE_DIR. A symlink left over
# from an earlier session is replaced; a real directory is left alone (the
# training command passes --save-dir explicitly anyway).
LOCAL_CKPT = os.path.join(REPO_DIR, "checkpoints")
if os.path.islink(LOCAL_CKPT):
    os.remove(LOCAL_CKPT)
if not os.path.exists(LOCAL_CKPT):
    os.symlink(SAVE_DIR, LOCAL_CKPT)
    print(f"Symlinked checkpoints/ -> {SAVE_DIR}")
else:
    print(f"Note: {LOCAL_CKPT} already exists as a directory, using --save-dir instead")

print(f"\nCheckpoints will be saved to: {SAVE_DIR}")
print(f"Network: {FILTERS} filters, {BLOCKS} blocks")
print(f"Training: {ITERATIONS} iterations, {GAMES_PER_ITER} games/iter, {WORKERS} workers")
print(f"MCTS: {SIMULATIONS} sims, search_batch={SEARCH_BATCH}, c_puct={C_PUCT}")
print(f"GPU: eval_batch={EVAL_BATCH}, timeout={GPU_BATCH_TIMEOUT_MS}ms")
print(f"Draw score: {DRAW_SCORE}")
if RESUME_RUN_DIR:
    print(f"\nResuming run directory: {RESUME_RUN_DIR}")
    if LOAD_CHECKPOINT_PATH:
        print(f"Loading checkpoint from: {LOAD_CHECKPOINT_PATH} (fast local path)")
    else:
        print(f"Loading checkpoint from: {RESUME_RUN_DIR} (run directory)")
    print(f"Load replay buffer: {LOAD_REPLAY_BUFFER}")

# Sanity-check the resume options now rather than when training starts. The
# training cell only uses LOAD_CHECKPOINT_PATH when RESUME_RUN_DIR is also set.
# These are warnings, not errors, because this cell auto-runs while the form
# is being edited.
if LOAD_CHECKPOINT_PATH and not RESUME_RUN_DIR:
    print(
        "\nWARNING: LOAD_CHECKPOINT_PATH is set but RESUME_RUN_DIR is empty; "
        "the checkpoint will be ignored. Set RESUME_RUN_DIR to resume."
    )
if RESUME_RUN_DIR and not os.path.isdir(RESUME_RUN_DIR):
    print(f"\nWARNING: RESUME_RUN_DIR does not exist: {RESUME_RUN_DIR}")
if LOAD_CHECKPOINT_PATH and not os.path.isfile(LOAD_CHECKPOINT_PATH):
    print(f"\nWARNING: LOAD_CHECKPOINT_PATH not found: {LOAD_CHECKPOINT_PATH}")

## 6. (Optional) Upload Checkpoint for Fast Loading

If you have a checkpoint file on your computer (e.g., `model_iter_010.pt`), upload it here for **much faster loading** than reading from Google Drive.

**Why this is faster:**
- Google Drive: ~2-5 MB/s read speed (130MB = 30-60 seconds)
- Colab local storage: ~500 MB/s (130MB = <1 second)

Run this cell, then set `LOAD_CHECKPOINT_PATH = "/content/your_model.pt"` in cell 5 and re-run cell 5. `LOAD_CHECKPOINT_PATH` is only used when `RESUME_RUN_DIR` is also set.

In [None]:
# Upload a checkpoint file from your computer to Colab's fast local storage
# (/content/). Afterwards, set LOAD_CHECKPOINT_PATH in cell 5 to the uploaded
# file path. RESUME_RUN_DIR must also be set for it to be used.

from google.colab import files
import os

UPLOAD_DIR = "/content"

print("Select a checkpoint file (.pt) to upload...")
print("(This will be stored in /content/ for fast access)\n")

# files.upload() saves into the current working directory by default. Cell 2
# changed that to the repo checkout, so switch to UPLOAD_DIR for the upload;
# otherwise the file would not be at the path reported below.
_prev_cwd = os.getcwd()
os.chdir(UPLOAD_DIR)
try:
    uploaded = files.upload()  # {original filename: file bytes}
finally:
    os.chdir(_prev_cwd)

if uploaded:
    for filename, content in uploaded.items():
        filepath = os.path.join(UPLOAD_DIR, filename)
        # If a file with this name already existed, the upload may have been
        # saved under a de-duplicated name. Make sure filepath holds these bytes.
        if not os.path.isfile(filepath) or os.path.getsize(filepath) != len(content):
            with open(filepath, "wb") as fh:
                fh.write(content)
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f"\n✓ Uploaded: {filepath} ({size_mb:.1f} MB)")
        print("\nTo use this checkpoint, set in cell 5:")
        print(f'  LOAD_CHECKPOINT_PATH = "{filepath}"')
        print("  (RESUME_RUN_DIR must also be set for it to be used)")
else:
    print("No file uploaded.")

## 7. Run Training

This starts the AlphaZero training loop. Progress is printed every 30 seconds.

Checkpoints are saved to Google Drive every `SAVE_INTERVAL` iterations, so they persist even if Colab disconnects.

**First run will be slower** due to:
- `torch.compile()` JIT compilation (~2-5 minutes, one-time)
- cuDNN algorithm auto-tuning (first few batches)

In [None]:
import os
import shutil

# Handle checkpoint loading from a different location
if RESUME_RUN_DIR and LOAD_CHECKPOINT_PATH:
    # User wants to load from a fast local path but save to the run directory
    # Copy the checkpoint to the run directory so --resume works correctly
    if os.path.exists(LOAD_CHECKPOINT_PATH):
        # Determine the target filename
        ckpt_filename = os.path.basename(LOAD_CHECKPOINT_PATH)
        target_path = os.path.join(RESUME_RUN_DIR, ckpt_filename)
        
        # Copy if not already there (or if source is newer)
        if not os.path.exists(target_path) or \
           os.path.getmtime(LOAD_CHECKPOINT_PATH) > os.path.getmtime(target_path):
            print(f"Copying checkpoint to run directory...")
            print(f"  From: {LOAD_CHECKPOINT_PATH}")
            print(f"  To:   {target_path}")
            shutil.copy2(LOAD_CHECKPOINT_PATH, target_path)
            print(f"  Done! ({os.path.getsize(target_path) / 1024 / 1024:.1f} MB)")
        else:
            print(f"Checkpoint already in run directory: {target_path}")
    else:
        raise FileNotFoundError(f"Checkpoint not found: {LOAD_CHECKPOINT_PATH}")

# Build the training command
cmd = (
    f"python {os.path.join(REPO_DIR, 'alphazero-cpp', 'scripts', 'train.py')}"
    f" --iterations {ITERATIONS}"
    f" --games-per-iter {GAMES_PER_ITER}"
    f" --simulations {SIMULATIONS}"
    f" --workers {WORKERS}"
    f" --eval-batch {EVAL_BATCH}"
    f" --search-batch {SEARCH_BATCH}"
    f" --gpu-batch-timeout-ms {GPU_BATCH_TIMEOUT_MS}"
    f" --c-puct {C_PUCT}"
    f" --filters {FILTERS}"
    f" --blocks {BLOCKS}"
    f" --train-batch {TRAIN_BATCH}"
    f" --lr {LR}"
    f" --epochs {EPOCHS}"
    f" --buffer-size {BUFFER_SIZE}"
    f" --save-dir {SAVE_DIR}"
    f" --save-interval {SAVE_INTERVAL}"
    f" --progress-interval 30"
    f" --no-visualization"
    f" --device cuda"
    f" --draw-score {DRAW_SCORE}"
)

# Add resume options
if RESUME_RUN_DIR:
    cmd += f" --resume {RESUME_RUN_DIR}"
    if LOAD_REPLAY_BUFFER:
        cmd += " --load-buffer"

print("Training command:")
print(cmd)
print("\n" + "=" * 60)
print("Starting training...")
print("=" * 60 + "\n")

!{cmd}

In [None]:
# List the saved run directories on Drive (the last few checkpoints plus a
# metrics summary), then print instructions for resuming in a new session.
import json
import os

MAX_LISTED_CHECKPOINTS = 5

print("=" * 60)
print("Saved checkpoints on Google Drive:")
print("=" * 60)

if os.path.exists(SAVE_DIR):
    for run_dir in sorted(os.listdir(SAVE_DIR)):
        run_path = os.path.join(SAVE_DIR, run_dir)
        if not os.path.isdir(run_path):
            continue

        # Not named `files`: that would shadow google.colab.files from the upload cell.
        entries = os.listdir(run_path)
        pt_files = sorted(
            name for name in entries
            if name.endswith(".pt") and not name.endswith("_emergency.pt")
        )
        print(f"\n  {run_dir}/")
        for name in pt_files[-MAX_LISTED_CHECKPOINTS:]:  # Show the last few checkpoints
            size_mb = os.path.getsize(os.path.join(run_path, name)) / (1024 * 1024)
            print(f"    {name}  ({size_mb:.1f} MB)")
        if len(pt_files) > MAX_LISTED_CHECKPOINTS:
            print(f"    ... and {len(pt_files) - MAX_LISTED_CHECKPOINTS} more checkpoints")

        # Show training metrics if available. A file that is partially written
        # or unreadable should not abort the listing of the remaining runs.
        metrics_path = os.path.join(run_path, "metrics_history.json")
        if os.path.exists(metrics_path):
            try:
                with open(metrics_path) as mf:
                    metrics = json.load(mf)
            except (OSError, json.JSONDecodeError) as err:
                print(f"    (could not read metrics_history.json: {err})")
            else:
                iters = metrics.get("iterations", []) if isinstance(metrics, dict) else []
                if iters:
                    print(f"    Iterations completed: {len(iters)}")
                    last = iters[-1]
                    loss = last.get("loss") if isinstance(last, dict) else None
                    # Only numbers can take :.4f; formatting the 'N/A' fallback
                    # with it would raise ValueError.
                    loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
                    print(f"    Latest loss: {loss_str}")
else:
    print("  No checkpoints found.")

print("\n" + "=" * 60)
print("To resume training in a new session:")
print("=" * 60)
print(f"""
  1. Run cells 1-5 (setup)
  
  2. In cell 5 (Configuration), set:
     RESUME_RUN_DIR = "{SAVE_DIR}/<run_directory>"
  
  3. (Optional) For FAST checkpoint loading:
     - Run cell 6 to upload checkpoint from your computer
     - Set: LOAD_CHECKPOINT_PATH = "/content/model_iter_XXX.pt"
     - Set: LOAD_REPLAY_BUFFER = False  (skip slow buffer loading)
  
  4. Run cell 7 to continue training
""")

In [None]:
# List every checkpoint in each run directory on Drive, with a metrics summary,
# then print instructions for resuming in a new session.
# NOTE(review): this looks like an older variant of the previous cell (it reads
# a different metrics file, training_metrics.json). Confirm which one train.py
# actually writes and consider deleting the other.
import json
import os

print("=" * 60)
print("Saved checkpoints on Google Drive:")
print("=" * 60)

if os.path.exists(SAVE_DIR):
    for run_dir in sorted(os.listdir(SAVE_DIR)):
        run_path = os.path.join(SAVE_DIR, run_dir)
        if not os.path.isdir(run_path):
            continue

        # Not named `files`: that would shadow google.colab.files from the upload cell.
        entries = os.listdir(run_path)
        pt_files = sorted(name for name in entries if name.endswith(".pt"))
        print(f"\n  {run_dir}/")
        for name in pt_files:
            size_mb = os.path.getsize(os.path.join(run_path, name)) / (1024 * 1024)
            print(f"    {name}  ({size_mb:.1f} MB)")

        # Show training metrics if available; an unreadable file is reported
        # instead of aborting the listing.
        metrics_path = os.path.join(run_path, "training_metrics.json")
        if os.path.exists(metrics_path):
            try:
                with open(metrics_path) as mf:
                    metrics = json.load(mf)
            except (OSError, json.JSONDecodeError) as err:
                print(f"    (could not read training_metrics.json: {err})")
            else:
                # Expected: a list of per-iteration dicts. Also tolerate a
                # {"iterations": [...]} wrapper instead of crashing on it.
                if isinstance(metrics, dict):
                    metrics = metrics.get("iterations", [])
                if isinstance(metrics, list) and metrics:
                    last = metrics[-1]
                    final_loss = last.get("total_loss", "N/A") if isinstance(last, dict) else "N/A"
                    print(f"    Iterations completed: {len(metrics)}")
                    print(f"    Final loss: {final_loss}")
else:
    print("  No checkpoints found.")

# Resume instructions: the config variable is RESUME_RUN_DIR (cell 5), and
# training runs in cell 7 (cell 6 is the optional upload cell).
print("\n" + "=" * 60)
print("To resume training in a new session:")
print("=" * 60)
print("  1. Run cells 1-5 again")
print("  2. In cell 5, set RESUME_RUN_DIR to the run directory, e.g.:")
print(f"     RESUME_RUN_DIR = \"{SAVE_DIR}/<run_directory>\"")
print("  3. Run cell 7 to continue training")