<a href="https://colab.research.google.com/github/jawaharganesh24189/DLA/blob/main/Hybrid_Chess_Engine_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 Hybrid Neural Chess Engine v2.0

## 📚 Overview

This is an **improved version** of the Hybrid Neural Chess Engine that combines:
- **Imitation Learning**: Learning from master games (e.g., Hikaru Nakamura)
- **Self-Play Reinforcement**: Improving through self-play

### ✨ Key Improvements in v2.0:
- ✅ **Early stopping** with validation split and patience mechanism
- ✅ **Google Drive integration** for persistent model storage
- ✅ **Improved model architecture** with batch normalization and dropout
- ✅ **Comprehensive metrics tracking** (Top-1/Top-5 accuracy, loss curves)
- ✅ **Checkpoint management** (best model + periodic saves)
- ✅ **Self-play game saving** as PGN files
- ✅ **Training visualization** (loss plots, accuracy curves)
- ✅ **Progress bars** using tqdm
- ✅ **Resume from checkpoint** capability
- ✅ **Organized directory structure** on Google Drive

### 🏗️ Architecture:
- **Policy Network**: CNN + Transformer → Move Probabilities (4208 moves, the vocabulary size produced by `generate_move_vocab`)
- **Value Network**: CNN → Position Evaluation (-1 to +1)

### 📖 Notebook Sections:
1. **Configuration**
2. **Setup & Installation**
3. **Google Drive Setup**
4. **Data Preparation**
5. **Dataset Class**
6. **Model Architecture** (with improvements)
7. **Training Utilities** (early stopping, checkpointing, metrics)
8. **Imitation Learning**
9. **Self-Play Reinforcement**
10. **Evaluation & Metrics**
11. **Model Saving & Loading**
12. **Play Against Engine**
13. **Visualization**

## 📋 SECTION 1 — Configuration

Configure all hyperparameters and paths here for easy adjustment.

In [None]:
# ===== CONFIGURATION =====
# All hyperparameters and settings in one place for easy tuning

CONFIG = {
    # === Paths ===
    'pgn_path': '/content/drive/MyDrive/ChessEngine/data/master_games.pgn',
    'save_dir': '/content/drive/MyDrive/ChessEngine/',
    
    # === Data Parameters ===
    'max_games': 200,  # Number of games to load from PGN
    'train_val_split': 0.8,  # 80% train, 20% validation
    
    # === Training Parameters ===
    'batch_size': 32,
    'epochs': 10,
    'learning_rate': 1e-3,
    'weight_decay': 1e-5,
    
    # === Early Stopping ===
    'patience': 5,  # Stop if no improvement for N epochs
    'min_delta': 0.001,  # Minimum change to qualify as improvement
    
    # === Checkpointing ===
    'save_every': 2,  # Save checkpoint every N epochs
    
    # === Self-Play Parameters ===
    'selfplay_games': 20,
    'max_plies': 200,
    
    # === Mode ===
    # Options: 'imitation', 'selfplay', or 'hybrid'
    'mode': 'hybrid',
    
    # === Random Seed ===
    'seed': 42
}

print("✅ Configuration loaded successfully!")
print("\nKey settings:")
for key, value in CONFIG.items():
    print(f"  • {key}: {value}")
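
The settings above are only exercised when a later cell reads them, so a typo in `mode` or a split fraction outside (0, 1) would surface late. A minimal sketch of an up-front sanity check follows. The helper `validate_config`, the set `VALID_MODES`, and the trimmed `example_cfg` are hypothetical names introduced for illustration; they are not part of the notebook.

```python
# Hypothetical helper, not part of the notebook: fail fast on bad settings.
VALID_MODES = {'imitation', 'selfplay', 'hybrid'}

def validate_config(cfg):
    """Return a list of human-readable problems found in cfg (empty if none)."""
    problems = []
    if cfg['mode'] not in VALID_MODES:
        problems.append(f"mode must be one of {sorted(VALID_MODES)}")
    if not 0.0 < cfg['train_val_split'] < 1.0:
        problems.append("train_val_split must be strictly between 0 and 1")
    for key in ('batch_size', 'epochs', 'patience', 'save_every'):
        if cfg[key] < 1:
            problems.append(f"{key} must be at least 1")
    return problems

# Trimmed copy of the values used in this notebook
example_cfg = {'mode': 'hybrid', 'train_val_split': 0.8, 'batch_size': 32,
               'epochs': 10, 'patience': 5, 'save_every': 2}
print(validate_config(example_cfg))                      # no problems expected
print(validate_config({**example_cfg, 'mode': 'mcts'}))  # reports the bad mode
```

Returning a list instead of raising lets the cell report every problem in one pass.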

## 🔧 SECTION 2 — Setup & Installation

Install required packages and import libraries.

In [None]:
# Install required libraries
!pip -q install python-chess torch torchvision tqdm matplotlib

# Standard library imports
import os
import json
import random
import copy
from datetime import datetime
from pathlib import Path
from collections import defaultdict

# Third-party imports
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Chess library
import chess
import chess.pgn

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

print("✅ All libraries imported successfully!")

In [None]:
# Set seeds for reproducibility
def set_seed(seed):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(CONFIG['seed'])

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Using device: {device}")

if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

CONFIG['device'] = device
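
Seeding makes a random generator replay the same sequence on every run. The sketch below shows this with Python's standard `random` module only; NumPy and PyTorch keep their own generators, which is why `set_seed` seeds each one separately. The helper `draw` is a hypothetical name used for illustration.

```python
import random

def draw(seed, n=3):
    """Seed Python's RNG, then draw n square indices in 0..63."""
    random.seed(seed)
    return [random.randint(0, 63) for _ in range(n)]

first = draw(42)
second = draw(42)
print(first == second)  # True: same seed, same sequence
```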

## 💾 SECTION 3 — Google Drive Setup

Mount Google Drive and create organized directory structure.

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
print("✅ Google Drive mounted successfully!")

In [None]:
def setup_directories(base_dir):
    """
    Create organized directory structure for the chess engine.
    
    Directory structure:
    ChessEngine/
    ├── models/       (saved model checkpoints)
    ├── games/        (self-play games in PGN format)
    ├── logs/         (training logs and metrics)
    ├── data/         (PGN data files)
    └── plots/        (training visualizations)
    """
    dirs = {
        'base': base_dir,
        'models': os.path.join(base_dir, 'models'),
        'games': os.path.join(base_dir, 'games'),
        'logs': os.path.join(base_dir, 'logs'),
        'data': os.path.join(base_dir, 'data'),
        'plots': os.path.join(base_dir, 'plots')
    }
    
    print("Creating directory structure...")
    for name, path in dirs.items():
        os.makedirs(path, exist_ok=True)
        print(f"  ✓ {name:10s}: {path}")
    
    return dirs

# Setup directories
DIRS = setup_directories(CONFIG['save_dir'])
print("\n✅ Directory structure created successfully!")

## 📊 SECTION 4 — Data Preparation

Board encoding and move vocabulary generation.

In [None]:
def board_to_tensor(board):
    """
    Convert chess board to 12×8×8 tensor.
    
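    Channels represent: [P,N,B,R,Q,K, p,n,b,r,q,k]
    Where uppercase = white pieces, lowercase = black pieces.
    Rows follow python-chess square order, so row 0 is rank 1 (White's back rank).
    Side to move, castling rights, and en passant squares are not encoded.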
    
    Args:
        board: chess.Board object
    
    Returns:
        torch.Tensor of shape (12, 8, 8)
    """
    tensor = torch.zeros(12, 8, 8, dtype=torch.float32)
    
    piece_to_idx = {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
        'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
    }
    
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            rank, file = divmod(square, 8)
            channel = piece_to_idx[piece.symbol()]
            tensor[channel, rank, file] = 1.0
    
    return tensor

# Test board encoding
test_board = chess.Board()
test_tensor = board_to_tensor(test_board)
print(f"Board tensor shape: {test_tensor.shape}")
print(f"✅ Board encoding working correctly!")
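
The encoder relies on python-chess numbering squares from a1 = 0 to h8 = 63, so `divmod(square, 8)` yields the rank and file indices directly. A small sketch of that mapping in plain Python, without the chess library; `square_to_rank_file` and `square_name` are hypothetical helper names.

```python
def square_to_rank_file(square):
    """Map a 0..63 square index (a1 = 0, h8 = 63) to (rank, file) indices."""
    return divmod(square, 8)

def square_name(square):
    """Algebraic name of a square index, e.g. 28 -> 'e4'."""
    rank, file = square_to_rank_file(square)
    return "abcdefgh"[file] + str(rank + 1)

for sq in (0, 28, 63):
    print(sq, square_to_rank_file(sq), square_name(sq))
```

Because rank 1 lands in row 0, each 8×8 plane is vertically flipped relative to a board printed from White's side.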

In [None]:
def generate_move_vocab():
    """
    Generate move vocabulary covering all possible chess moves.
    
    Includes:
    - Standard moves (source square → destination square)
    - Promotion moves (pawn reaching last rank with piece selection)
    
    Returns:
        move_to_idx: dict mapping (from_sq, to_sq, promotion) -> index
        idx_to_move: dict mapping index -> (from_sq, to_sq, promotion)
        vocab_size: total number of unique moves
    """
    moves = []
    
    # Standard moves: 64 × 63 = 4032 ordered pairs of distinct squares
    for from_sq in range(64):
        for to_sq in range(64):
            if from_sq != to_sq:
                moves.append((from_sq, to_sq, None))
    
    # Promotions for white pawns (0-indexed rank 6 → 7, i.e. 7th → 8th rank)
    for file in range(8):
        from_sq = chess.square(file, 6)
        for to_file in [file-1, file, file+1]:  # Capture left, push, capture right
            if 0 <= to_file < 8:
                to_sq = chess.square(to_file, 7)
                for promo in [chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT]:
                    moves.append((from_sq, to_sq, promo))
    
    # Promotions for black pawns (0-indexed rank 1 → 0, i.e. 2nd → 1st rank)
    for file in range(8):
        from_sq = chess.square(file, 1)
        for to_file in [file-1, file, file+1]:
            if 0 <= to_file < 8:
                to_sq = chess.square(to_file, 0)
                for promo in [chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT]:
                    moves.append((from_sq, to_sq, promo))
    
    move_to_idx = {move: idx for idx, move in enumerate(moves)}
    idx_to_move = {idx: move for idx, move in enumerate(moves)}
    
    return move_to_idx, idx_to_move, len(moves)

# Generate move vocabulary
MOVE_TO_IDX, IDX_TO_MOVE, VOCAB_SIZE = generate_move_vocab()
print(f"Move vocabulary size: {VOCAB_SIZE}")
print(f"✅ Move vocabulary generated successfully!")
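
The vocabulary size printed above can be checked by hand. The sketch below counts the same entries that `generate_move_vocab` enumerates, using plain arithmetic; the variable names are illustrative only.

```python
# Ordinary moves: every ordered pair of distinct squares
standard = 64 * 63

# Promotion targets per side: a pawn can push straight or capture diagonally,
# so the two edge files have 2 target squares and the 6 inner files have 3.
targets_per_side = sum(
    1 for f in range(8) for t in (f - 1, f, f + 1) if 0 <= t < 8
)
promotions = 2 * targets_per_side * 4  # two colours, four promotion pieces

print(standard, targets_per_side, promotions, standard + promotions)
```

Most of the 4032 ordinary entries can never occur in a real game (no piece moves from a1 to b8, for example), which keeps indexing simple at the cost of unused output logits.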

In [None]:
def move_to_index(move, move_to_idx):
    """Convert chess.Move to vocabulary index."""
    key = (move.from_square, move.to_square, move.promotion)
    return move_to_idx.get(key, None)

def index_to_move(idx, idx_to_move):
    """Convert vocabulary index to chess.Move."""
    if idx not in idx_to_move:
        return None
    from_sq, to_sq, promo = idx_to_move[idx]
    return chess.Move(from_sq, to_sq, promotion=promo)

# Test conversions
test_move = chess.Move.from_uci("e2e4")
test_idx = move_to_index(test_move, MOVE_TO_IDX)
recovered_move = index_to_move(test_idx, IDX_TO_MOVE)
print(f"Test: {test_move} → {test_idx} → {recovered_move}")
print(f"✅ Move conversion working correctly!")
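
Since the non-promotion moves are appended in a fixed nested-loop order, their index also has a closed form. The sketch below derives it in plain Python; `uci_to_squares` and `standard_move_index` are hypothetical helpers, and the formula covers only the first 4032 (non-promotion) entries.

```python
def uci_to_squares(uci):
    """Parse a 4-character UCI string such as 'e2e4' into square indices."""
    def sq(name):
        return (int(name[1]) - 1) * 8 + "abcdefgh".index(name[0])
    return sq(uci[:2]), sq(uci[2:4])

def standard_move_index(from_sq, to_sq):
    """Index of a non-promotion move under the enumeration order used above."""
    # Each source square owns 63 consecutive slots; the skipped from == to
    # entry shifts every later destination down by one.
    return from_sq * 63 + (to_sq if to_sq < from_sq else to_sq - 1)

from_sq, to_sq = uci_to_squares("e2e4")
print(from_sq, to_sq, standard_move_index(from_sq, to_sq))
```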

## 📚 SECTION 5 — Dataset Class

Load and process PGN games with train/validation split.

In [None]:
class ChessDataset(Dataset):
    """
    Chess dataset from PGN file.
    Extracts board positions and corresponding moves from master games.
    """
    
    def __init__(self, pgn_path, max_games=None):
        self.positions = []
        self.moves = []
        
        print(f"Loading games from: {pgn_path}")
        
        if not os.path.exists(pgn_path):
            print(f"⚠️  Warning: File not found: {pgn_path}")
            print("   Using empty dataset. Make sure to upload your PGN file to Google Drive!")
            return
        
        with open(pgn_path) as pgn_file:
            game_count = 0
            pbar = tqdm(desc="Loading games", total=max_games)
            
            # Check the game limit first so no extra game is parsed and discarded
            while max_games is None or game_count < max_games:
                game = chess.pgn.read_game(pgn_file)
                if game is None:  # End of file
                    break
                
                board = game.board()
                for move in game.mainline_moves():
                    # Pair the position *before* the move with the move's index.
                    # Moves missing from the vocabulary are skipped as samples
                    # but still played so the board stays in sync.
                    move_idx = move_to_index(move, MOVE_TO_IDX)
                    if move_idx is not None:
                        self.positions.append(board_to_tensor(board))
                        self.moves.append(move_idx)
                    
                    board.push(move)
                
                game_count += 1
                pbar.update(1)
            
            pbar.close()
        
        print(f"✅ Loaded {game_count} games with {len(self.positions)} positions")
    
    def __len__(self):
        return len(self.positions)
    
    def __getitem__(self, idx):
        return self.positions[idx], self.moves[idx]

In [None]:
def create_dataloaders(dataset, batch_size, train_split=0.8):
    """
    Split dataset into train/validation sets and create DataLoaders.
    
    Args:
        dataset: ChessDataset instance
        batch_size: Batch size for training
        train_split: Fraction of data to use for training (rest for validation)
    
    Returns:
        train_loader, val_loader: PyTorch DataLoader objects
    """
    if len(dataset) == 0:
        print("⚠️  Empty dataset! Cannot create dataloaders.")
        return None, None
    
    # Calculate split sizes
    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    
    # Split dataset
    train_dataset, val_dataset = random_split(
        dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(CONFIG['seed'])
    )
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=0,  # Use 0 for Colab compatibility
        pin_memory=(device.type == 'cuda')
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=0,
        pin_memory=(device.type == 'cuda')
    )
    
    print(f"📊 Dataset split:")
    print(f"   Train: {len(train_dataset):,} positions")
    print(f"   Val:   {len(val_dataset):,} positions")
    
    return train_loader, val_loader
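
The split sizes follow from simple integer arithmetic: the training share is truncated and the remainder goes to validation. A sketch of that calculation in isolation, with `split_sizes` as a hypothetical helper name:

```python
def split_sizes(n_samples, train_split=0.8):
    """Return (train_size, val_size); the remainder goes to validation."""
    train_size = int(train_split * n_samples)
    return train_size, n_samples - train_size

for n in (10, 1003, 7, 1):
    print(n, split_sizes(n))
```

With very small datasets the truncation can leave the training set empty (a single position gives a 0/1 split), so it is worth checking the printed sizes before training.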

## 🏗️ SECTION 6 — Model Architecture

Improved neural network architectures with batch normalization and dropout.

In [None]:
class PolicyNetwork(nn.Module):
    """
    Policy Network: Predicts move probabilities.
    
    Architecture: CNN + Transformer + Policy Head
    
    Improvements in v2:
    - Batch normalization for training stability
    - Dropout for regularization
    - Xavier weight initialization
    - Deeper CNN layers
    """
    
    def __init__(self, vocab_size=4208, dropout=0.2):
        # Default matches the vocabulary built by generate_move_vocab():
        # 64 * 63 ordinary moves + 176 promotion moves.
        super().__init__()
        
        # CNN layers for spatial feature extraction
        self.conv1 = nn.Conv2d(12, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        
        # Flatten: 256 * 8 * 8 = 16384
        self.fc1 = nn.Linear(16384, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(dropout)
        
        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=512,
            nhead=8,
            dim_feedforward=2048,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        
        # Policy head
        self.fc_policy = nn.Linear(512, vocab_size)
        self.dropout2 = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, x):
        # CNN feature extraction
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Dense layer
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout1(x)
        
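        # Transformer (expects batch, seq_len, features).
        # Note: the whole board is fed as one 512-d token (seq_len = 1), so
        # self-attention has no other tokens to attend to and reduces to a
        # learned projection of that token, followed by the feed-forward block.
        x = x.unsqueeze(1)  # Add sequence dimension
        x = self.transformer(x)
        x = x.squeeze(1)  # Remove sequence dimension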
        
        # Policy head
        x = self.dropout2(x)
        logits = self.fc_policy(x)
        
        return logits

print("✅ PolicyNetwork defined!")
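
The `16384` input size of `fc1` comes from the convolution settings: a 3×3 kernel with padding 1 and stride 1 leaves the 8×8 board size unchanged, so only the channel count grows. A sketch of that arithmetic, using the standard output-size formula; `conv_out_size` is a hypothetical helper name.

```python
def conv_out_size(size, kernel=3, padding=1, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

side = 8
for _ in range(3):             # conv1, conv2, conv3
    side = conv_out_size(side)

flattened = 256 * side * side  # 256 channels after conv3
print(side, flattened)
```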

In [None]:
class ValueNetwork(nn.Module):
    """
    Value Network: Evaluates board positions.
    
    Architecture: CNN + Dense Layers
    Output: Single value in range [-1, 1]
    
    Improvements in v2:
    - Batch normalization
    - Dropout
    - Deeper architecture
    - Better initialization
    """
    
    def __init__(self, dropout=0.2):
        super().__init__()
        
        # CNN layers
        self.conv1 = nn.Conv2d(12, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        
        # Flatten: 256 * 8 * 8 = 16384
        self.fc1 = nn.Linear(16384, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(dropout)
        
        self.fc2 = nn.Linear(512, 256)
        self.bn_fc2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(dropout)
        
        self.fc_value = nn.Linear(256, 1)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, x):
        # CNN feature extraction
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Dense layers
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout1(x)
        
        x = F.relu(self.bn_fc2(self.fc2(x)))
        x = self.dropout2(x)
        
        # Value head (Tanh to bound output in [-1, 1])
        value = torch.tanh(self.fc_value(x))
        
        return value

print("✅ ValueNetwork defined!")

In [None]:
# Initialize models
policy_net = PolicyNetwork(vocab_size=VOCAB_SIZE).to(device)
value_net = ValueNetwork().to(device)

print(f"\n📊 Model Statistics:")
print(f"   Policy Network parameters: {sum(p.numel() for p in policy_net.parameters()):,}")
print(f"   Value Network parameters:  {sum(p.numel() for p in value_net.parameters()):,}")
print(f"\n✅ Models initialized and moved to {device}!")

## 🎯 SECTION 7 — Training Utilities

Early stopping, checkpointing, and metrics tracking.

In [None]:
class EarlyStopping:
    """
    Early stopping to stop training when validation loss stops improving.
    """
    
    def __init__(self, patience=5, min_delta=0.001, verbose=True):
        """
        Args:
            patience: How many epochs to wait after last improvement
            min_delta: Minimum change to qualify as an improvement
            verbose: Whether to print messages
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_epoch = 0
    
    def __call__(self, val_loss, epoch):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_epoch = epoch
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f"   Early stopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
    
    def should_stop(self):
        return self.early_stop

print("✅ EarlyStopping class defined!")
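
To see how `patience` and `min_delta` interact, it helps to trace the rule on a short list of validation losses. The function below is a standalone sketch of the same comparison used in `EarlyStopping.__call__`; `early_stop_epoch` and the sample losses are hypothetical and chosen only for illustration.

```python
def early_stop_epoch(val_losses, patience=5, min_delta=0.001):
    """Return (stop_epoch, best_epoch) under the EarlyStopping rule.

    stop_epoch is None if patience is never exhausted. Epochs are 1-indexed.
    """
    best_loss, best_epoch, counter = None, 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        # An epoch counts as an improvement only if it beats the best loss
        # by at least min_delta.
        if best_loss is None or loss <= best_loss - min_delta:
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch
    return None, best_epoch

losses = [1.0, 0.9, 0.8995, 0.8999, 0.85]
print(early_stop_epoch(losses, patience=2))  # (4, 2)
print(early_stop_epoch(losses, patience=5))  # (None, 5)
```

Epoch 3 is lower than epoch 2 but by less than `min_delta`, so it still counts against patience. With `patience=2` training stops at epoch 4 and never reaches the better loss at epoch 5.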

In [None]:
def compute_accuracy(logits, targets, k_values=(1, 5)):
    """
    Compute Top-K accuracy.
    
    Args:
        logits: Model output logits (batch_size, vocab_size)
        targets: Ground truth indices (batch_size,)
        k_values: List of K values to compute (e.g., [1, 5])
    
    Returns:
        Dictionary with Top-K accuracies
    """
    batch_size = targets.size(0)
    _, pred = logits.topk(max(k_values), dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(targets.view(1, -1).expand_as(pred))
    
    accuracies = {}
    for k in k_values:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        accuracies[f'top{k}'] = (correct_k / batch_size * 100).item()
    
    return accuracies

print("✅ Accuracy computation function defined!")
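
Top-K accuracy asks whether the played move appears among the model's K highest-scoring moves. The sketch below computes the same quantity on plain Python lists, which makes the definition easy to verify by hand; `topk_accuracy` and the toy scores are hypothetical.

```python
def topk_accuracy(score_rows, targets, k):
    """Percentage of rows whose target index is among the k highest scores."""
    hits = 0
    for scores, target in zip(score_rows, targets):
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        hits += target in ranked[:k]
    return 100.0 * hits / len(targets)

scores = [
    [0.1, 2.0, 0.3, 0.0],   # best guess: index 1
    [1.5, 0.2, 1.4, 0.1],   # best guess: index 0, runner-up: index 2
    [0.0, 0.1, 0.2, 3.0],   # best guess: index 3, runner-up: index 2
]
targets = [1, 2, 0]
print(topk_accuracy(scores, targets, k=1), topk_accuracy(scores, targets, k=2))
```

Here one of three targets is the top guess and two of three fall within the top two, so the values are one third and two thirds, expressed as percentages.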

In [None]:
def save_checkpoint(model, optimizer, epoch, metrics, filepath, is_best=False):
    """
    Save model checkpoint to Google Drive.
    
    Args:
        model: PyTorch model
        optimizer: Optimizer state
        epoch: Current epoch number
        metrics: Dictionary of training metrics
        filepath: Path to save checkpoint
        is_best: Whether this is the best model so far
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
        'timestamp': datetime.now().isoformat()
    }
    
    torch.save(checkpoint, filepath)
    
    if is_best:
        best_path = filepath.replace('.pth', '_best.pth')
        torch.save(checkpoint, best_path)
        print(f"   💾 Saved best model to: {best_path}")
    else:
        print(f"   💾 Saved checkpoint to: {filepath}")

def load_checkpoint(filepath, model, optimizer=None):
    """
    Load model checkpoint from file.
    
    Args:
        filepath: Path to checkpoint file
        model: PyTorch model to load weights into
        optimizer: Optional optimizer to restore state
    
    Returns:
        epoch, metrics: Epoch number and training metrics from checkpoint
    """
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    print(f"✅ Loaded checkpoint from epoch {checkpoint['epoch']}")
    return checkpoint['epoch'], checkpoint['metrics']

print("✅ Checkpoint functions defined!")
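
`save_checkpoint` derives the best-model filename with `str.replace`, which rewrites every `.pth` occurrence in the path, including one inside a directory name. A sketch of an alternative that only touches the final filename, using the standard `pathlib` module; `best_path_for` is a hypothetical helper and not part of the notebook.

```python
from pathlib import PurePosixPath

def best_path_for(filepath):
    """Insert '_best' before the extension of the final path component."""
    p = PurePosixPath(filepath)
    return str(p.with_name(f"{p.stem}_best{p.suffix}"))

print(best_path_for('/content/drive/MyDrive/ChessEngine/models/policy_net_epoch_4.pth'))
print(best_path_for('/tmp/run.pth/model.pth'))  # directory name is left alone
```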

## 📖 SECTION 8 — Imitation Learning

Train policy network on master games with early stopping and validation.

In [None]:
def train_policy_network(model, train_loader, val_loader, epochs, learning_rate, 
                          save_dir, patience=5, save_every=2):
    """
    Train policy network with early stopping and checkpointing.
    
    Args:
        model: PolicyNetwork instance
        train_loader: Training data loader
        val_loader: Validation data loader
        epochs: Maximum number of epochs
        learning_rate: Learning rate for optimizer
        save_dir: Directory to save checkpoints and logs
        patience: Early stopping patience
        save_every: Save checkpoint every N epochs
    
    Returns:
        history: Dictionary with training metrics
    """
    # Setup
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, 
                           weight_decay=CONFIG['weight_decay'])
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience, min_delta=CONFIG['min_delta'])
    
    # Tracking
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_top1': [],
        'train_top5': [],
        'val_top1': [],
        'val_top5': [],
        'epochs': []
    }
    
    best_val_loss = float('inf')
    
    print("\n" + "="*60)
    print("🚀 Starting Policy Network Training")
    print("="*60)
    
    for epoch in range(epochs):
        # ===== TRAINING PHASE =====
        model.train()
        train_loss = 0.0
        train_top1 = 0.0
        train_top5 = 0.0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for boards, moves in pbar:
            boards = boards.to(device)
            moves = moves.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            logits = model(boards)
            loss = criterion(logits, moves)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Compute metrics
            train_loss += loss.item()
            acc = compute_accuracy(logits, moves, k_values=[1, 5])
            train_top1 += acc['top1']
            train_top5 += acc['top5']
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'top1': f"{acc['top1']:.2f}%"
            })
        
        train_loss /= len(train_loader)
        train_top1 /= len(train_loader)
        train_top5 /= len(train_loader)
        
        # ===== VALIDATION PHASE =====
        model.eval()
        val_loss = 0.0
        val_top1 = 0.0
        val_top5 = 0.0
        
        with torch.no_grad():
            pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
            for boards, moves in pbar:
                boards = boards.to(device)
                moves = moves.to(device)
                
                logits = model(boards)
                loss = criterion(logits, moves)
                
                val_loss += loss.item()
                acc = compute_accuracy(logits, moves, k_values=[1, 5])
                val_top1 += acc['top1']
                val_top5 += acc['top5']
                
                pbar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'top1': f"{acc['top1']:.2f}%"
                })
        
        val_loss /= len(val_loader)
        val_top1 /= len(val_loader)
        val_top5 /= len(val_loader)
        
        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_top1'].append(train_top1)
        history['train_top5'].append(train_top5)
        history['val_top1'].append(val_top1)
        history['val_top5'].append(val_top5)
        history['epochs'].append(epoch + 1)
        
        # Print epoch summary
        print(f"\n📊 Epoch {epoch+1} Summary:")
        print(f"   Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"   Train Top-1: {train_top1:.2f}% | Val Top-1: {val_top1:.2f}%")
        print(f"   Train Top-5: {train_top5:.2f}% | Val Top-5: {val_top5:.2f}%")
        
        # Save checkpoint
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
        
        if (epoch + 1) % save_every == 0 or is_best:
            checkpoint_path = os.path.join(save_dir, 'models', 
                                          f'policy_net_epoch_{epoch+1}.pth')
            save_checkpoint(model, optimizer, epoch + 1, history, 
                          checkpoint_path, is_best=is_best)
        
        # Early stopping check
        early_stopping(val_loss, epoch + 1)
        if early_stopping.should_stop():
            print(f"\n⏹️  Early stopping triggered!")
            print(f"   Best epoch: {early_stopping.best_epoch}")
            print(f"   Best val loss: {early_stopping.best_loss:.4f}")
            break
        
        print()
    
    # Save final history
    history_path = os.path.join(save_dir, 'logs', 'policy_training_history.json')
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"\n💾 Training history saved to: {history_path}")
    
    print("\n" + "="*60)
    print("✅ Training Complete!")
    print("="*60)
    
    return history

print("✅ Training function defined!")
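
The training loop writes a checkpoint every `save_every` epochs and also whenever the validation loss reaches a new minimum. A standalone sketch that traces which epochs would be saved for a given loss sequence; `checkpoint_epochs` and the sample losses are hypothetical.

```python
def checkpoint_epochs(val_losses, save_every=2):
    """Return [(epoch, is_best)] for every epoch that would write a checkpoint."""
    saved, best = [], float('inf')
    for epoch, loss in enumerate(val_losses, start=1):
        is_best = loss < best
        if is_best:
            best = loss
        if epoch % save_every == 0 or is_best:
            saved.append((epoch, is_best))
    return saved

print(checkpoint_epochs([2.1, 1.8, 1.9, 1.85, 1.7, 1.75]))
```

In this trace only epoch 3 writes nothing: it is neither a multiple of `save_every` nor a new best. Note that the best-model check uses a strict `<` with no `min_delta`, so it can flag an epoch as best that early stopping does not count as an improvement.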

### Training Example

Load dataset and train the policy network.

In [None]:
# Load dataset
dataset = ChessDataset(CONFIG['pgn_path'], max_games=CONFIG['max_games'])

# Create dataloaders
if len(dataset) > 0:
    train_loader, val_loader = create_dataloaders(
        dataset, 
        batch_size=CONFIG['batch_size'],
        train_split=CONFIG['train_val_split']
    )
    
    # Train policy network
    if CONFIG['mode'] in ['imitation', 'hybrid']:
        print("\n🎯 Training Policy Network...")
        history = train_policy_network(
            policy_net,
            train_loader,
            val_loader,
            epochs=CONFIG['epochs'],
            learning_rate=CONFIG['learning_rate'],
            save_dir=CONFIG['save_dir'],
            patience=CONFIG['patience'],
            save_every=CONFIG['save_every']
        )
else:
    print("⚠️  No dataset loaded. Skipping training.")
    print("   Upload a PGN file to Google Drive and update CONFIG['pgn_path']")

## 🎮 SECTION 9 — Self-Play Reinforcement

Generate games through self-play and save them as PGN files.

In [None]:
def sample_legal_move_from_policy(board, policy_net):
    """
    Sample a legal move from the policy network.
    
    Args:
        board: chess.Board object
        policy_net: PolicyNetwork model
    
    Returns:
        chess.Move: Selected move
    """
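    # Inference mode: dropout is disabled and BatchNorm uses running statistics.
    # In training mode, BatchNorm1d raises an error on a batch of one position.
    policy_net.eval()
    
    # Get policy logits
    board_tensor = board_to_tensor(board).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = policy_net(board_tensor).squeeze(0)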
    
    # Mask illegal moves
    legal_moves = list(board.legal_moves)
    legal_indices = []
    for move in legal_moves:
        idx = move_to_index(move, MOVE_TO_IDX)
        if idx is not None:
            legal_indices.append(idx)
    
    if not legal_indices:
        # Fallback to random legal move
        return random.choice(legal_moves)
    
    # Get probabilities for legal moves only
    legal_logits = logits[legal_indices]
    probs = F.softmax(legal_logits, dim=0)
    
    # Sample move
    sampled_idx = torch.multinomial(probs, 1).item()
    move_idx = legal_indices[sampled_idx]
    
    return index_to_move(move_idx, IDX_TO_MOVE)

print("✅ Move sampling function defined!")
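
The masking step works by taking the softmax over the legal entries only, so illegal moves get zero probability no matter how high their logits are. A sketch of the same idea on plain Python lists; `sample_legal_index` and the toy logits are hypothetical, and the standard `random` module stands in for `torch.multinomial`.

```python
import math
import random

def sample_legal_index(logits, legal_indices, rng=random):
    """Softmax over the legal entries only, then sample one of them."""
    legal_logits = [logits[i] for i in legal_indices]
    peak = max(legal_logits)                      # subtract max for stability
    weights = [math.exp(x - peak) for x in legal_logits]
    total = sum(weights)
    probs = [w / total for w in weights]
    choice = rng.choices(legal_indices, weights=probs, k=1)[0]
    return choice, probs

logits = [2.0, -1.0, 0.5, 3.0, 0.0]   # toy scores for a 5-move vocabulary
legal = [0, 2, 4]                      # index 3 scores highest but is illegal
move_idx, probs = sample_legal_index(logits, legal, random.Random(0))
print(move_idx, [round(p, 3) for p in probs])
```

Sampling, as opposed to always taking the highest-probability move, keeps self-play games varied, at the price of occasionally playing a weak move.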

In [None]:
PIECE_VALUES = {
    chess.PAWN: 1,
    chess.KNIGHT: 3,
    chess.BISHOP: 3,
    chess.ROOK: 5,
    chess.QUEEN: 9,
    chess.KING: 0
}

def calculate_material_score(board):
    """Calculate material balance on the board."""
    white_score = 0
    black_score = 0
    
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            value = PIECE_VALUES.get(piece.piece_type, 0)
            if piece.color == chess.WHITE:
                white_score += value
            else:
                black_score += value
    
    return white_score, black_score

print("✅ Material scoring function defined!")

In [None]:
def self_play_game(policy_net, max_plies=200):
    """
    Play a complete game using the policy network for both sides.
    
    Args:
        policy_net: PolicyNetwork model
        max_plies: Maximum number of plies (half-moves) before the game is adjudicated by material
    
    Returns:
        game: chess.pgn.Game object
        result: Game result string ('1-0', '0-1', '1/2-1/2')
    """
    board = chess.Board()
    game = chess.pgn.Game()
    game.headers["Event"] = "Self-Play"
    game.headers["Date"] = datetime.now().strftime("%Y.%m.%d")
    game.headers["White"] = "PolicyNet"
    game.headers["Black"] = "PolicyNet"
    
    node = game
    ply_count = 0
    
    while not board.is_game_over() and ply_count < max_plies:
        move = sample_legal_move_from_policy(board, policy_net)
        board.push(move)
        node = node.add_variation(move)
        ply_count += 1
    
    # Determine result
    if board.is_game_over():
        # Checkmate is decisive; stalemate, insufficient material, the 75-move
        # rule and fivefold repetition are all reported as "1/2-1/2".
        result = board.result()
    else:
        # Ply limit reached without a natural ending: adjudicate by material
        white_mat, black_mat = calculate_material_score(board)
        if white_mat > black_mat:
            result = "1-0"
        elif black_mat > white_mat:
            result = "0-1"
        else:
            result = "1/2-1/2"
    
    game.headers["Result"] = result
    return game, result

print("✅ Self-play function defined!")
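
When a game hits the ply limit, the result is decided by material count alone. The sketch below reproduces that adjudication from a FEN string in plain Python, using the same piece values as `PIECE_VALUES`; `PIECE_POINTS`, `material_from_fen`, and `adjudicate` are hypothetical names.

```python
PIECE_POINTS = {'p': 1, 'n': 3, 'b': 3, 'r': 5, 'q': 9, 'k': 0}

def material_from_fen(fen):
    """Return (white, black) material totals from a FEN string."""
    placement = fen.split()[0]   # first field holds the piece placement
    white = sum(PIECE_POINTS[c.lower()] for c in placement if c.isupper())
    black = sum(PIECE_POINTS[c] for c in placement if c.islower())
    return white, black

def adjudicate(fen):
    """Result string decided purely by material balance."""
    white, black = material_from_fen(fen)
    if white > black:
        return "1-0"
    if black > white:
        return "0-1"
    return "1/2-1/2"

start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
no_black_queen = "rnb1kbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
print(material_from_fen(start), adjudicate(start))
print(material_from_fen(no_black_queen), adjudicate(no_black_queen))
```

Material is a coarse proxy: it ignores king safety and piece activity, so results adjudicated this way are noisier than results from games that ended naturally.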

In [None]:
def generate_selfplay_games(policy_net, num_games, save_dir):
    """
    Generate multiple self-play games and save them to Google Drive.
    
    Args:
        policy_net: PolicyNetwork model
        num_games: Number of games to generate
        save_dir: Directory to save PGN files
    
    Returns:
        statistics: Dictionary with game statistics
    """
    stats = {
        'white_wins': 0,
        'black_wins': 0,
        'draws': 0,
        'total_plies': 0,
        'games': []
    }
    
    games_dir = os.path.join(save_dir, 'games')
    os.makedirs(games_dir, exist_ok=True)
    
    print(f"\n🎮 Generating {num_games} self-play games...")
    print("="*60)
    
    for i in tqdm(range(num_games), desc="Self-play"):
        game, result = self_play_game(policy_net, max_plies=CONFIG['max_plies'])
        
        # Update statistics
        if result == "1-0":
            stats['white_wins'] += 1
        elif result == "0-1":
            stats['black_wins'] += 1
        else:
            stats['draws'] += 1
        
        ply_count = len(list(game.mainline_moves()))
        stats['total_plies'] += ply_count
        stats['games'].append({
            'game_number': i + 1,
            'result': result,
            'plies': ply_count
        })
        
        # Save game to PGN file
        pgn_path = os.path.join(games_dir, f'selfplay_game_{i+1}.pgn')
        with open(pgn_path, 'w') as f:
            print(game, file=f)
    
    # Calculate averages
    stats['avg_plies'] = stats['total_plies'] / num_games if num_games > 0 else 0
    
    # Print summary (guard the percentages against num_games == 0)
    denom = max(num_games, 1)
    print("\n📊 Self-Play Statistics:")
    print(f"   White wins: {stats['white_wins']} ({stats['white_wins']/denom*100:.1f}%)")
    print(f"   Black wins: {stats['black_wins']} ({stats['black_wins']/denom*100:.1f}%)")
    print(f"   Draws:      {stats['draws']} ({stats['draws']/denom*100:.1f}%)")
    print(f"   Avg plies:  {stats['avg_plies']:.1f}")
    print(f"\n💾 Games saved to: {games_dir}")
    
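    # Save statistics (create the logs directory if it does not exist yet)
    logs_dir = os.path.join(save_dir, 'logs')
    os.makedirs(logs_dir, exist_ok=True)
    stats_path = os.path.join(logs_dir, 'selfplay_stats.json')
    with open(stats_path, 'w') as f:
        json.dump(stats, f, indent=2)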
    
    return stats

print("✅ Self-play generation function defined!")

In [None]:
# Generate self-play games (if in selfplay or hybrid mode)
if CONFIG['mode'] in ['selfplay', 'hybrid']:
    print("\n🎮 Starting self-play...")
    selfplay_stats = generate_selfplay_games(
        policy_net,
        num_games=CONFIG['selfplay_games'],
        save_dir=CONFIG['save_dir']
    )
else:
    print("⏭️  Skipping self-play (mode='imitation')")
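
The win, loss, and draw tallies above can also be derived from a plain list of result strings, which is handy when results are reloaded from saved PGN files. The sketch below uses a hypothetical helper, `summarize_results`, which is not defined in the notebook, and guards against an empty list.

```python
from collections import Counter


def summarize_results(results):
    """Tally PGN result strings and return counts plus percentages."""
    counts = Counter(results)
    total = len(results)
    summary = {}
    labels = (("white_wins", "1-0"), ("black_wins", "0-1"), ("draws", "1/2-1/2"))
    for label, key in labels:
        n = counts.get(key, 0)
        # Avoid ZeroDivisionError when no games were played.
        pct = 100.0 * n / total if total else 0.0
        summary[label] = {"count": n, "pct": pct}
    return summary


print(summarize_results(["1-0", "1/2-1/2", "0-1", "1-0"]))
```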

## 📈 SECTION 10 — Training Visualization

Visualize training metrics and save plots to Google Drive.

In [None]:
def plot_training_history(history, save_dir):
    """
    Plot training history and save figures to Google Drive.
    
    Args:
        history: Dictionary with training metrics
        save_dir: Directory to save plots
    """
    if not history or 'epochs' not in history:
        print("⚠️  No training history available")
        return
    
    plots_dir = os.path.join(save_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)
    
    epochs = history['epochs']
    
    # Create figure with subplots
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot 1: Loss curves
    ax = axes[0]
    ax.plot(epochs, history['train_loss'], 'b-o', label='Train Loss', linewidth=2)
    ax.plot(epochs, history['val_loss'], 'r-s', label='Val Loss', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Accuracy curves
    ax = axes[1]
    ax.plot(epochs, history['train_top1'], 'b-o', label='Train Top-1', linewidth=2)
    ax.plot(epochs, history['val_top1'], 'r-s', label='Val Top-1', linewidth=2)
    ax.plot(epochs, history['train_top5'], 'b--^', label='Train Top-5', linewidth=2, alpha=0.7)
    ax.plot(epochs, history['val_top5'], 'r--v', label='Val Top-5', linewidth=2, alpha=0.7)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save figure
    plot_path = os.path.join(plots_dir, 'training_history.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"💾 Training plots saved to: {plot_path}")
    
    plt.show()

print("✅ Visualization function defined!")
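
`plot_training_history` checks only for the `epochs` key, so a history that lacks another series would fail with a `KeyError` partway through plotting. A possible pre-check is sketched below. `validate_history` and `REQUIRED_KEYS` are hypothetical names; the key list is taken from the series the function plots.

```python
REQUIRED_KEYS = (
    "epochs", "train_loss", "val_loss",
    "train_top1", "val_top1", "train_top5", "val_top5",
)


def validate_history(history):
    """Return a list of problems found in a training-history dict."""
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in history]
    if "epochs" in history:
        n = len(history["epochs"])
        for k in REQUIRED_KEYS:
            # Every series must have one value per epoch.
            if k in history and len(history[k]) != n:
                problems.append(f"{k} has {len(history[k])} values, expected {n}")
    return problems


history = {k: [0.0, 0.0] for k in REQUIRED_KEYS}
history["epochs"] = [1, 2]
print(validate_history(history))  # []
```

An empty list means the dict has the shape the plotting function expects.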

In [None]:
# Visualize training history (if available)
if 'history' in locals() and history:
    plot_training_history(history, CONFIG['save_dir'])
else:
    print("⚠️  No training history to visualize")
    print("   Run training first to generate metrics")

## ♟️ SECTION 11 — Play Against the Engine

Interactive gameplay and move prediction.

In [None]:
def predict_best_move(board, policy_net):
    """
    Predict the best move for the current position (greedy selection).
    
    Args:
        board: chess.Board object
        policy_net: PolicyNetwork model
    
    Returns:
        chess.Move: Best predicted move
    """
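    # Get policy logits. Eval mode disables dropout and makes batch norm use
    # its running statistics; call policy_net.train() before further training.
    policy_net.eval()
    board_tensor = board_to_tensor(board).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = policy_net(board_tensor).squeeze(0)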
    
    # Get legal moves
    legal_moves = list(board.legal_moves)
    legal_indices = []
    legal_moves_list = []
    
    for move in legal_moves:
        idx = move_to_index(move, MOVE_TO_IDX)
        if idx is not None:
            legal_indices.append(idx)
            legal_moves_list.append(move)
    
    if not legal_indices:
        return random.choice(legal_moves)
    
    # Get scores for legal moves
    legal_logits = logits[legal_indices]
    best_idx = torch.argmax(legal_logits).item()
    
    return legal_moves_list[best_idx]

print("✅ Move prediction function defined!")
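
Greedy selection always returns the same move for a given position, whereas self-play benefits from variety. The contrast can be sketched without PyTorch as below. This is an illustration only: `pick_greedy`, `pick_sampled`, and the `temperature` parameter are hypothetical and do not describe how `sample_legal_move_from_policy` is implemented.

```python
import math
import random


def pick_greedy(legal_logits):
    """Index of the highest-scoring legal move."""
    return max(range(len(legal_logits)), key=lambda i: legal_logits[i])


def pick_sampled(legal_logits, temperature=1.0, rng=random):
    """Sample an index from a softmax over the legal-move logits."""
    scaled = [x / temperature for x in legal_logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


logits = [0.2, 2.5, -1.0]
print(pick_greedy(logits))  # 1
print(pick_sampled(logits, temperature=0.5, rng=random.Random(0)))
```

Lower temperatures concentrate the probability mass on the top move, so sampling approaches greedy play as `temperature` approaches zero.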

In [None]:
def play_against_engine(policy_net, player_color=chess.WHITE):
    """
    Play a game against the trained engine.
    
    Args:
        policy_net: PolicyNetwork model
        player_color: chess.WHITE or chess.BLACK
    """
    board = chess.Board()
    print("\n" + "="*60)
    print("♟️  Playing Against Chess Engine")
    print("="*60)
    print(f"You are playing as: {'White' if player_color == chess.WHITE else 'Black'}")
    print("Enter moves in UCI format (e.g., 'e2e4')")
    print("Type 'quit' to exit\n")
    
    while not board.is_game_over():
        print(board)
        print()
        
        if board.turn == player_color:
            # Player's turn
            while True:
                move_str = input("Your move: ").strip().lower()
                if move_str == 'quit':
                    print("Game ended by player.")
                    return
                
                try:
                    move = chess.Move.from_uci(move_str)
                    if move in board.legal_moves:
                        board.push(move)
                        break
                    else:
                        print("❌ Illegal move! Try again.")
                except ValueError:
                    print("❌ Invalid format! Use UCI notation (e.g., 'e2e4')")
        else:
            # Engine's turn
            print("🤖 Engine is thinking...")
            move = predict_best_move(board, policy_net)
            print(f"🤖 Engine plays: {move.uci()}")
            board.push(move)
        
        print()
    
    # Game over
    print(board)
    print("\n" + "="*60)
    print("🏁 Game Over!")
    print(f"Result: {board.result()}")
    print("="*60)

print("✅ Interactive play function defined!")
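
A UCI move is a source square, a target square, and an optional promotion letter. A format pre-check, sketched below, lets the prompt tell a malformed string apart from a well-formed but illegal move before python-chess is consulted. `UCI_PATTERN` and `looks_like_uci` are hypothetical names, and the pattern covers ordinary moves only (no null move `0000`).

```python
import re

# Source square, target square, optional promotion piece (q, r, b, n).
UCI_PATTERN = re.compile(r"[a-h][1-8][a-h][1-8][qrbn]?")


def looks_like_uci(move_str):
    """True if the string has the shape of a UCI move (legality is not checked)."""
    return UCI_PATTERN.fullmatch(move_str.strip().lower()) is not None


print(looks_like_uci("e2e4"))   # True
print(looks_like_uci("e7e8q"))  # True
print(looks_like_uci("Nf3"))    # False
```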

In [None]:
# Example: Display current board state
example_board = chess.Board()
print("Starting position:")
print(example_board)
print("\n✅ Ready to play! Use play_against_engine(policy_net) to start a game.")

## 💾 SECTION 12 — Model Management

Load the best checkpoint and list the models saved on Google Drive, for example before resuming training.

In [None]:
def load_best_model(save_dir, model, model_name='policy_net'):
    """
    Load the best saved model checkpoint.
    
    Args:
        save_dir: Base directory for saved models
        model: Model instance to load weights into
        model_name: Name of the model ('policy_net' or 'value_net')
    
    Returns:
        epoch, metrics: Information from the checkpoint
    """
    best_path = os.path.join(save_dir, 'models', f'{model_name}_best.pth')
    
    if not os.path.exists(best_path):
        print(f"⚠️  No best model found at: {best_path}")
        return None, None
    
    checkpoint = torch.load(best_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"✅ Loaded best {model_name} from epoch {checkpoint['epoch']}")
    print(f"   Validation loss: {checkpoint['metrics']['val_loss'][-1]:.4f}")
    
    return checkpoint['epoch'], checkpoint['metrics']

def list_saved_models(save_dir):
    """List all saved model checkpoints."""
    models_dir = os.path.join(save_dir, 'models')
    
    if not os.path.exists(models_dir):
        print(f"⚠️  Models directory not found: {models_dir}")
        return
    
    model_files = sorted([f for f in os.listdir(models_dir) if f.endswith('.pth')])
    
    if not model_files:
        print("No saved models found.")
        return
    
    print("\n📁 Saved Models:")
    print("="*60)
    for f in model_files:
        path = os.path.join(models_dir, f)
        size = os.path.getsize(path) / (1024 * 1024)  # MB
        print(f"  • {f:40s} ({size:.2f} MB)")
    print("="*60)

print("✅ Model management functions defined!")

In [None]:
# List available models
list_saved_models(CONFIG['save_dir'])
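
To resume training rather than load the best model, the most recently written checkpoint is usually the one wanted. A sketch using only the standard library follows. `latest_checkpoint` is a hypothetical helper, and it assumes checkpoints are `.pth` files in a `models` subdirectory, as in `list_saved_models`.

```python
import os


def latest_checkpoint(save_dir, suffix=".pth"):
    """Return the path of the newest checkpoint file, or None if there is none."""
    models_dir = os.path.join(save_dir, "models")
    if not os.path.isdir(models_dir):
        return None
    paths = [
        os.path.join(models_dir, name)
        for name in os.listdir(models_dir)
        if name.endswith(suffix)
    ]
    # The file with the newest modification time wins.
    return max(paths, key=os.path.getmtime, default=None)
```

The returned path could then be passed to `torch.load(path, map_location=device)` in the same way `load_best_model` does.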

## 🎉 Summary

### ✅ What's Included in v2.0:

1. **Configuration Management**: All hyperparameters in one place
2. **Google Drive Integration**: Persistent storage for models, games, and logs
3. **Improved Architecture**: Batch normalization and dropout for better training
4. **Training Utilities**: Early stopping, checkpointing, metrics tracking
5. **Validation Split**: Proper train/val split with monitoring
6. **Top-K Accuracy**: Top-1 and Top-5 accuracy metrics
7. **Self-Play**: Generate and save games in PGN format
8. **Visualization**: Training curves and accuracy plots
9. **Model Management**: Save/load checkpoints, resume training
10. **Interactive Play**: Play against the trained engine

### 🚀 Next Steps:

1. **Upload your PGN data** to `/content/drive/MyDrive/ChessEngine/data/`
2. **Run training** by executing the cells in order
3. **Generate self-play games** to create additional training data
4. **Visualize results** with the plotting functions
5. **Play against your engine** to test its strength

### 📚 Tips:

- Start with a small dataset (50-100 games) to test functionality
- Monitor validation loss to avoid overfitting
- Adjust `patience` and `min_delta` for early stopping
- Save checkpoints frequently in case of interruption
- Use self-play to generate as much additional training data as you need
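
How `patience` and `min_delta` interact can be seen in a few lines. The class below is a generic sketch of the technique, not the notebook's own early-stopping utility; `EarlyStopper` and its method names are hypothetical.

```python
class EarlyStopper:
    """Stop when validation loss has not improved by min_delta for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            # Improvement large enough to count: reset the counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2, min_delta=0.01)
for epoch, loss in enumerate([1.00, 0.80, 0.795, 0.792, 0.60], start=1):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 4
        break
```

With these settings the two small improvements after epoch 2 fall below `min_delta`, so training stops before the large drop at epoch 5; a higher `patience` would have reached it.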

### 🔧 Troubleshooting:

- **Out of memory**: Reduce `batch_size` in CONFIG
- **Training too slow**: Use GPU runtime (Runtime → Change runtime type)
- **No improvement**: Try different learning rates or architectures
- **Overfitting**: Increase dropout or reduce model complexity

Happy training! 🧠♟️