In [1]:
!pip install torch pytorch-lightning python-chess



In [2]:
import datetime
import logging
import math
import random
from typing import Dict, List, Optional, Tuple

import chess
import chess.pgn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Configuration

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s] %(message)s')

# Reproducibility: self-play samples moves with torch.multinomial and the networks
# are randomly initialised, so seed every RNG this notebook touches up front.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters
NUM_SELFPLAY_GAMES = 2         # Number of games per training epoch
MAX_MOVES_PER_GAME = 500       # Safety limit on moves
DISCOUNT_FACTOR = 0.99
LR = 1e-3
PPO_EPOCHS = 3
BATCH_SIZE = 32
EPS_CLIP = 0.2
ENTROPY_BONUS = 0.01
MCTS_SIMULATIONS = 32
REWARD_WIN = 1.0
REWARD_LOSS = -1.0
REWARD_DRAW = 0.0

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {DEVICE}")

# Transformer parameters
BOARD_FEATURES = 13   # Channels for board encoding
EMBED_DIM = 128       # Transformer embedding dimension
NUM_HEADS = 4         # Transformer multi-head attention
NUM_LAYERS = 2        # Number of Transformer encoder layers
FFN_HIDDEN = 256      # Hidden size in feed-forward network
DROPOUT = 0.1

# Action space: one index per (from_square, to_square) pair, 64 * 64 = 4096
ACTION_SPACE_SIZE = 64 * 64

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f95ea4e9ad0>>
Traceback (most recent call last):
  File "/home/mmhfn1/anaconda3/envs/py311/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
2025-01-03 18:47:16,465 [INFO] Using device: cuda


In [3]:
# Board to tensor encoding and decoding

def board_to_tensor(board: chess.Board) -> torch.Tensor:
    """
    Turn a chess.Board into a float32 tensor of shape [BOARD_FEATURES, 8, 8].

    Plane layout:
      - planes 0-5:  white P, N, B, R, Q, K
      - planes 6-11: black p, n, b, r, q, k
      - plane 12:    filled with 1.0 when White is to move, left at 0.0 otherwise

    Row 0 of each plane corresponds to the board's highest rank (square // 8 == 7).
    """
    # Plane offset of each piece type inside one colour's group of six planes.
    type_offset = {
        chess.PAWN: 0,
        chess.KNIGHT: 1,
        chess.BISHOP: 2,
        chess.ROOK: 3,
        chess.QUEEN: 4,
        chess.KING: 5,
    }

    planes = torch.zeros((BOARD_FEATURES, 8, 8), dtype=torch.float32)
    for sq, occupant in board.piece_map().items():
        plane = type_offset.get(occupant.piece_type)
        if plane is None:
            # Unknown piece type: leave the square empty in every plane.
            continue
        if occupant.color != chess.WHITE:
            plane += 6
        rank_idx, file_idx = divmod(sq, 8)
        planes[plane, 7 - rank_idx, file_idx] = 1.0

    if board.turn == chess.WHITE:
        planes[12, :, :] = 1.0

    return planes


def move_to_index(move: chess.Move) -> int:
    """
    Flatten a move into a single action index.

    The index is from_square * 64 + to_square, so with both squares in [0..63]
    the result lies in [0..4095]. The promotion piece is not part of the index.
    """
    origin, target = move.from_square, move.to_square
    return 64 * origin + target


def index_to_move(index: int, board: chess.Board) -> chess.Move:
    """
    Decode an action index back into a python-chess Move (inverse of move_to_index).

    The plain (from_square, to_square) move is returned whenever it is legal on
    `board`. If it is not legal, but the destination is on rank 0 or 7 and the
    origin square holds a pawn, a queen promotion is returned instead. In every
    other case the plain move is returned unchanged, even though it is not legal.
    """
    from_sq, to_sq = divmod(index, 64)
    plain = chess.Move(from_sq, to_sq)

    if plain in board.legal_moves:
        return plain

    # Only a pawn arriving on a back rank can be missing a promotion piece.
    if chess.square_rank(to_sq) not in [0, 7]:
        return plain
    mover = board.piece_at(from_sq)
    if mover is None or mover.piece_type != chess.PAWN:
        return plain

    return chess.Move(from_sq, to_sq, promotion=chess.QUEEN)


