## Install relevant dependencies
This notebook assumes that PyTorch is already installed.

In [None]:
!pip install python-chess
!pip install tqdm

## Download chess game data
You can skip this step if you have your own PGN files to use.

In [None]:
!curl https://database.lichess.org/standard/lichess_db_standard_rated_2013-01.pgn.zst --output games.pgn.zst
!zstd --decompress games.pgn.zst

# !mkdir pgns/ # uncomment if pgns dir does not exist
!mv games.pgn pgns/
!rm games.pgn.zst

## Create list of games

In [None]:
import chess.pgn

# Upper bound on the number of games read from the PGN file.
MAX_GAMES = 5000  # increase this limit for a better model

all_games = []

# The context manager guarantees the PGN file is closed even if parsing
# raises part-way through the file.
with open("pgns/games.pgn", "r", encoding="utf-8") as pgn:
    for i in range(MAX_GAMES):
        game = chess.pgn.read_game(pgn)
        if game is None:
            break  # End of games

        all_games.append(game)

print(f"{len(all_games)} games parsed")

## Create list of distinct chess positions
The goal of this step is to build a diverse set of chess positions (as FEN strings) that will be used to create the training dataset.

In [None]:
import random

# A set de-duplicates positions that occur in more than one game
# (e.g. common openings).
all_positions = set()

