# SFT Training - Supervised Fine-Tuning for Chess AI

**Self-contained notebook** - All classes defined inline, no external imports needed!

**Features:**
- ✅ Fixed 4672 action space (canonical encoding)
- ✅ Complete implementation in single notebook
- ✅ Google Colab ready
- ✅ Just upload PGN files and run!

## 1. Install Dependencies

In [None]:
!pip install torch numpy python-chess tqdm -q

# Mount Google Drive when running inside Colab; elsewhere this cell is a no-op.
try:
    from google.colab import drive
except ImportError:
    # The google.colab package could not be imported, so skip mounting.
    print("Not in Colab")
else:
    # Keep the mount best-effort, but report the real reason when it fails
    # instead of mislabelling the failure as "Not in Colab". Catching
    # Exception (not a bare except) lets Ctrl-C / kernel interrupts through.
    try:
        drive.mount('/content/drive')
    except Exception as exc:
        print(f"Google Drive mount failed: {exc}")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import chess
import chess.pgn
import os
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f" Device: {device}")

## 2. Define All Classes

In [None]:
# ============================================================
# CANONICAL MOVE ENCODER - Fixed 4672 action space
# ============================================================
class CanonicalMoveEncoder:
    """Bidirectional lookup between UCI move strings and policy indices.

    The index space has 4672 slots laid out in three consecutive groups:

    * 0..3583    - sliding moves: 64 origins x 8 directions x 7 distances
    * 3584..4095 - knight jumps:  64 origins x 8 jump patterns
    * 4096..4671 - underpromotions: 64 origins x 3 file shifts x 3 pieces

    A slot is consumed for every (origin, pattern) pair even when the target
    would fall off the board, so the two dictionaries hold fewer than 4672
    entries while indices stay position-independent.
    """

    def __init__(self):
        self.move_to_idx, self.idx_to_move = self._build_canonical_map()
        print(f"üîí Canonical move map: {len(self.move_to_idx)} moves")

    def _build_canonical_map(self):
        """Enumerate all slots and return ``(uci -> index, index -> uci)``."""
        uci_to_slot, slot_to_uci = {}, {}

        def register(uci_text, slot_no):
            uci_to_slot[uci_text] = slot_no
            slot_to_uci[slot_no] = uci_text

        def on_board(rank_no, file_no):
            return 0 <= rank_no < 8 and 0 <= file_no < 8

        # Advances once per enumerated pattern, whether or not it is mapped.
        slot = 0

        # Queen moves (3584): (rank delta, file delta) rays, 1..7 steps each.
        rays = [(1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]
        for origin in range(64):
            origin_rank, origin_file = divmod(origin, 8)
            for rank_dir, file_dir in rays:
                for steps in range(1, 8):
                    target_rank = origin_rank + rank_dir * steps
                    target_file = origin_file + file_dir * steps
                    if on_board(target_rank, target_file):
                        target = target_rank * 8 + target_file
                        register(chess.Move(origin, target).uci(), slot)
                    slot += 1

        # Knight moves (512)
        jumps = [(2, 1), (1, 2), (-1, 2), (-2, 1), (-2, -1), (-1, -2), (1, -2), (2, -1)]
        for origin in range(64):
            origin_rank, origin_file = divmod(origin, 8)
            for rank_jump, file_jump in jumps:
                target_rank = origin_rank + rank_jump
                target_file = origin_file + file_jump
                if on_board(target_rank, target_file):
                    target = target_rank * 8 + target_file
                    register(chess.Move(origin, target).uci(), slot)
                slot += 1

        # Underpromotions (576): only origins on rank index 6 (stepping up)
        # or rank index 1 (stepping down) produce entries; every other origin
        # still consumes its nine slots.
        promo_symbols = ('r', 'b', 'n')
        for origin in range(64):
            origin_rank, origin_file = divmod(origin, 8)
            if origin_rank == 6:
                advance = 1
            elif origin_rank == 1:
                advance = -1
            else:
                advance = 0

            for file_shift in (0, -1, 1):
                target_file = origin_file + file_shift
                reachable = advance != 0 and 0 <= target_file < 8
                for symbol in promo_symbols:
                    if reachable:
                        target = (origin_rank + advance) * 8 + target_file
                        piece_type = chess.Piece.from_symbol(symbol).piece_type
                        promo_move = chess.Move(origin, target, promotion=piece_type)
                        register(promo_move.uci(), slot)
                    slot += 1

        return uci_to_slot, slot_to_uci

    def encode_move(self, move_uci):
        """Return the policy index for ``move_uci``, or ``None`` if unmapped.

        Queen promotions have no slot of their own: a five-character UCI
        ending in ``'q'`` is looked up through its first four characters, so
        it shares the index of the plain move between the same squares.
        """
        direct_hit = self.move_to_idx.get(move_uci)
        if direct_hit is not None:
            return direct_hit

        # Queen promotion fallback
        looks_like_queen_promo = (
            isinstance(move_uci, str)
            and len(move_uci) == 5
            and move_uci.endswith('q')
        )
        if not looks_like_queen_promo:
            return None
        return self.move_to_idx.get(move_uci[:4])


# ============================================================
# BOARD ENCODER - 32-channel tensor representation
# ============================================================
class BoardEncoder:
    """Turn a board (plus an optional previous board) into a (32, 8, 8) array.

    Plane layout, indexed ``[channel, rank, file]``:

    * 0-11  - pieces of the current position (white P N B R Q K, then black)
    * 12-23 - pieces of ``prev_board`` in the same order (all zero if absent)
    * 24    - filled with 1 when white is to move
    * 25-28 - castling rights: white kingside/queenside, black kingside/queenside
    * 29    - single 1 on the en-passant square, when the board reports one
    * 30-31 - filled with 1 when ``is_repetition(1)`` / ``is_repetition(2)`` hold
    """

    def __init__(self):
        white_symbols = 'PNBRQK'
        # White pieces occupy channels 0-5, their lowercase twins 6-11.
        self.piece_map = {
            symbol: channel
            for channel, symbol in enumerate(white_symbols + white_symbols.lower())
        }

    def _paint_pieces(self, planes, position, channel_offset):
        """Set one cell per piece of ``position``, shifted by ``channel_offset``."""
        for square, piece in position.piece_map().items():
            channel = self.piece_map[piece.symbol()] + channel_offset
            planes[channel, chess.square_rank(square), chess.square_file(square)] = 1.0

    def encode(self, board, prev_board=None):
        """Return the float32 feature planes for ``board``.

        ``prev_board`` feeds channels 12-23 only; its move history is not read.
        """
        planes = np.zeros((32, 8, 8), dtype=np.float32)

        # Current pieces (0-11)
        self._paint_pieces(planes, board, 0)

        # Previous position (12-23)
        if prev_board:
            self._paint_pieces(planes, prev_board, 12)

        # Metadata (24-31): each true flag fills its whole plane with ones.
        # NOTE(review): python-chess appears to count the current position as
        # the first occurrence, which would make is_repetition(1) always true
        # and plane 30 constant - confirm whether counts 2 and 3 were intended.
        whole_plane_flags = (
            (24, board.turn == chess.WHITE),
            (25, board.has_kingside_castling_rights(chess.WHITE)),
            (26, board.has_queenside_castling_rights(chess.WHITE)),
            (27, board.has_kingside_castling_rights(chess.BLACK)),
            (28, board.has_queenside_castling_rights(chess.BLACK)),
            (30, board.is_repetition(1)),
            (31, board.is_repetition(2)),
        )
        for channel, is_set in whole_plane_flags:
            if is_set:
                planes[channel, :, :] = 1.0

        ep_target = board.ep_square
        if ep_target:
            planes[29, chess.square_rank(ep_target), chess.square_file(ep_target)] = 1.0

        return planes


# ============================================================
# MODEL ARCHITECTURE - SmallResNet
# ============================================================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in an identity skip connection.

    Channel count and spatial size are preserved (padding=1, no stride), so
    the input can be added straight onto the second stage's output.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Bias is omitted because each conv is followed by a BatchNorm layer.
        conv_options = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        """Return ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        # Skip connection, added in place onto the freshly computed tensor.
        hidden += x
        return F.relu(hidden)

class SmallResNet(nn.Module):
    """Residual tower with a policy head and a value head.

    Expects input planes of shape (batch, 32, 8, 8). ``forward`` returns
    ``(policy_logits, value)``: raw logits of width ``action_size`` and a
    tanh-squashed scalar per sample.
    """

    def __init__(self, num_res_blocks=6, num_channels=64, action_size=4672):
        super().__init__()
        squares = 8 * 8
        policy_planes = 32
        value_planes = 3

        # Stem: lift the 32 input planes to the tower width.
        self.conv_input = nn.Conv2d(32, num_channels, 3, padding=1, bias=False)
        self.bn_input = nn.BatchNorm2d(num_channels)
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(num_channels) for _ in range(num_res_blocks)]
        )

        # Policy head
        self.policy_conv = nn.Conv2d(num_channels, policy_planes, 1, bias=False)
        self.policy_bn = nn.BatchNorm2d(policy_planes)
        self.policy_fc = nn.Linear(policy_planes * squares, action_size)

        # Value head
        self.value_conv = nn.Conv2d(num_channels, value_planes, 1, bias=False)
        self.value_bn = nn.BatchNorm2d(value_planes)
        self.value_fc1 = nn.Linear(value_planes * squares, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def _policy_logits(self, features):
        """1x1 conv -> BN -> ReLU -> flatten -> linear; no softmax is applied."""
        planes = F.relu(self.policy_bn(self.policy_conv(features)))
        flat = planes.view(-1, self.policy_fc.in_features)
        return self.policy_fc(flat)

    def _value(self, features):
        """1x1 conv -> BN -> ReLU -> flatten -> two linear layers -> tanh."""
        planes = F.relu(self.value_bn(self.value_conv(features)))
        flat = planes.view(-1, self.value_fc1.in_features)
        hidden = F.relu(self.value_fc1(flat))
        return torch.tanh(self.value_fc2(hidden))

    def forward(self, x):
        """Run stem and tower once, then feed the shared features to both heads."""
        features = F.relu(self.bn_input(self.conv_input(x)))
        for block in self.res_blocks:
            features = block(features)

        policy_logits = self._policy_logits(features)
        value = self._value(features)
        return policy_logits, value


# ============================================================
# LOSS FUNCTION
# ============================================================
class AlphaZeroLoss(nn.Module):
    """Sum of a policy cross-entropy term and a value mean-squared-error term."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, policy_pred, value_pred, policy_target, value_target):
        """Return ``(total, policy_loss, value_loss)`` as scalar tensors.

        ``policy_pred`` holds raw logits; ``policy_target`` is a per-row weight
        vector over the same classes (one-hot in this notebook). The two terms
        are added unweighted.
        """
        # Cross-entropy against the target distribution, averaged over rows.
        log_probs = F.log_softmax(policy_pred, dim=1)
        per_row_ce = -torch.sum(policy_target * log_probs, dim=1)
        policy_loss = per_row_ce.mean()

        value_loss = self.mse(value_pred, value_target)

        total = policy_loss + value_loss
        return total, policy_loss, value_loss


# ============================================================
# DATASET
# ============================================================
class ChessPGNDataset(Dataset):
    """In-memory supervised samples extracted from PGN game files.

    Every mainline move of every decisive or drawn game becomes one sample:
    the encoded position before the move, the index of the move played, and
    the final result from the side-to-move's point of view (+1 win, -1 loss,
    0 draw). Games whose ``Result`` header is anything else are skipped.

    ``self.samples`` holds ``(state, move_idx, value)`` tuples. The policy is
    kept as an integer index and expanded to a dense one-hot vector only in
    ``__getitem__``; storing a 4672-float vector per position would cost
    roughly 18 KB per sample and dominate memory.

    Args:
        pgn_paths: iterable of PGN file paths to read.
        max_games: optional cap on accepted games, applied to each file
            separately; a falsy value means no cap.
    """

    # Width of the dense policy vector produced by __getitem__.
    ACTION_SIZE = 4672

    # PGN "Result" header -> outcome from white's point of view.
    _RESULT_TO_OUTCOME = {"1-0": 1.0, "0-1": -1.0, "1/2-1/2": 0.0}

    def __init__(self, pgn_paths, max_games=None):
        self.move_encoder = CanonicalMoveEncoder()
        self.board_encoder = BoardEncoder()
        self.samples = []

        for pgn_path in pgn_paths:
            print(f"\nüìÇ Processing: {pgn_path}")
            self._process_pgn(pgn_path, max_games)

        print(f"\n Total samples: {len(self.samples)}")

    def _process_pgn(self, pgn_path, max_games):
        """Append the samples of one PGN file to ``self.samples``.

        Skipped games (unknown result) neither advance the progress bar nor
        count towards ``max_games``.
        """
        games_processed = 0

        with open(pgn_path, 'r', encoding='utf-8', errors='ignore') as f:
            # Context manager so the bar is closed even if parsing raises.
            with tqdm(desc="Processing games") as pbar:
                while not (max_games and games_processed >= max_games):
                    game = chess.pgn.read_game(f)
                    if game is None:
                        break

                    result = game.headers.get("Result", "*")
                    game_outcome = self._RESULT_TO_OUTCOME.get(result)
                    if game_outcome is None:
                        continue

                    self._collect_game_samples(game, game_outcome)

                    games_processed += 1
                    pbar.update(1)
                    pbar.set_postfix({'samples': len(self.samples)})

    def _collect_game_samples(self, game, game_outcome):
        """Replay one game's mainline and record a sample per encodable move."""
        board = game.board()
        prev_board = None

        for move in game.mainline_moves():
            move_idx = self.move_encoder.encode_move(move.uci())

            if move_idx is not None:
                state = self.board_encoder.encode(board, prev_board)
                # Flip the sign so the target is relative to the side to move.
                value_target = game_outcome if board.turn == chess.WHITE else -game_outcome
                self.samples.append((state, move_idx, value_target))

            # Only the piece placement of the previous position is encoded,
            # so skip copying the move stack (otherwise O(plies) per move).
            prev_board = board.copy(stack=False)
            board.push(move)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict with 'state' (32x8x8), 'policy' (one-hot), 'value' (1,)."""
        state, move_idx, value = self.samples[idx]
        policy = torch.zeros(self.ACTION_SIZE, dtype=torch.float32)
        policy[move_idx] = 1.0
        return {
            'state': torch.FloatTensor(state),
            'policy': policy,
            'value': torch.FloatTensor([value])
        }

# Progress marker: printed once every definition in this cell has executed.
print(" All classes defined!")

## 3. Configuration

In [None]:
# ============================================================
# CONFIGURATION - Update paths for your data!
# ============================================================
# PGN files to train on; each one is read in full (up to MAX_GAMES games).
PGN_PATHS = [
    '/content/drive/MyDrive/data/games_2000.pgn',  # ← UPDATE THESE!
    '/content/drive/MyDrive/data/games_2001.pgn',
]

# Directory that receives the per-epoch and best checkpoints.
SAVE_PATH = '/content/drive/MyDrive/models'  # ← UPDATE THIS!

# Hyperparameters
# MAX_GAMES is passed to ChessPGNDataset as max_games, which applies the cap
# to every PGN file separately rather than to the combined total.
MAX_GAMES = 8000
BATCH_SIZE = 256
EPOCHS = 20
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# Echo the main settings (MAX_GAMES and WEIGHT_DECAY are not printed).
print(f"üìä Configuration:")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Epochs: {EPOCHS}")
print(f"   Learning rate: {LEARNING_RATE}")

## 4. Load Dataset

In [None]:
# Parse every file in PGN_PATHS into in-memory samples; this runs eagerly
# inside the constructor, so this line is where the PGN reading happens.
dataset = ChessPGNDataset(PGN_PATHS, max_games=MAX_GAMES)
# Shuffled mini-batches; num_workers=0 loads batches in the main process.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

print(f" Dataset ready: {len(dataset)} samples")

## 5. Initialize Model

In [None]:
# Build the network on the selected device together with its optimizer and loss.
# The architecture values here are repeated as literals in the checkpoint
# metadata written by the training loop; keep the two in sync when changing them.
model = SmallResNet(num_res_blocks=6, num_channels=64, action_size=4672).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = AlphaZeroLoss()

# Count every parameter tensor element for the summary below.
total_params = sum(p.numel() for p in model.parameters())
print(f" Model initialized")
print(f"   Parameters: {total_params:,}")
# The policy width printed here is a literal, not read back from the model.
print(f"   Policy output: 4672")

## 6. Training Loop

In [None]:
# Train for EPOCHS passes over `loader`, writing one checkpoint per epoch and
# refreshing sft_best.pth whenever the mean training loss improves.
os.makedirs(SAVE_PATH, exist_ok=True)
best_loss = float('inf')

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0

    pbar = tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}")

    for batch in pbar:
        # Move the three batch tensors onto the training device.
        states = batch['state'].to(device)
        policy_targets = batch['policy'].to(device)
        value_targets = batch['value'].to(device)

        # Forward pass and combined loss.
        optimizer.zero_grad()
        policy_pred, value_pred = model(states)
        loss, p_loss, v_loss = criterion(
            policy_pred, value_pred, policy_targets, value_targets
        )

        # Backward pass; the global gradient norm is clipped to 1.0 before stepping.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'p': f'{p_loss.item():.4f}',
            'v': f'{v_loss.item():.4f}',
        })

    # Mean of the per-batch losses (each batch weighs the same).
    avg_loss = total_loss / len(loader)
    print(f"\nEpoch {epoch} - Avg Loss: {avg_loss:.4f}")

    # Save checkpoint: one payload, written to the per-epoch file and, when
    # this epoch is the best so far, to the "best" file as well.
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'architecture': {'action_size': 4672, 'num_res_blocks': 6, 'num_channels': 64},
        'metadata': {'epoch': epoch, 'loss': avg_loss, 'stage': 'sft'}
    }
    torch.save(checkpoint, f"{SAVE_PATH}/sft_epoch_{epoch}.pth")

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(checkpoint, f"{SAVE_PATH}/sft_best.pth")
        print(f"üåü New best model saved!")

print(f"\n Training complete! Best loss: {best_loss:.4f}")
print(f"üìÅ Model saved to: {SAVE_PATH}/sft_best.pth")