def chess_outcome_to_reward(outcome: chess.Outcome) -> float:
    """
    Score a finished game from White's point of view.

    Returns REWARD_DRAW when the outcome has no winner, REWARD_WIN when White
    won and REWARD_LOSS otherwise.
    """
    winner = outcome.winner
    if winner is None:
        return REWARD_DRAW
    return REWARD_WIN if winner == chess.WHITE else REWARD_LOSS

In [4]:
# Transformer Based NN

class PositionalEncoding2D(nn.Module):
    """
    Fixed sinusoidal 2D positional encoding.

    Builds a non-trainable buffer `pe` of shape [embed_dim, height, width]: the
    first embed_dim // 2 channels depend only on the row index, the remaining
    channels only on the column index. Inside each half, even channels hold sines
    and odd channels hold cosines of position * frequency.
    """

    def __init__(self, embed_dim=128, height=8, width=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.height = height
        self.width = width

        half = embed_dim // 2        # channels per axis, e.g. 128 -> 64
        num_freqs = half // 2        # one (sin, cos) channel pair per frequency
        div_term = torch.exp(torch.arange(0, num_freqs, 1) * -(math.log(10000.0) / num_freqs))

        row_table = self._axis_table(height, half, div_term)   # [half, height]
        col_table = self._axis_table(width, half, div_term)    # [half, width]

        pe = torch.zeros(embed_dim, height, width)
        # Row channels are constant along x, column channels are constant along y.
        pe[:half] = row_table.unsqueeze(2).expand(half, height, width)
        pe[half:] = col_table.unsqueeze(1).expand(half, height, width)

        self.register_buffer('pe', pe)

    @staticmethod
    def _axis_table(length: int, channels: int, div_term: torch.Tensor) -> torch.Tensor:
        # Interleaved sin/cos table of shape [channels, length] for one board axis.
        # Channels beyond 2 * len(div_term) stay zero.
        table = torch.zeros(channels, length)
        pair_count = div_term.numel()
        angles = torch.arange(0, length).unsqueeze(0) * div_term.unsqueeze(1)  # [pairs, length]
        table[0:2 * pair_count:2] = torch.sin(angles)
        table[1:2 * pair_count:2] = torch.cos(angles)
        return table

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, embed_dim, height, width]; the encoding broadcasts over the batch.
        encoding = self.pe.unsqueeze(0).to(x.device)
        return x + encoding