for game in all_games:
    board = game.board()
    moves = list(game.mainline_moves())
    positions = []

    # Replay the game and record the FEN reached after every move.
    for move in moves:
        board.push(move)
        positions.append(board.fen())

    # Sample about one position per 7 moves played, capped at 10 per game.
    # The previous expression `min(10, len(moves)) // 7` could only ever be
    # 0 or 1, which made the cap of 10 meaningless.
    n_samples = min(10, len(positions) // 7)
    random_positions = random.sample(positions, n_samples)
    all_positions.update(random_positions)

all_positions = list(all_positions)
print(f"{len(all_positions)} unique positions")

## Define functions to convert between tensor and FEN string
This will let us encode chess positions in a way the NNs can use

In [None]:
import torch

# Tensor channel used for each piece symbol: white pieces (uppercase) take
# channels 0-5 and black pieces (lowercase) take channels 6-11.
piece_to_idx = {symbol: channel for channel, symbol in enumerate("PNBRQKpnbrqk")}

def pos_to_tensor(fen, device="cpu"):
    """Encode a FEN string as a ``(15, 8, 8)`` float tensor.

    Channel layout:
      * 0-11  one plane per piece symbol (see ``piece_to_idx``); an occupied
              square holds +1 for a white piece and -1 for a black piece.
      * 12    white castling flags: ``[12, 0, 0]`` for 'K', ``[12, 0, 7]`` for 'Q'.
      * 13    black castling flags: ``[13, 7, 0]`` for 'k', ``[13, 7, 7]`` for 'q'
              (stored as -1).
      * 14    side to move: the whole plane is +1 for white, -1 for black.

    Row 0 of each plane corresponds to rank 8 and column 0 to file a.
    The en passant square and the move counters of the FEN are not encoded.

    Args:
        fen: FEN string; the side-to-move and castling fields are read from
            its second and third space-separated fields.
        device: torch device on which the tensor is allocated.

    Returns:
        The encoded position as a ``torch.Tensor`` of shape ``(15, 8, 8)``.
    """
    fields = fen.split(" ")
    wtm = fields[1] == "w"
    castling_rights = fields[2]

    board = chess.Board(fen)
    tensor = torch.zeros(15, 8, 8, device=device)

    # Visit only the occupied squares instead of probing all 64.
    for sqr, piece in board.piece_map().items():
        p = piece.symbol()
        row = 7 - chess.square_rank(sqr)
        col = chess.square_file(sqr)
        tensor[piece_to_idx[p], row, col] = 1 if p.isupper() else -1

    # Encode castling rights
    if 'K' in castling_rights:
        tensor[12, 0, 0] = 1
    if 'Q' in castling_rights:
        tensor[12, 0, 7] = 1
    if 'k' in castling_rights:
        tensor[13, 7, 0] = -1
    if 'q' in castling_rights:
        tensor[13, 7, 7] = -1

    # Encode side to move
    tensor[14] = 1 if wtm else -1

    return tensor

def tensor_to_pos(tensor):
    """Decode a position tensor produced by ``pos_to_tensor`` back into a FEN.

    Any non-zero entry in one of the 12 piece planes is treated as a piece of
    that plane's type. Castling flags are read from the fixed cells of planes
    12 and 13, and the side to move from the sign of the mean of plane 14.
    The en passant square and move counters are not restored from the tensor.

    Args:
        tensor: position tensor of shape ``(15, 8, 8)``.

    Returns:
        The FEN string of the reconstructed board.
    """
    board = chess.Board(None)

    # Decode the board pieces
    for piece_symbol, idx in piece_to_idx.items():
        occupied = tensor[idx].abs() > 0
        # .tolist() yields plain Python ints, so python-chess receives real
        # integer square indices rather than 0-dim tensors.
        for row, col in occupied.nonzero().tolist():
            square = chess.square(col, 7 - row)
            board.set_piece_at(square, chess.Piece.from_symbol(piece_symbol))

    # Decode castling rights: (channel, row, col, expected value, FEN symbol)
    castling_cells = (
        (12, 0, 0, 1, 'K'),
        (12, 0, 7, 1, 'Q'),
        (13, 7, 0, -1, 'k'),
        (13, 7, 7, -1, 'q'),
    )
    castling_rights = ''.join(
        symbol
        for channel, row, col, expected, symbol in castling_cells
        if tensor[channel, row, col] == expected
    )
    board.set_castling_fen(castling_rights)

    # Decode side to move (True means white to move)
    board.turn = bool(tensor[14].mean() > 0)

    return board.fen()
    
# Sanity check: encode the first collected position and show the tensor shape.
print(f"Shape of encoded chess position tensor: {pos_to_tensor(all_positions[0]).shape}")

## Set pytorch device to cuda if available

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocator history recording is a private, CUDA-specific debugging hook, so
# only enable it when a GPU is actually present; this keeps the cell runnable
# on CPU-only machines.
if device == "cuda":
    torch.cuda.memory._record_memory_history()

# New tensors default to the CPU; tensors and the model are moved to `device`
# explicitly where needed.
torch.set_default_device("cpu")
print(f"using {device}")

## Create datasets for training and testing

In [None]:
from torch.utils.data import Dataset, DataLoader

class PositionDataset(Dataset):
    """Map-style dataset that wraps an in-memory sequence of position tensors."""

    def __init__(self, tensors):
        # The sequence is stored as given; nothing is copied or converted.
        self.tensors = tensors

    def __len__(self):
        """Return the number of stored items."""
        return len(self.tensors)

    def __getitem__(self, idx):
        """Return the item stored at index ``idx``."""
        item = self.tensors[idx]
        return item

# Encode every unique position once (allocated directly on `device`), then
# shuffle so the splits below are random.
tensors = list(map(lambda pos: pos_to_tensor(pos, device), all_positions))
random.shuffle(tensors)

# Boundaries of the 80% / 10% / 10% train / validation / test split
total_tensors = len(tensors)
train_end, val_end = int(total_tensors * 0.8), int(total_tensors * 0.9)

# Slice the shuffled list into the three splits
train_tensors, val_tensors, test_tensors = (
    tensors[:train_end],
    tensors[train_end:val_end],
    tensors[val_end:],
)

# Wrap each split in a dataset
train_dataset, val_dataset, test_dataset = map(
    PositionDataset, (train_tensors, val_tensors, test_tensors)
)

print(
    f"len training set: {len(train_dataset)}",
    f"len validation set: {len(val_dataset)}",
    f"len test set: {len(test_dataset)}",
    sep="\n",
)

## Define the structure of the NN
We are training an autoencoder that will learn to deconstruct, then reconstruct chess positions.\
Once trained, we can use the encoder to generate our embeddings

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional encoder that compresses a position tensor into an embedding.

    Three conv + ReLU + 2x2 max-pool stages are followed by two fully connected
    layers. ``fc1`` takes ``filters * 4`` input features, i.e. it expects the
    pooling stages to have reduced the spatial size to 1x1 (as they do for
    8x8 inputs).

    Reads ``position_channels``, ``n_embed``, ``filters`` and ``fc_size`` from
    ``hyperparams``.
    """

    def __init__(self, hyperparams):
        super().__init__()

        in_channels = hyperparams["position_channels"]
        embed_dim = hyperparams["n_embed"]
        base = hyperparams["filters"]
        hidden = hyperparams["fc_size"]

        # All convolutions preserve the spatial size (3x3, stride 1, padding 1).
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels, base, **conv_kwargs)
        self.conv2 = nn.Conv2d(base, base * 2, **conv_kwargs)
        self.conv3 = nn.Conv2d(base * 2, base * 4, **conv_kwargs)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(base * 4 * 1 * 1, hidden)
        self.fc2 = nn.Linear(hidden, embed_dim)  # Compressed representation

    def forward(self, x):
        """Map a batch of position tensors to a batch of embeddings."""
        # Each stage: convolution -> ReLU -> halve the spatial size.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        x = self.flatten(x)
        # No activation on the final layer: the embedding is unconstrained.
        return self.fc2(F.relu(self.fc1(x)))

class Decoder(nn.Module):
    """Decoder that reconstructs a position tensor from an embedding.

    Two fully connected layers expand the embedding to ``filters * 4`` features,
    which are reshaped to a 1x1 grid and upsampled by three stride-2 transposed
    convolutions, each doubling the spatial size.

    Reads ``position_channels``, ``n_embed``, ``filters`` and ``fc_size`` from
    ``hyperparams``.
    """

    def __init__(self, hyperparams):
        super().__init__()

        out_channels = hyperparams["position_channels"]
        embed_dim = hyperparams["n_embed"]
        base = hyperparams["filters"]
        hidden = hyperparams["fc_size"]

        self.fc1 = nn.Linear(embed_dim, hidden)
        self.fc2 = nn.Linear(hidden, base * 4 * 1 * 1)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(base * 4, 1, 1))

        # Shared settings: every transposed convolution doubles height and width.
        up_kwargs = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv1 = nn.ConvTranspose2d(base * 4, base * 2, **up_kwargs)
        self.deconv2 = nn.ConvTranspose2d(base * 2, base, **up_kwargs)
        self.deconv3 = nn.ConvTranspose2d(base, out_channels, **up_kwargs)

    def forward(self, x):
        """Map a batch of embeddings to a batch of reconstructed position tensors."""
        features = F.relu(self.fc1(x))
        grid = self.unflatten(F.relu(self.fc2(features)))
        grid = F.relu(self.deconv1(grid))
        grid = F.relu(self.deconv2(grid))
        # No activation on the output layer: values may be negative.
        return self.deconv3(grid)

class PositionAutoEncoder(nn.Module):
    """Autoencoder that chains an ``Encoder`` and a ``Decoder``.

    Both sub-modules are built from the same ``hyperparams`` dictionary.
    """

    def __init__(self, hyperparams):
        super().__init__()
        self.encoder = Encoder(hyperparams)
        self.decoder = Decoder(hyperparams)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))

    @torch.no_grad()
    def embed(self, x):
        """Return the encoder output for ``x`` without tracking gradients."""
        return self.encoder(x)

## Define model hyperparameters

In [None]:
# Training and architecture settings. The model classes read
# position_channels, n_embed, filters and fc_size; dropout_rate and version
# are not read by any code visible in this notebook.
hyperparams = dict(
    batch_size=32,
    n_epochs=50,
    learning_rate=17e-4,
    dropout_rate=0,
    position_channels=15,
    n_embed=128,
    filters=32,
    fc_size=256,
    version=6,
)

# Convenience aliases for the training-related settings
batch_size, n_epochs, learning_rate = (
    hyperparams[key] for key in ("batch_size", "n_epochs", "learning_rate")
)

## Initialize the model and optimizer

In [None]:
from torch.optim import AdamW, lr_scheduler

# init dataloaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)

# init model and move its parameters to the selected device
model = PositionAutoEncoder(hyperparams).to(device)

# init optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate)
param_count = sum(param.numel() for param in model.parameters())
num_params = param_count / 1e6
print(f"{num_params:.2f}M parameters")

# init lr scheduler: one scheduler step is taken per training batch, so the
# cycle spans every batch of every epoch
num_training_steps = n_epochs * len(train_loader)
scheduler = lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=learning_rate,
    total_steps=num_training_steps,
)
print(f"training iterations: {num_training_steps}")


## Run the training loop
We train by minimizing the MSE loss on the reconstructed position encoding.\
Validation loss is calculated after each epoch to ensure learning

In [None]:
from tqdm.auto import tqdm

def _mean_batch_loss(model, loader, criterion, device):
    """Return the mean of the per-batch reconstruction losses over ``loader``.

    Each batch is moved to ``device``, passed through ``model`` and compared
    with itself using ``criterion``. The result is the unweighted mean of the
    batch losses, returned as a 0-dim tensor on ``device``. The caller is
    responsible for eval mode and for disabling gradient tracking.
    """
    losses = torch.zeros(len(loader), device=device)
    for k, batch in enumerate(loader):
        batch = batch.to(device)
        losses[k] = criterion(model(batch), batch).item()
    return losses.mean()


criterion = nn.MSELoss()
progress_bar = tqdm(range(num_training_steps))

for epoch in range(n_epochs):
    model.train() # switch model to training mode

    for batch in train_loader:
        batch = batch.to(device)
        outputs = model(batch)
        # Autoencoder objective: the input batch is its own target.
        loss = criterion(outputs, batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch
        progress_bar.update(1)

    print(f"finished epcoh: {epoch}")
    with torch.no_grad():
        model.eval() # switch model to evaluation mode

        # Both losses are measured in eval mode after the epoch's updates, so
        # they are directly comparable. The training loss is recomputed over
        # the whole training set rather than averaged from the training steps.
        avg_val_loss = _mean_batch_loss(model, val_loader, criterion, device)
        avg_train_loss = _mean_batch_loss(model, train_loader, criterion, device)

        print(f"learning rate: {optimizer.param_groups[0]['lr']}")
        print(f"val loss: {avg_val_loss}")
        print(f"train loss: {avg_train_loss}")

# Release the progress bar once training is finished.
progress_bar.close()
    

## Save the trained model weights and metadata

In [None]:
import os

# Create the output directory if it is missing so that saving works on a
# fresh checkout (no-op when the directory already exists).
os.makedirs("models", exist_ok=True)

# NOTE: the dataset objects are pickled together with the weights, so loading
# this checkpoint requires the PositionDataset class to be defined.
checkpoint = {
    "model": model.state_dict(),
    "train_set": train_dataset,
    "val_set": val_dataset,
    "test_set": test_dataset,
    "hyperparameters": hyperparams
}
torch.save(checkpoint, "models/v0.pt")

## Load a saved model

In [None]:
# The checkpoint contains pickled PositionDataset objects in addition to the
# weights, so it cannot be read in weights-only mode (the default in recent
# PyTorch releases). SECURITY: weights_only=False unpickles arbitrary objects
# and can execute code; only load checkpoint files you created yourself.
# map_location remaps tensors that were saved on another device onto `device`.
chkp = torch.load("models/v0.pt", map_location=device, weights_only=False)

# Rebuild the model from the stored hyperparameters and restore its weights.
emb_model = PositionAutoEncoder(chkp["hyperparameters"]).to(device)
emb_model.load_state_dict(chkp["model"])
emb_model.eval()

train_dataset = chkp["train_set"]
val_dataset = chkp["val_set"]
test_dataset = chkp["test_set"]

# All three splits together form the collection that is searched later.
embed_data = list(train_dataset + val_dataset + test_dataset)

## Embed collection of chess positions

In [None]:
# Split the collection into chunks of at most 256 positions so the whole
# set is not pushed through the encoder in one forward pass.
batches = [embed_data[start:start + 256] for start in range(0, len(embed_data), 256)]

# Embed each chunk, then join the results and add a singleton middle
# dimension: (num_positions, 1, n_embed).
batch_embeds = [emb_model.embed(torch.stack(batch)) for batch in batches]
embeds = torch.cat(batch_embeds).unsqueeze(1)
embeds.shape

## Search for similar positions

In [None]:
# embed a query position
query = emb_model.embed(test_dataset[0].unsqueeze(0)).unsqueeze(0)

# calculate similarities and find top matches
similarities = F.cosine_similarity(embeds, query, dim=2)
top_matches = torch.topk(similarities, 10, dim=0)

# print(top_matches.values)

# convert matches to FEN strings
top_tensors = torch.stack(list(embed_data))[top_matches.indices].squeeze(1)
top_tensors = list(torch.split(top_tensors, 1, dim=0))
positions = [tensor_to_pos(t.squeeze(0)) for t in top_tensors]
query_pos = tensor_to_pos(test_dataset[0])

print(f"query position: {query_pos}")
print(f"similar positions: {positions}")

## Inspect similar positions
Pass a FEN string to chess.Board() to view it

In [None]:
# Build a board from the query FEN; as the cell's last expression it is displayed.
chess.Board(query_pos)

In [None]:
 # Show the second-ranked match. Presumably positions[0] is the query itself,
 # because the query comes from test_dataset, which is part of embed_data.
 chess.Board(positions[1])