# Chess BPT toy model

Downloaded [Lichess elite database](https://database.nikonoel.fr/)

In [24]:
import json
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
from schemas import GameState

def read_games_from_json(json_file_path: str) -> list[GameState]:
    with open(json_file_path, 'r', encoding='utf-8') as f:
        games_json = json.load(f)
    
    # Convert each JSON object back to a GameState
    games = [GameState.model_validate_json(game_json) if isinstance(game_json, str) 
             else GameState.model_validate(game_json) 
             for game_json in games_json]
    
    return games

    # All possible pieces (6 types × 2 colors + empty square)
pieces = {
    'white pawn', 'white rook', 'white knight', 'white bishop', 'white queen', 'white king',
    'black pawn', 'black rook', 'black knight', 'black bishop', 'black queen', 'black king',
    '.'  # empty square
}

def read_possible_moves_from_txt(txt_file_path):
    """
    Read possible moves from a text file
    """
    with open(txt_file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    
    # Split by space to get individual moves
    possible_moves = set(content.split())
    return possible_moves

# Example usage
moves = read_possible_moves_from_txt("data/processed/possible_moves.txt")
    

# Get the elements
print(f"Pieces ({len(pieces)}): {sorted(pieces)}")
print(f"Number of basic move patterns: {len(moves)} {moves}")

EMPTY = '.'

# Read the processed game file
processed_games = read_games_from_json('data/processed/games.json')

# Create character-to-index mapping from the processed games
piece_tokens = list(pieces)
move_tokens = list(moves)
tokens = piece_tokens + move_tokens
stoi = {ch: i for i, ch in enumerate(set(tokens))}
itos = {i: ch for ch, i in stoi.items()}

encode: Callable[[str], list[int]] = lambda s: [stoi[c] for c in s]
decode: Callable[[list[int]], str] = lambda l: ''.join([itos[i] for i in l])
vocab_size = len(stoi)
board_context_length = len(pieces)
move_context_length = 9
print(f'{board_context_length=}')
print(f'{move_context_length=}')
print(f'{len(stoi)=}')
stoi


Pieces (13): ['.', 'black bishop', 'black king', 'black knight', 'black pawn', 'black queen', 'black rook', 'white bishop', 'white king', 'white knight', 'white pawn', 'white queen', 'white rook']
Number of basic move patterns: 2007 {'Bxe3', 'R3d5', 'Ncb1', 'c8=Q+', 'Nxd1', 'Rd1#', 'Qxg3', 'Bf1+', 'Nge7', 'Ng2', 'Nfe4+', 'Rbe8', 'Rab5', 'N2b3', 'Ree4', 'a7', 'Nb7', 'Bc6+', 'a7+', 'Bh7+', 'Re6', 'Rxb8', 'Nxg2', 'Nbd2', 'Bf8+', 'Rxc2', 'e5+', 'Bxa2+', 'Rxe7+', 'Qb7#', 'Rdxa5', 'Qh6', 'Rgh8', 'cxd4', 'N6xc7', 'Nbxc4', 'Qxc8+', 'Qb6#', 'f5+', 'Kg8', 'Bxe4', 'Rxe7', 'Nf3+', 'Bxb4+', 'Ree1', 'Rg6', 'Ree7', 'Rd4', 'Qxd1+', 'Rdc1', 'N5c6', 'Rexf4', 'Qxf7', 'R5c3', 'Kf1', 'Rge3', 'Rhf8', 'Qe6#', 'N3c4', 'Rxe3', 'bxa4', 'Rxf3', 'Kxd2', 'Raxa7', 'Rxa8+', 'Re2', 'Bd2+', 'Nef6', 'Qxf4+', 'Qa5+', 'Qc1#', 'Qh2+', 'Qe6+', 'Bxf4+', 'e1=Q', 'Rb4+', 'b1=Q', 'Be1', 'd4', 'Bxg4', 'Rf1', 'Bxe5+', 'Rxd2', 'N7e5', 'Qxg1+', 'Nxd2+', 'Qg8+', 'Rcd3', 'Qe3+', 'Rc5#', 'Rec2', 'Qf3#', 'Rh4', 'Qd2', 'Kxe6', 'R1a7', 

{'Bxe3': 0,
 'R3d5': 1,
 'Ncb1': 2,
 'c8=Q+': 3,
 'Nxd1': 4,
 'Rd1#': 5,
 'Qxg3': 6,
 'Bf1+': 7,
 'Nge7': 8,
 'Ng2': 9,
 'Nfe4+': 10,
 'Rbe8': 11,
 'Rab5': 12,
 'N2b3': 13,
 'Ree4': 14,
 'a7': 15,
 'Nb7': 16,
 'Bc6+': 17,
 'a7+': 18,
 'Bh7+': 19,
 'Re6': 20,
 'Rxb8': 21,
 'Nxg2': 22,
 'Nbd2': 23,
 'Bf8+': 24,
 'Rxc2': 25,
 'e5+': 26,
 'Bxa2+': 27,
 'Rxe7+': 28,
 'black pawn': 29,
 'Qb7#': 30,
 'Rdxa5': 31,
 'Qh6': 32,
 'Rgh8': 33,
 'cxd4': 34,
 'N6xc7': 35,
 'Nbxc4': 36,
 'Qxc8+': 37,
 'Qb6#': 38,
 'f5+': 39,
 'Kg8': 40,
 'Bxe4': 41,
 'Rxe7': 42,
 'Nf3+': 43,
 'Bxb4+': 44,
 'Ree1': 45,
 'Rg6': 46,
 'Ree7': 47,
 'Rd4': 48,
 'Qxd1+': 49,
 'Rdc1': 50,
 'N5c6': 51,
 'Rexf4': 52,
 'Qxf7': 53,
 'R5c3': 54,
 'Kf1': 55,
 'Rge3': 56,
 'Rhf8': 57,
 'Qe6#': 58,
 'N3c4': 59,
 'Rxe3': 60,
 'bxa4': 61,
 'Rxf3': 62,
 'Kxd2': 63,
 'Raxa7': 64,
 'Rxa8+': 65,
 'Re2': 66,
 'Bd2+': 67,
 'Nef6': 68,
 'Qxf4+': 69,
 'Qa5+': 70,
 'Qc1#': 71,
 'Qh2+': 72,
 'Qe6+': 73,
 'Bxf4+': 74,
 'e1=Q': 75,
 'Rb4+': 76,
 '

In [81]:
def tokenize_san_move(move: str):
    """
    Tokenize a SAN (Standard Algebraic Notation) chess move into a vector.
    Input format expected: 'Nf3:w', 'dxc4:b', etc.
    Returns a list of integers representing the move components.
    """
    
    if move == ".":
        return [0, 0, 0, 0, 0, 0, 0, 0, 0]  # Updated to return 9 elements
    
    # Component vocabularies
    piece_vocab = {"": 0, "P": 1, "N": 2, "B": 3, "R": 4, "Q": 5, "K": 6}
    file_vocab = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4, "f": 5, "g": 6, "h": 7}
    rank_vocab = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4, "6": 5, "7": 6, "8": 7}
    action_vocab = {"": 0, "x": 1, "+": 2, "#": 3}
    color_vocab = {"w": 0, "b": 1}
    
    # Parse move
    move_part, color = move.split(":")
    
    # Handle castling
    if move_part == "O-O":
        # Return 9 elements to match regular moves
        return [7, 0, 0, 0, 0, 0, 0, 0, color_vocab[color]]
    elif move_part == "O-O-O":
        # Return 9 elements to match regular moves
        return [8, 0, 0, 0, 0, 0, 0, 0, color_vocab[color]]
    
    
    # Extract check/checkmate
    check = ""
    if move_part.endswith("+"):
        check = "+"
        move_part = move_part[:-1]
    elif move_part.endswith("#"):
        check = "#"
        move_part = move_part[:-1]
    
    # Extract piece (default to pawn if not specified)
    if move_part[0].isupper() and move_part[0] in "PNBRQK":
        piece = move_part[0]
        move_part = move_part[1:]
    else:
        piece = ""  # Pawn is implied
    
    # Extract capture
    if "x" in move_part:
        action = "x"
        move_part = move_part.replace("x", "")
    else:
        action = ""
    
    # Extract disambiguation if present (like Nbd7)
    disambig_file = 8  # Default value for no disambiguation
    disambig_rank = 8
    if len(move_part) > 2:
        if move_part[0] in "abcdefgh":
            disambig_file = file_vocab[move_part[0]]
            move_part = move_part[1:]
        elif move_part[0] in "12345678":
            disambig_rank = rank_vocab[move_part[0]]
            move_part = move_part[1:]
    
    # Extract destination square
    dest_file = move_part[-2] if len(move_part) >= 2 else None
    dest_rank = move_part[-1] if len(move_part) >= 1 else None
    
    # Handle promotion if present
    promotion = 0  # 0 = no promotion
    if "=" in move_part:
        promotion_piece = move_part.split("=")[1]
        promotion = piece_vocab.get(promotion_piece, 0)
        # Remove promotion part from destination
        dest_file = move_part.split("=")[0][-2] if len(move_part.split("=")[0]) >= 2 else None
        dest_rank = move_part.split("=")[0][-1] if len(move_part.split("=")[0]) >= 1 else None
    
    # Build vector
    vector = [
        piece_vocab.get(piece, 0),
        disambig_file,
        disambig_rank,
        action_vocab.get(action, 0),
        file_vocab.get(dest_file, 8) if dest_file else 8,
        rank_vocab.get(dest_rank, 8) if dest_rank else 8,
        promotion,
        action_vocab.get(check, 0),
        color_vocab.get(color, 0)
    ]
    
    return vector

def detokenize_san_move(vector: list[float]):
    """
    Convert a tokenized chess move vector back to SAN (Standard Algebraic Notation).
    Takes a list of 9 integers and returns a string in format like 'Nf3:w' or 'dxc4:b'.
    """
    
    vector = [int(v) for v in vector]

    if len(vector) != 9:
        raise ValueError("Vector must have exactly 9 elements")
    
    # If all zeros, it's an empty move
    if all(v == 0 for v in vector):
        return "."
    
    # Component vocabularies (inverted from the tokenizer)
    piece_vocab = {0: "", 1: "P", 2: "N", 3: "B", 4: "R", 5: "Q", 6: "K"}
    file_vocab = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e", 5: "f", 6: "g", 7: "h", 8: ""}
    rank_vocab = {0: "1", 1: "2", 2: "3", 3: "4", 4: "5", 5: "6", 6: "7", 7: "8", 8: ""}
    action_vocab = {0: "", 1: "x", 2: "+", 3: "#"}
    color_vocab = {0: "w", 1: "b"}
    
    # Extract components from vector
    piece_id, disambig_file_id, disambig_rank_id, action_id, dest_file_id, dest_rank_id, promotion_id, check_id, color_id = vector
    
    # Handle castling
    if piece_id == 7:
        return f"O-O:{color_vocab[color_id]}"
    elif piece_id == 8:
        return f"O-O-O:{color_vocab[color_id]}"
    
    # Start building the move
    move = ""
    
    # Add piece (except for pawns which are implied)
    if piece_id > 0:
        move += piece_vocab[piece_id]
        
    if disambig_rank_id > 8:
        x=1
    # Add disambiguation if present
    if disambig_file_id != 8:
        move += file_vocab[disambig_file_id]
    if disambig_rank_id != 8:
        move += rank_vocab[disambig_rank_id]
    
    # Add capture
    if action_id == 1:
        move += "x"
    
    # Add destination square
    if dest_file_id != 8 and dest_rank_id != 8:
        move += file_vocab[dest_file_id] + rank_vocab[dest_rank_id]
    
    # Add promotion
    if promotion_id > 0:
        move += f"={piece_vocab[promotion_id]}"
    
    # Add check/checkmate
    if check_id > 0:
        move += action_vocab[check_id]
    
    # Add color
    move += f":{color_vocab[color_id]}"
    
    return move

for move in processed_games[10].moves:
    encoded_move = tokenize_san_move(move)
    decoded_move = detokenize_san_move(encoded_move)
    
    assert move == decoded_move
    print(move)
    print(tokenize_san_move(move))
    


a3:w
[0, 8, 8, 0, 0, 2, 0, 0, 0]
c5:b
[0, 8, 8, 0, 2, 4, 0, 0, 1]
Bg2:w
[3, 8, 8, 0, 6, 1, 0, 0, 0]
O-O:b
[7, 0, 0, 0, 0, 0, 0, 0, 1]
Nd2:w
[2, 8, 8, 0, 3, 1, 0, 0, 0]
Bb4+:b
[3, 8, 8, 0, 1, 3, 0, 2, 1]
g3:w
[0, 8, 8, 0, 6, 2, 0, 0, 0]
e6:b
[0, 8, 8, 0, 4, 5, 0, 0, 1]
c4:w
[0, 8, 8, 0, 2, 3, 0, 0, 0]
Nf6:b
[2, 8, 8, 0, 5, 5, 0, 0, 1]
d4:w
[0, 8, 8, 0, 3, 3, 0, 0, 0]


In [56]:
def process_all_games_to_tensors(games):
    board_data = []
    moves_data = []
    target_data = []
    
    for game in games:
        board_data.append([[stoi[piece] for piece in row] for row in game.board])
        game_moves = []
        for move in game.moves[:-2]:
            encoded_move = tokenize_san_move(move)
            game_moves.append(encoded_move)
        moves_data.append(game_moves)

        target_move = game.moves[-1]
        target_data.append(tokenize_san_move(target_move))
    
    # Convert all to tensors
    board_tensors = torch.tensor(board_data, dtype=torch.long).view(-1, 64)  # [num_examples, 64]
    move_tensors = torch.tensor(moves_data, dtype=torch.long).view(-1, 81) # [num_examples, 81]
    target_tensors = torch.tensor(target_data, dtype=torch.long) # [num_examples, 9]
    
    return board_tensors, target_tensors, move_tensors

board_tensors, target_data, move_data = process_all_games_to_tensors(processed_games)
board_tensors.shape, move_data.shape, target_data.shape

(torch.Size([163674, 64]), torch.Size([163674, 81]), torch.Size([163674, 9]))

In [61]:
from torch.utils.data import Dataset, DataLoader

class ChessDataset(Dataset):
    def __init__(self, board_tensors, move_data, target_data):
        # Convert all tensors to float32
        self.board_tensors = board_tensors.float()
        self.move_data = move_data.float()
        self.target_data = target_data.float()
        
        
    def __len__(self):
        return len(self.board_tensors)
    
    def __getitem__(self, idx):
        return {
            'board': self.board_tensors[idx],
            'move': self.move_data[idx],
            'target': self.target_data[idx]
        }
        
# Create dataset
chess_dataset = ChessDataset(board_tensors, move_data, target_data)

# Create data loaders
batch_size = 64  # Adjust based on your GPU memory
train_loader = DataLoader(
    chess_dataset, 
    batch_size=batch_size, 
    shuffle=True,
)

In [62]:
import torch.nn as nn
import torch.nn.functional as F

class ChessModel(nn.Module):
    def __init__(self):
        super(ChessModel, self).__init__()
        # Process board state
        self.board_encoder = nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU()
        )
        
        # Process move data
        self.move_encoder = nn.Sequential(
            nn.Linear(81, 256),
            nn.ReLU()
        )
        
        # Combined processing
        self.combined = nn.Sequential(
            nn.Linear(512 + 256, 512),
            nn.ReLU(),
            nn.Linear(512, 9)  # Output matches target_data dimension
        )
        
    def forward(self, board, move):
        board_feat = self.board_encoder(board)
        move_feat = self.move_encoder(move)
        combined = torch.cat([board_feat, move_feat], dim=1)
        return self.combined(combined)

In [77]:
model = ChessModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()  # Or another loss function appropriate for your task

num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for batch in train_loader:
        board = batch['board'].to(device)
        move = batch['move'].to(device)
        target = batch['target'].to(device)
        
        # Forward pass
        outputs = model(board, move)
        loss = criterion(outputs, target)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

Epoch 1, Loss: 4.02970052685842
Epoch 2, Loss: 1.5368297605937677
Epoch 3, Loss: 1.508978387450128
Epoch 4, Loss: 1.4926161624380534
Epoch 5, Loss: 1.4750401753993403
Epoch 6, Loss: 1.4582338454762505
Epoch 7, Loss: 1.4398113368683443
Epoch 8, Loss: 1.4167798953982793
Epoch 9, Loss: 1.3942997719600296
Epoch 10, Loss: 1.3701070013504835


In [82]:
def generate_predictions(model, board_tensors, move_data, device=None, batch_size=64):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.to(device)
    model.eval()  # Set model to evaluation mode
    
    # Convert inputs to float if they're not already
    if board_tensors.dtype != torch.float32:
        board_tensors = board_tensors.float()
    if move_data.dtype != torch.float32:
        move_data = move_data.float()
    
    # Create a DataLoader for batch processing
    dataset = torch.utils.data.TensorDataset(board_tensors, move_data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    
    all_predictions = []
    
    with torch.no_grad():  # Disable gradient calculation for inference
        for board_batch, move_batch in dataloader:
            board_batch = board_batch.to(device)
            move_batch = move_batch.to(device)
            
            # Forward pass
            predictions = model(board_batch, move_batch)
            all_predictions.append(predictions.cpu())  # Move back to CPU for concatenation
    
    # Concatenate all batches
    return torch.cat(all_predictions, dim=0)

# For the entire dataset
predictions = generate_predictions(model, board_tensors, move_data)
predicted_moves = torch.tensor([detokenize_san_move(pred.tolist()) for pred in predictions])

# For a single position (add batch dimension)
single_board = board_tensors[0].unsqueeze(0)
single_move = move_data[0].unsqueeze(0)
single_prediction = generate_predictions(model, single_board, single_move)

# Process the predictions (example - getting the highest probability move)
_, predicted_classes = torch.max(predictions, dim=1)

KeyError: 9

In [68]:
predictions.shape

torch.Size([163674, 9])