# RL Training - Reinforcement Learning with MCTS Self-Play

**Self-contained notebook** - All classes defined inline!

**Features:**
- ✅ Loads SFT checkpoint automatically
- ✅ MCTS-guided self-play
- ✅ Complete implementation in single notebook
- ✅ Google Colab ready

## 1. Install Dependencies

In [None]:
!pip install torch numpy python-chess tqdm -q

# Mount Google Drive when running inside Colab; elsewhere this cell is a no-op.
try:
    from google.colab import drive
except ImportError:
    # google.colab is only importable inside a Colab runtime.
    print("Not in Colab")
else:
    try:
        drive.mount('/content/drive')
    except Exception as exc:
        # Best-effort: keep the notebook usable without Drive, but report the
        # real reason instead of claiming we are not in Colab.
        print(f"Google Drive mount failed: {exc}")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import chess
import os
import random
import pickle
from collections import deque
from tqdm import tqdm

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
# `device` is read by the model, MCTS and training cells further down.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Device: {device}")

## 2. Define All Classes (Copy from SFT + Add MCTS)

In [None]:
# ============================================================
# CANONICAL MOVE ENCODER
# ============================================================
class CanonicalMoveEncoder:
    """Bidirectional mapping between UCI move strings and flat action indices.

    The index space consists of three consecutive sections:

    * sliding ("queen") moves: 64 squares x 8 directions x 7 distances,
    * knight moves: 64 squares x 8 jumps,
    * underpromotions: 64 squares x 3 file steps x 3 pieces (r, b, n).

    Every slot consumes an index even when its destination falls off the
    board, so the total is always 3584 + 512 + 576 = 4672.
    """

    def __init__(self):
        self.move_to_idx, self.idx_to_move = self._build_canonical_map()

    def _build_canonical_map(self):
        """Enumerate all action slots and return ``(move_to_idx, idx_to_move)``."""
        forward, backward = {}, {}
        slot = 0

        def on_board(row, col):
            return 0 <= row < 8 and 0 <= col < 8

        # Section 1: sliding moves along the eight compass directions.
        rays = [(1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]
        for origin in range(64):
            row, col = divmod(origin, 8)
            for row_step, col_step in rays:
                for steps in range(1, 8):
                    target_row = row + row_step * steps
                    target_col = col + col_step * steps
                    if on_board(target_row, target_col):
                        name = chess.Move(origin, target_row * 8 + target_col).uci()
                        forward[name] = slot
                        backward[slot] = name
                    # Off-board slots still advance the counter.
                    slot += 1

        # Section 2: knight jumps.
        jumps = [(2, 1), (1, 2), (-1, 2), (-2, 1), (-2, -1), (-1, -2), (1, -2), (2, -1)]
        for origin in range(64):
            row, col = divmod(origin, 8)
            for row_step, col_step in jumps:
                target_row, target_col = row + row_step, col + col_step
                if on_board(target_row, target_col):
                    name = chess.Move(origin, target_row * 8 + target_col).uci()
                    forward[name] = slot
                    backward[slot] = name
                slot += 1

        # Section 3: underpromotions. Only squares on rank index 6 (moving up)
        # or rank index 1 (moving down) produce entries; all 64 squares still
        # reserve their nine slots.
        for origin in range(64):
            row, col = divmod(origin, 8)
            if row == 6:
                advance = 1
            elif row == 1:
                advance = -1
            else:
                advance = 0
            for col_step in (0, -1, 1):
                for symbol in ('r', 'b', 'n'):
                    target_col = col + col_step
                    if advance != 0 and 0 <= target_col < 8:
                        target = (row + advance) * 8 + target_col
                        kind = chess.Piece.from_symbol(symbol).piece_type
                        name = chess.Move(origin, target, promotion=kind).uci()
                        forward[name] = slot
                        backward[slot] = name
                    slot += 1

        return forward, backward

    def encode_move(self, move_uci):
        """Return the action index for ``move_uci``, or ``None`` if unmapped.

        A five-character move ending in ``'q'`` (queen promotion) has no entry
        of its own; it is looked up under its four-character from/to prefix.
        """
        direct = self.move_to_idx.get(move_uci, None)
        if direct is not None:
            return direct

        is_queen_promotion = (
            isinstance(move_uci, str) and len(move_uci) == 5 and move_uci[-1] == 'q'
        )
        if not is_queen_promotion:
            return None
        return self.move_to_idx.get(move_uci[:4], None)


# ============================================================
# BOARD ENCODER
# ============================================================
class BoardEncoder:
    """Turn a ``chess.Board`` into a ``(32, 8, 8)`` float32 plane stack.

    Plane layout:
      0-11   pieces of the current position (P N B R Q K p n b r q k)
      12-23  pieces of ``prev_board`` in the same order (zeros when absent)
      24     all ones when White is to move
      25-28  castling rights: White kingside, White queenside,
             Black kingside, Black queenside
      29     a single 1 on the en passant square, if there is one
      30-31  all ones when ``board.is_repetition(1)`` / ``(2)`` is true
    """

    def __init__(self):
        # Plane offset of each piece symbol inside a 12-plane group.
        self.piece_map = dict(zip('PNBRQKpnbrqk', range(12)))

    def encode(self, board, prev_board=None):
        """Encode ``board`` (and optionally the position before it) as planes."""
        planes = np.zeros((32, 8, 8), dtype=np.float32)

        # Piece placement: current position first, previous position 12 planes later.
        positions = [(0, board)]
        if prev_board:
            positions.append((12, prev_board))
        for offset, position in positions:
            for sq, pc in position.piece_map().items():
                plane = self.piece_map[pc.symbol()] + offset
                planes[plane, chess.square_rank(sq), chess.square_file(sq)] = 1.0

        # Whole-plane boolean features.
        broadcast_flags = (
            (24, board.turn == chess.WHITE),
            (25, board.has_kingside_castling_rights(chess.WHITE)),
            (26, board.has_queenside_castling_rights(chess.WHITE)),
            (27, board.has_kingside_castling_rights(chess.BLACK)),
            (28, board.has_queenside_castling_rights(chess.BLACK)),
        )
        for plane, is_set in broadcast_flags:
            if is_set:
                planes[plane, :, :] = 1.0

        ep = board.ep_square
        if ep:
            planes[29, chess.square_rank(ep), chess.square_file(ep)] = 1.0

        # NOTE(review): is_repetition(1) presumably holds for every position
        # (the current occurrence counts), which would make plane 30 constant;
        # confirm against python-chess. Left as-is so the planes match whatever
        # the existing checkpoints were trained on.
        for plane, occurrences in ((30, 1), (31, 2)):
            if board.is_repetition(occurrences):
                planes[plane, :, :] = 1.0

        return planes


# ============================================================
# MODEL
# ============================================================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in an identity skip connection.

    The channel count and spatial size are preserved, so the block's output
    can be added directly to its input.
    """

    def __init__(self, num_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        # The skip connection is added before the final activation.
        return F.relu(hidden + x)

class SmallResNet(nn.Module):
    """Residual tower with separate policy and value heads.

    The stem convolution takes 32 input channels and both heads flatten an
    8x8 grid, so ``forward`` expects a ``(batch, 32, 8, 8)`` tensor. It returns
    raw policy logits of width ``action_size`` and one tanh-squashed value per
    sample.
    """

    def __init__(self, num_res_blocks=6, num_channels=64, action_size=4672):
        super().__init__()
        # Stem.
        self.conv_input = nn.Conv2d(32, num_channels, 3, padding=1, bias=False)
        self.bn_input = nn.BatchNorm2d(num_channels)

        # Residual tower.
        tower = []
        for _ in range(num_res_blocks):
            tower.append(ResidualBlock(num_channels))
        self.res_blocks = nn.ModuleList(tower)

        # Policy head: 1x1 conv down to 32 planes, then a linear layer.
        self.policy_conv = nn.Conv2d(num_channels, 32, 1, bias=False)
        self.policy_bn = nn.BatchNorm2d(32)
        self.policy_fc = nn.Linear(32 * 8 * 8, action_size)

        # Value head: 1x1 conv down to 3 planes, then a two-layer MLP.
        self.value_conv = nn.Conv2d(num_channels, 3, 1, bias=False)
        self.value_bn = nn.BatchNorm2d(3)
        self.value_fc1 = nn.Linear(3 * 8 * 8, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def _policy_head(self, features):
        """Map tower features to unnormalised move logits."""
        planes = F.relu(self.policy_bn(self.policy_conv(features)))
        flat = planes.view(-1, 32 * 8 * 8)
        return self.policy_fc(flat)

    def _value_head(self, features):
        """Map tower features to a scalar in [-1, 1] per sample."""
        planes = F.relu(self.value_bn(self.value_conv(features)))
        flat = planes.view(-1, 3 * 8 * 8)
        hidden = F.relu(self.value_fc1(flat))
        return torch.tanh(self.value_fc2(hidden))

    def forward(self, x):
        features = F.relu(self.bn_input(self.conv_input(x)))
        for block in self.res_blocks:
            features = block(features)
        policy_logits = self._policy_head(features)
        value = self._value_head(features)
        return policy_logits, value


# ============================================================
# MCTS
# ============================================================
class MCTSNode:
    """One node of the search tree, holding the statistics PUCT needs.

    ``value_sum`` accumulates values as passed to ``backpropagate`` on this
    node; the sign is flipped at each step up the tree, so a parent scores a
    child with the negated child value.
    """

    def __init__(self, parent=None, move=None, prior=0.0):
        self.parent, self.move, self.prior = parent, move, prior
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0.0

    def value(self):
        """Mean backed-up value, or 0.0 for an unvisited node."""
        if self.visit_count > 0:
            return self.value_sum / self.visit_count
        return 0.0

    def select_child(self, c_puct=1.5):
        """Return the child with the highest PUCT score (``None`` if childless)."""
        sqrt_parent = np.sqrt(max(1, self.visit_count))
        winner, top_score = None, -float('inf')
        for candidate in self.children.values():
            # Negated because the child's value is from the opponent's side.
            exploit = -candidate.value() if candidate.visit_count > 0 else 0.0
            explore = c_puct * candidate.prior * sqrt_parent / (1 + candidate.visit_count)
            total = exploit + explore
            if total > top_score:
                winner, top_score = candidate, total
        return winner

    def expand(self, moves_and_priors):
        """Create a child per ``move -> prior`` entry; existing children are kept."""
        for move, prior in moves_and_priors.items():
            if move in self.children:
                continue
            self.children[move] = MCTSNode(parent=self, move=move, prior=prior)

    def backpropagate(self, value):
        """Add ``value`` here and up to the root, flipping the sign each level."""
        current, signed = self, value
        while current:
            current.visit_count += 1
            current.value_sum += signed
            signed = -signed
            current = current.parent

class MCTS:
    """Neural-network-guided Monte Carlo Tree Search using PUCT selection.

    The model is queried as-is under ``torch.no_grad()``; this class does not
    change its train/eval mode, so callers should put it in eval mode first.

    Args:
        model: callable returning ``(policy_logits, value)`` for a batch of
            encoded boards.
        move_encoder: object with ``encode_move(uci) -> index or None``.
        board_encoder: object with ``encode(board, prev_board) -> array``.
        c_puct: exploration constant passed to ``MCTSNode.select_child``.
        device: device the encoded board is moved to before the forward pass.
        action_size: length of the returned policy vector; must cover every
            index ``move_encoder`` can produce.
    """

    def __init__(self, model, move_encoder, board_encoder, c_puct=1.5, device='cpu',
                 action_size=4672):
        self.model = model
        self.move_encoder = move_encoder
        self.board_encoder = board_encoder
        self.c_puct = c_puct
        self.device = device
        self.action_size = action_size

    def search(self, board, num_simulations, temperature=1.0, prev_board=None):
        """Run ``num_simulations`` simulations from ``board``.

        Returns:
            A float32 vector of length ``action_size`` holding the visit-count
            distribution over moves (one-hot when ``temperature == 0``). It is
            all zeros when ``board`` is already game-over.
        """
        if board.is_game_over():
            # Nothing to search; also avoids expanding a position that may
            # have no legal moves.
            return np.zeros(self.action_size, dtype=np.float32)

        root = MCTSNode()
        self._expand_node(root, board, prev_board)

        for _ in range(num_simulations):
            node = root
            search_board = board.copy()
            search_prev_board = prev_board

            # Selection: descend until an unexpanded or terminal node.
            while node.children and not search_board.is_game_over():
                node = node.select_child(self.c_puct)
                if node.move:
                    search_prev_board = search_board.copy()
                    search_board.push(node.move)

            # Evaluation: exact result at terminals, network value otherwise.
            if search_board.is_game_over():
                value = self._terminal_value(search_board)
            else:
                value = self._expand_node(node, search_board, search_prev_board)

            node.backpropagate(value)

        return self._get_action_probs(root, temperature)

    @staticmethod
    def _terminal_value(board):
        """Score a finished game from the side to move: +1 win, -1 loss, 0 draw."""
        result = board.result()
        if result == "1-0":
            winner = chess.WHITE
        elif result == "0-1":
            winner = chess.BLACK
        else:
            return 0.0
        return 1.0 if board.turn == winner else -1.0

    def _expand_node(self, node, board, prev_board):
        """Evaluate ``board`` with the network and attach prior-weighted children.

        Priors are the softmax policy restricted to legal moves and
        renormalised; if that mass is zero, a uniform prior is used. A position
        without legal moves is left unexpanded.

        Returns:
            The network's value estimate for ``board`` as a Python float.
        """
        state = self.board_encoder.encode(board, prev_board)
        state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)

        with torch.no_grad():
            policy_logits, value = self.model(state_tensor)

        policy = torch.softmax(policy_logits, dim=1).cpu().numpy()[0]
        value = value.cpu().item()

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # Previously this fell through to 1.0 / len(legal_moves).
            return value

        priors = {}
        for move in legal_moves:
            move_idx = self.move_encoder.encode_move(move.uci())
            priors[move] = float(policy[move_idx]) if move_idx is not None else 0.0

        total_prior = sum(priors.values())
        if total_prior > 0:
            priors = {m: p / total_prior for m, p in priors.items()}
        else:
            uniform = 1.0 / len(legal_moves)
            priors = {m: uniform for m in legal_moves}

        node.expand(priors)
        return value

    def _get_action_probs(self, root, temperature):
        """Convert root child visit counts into a policy vector.

        When no child has been visited, the children's priors are used as
        weights (uniform if those are all zero too) so the result never
        contains NaN.
        """
        action_probs = np.zeros(self.action_size, dtype=np.float32)
        if not root.children:
            return action_probs

        moves = list(root.children)
        weights = np.array([root.children[m].visit_count for m in moves], dtype=np.float32)
        if weights.sum() <= 0:
            priors = np.array([root.children[m].prior for m in moves], dtype=np.float32)
            weights = priors if priors.sum() > 0 else np.ones(len(moves), dtype=np.float32)

        if temperature == 0:
            probs = np.zeros(len(moves), dtype=np.float32)
            probs[np.argmax(weights)] = 1.0
        else:
            scaled = weights ** (1.0 / temperature)
            probs = scaled / scaled.sum()

        for move, prob in zip(moves, probs):
            move_idx = self.move_encoder.encode_move(move.uci())
            if move_idx is not None:
                action_probs[move_idx] += prob

        return action_probs


# ============================================================
# REPLAY BUFFER
# ============================================================
class ReplayBuffer:
    """Bounded FIFO store of training samples.

    Backed by a ``deque`` with ``maxlen=max_size``, so once full, adding new
    samples discards the oldest ones.
    """

    def __init__(self, max_size=100000):
        self.buffer = deque(maxlen=max_size)

    def add_batch(self, samples):
        """Append every item of ``samples`` to the buffer."""
        self.buffer.extend(samples)

    def sample(self, batch_size):
        """Return up to ``batch_size`` distinct entries chosen at random."""
        available = len(self.buffer)
        count = batch_size if batch_size < available else available
        return random.sample(self.buffer, count)

    def __len__(self):
        return len(self.buffer)


# ============================================================
# LOSS FUNCTION
# ============================================================
class AlphaZeroLoss(nn.Module):
    """Sum of a soft-target cross-entropy policy loss and an MSE value loss."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, policy_pred, value_pred, policy_target, value_target):
        """Return ``(total_loss, policy_loss, value_loss)``.

        ``policy_pred`` holds raw logits; ``policy_target`` is a probability
        distribution over the same dimension 1.
        """
        log_probs = F.log_softmax(policy_pred, dim=1)
        # Cross-entropy against the target distribution, averaged over the batch.
        weighted = policy_target * log_probs
        policy_loss = -weighted.sum(dim=1).mean()
        value_loss = self.mse(value_pred, value_target)
        total_loss = policy_loss + value_loss
        return total_loss, policy_loss, value_loss

print(" All classes defined!")

## 3. Configuration

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================
# Path of the supervised (SFT) checkpoint used to warm-start the model.
SFT_CHECKPOINT = '/content/drive/MyDrive/models/sft_best.pth'  # <- UPDATE to your checkpoint location
# Directory that receives one RL checkpoint per training iteration.
SAVE_PATH = '/content/drive/MyDrive/models'

# RL Hyperparameters
RL_ITERATIONS = 100      # outer iterations (self-play + training + save)
GAMES_PER_ITER = 30      # self-play games generated per iteration
NUM_SIMULATIONS = 100    # MCTS simulations per move during self-play
BATCH_SIZE = 64          # samples drawn from the replay buffer per training step
LEARNING_RATE = 5e-5     # learning rate handed to the Adam optimizer
# NOTE(review): EVAL_INTERVAL is not referenced anywhere in the visible part of
# this notebook — presumably meant for a periodic evaluation step; confirm.
EVAL_INTERVAL = 5

# Echo the main settings so they are recorded in the cell output.
print(f"üìä RL Configuration:")
print(f"   MCTS simulations: {NUM_SIMULATIONS}")
print(f"   Games per iteration: {GAMES_PER_ITER}")
print(f"   Total iterations: {RL_ITERATIONS}")

## 4. Load Model & SFT Checkpoint

In [None]:
# Build the network and move it to the selected device.
model = SmallResNet(num_res_blocks=6, num_channels=64, action_size=4672)
model = model.to(device)

# Warm-start from the SFT weights when the checkpoint file is present;
# otherwise keep the freshly initialised parameters.
if not os.path.exists(SFT_CHECKPOINT):
    print(f" No SFT checkpoint found, training from scratch")
else:
    print(f"üîÑ Loading SFT checkpoint: {SFT_CHECKPOINT}")
    checkpoint = torch.load(SFT_CHECKPOINT, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f" SFT checkpoint loaded!")

# Encoders shared by self-play and the tree search.
move_encoder = CanonicalMoveEncoder()
board_encoder = BoardEncoder()

# Training objects. The optimizer starts fresh: only model weights are
# restored from the checkpoint above.
criterion = AlphaZeroLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Search and experience storage.
mcts = MCTS(
    model=model,
    move_encoder=move_encoder,
    board_encoder=board_encoder,
    c_puct=1.5,
    device=device,
)
replay_buffer = ReplayBuffer(max_size=100000)

print(f" RL trainer initialized")

## 5. Self-Play Function

In [None]:
def generate_self_play_game(mcts, num_simulations=100, max_moves=200, temperature_moves=30):
    """Play one self-play game with MCTS and return its training samples.

    Args:
        mcts: searcher providing ``search(board, num_simulations, temperature,
            prev_board)`` plus ``board_encoder`` and ``move_encoder``
            attributes. Its own encoders are used so the stored states match
            what the search evaluated.
        num_simulations: MCTS simulations per move.
        max_moves: cap on the number of moves played; a game still unfinished
            at the cap is scored as a draw.
        temperature_moves: number of initial moves searched with temperature
            1.0; later moves use temperature 0.0.

    Returns:
        ``(samples, game_outcome)``. ``samples`` is a list of
        ``(state, policy_target, value)`` tuples with ``value`` from the
        perspective of the side to move in that state. ``game_outcome`` is
        from White's perspective: 1.0, -1.0 or 0.0.
    """
    board_encoder = mcts.board_encoder
    move_encoder = mcts.move_encoder

    board = chess.Board()
    history = []  # (encoded state, MCTS policy, white_to_move)
    prev_board = None
    move_count = 0

    while not board.is_game_over() and move_count < max_moves:
        state = board_encoder.encode(board, prev_board)
        temperature = 1.0 if move_count < temperature_moves else 0.0

        policy_target = mcts.search(board, num_simulations, temperature, prev_board)
        history.append((state, policy_target, board.turn == chess.WHITE))

        # Project the policy vector back onto the legal moves and sample one.
        legal_moves = list(board.legal_moves)
        weights = np.zeros(len(legal_moves), dtype=np.float64)
        for slot, move in enumerate(legal_moves):
            move_idx = move_encoder.encode_move(move.uci())
            if move_idx is not None:
                weights[slot] = policy_target[move_idx]
        weights = np.clip(weights, 0.0, None)
        total = weights.sum()

        if np.isfinite(total) and total > 0:
            choice = np.random.choice(len(legal_moves), p=weights / total)
            selected_move = legal_moves[choice]
        else:
            # No usable probability mass (all zero or NaN): pick uniformly.
            selected_move = random.choice(legal_moves)

        prev_board = board.copy()
        board.push(selected_move)
        move_count += 1

    # Score by the actual result whenever the game finished, even if the final
    # move coincided with the move cap; only unfinished games become draws.
    if board.is_game_over():
        result = board.result()
        game_outcome = 1.0 if result == "1-0" else (-1.0 if result == "0-1" else 0.0)
    else:
        game_outcome = 0.0

    # Express the outcome from the perspective of the side to move in each state.
    final_samples = [
        (state, policy, game_outcome if was_white else -game_outcome)
        for state, policy, was_white in history
    ]

    return final_samples, game_outcome

print(" Self-play function defined")

## 6. RL Training Loop

In [None]:
# Make sure the checkpoint directory exists before the first save.
os.makedirs(SAVE_PATH, exist_ok=True)
# NOTE(review): best_winrate is initialised here but never read or updated in
# the visible code — presumably reserved for an evaluation step; confirm.
best_winrate = 0.0

# Each iteration: generate self-play games, train on the replay buffer, then
# write a checkpoint.
for iteration in range(1, RL_ITERATIONS + 1):
    print(f"\n{'='*60}")
    print(f"Iteration {iteration}/{RL_ITERATIONS}")
    print(f"{'='*60}")

    # Self-play: the network is put in eval mode while MCTS queries it.
    print(f"üéÆ Self-play: {GAMES_PER_ITER} games...")
    model.eval()
    for _ in tqdm(range(GAMES_PER_ITER), desc="Self-play"):
        # The game outcome (second return value) is not used here.
        samples, _ = generate_self_play_game(mcts, NUM_SIMULATIONS)
        replay_buffer.add_batch(samples)

    print(f"üìä Buffer size: {len(replay_buffer)} samples")

    # Training: back to train mode for the gradient updates.
    print(f"üèãÔ∏è Training on replay buffer...")
    model.train()

    for step in range(500):  # fixed 500 optimizer steps per iteration
        # sample() returns at most BATCH_SIZE entries (fewer if the buffer is smaller).
        batch = replay_buffer.sample(BATCH_SIZE)
        states, policies, values = zip(*batch)

        states = torch.stack([torch.FloatTensor(s) for s in states]).to(device)
        policies = torch.stack([torch.FloatTensor(p) for p in policies]).to(device)
        # Add a trailing dimension so targets line up with the single-output value head.
        values = torch.FloatTensor(values).unsqueeze(1).to(device)

        policy_pred, value_pred = model(states)
        loss, p_loss, v_loss = criterion(policy_pred, value_pred, policies, values)

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        if step % 100 == 0:
            print(f"  Step {step}/500: Loss={loss.item():.4f}")

    # Save a checkpoint after every iteration (weights, optimizer state,
    # architecture description and run metadata).
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'architecture': {'action_size': 4672, 'num_res_blocks': 6, 'num_channels': 64},
        'metadata': {'iteration': iteration, 'buffer_size': len(replay_buffer), 'stage': 'rl'}
    }, f"{SAVE_PATH}/rl_iter_{iteration}.pth")

    print(f"üíæ Checkpoint saved: rl_iter_{iteration}.pth")

print(f"\n RL Training complete!")
print(f"üìÅ Final model: {SAVE_PATH}/rl_iter_{RL_ITERATIONS}.pth")