class TransformerChessNet(nn.Module):
    """
    Transformer policy/value network over an encoded board.

    Pipeline: 3x3 convolution lifts the BOARD_FEATURES input planes to EMBED_DIM
    channels -> 2D positional encoding is added -> the 8x8 grid is flattened into
    64 tokens -> Transformer encoder -> mean over tokens -> two linear heads.

    Outputs:
      - policy logits of size ACTION_SPACE_SIZE
      - a scalar value squashed with tanh
    """

    def __init__(self, name: str):
        super().__init__()
        self.name = name

        # Stem: [BOARD_FEATURES, 8, 8] -> [EMBED_DIM, 8, 8]; padding keeps the 8x8 grid.
        self.initial_conv = nn.Conv2d(BOARD_FEATURES, EMBED_DIM, kernel_size=3, padding=1)
        self.pos_enc = PositionalEncoding2D(embed_dim=EMBED_DIM, height=8, width=8)

        # Encoder stack working on batch-first token sequences.
        layer_template = nn.TransformerEncoderLayer(
            d_model=EMBED_DIM,
            nhead=NUM_HEADS,
            dim_feedforward=FFN_HIDDEN,
            dropout=DROPOUT,
            activation='relu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=NUM_LAYERS)

        # Output heads, both fed by the pooled board embedding.
        self.policy_head = nn.Linear(EMBED_DIM, ACTION_SPACE_SIZE)
        self.value_head = nn.Linear(EMBED_DIM, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: batch of encoded boards, [batch_size, BOARD_FEATURES, 8, 8].

        Returns:
            (policy_logits, value): [batch_size, ACTION_SPACE_SIZE] raw logits and
            [batch_size, 1] tanh-squashed values.
        """
        feature_map = self.pos_enc(self.initial_conv(x))     # [B, E, 8, 8]

        batch, channels, rows, cols = feature_map.shape
        # Channels-last, then merge the two spatial axes into one token axis.
        tokens = feature_map.permute(0, 2, 3, 1).contiguous().view(batch, rows * cols, channels)

        encoded = self.transformer(tokens)                   # [B, 64, E]
        board_embedding = encoded.mean(dim=1)                # [B, E]

        policy_logits = self.policy_head(board_embedding)
        value = torch.tanh(self.value_head(board_embedding))
        return policy_logits, value

In [5]:
# MCTS

class MCTSNode:
    """
    One position in the search tree.

    Holds the board, a link to the parent node, the children keyed by the move
    that reaches them, and the running search statistics (visit count, summed
    value, prior probability).
    """

    def __init__(self, board: chess.Board, parent=None):
        self.board = board
        self.parent = parent
        self.children: Dict[chess.Move, "MCTSNode"] = {}
        self.visit_count = 0
        self.value_sum = 0.0
        self.prior = 0.0

    def is_leaf(self):
        # A node is a leaf until expand() has attached at least one child.
        return not self.children

    def expand(self, moves: List[chess.Move], priors: torch.Tensor):
        """Attach one child per move, taking its prior from `priors` at the move's action index."""
        for mv in moves:
            successor = self.board.copy()
            successor.push(mv)
            slot = move_to_index(mv)
            child = MCTSNode(successor, parent=self)
            child.prior = priors[slot].item() if slot < ACTION_SPACE_SIZE else 0.0
            self.children[mv] = child

    def value(self):
        # Mean backed-up value; 0.0 for a node that has never been visited.
        return self.value_sum / self.visit_count if self.visit_count else 0.0

    def ucb_score(self, total_visit_count, c_puct=1.0):
        # Mean value plus an exploration bonus that shrinks as this node is visited.
        exploration = c_puct * self.prior * math.sqrt(total_visit_count) / (1 + self.visit_count)
        return self.value() + exploration


def mcts_search(root: MCTSNode, net: nn.Module, simulations=MCTS_SIMULATIONS):
    """
    Run `simulations` rounds of MCTS starting at `root`, updating node statistics in place.

    Each round walks down the tree by maximum UCB score, expands the reached leaf
    with network priors (unless the game is over there), evaluates it, and backs
    the result up to the root.

    Value convention: a node's `value_sum` is read by its *parent* during
    selection, so it is stored from the point of view of the player who moved
    into the node (the opponent of the node's side to move). The leaf evaluation
    is first expressed for the side to move at the leaf:
      - network value: taken as-is, because self-play trains the value head on
        the outcome from the mover's perspective;
      - finished game: chess_outcome_to_reward is White-relative, so it is
        negated when Black is to move.
    It is then negated once before being written into the leaf, and the sign
    keeps alternating on the way up.

    Args:
        root: node to search from.
        net: network returning (policy_logits, value) for a batch of board tensors;
            it is switched to eval mode.
        simulations: number of rounds to run.
    """
    net.eval()
    for sim in range(simulations):
        node = root

        # Selection: descend through already-expanded, non-terminal nodes.
        while not node.is_leaf() and not node.board.is_game_over():
            parent_visits = node.visit_count
            node = max(node.children.values(),
                       key=lambda child: child.ucb_score(parent_visits))

        # board.outcome() is None exactly when the game is not over.
        outcome = node.board.outcome()
        if outcome is None:
            # Expansion + network evaluation.
            board_tensor = board_to_tensor(node.board).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                policy_logits, value = net(board_tensor)
            priors = F.softmax(policy_logits[0], dim=-1)  # [4096]
            node.expand(list(node.board.legal_moves), priors)
            value_for_side_to_move = value.item()
        else:
            # Terminal evaluation, converted from White's view to the side to move.
            white_reward = chess_outcome_to_reward(outcome)
            if node.board.turn == chess.WHITE:
                value_for_side_to_move = white_reward
            else:
                value_for_side_to_move = -white_reward

        # Backpropagation: start from the view of the player who moved into the leaf.
        current_value = -value_for_side_to_move
        while node is not None:
            node.visit_count += 1
            node.value_sum += current_value
            current_value = -current_value
            node = node.parent

        logging.debug(f"MCTS simulation {sim+1}/{simulations} done")


def select_move_mcts(board: chess.Board, net: nn.Module, simulations=MCTS_SIMULATIONS, temperature=1.0) -> Optional[chess.Move]:
    """
    Pick a move for `board` by running MCTS guided by `net`.

    The root is expanded once with the network's priors, `simulations` rounds of
    search are run, and a move is chosen from the root's children according to
    their visit counts.

    Args:
        board: position to move in; it is copied, not modified.
        net: network returning (policy_logits, value); it is switched to eval mode
            so the root priors are computed without dropout, matching mcts_search.
        simulations: number of MCTS rounds.
        temperature: below 1e-8 the most-visited child is taken; otherwise a child
            is sampled with weight visit_count ** (1 / temperature).

    Returns:
        The selected move, or None when the position has no legal moves or no
        root child received a visit.
    """
    legal_moves = list(board.legal_moves)
    if not legal_moves:
        # Nothing to search: skip the network call and the simulations entirely.
        logging.debug("No legal moves at MCTS root, returning None.")
        return None

    root = MCTSNode(board.copy())
    net.eval()
    with torch.no_grad():
        board_tensor = board_to_tensor(board).unsqueeze(0).to(DEVICE)
        policy_logits, _ = net(board_tensor)
    priors = F.softmax(policy_logits[0], dim=-1)
    root.expand(legal_moves, priors)

    # Run MCTS
    mcts_search(root, net, simulations=simulations)

    candidate_moves = list(root.children.keys())
    visit_counts = torch.tensor(
        [root.children[move].visit_count for move in candidate_moves],
        dtype=torch.float32,
    )
    if visit_counts.sum() == 0:
        # fallback
        logging.debug("No visit counts in MCTS root children, returning None.")
        return None

    if temperature < 1e-8:
        move_index = torch.argmax(visit_counts).item()
    else:
        # multinomial accepts unnormalised non-negative weights.
        probs = visit_counts ** (1.0 / temperature)
        move_index = torch.multinomial(probs, 1).item()

    selected_move = candidate_moves[move_index]
    logging.debug(f"MCTS selected move: {selected_move}")
    return selected_move

In [6]:
# Replay Buffer

class ReplayBuffer:
    """
    Append-only store of trajectory steps, kept as six parallel lists.

    Index i of `states`, `actions`, `log_probs`, `values`, `rewards` and `dones`
    together describes the i-th recorded step.
    """

    # Attribute names of the parallel lists, in the argument order of add().
    _FIELDS = ("states", "actions", "log_probs", "values", "rewards", "dones")

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def add(self, state, action, log_prob, value, reward, done):
        # Append one step: each argument goes to the end of its own list.
        step = (state, action, log_prob, value, reward, done)
        for field, item in zip(self._FIELDS, step):
            getattr(self, field).append(item)

    def clear(self):
        # Empty every list in place (the list objects themselves are kept).
        for field in self._FIELDS:
            getattr(self, field).clear()

    def size(self):
        # Number of stored steps.
        return len(self.states)

In [7]:
# PPO Update

def ppo_update(net: nn.Module, optimizer: optim.Optimizer,
               states: torch.Tensor, actions: torch.Tensor,
               old_log_probs: torch.Tensor,
               returns: torch.Tensor, advantages: torch.Tensor):
    """
    Run one pass of clipped-PPO minibatch updates over the given samples.

    The samples are visited in a fresh random order, BATCH_SIZE at a time (the
    last minibatch may be smaller). Each minibatch takes one optimizer step on

        policy_loss + 0.5 * value_loss - ENTROPY_BONUS * entropy

    where policy_loss is the clipped surrogate (clip range EPS_CLIP), value_loss
    is the MSE between predicted values and `returns`, and entropy is the mean
    entropy of the policy distribution over the whole action space.

    Args:
        net: model returning (policy_logits, values) for a batch of states.
        optimizer: optimizer over the model's parameters.
        states: first dimension indexes the N samples.
        actions: [N] integer action indices.
        old_log_probs: [N] log-probabilities recorded when the actions were taken.
        returns: [N] value targets.
        advantages: [N] advantage estimates.
    """
    sample_count = states.size(0)
    order = torch.randperm(sample_count).to(DEVICE)

    for lo in range(0, sample_count, BATCH_SIZE):
        chosen = order[lo:lo + BATCH_SIZE]
        mb_states = states[chosen]
        mb_actions = actions[chosen]
        mb_old_log_probs = old_log_probs[chosen]
        mb_returns = returns[chosen]
        mb_advantages = advantages[chosen]

        logits, value_pred = net(mb_states)      # [bs, 4096], [bs, 1]
        log_pi = F.log_softmax(logits, dim=-1)   # [bs, 4096]

        # Log-probability the current policy gives to each action that was taken.
        taken_log_pi = log_pi.gather(1, mb_actions.unsqueeze(-1)).squeeze(-1)  # [bs]
        ratio = torch.exp(taken_log_pi - mb_old_log_probs)

        unclipped = ratio * mb_advantages
        clipped = torch.clamp(ratio, 1.0 - EPS_CLIP, 1.0 + EPS_CLIP) * mb_advantages
        policy_loss = -torch.min(unclipped, clipped).mean()

        value_loss = F.mse_loss(value_pred.squeeze(-1), mb_returns)

        entropy = -(log_pi * torch.exp(log_pi)).sum(dim=-1).mean()

        loss = policy_loss + 0.5 * value_loss - ENTROPY_BONUS * entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    logging.debug(f"PPO update done with {states.size(0)} samples.")

In [8]:
# ChessBot Class

class ChessBot:
    """
    A self-play agent: one TransformerChessNet, its Adam optimizer and a replay buffer.

    Moves are chosen with MCTS; training consumes the whole replay buffer with
    PPO and then clears it.
    """

    def __init__(self, name: str):
        self.name = name
        self.net = TransformerChessNet(name).to(DEVICE)
        logging.debug(f"{name}: Model initialized. First parameter device: {next(self.net.parameters()).device}")
        self.optimizer = optim.Adam(self.net.parameters(), lr=LR)
        self.replay = ReplayBuffer()

    def select_move(self, board: chess.Board) -> Tuple[Optional[chess.Move], float, float]:
        """
        Choose a move with MCTS and report the network's own view of it.

        Returns:
            (move, log_prob, value) where log_prob is the network's log-probability
            of the chosen move over the full action space and value is its value
            estimate for `board`. When MCTS yields no move, (None, 0.0, 0.0).
        """
        self.net.eval()
        move = select_move_mcts(board, self.net, simulations=MCTS_SIMULATIONS, temperature=1.0)
        if move is None:
            return None, 0.0, 0.0

        # compute log_prob and value from the net
        board_tensor = board_to_tensor(board).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            policy_logits, value = self.net(board_tensor)
        policy_log_probs = F.log_softmax(policy_logits, dim=-1)  # [1, 4096]

        move_idx = move_to_index(move)
        log_prob = policy_log_probs[0, move_idx].item()
        val = value[0].item()
        return move, log_prob, val

    @staticmethod
    def _compute_gae(rewards: List[float], values: List[float],
                     dones: List[float]) -> Tuple[List[float], List[float]]:
        # Backward pass over the buffer producing (returns, advantages) as plain lists.
        # The accumulator restarts at the last sample and at every step flagged done,
        # so consecutive games stored in one buffer do not leak into each other.
        size = len(rewards)
        returns = [0.0] * size
        advantages = [0.0] * size
        gae = 0.0
        for i in reversed(range(size)):
            if i == size - 1 or dones[i] == 1.0:
                gae = rewards[i] - values[i]
            else:
                delta = rewards[i] + DISCOUNT_FACTOR * values[i + 1] - values[i]
                gae = delta + DISCOUNT_FACTOR * gae
            advantages[i] = gae
            returns[i] = gae + values[i]
        return returns, advantages

    def train_on_replay(self):
        """
        Train the network on everything in the replay buffer, then empty the buffer.

        Returns and advantages are computed on the host from the stored Python
        numbers, the advantages are standardised, and PPO_EPOCHS passes of
        ppo_update are run. Does nothing when the buffer is empty.
        """
        size = self.replay.size()
        logging.debug(f"{self.name}: Replay buffer size = {size}")
        if size == 0:
            return

        states = torch.stack(self.replay.states).to(DEVICE)  # [N, 13, 8, 8]
        actions = torch.tensor(self.replay.actions, dtype=torch.long).to(DEVICE)  # [N]
        log_probs_old = torch.tensor(self.replay.log_probs, dtype=torch.float32).to(DEVICE)  # [N]

        # Compute returns and advantages (simple GAE) without touching the device:
        # indexing device tensors element by element would synchronise on every step.
        return_list, advantage_list = self._compute_gae(
            self.replay.rewards, self.replay.values, self.replay.dones
        )

        returns = torch.tensor(return_list, dtype=torch.float32).to(DEVICE)
        advantages = torch.tensor(advantage_list, dtype=torch.float32).to(DEVICE)
        if size > 1:
            # The standard deviation of a single sample is NaN, which would poison
            # the loss and every parameter; only standardise when it is defined.
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(PPO_EPOCHS):
            ppo_update(self.net, self.optimizer, states, actions,
                       log_probs_old, returns, advantages)

        self.replay.clear()
        logging.debug(f"{self.name}: Completed training on replay.")

In [9]:
# Self-Play

def play_selfplay_game(bot_white: ChessBot, bot_black: ChessBot, game_id: int):
    """
    Play one game between the two bots, record it, and fill their replay buffers.

    The game stops when python-chess reports it over, after MAX_MOVES_PER_GAME
    half-moves, or when a bot returns no move.

    Replay data: every position a bot moved from becomes one step in that bot's
    buffer. The game result is stored as a reward on the bot's *last* step only
    (marked done=1.0); all earlier steps get reward 0.0. ChessBot.train_on_replay
    accumulates discounted per-step rewards, so repeating the result on every step
    would make the value targets grow with game length, far beyond the [-1, 1]
    range of the tanh value head. White's result is chess_outcome_to_reward(outcome),
    Black's is its negation, and an unfinished game counts as 0.0.

    Side effects: logs the final FEN and PGN and writes game_{game_id}.pgn and
    game_{game_id}.fen to the current working directory.
    """
    board = chess.Board()

    # Create a PGN game object with accurate headers
    pgn_game = chess.pgn.Game()
    pgn_game.headers["Event"] = f"Self-Play Game #{game_id}"
    pgn_game.headers["Site"] = "Local Machine"
    pgn_game.headers["Date"] = datetime.date.today().strftime("%Y.%m.%d")
    pgn_game.headers["Round"] = str(game_id)
    pgn_game.headers["White"] = bot_white.name
    pgn_game.headers["Black"] = bot_black.name
    pgn_game.headers["Result"] = "*"  # provisional

    node = pgn_game

    # Per-colour lookup tables replace two copies of the same move-recording code.
    bots = {chess.WHITE: bot_white, chess.BLACK: bot_black}
    side_names = {chess.WHITE: "White", chess.BLACK: "Black"}
    # Each entry: (state_tensor, action_index, log_prob, value).
    trajectories = {chess.WHITE: [], chess.BLACK: []}

    move_count = 0
    while not board.is_game_over() and move_count < MAX_MOVES_PER_GAME:
        move_count += 1
        side = board.turn
        move, log_p, val = bots[side].select_move(board)
        if move is None:
            logging.debug(f"{side_names[side]} has no move. Breaking.")
            break

        # Record the position *before* the move is played.
        trajectories[side].append((board_to_tensor(board), move_to_index(move), log_p, val))
        board.push(move)
        node = node.add_variation(move)

    outcome = board.outcome()
    final_reward_white = 0.0
    if outcome is not None:
        final_reward_white = chess_outcome_to_reward(outcome)
    final_rewards = {chess.WHITE: final_reward_white, chess.BLACK: -final_reward_white}

    # Add to replay buffers: terminal step carries the result, the rest carry zero.
    for side, steps in trajectories.items():
        last_index = len(steps) - 1
        for i, (state, action, log_p, val) in enumerate(steps):
            is_last = i == last_index
            bots[side].replay.add(
                state, action, log_p, val,
                final_rewards[side] if is_last else 0.0,
                1.0 if is_last else 0.0,
            )

    # Update the PGN "Result"
    if outcome is not None:
        if outcome.winner == chess.WHITE:
            pgn_game.headers["Result"] = "1-0"
        elif outcome.winner == chess.BLACK:
            pgn_game.headers["Result"] = "0-1"
        else:
            pgn_game.headers["Result"] = "1/2-1/2"
    else:
        pgn_game.headers["Result"] = "*"

    fen_str = board.fen()
    pgn_str = str(pgn_game)
    logging.info(f"Game #{game_id} finished. FEN: {fen_str}")
    logging.info(f"PGN:\n{pgn_str}")

    # Save PGN and FEN
    with open(f"game_{game_id}.pgn", "w") as f:
        f.write(pgn_str + "\n")
    with open(f"game_{game_id}.fen", "w") as f:
        f.write(fen_str + "\n")

In [None]:
# Main Loop

def main_train_loop(num_epochs=1):
    """
    Create a White bot and a Black bot and alternate self-play with training.

    Each epoch plays NUM_SELFPLAY_GAMES games (game ids keep counting up across
    epochs), trains the White bot and then the Black bot on their replay buffers,
    and logs the first scalar of each network's first parameter as a quick check.
    """
    logging.info("Starting main training loop.")
    logging.debug(f"Number of epochs: {num_epochs}. Games per epoch: {NUM_SELFPLAY_GAMES}")

    bot_white = ChessBot("Transformer-Bot-White")
    bot_black = ChessBot("Transformer-Bot-Black")
    both_bots = (bot_white, bot_black)

    next_game_id = 0
    for epoch in range(num_epochs):
        logging.info(f"===== EPOCH {epoch+1}/{num_epochs} =====")

        # Self-play phase
        for _ in range(NUM_SELFPLAY_GAMES):
            play_selfplay_game(bot_white, bot_black, next_game_id)
            next_game_id += 1

        # Training phase: White first, then Black
        for bot in both_bots:
            logging.info(f"Training {bot.name}")
            bot.train_on_replay()

        # Debug aid: one parameter value per network
        for bot in both_bots:
            logging.debug(f"{bot.name} param sample: {next(bot.net.parameters()).flatten()[0]}")

    logging.info("Training complete.")


# Entry point: run a single training epoch when this cell/script is executed directly.
if __name__ == "__main__":
    main_train_loop(num_epochs=1)

2025-01-03 18:47:16,554 [INFO] Starting main training loop.
2025-01-03 18:47:16,555 [DEBUG] Number of epochs: 1. Games per epoch: 2
2025-01-03 18:47:16,814 [DEBUG] Transformer-Bot-White: Model initialized. First parameter device: cuda:0
2025-01-03 18:47:17,559 [DEBUG] Transformer-Bot-Black: Model initialized. First parameter device: cuda:0
2025-01-03 18:47:17,560 [INFO] ===== EPOCH 1/1 =====
2025-01-03 18:47:17,858 [DEBUG] MCTS simulation 1/32 done
2025-01-03 18:47:17,860 [DEBUG] MCTS simulation 2/32 done
2025-01-03 18:47:17,863 [DEBUG] MCTS simulation 3/32 done
2025-01-03 18:47:17,866 [DEBUG] MCTS simulation 4/32 done
2025-01-03 18:47:17,868 [DEBUG] MCTS simulation 5/32 done
2025-01-03 18:47:17,870 [DEBUG] MCTS simulation 6/32 done
2025-01-03 18:47:17,872 [DEBUG] MCTS simulation 7/32 done
2025-01-03 18:47:17,877 [DEBUG] MCTS simulation 8/32 done
2025-01-03 18:47:17,880 [DEBUG] MCTS simulation 9/32 done
2025-01-03 18:47:17,885 [DEBUG] MCTS simulation 10/32 done
2025-01-03 18:47:17,887 