**IMPORTS**

In [2]:
import aux_functions
import importlib

importlib.reload(aux_functions)
from aux_functions import *


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch
#from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset, SubsetRandomSampler
# For masking
from torch.masked import masked_tensor, as_masked_tensor

import numpy as np
import chess
from datetime import datetime
from sklearn.model_selection import KFold

**DATA PROCESSING**

- Importing the pgn data
- Transforming the data to sparce tensors 
- Splitting the data into training and testing

In [3]:
TEST_PERCENT = 0.25

# Load pgn paths
pgns = import_data(5)

# Convert pgns to tensors
board_tensors, next_moves = parse_pgn_to_tensors(pgns)

# Converting the dataset into a custom pytorch one
dataset = ChessDataset(board_tensors, next_moves)

torch.manual_seed(0)
# Splitting the data into train and test
train_dataset, test_data = torch.utils.data.random_split(dataset, [1-TEST_PERCENT, TEST_PERCENT])

print(len(test_data))  # Number of states
print(train_dataset.indices)

845
[414, 235, 1355, 1009, 437, 1064, 2105, 260, 2175, 81, 291, 79, 728, 2653, 2508, 1510, 2518, 2965, 2428, 1621, 2124, 763, 1293, 797, 1454, 1299, 443, 686, 967, 2616, 231, 283, 2045, 49, 2573, 717, 2412, 2931, 1205, 376, 678, 3326, 2258, 986, 1175, 861, 929, 1622, 3254, 3260, 3038, 408, 705, 1465, 949, 1752, 3331, 3113, 222, 365, 879, 128, 3065, 1309, 1241, 3233, 284, 2849, 2422, 2786, 1131, 1511, 92, 1568, 114, 2332, 2917, 1498, 1287, 1329, 126, 2037, 1326, 1383, 995, 1641, 3007, 3150, 293, 671, 2797, 206, 974, 3096, 2249, 229, 2462, 4, 2707, 1108, 35, 1747, 3098, 3055, 442, 1006, 2830, 2053, 393, 2667, 2633, 883, 1056, 1983, 1274, 1399, 3317, 100, 2969, 1744, 3337, 62, 1575, 2861, 1433, 715, 719, 2141, 2795, 2809, 446, 2584, 2959, 2831, 470, 1093, 2034, 381, 2264, 1448, 683, 1388, 1432, 1123, 1047, 941, 2257, 976, 959, 699, 790, 767, 2064, 2994, 2568, 2367, 833, 3168, 1400, 1325, 2240, 277, 2949, 1449, 2513, 3061, 1336, 1401, 3335, 290, 3339, 1367, 2817, 3114, 1415, 787, 2087, 230

**NEURAL NETWORK DESIGN**
- 2 Convolutional layers
- 2 Fully connected hidden layers

In [4]:
# Whether to do the operations on the cpu or gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PieceToMoveNet(nn.Module):
    def __init__(self):
        super().__init__()

        # Takes as input a tensor of 14 channels (8x8 board)
        self.conv1 = nn.Conv2d(14, 6, 3)  # 6 filters, 3x3 kernel
        self.pool = nn.MaxPool2d(2, 2)    # Max pooling with 2x2 window
        self.conv2 = nn.Conv2d(6, 16, 3)  # 16 filters, 3x3 kernel

        # Using droput to reduce overfitting
        self.dropout = nn.Dropout(p=0.3)
        # Using batch normalization to make training faster and more stable
        self.bn1 = nn.BatchNorm1d(120)  # For the 1st layer
        self.bn2 = nn.BatchNorm1d(84)   # For the 2nd layer
        
        # Output from conv2 will be (16 channels, 1x1 feature maps)
        self.fc1 = nn.Linear(16 * 1 * 1, 120)
        self.fc2 = nn.Linear(120, 84)
        # Predicts the tile to move the piece from (64 possible tiles on the board)
        self.fc3 = nn.Linear(84, 64)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # Apply first conv + pooling
        x = F.relu(self.conv2(x))             # Apply second conv to get (16 x 1 x 1)
        x = torch.flatten(x, 1)               # Flatten all dimensions except batch size
        x = F.relu(self.bn1(self.fc1(x)))     # Fully connected layer 1 and batch normalization
        x = self.dropout(x)                   # Dropout of some first layer neurons
        x = F.relu(self.bn2(self.fc2(x)))     # Fully connected layer 2
        x = self.fc3(x)                       # Output layer (no activation, logits for classification)
        return x

**TRAINING LOOP**

In [7]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")


def train_epoch(model, optimizer, train_loader, loss_fn, train_sampler_size): 
    """
    Trains the model for one epoch and returns the average training loss and accuracy.
    """

    running_loss = 0.  
    running_correct = 0.

    # Looping through all samples in a batch
    for i, data in enumerate(train_loader):
        # Extracting the board tensor
        inputs = data[0]
        # Extracting the tile of the piece to move
        labels = data[1][:, 0]
        
        # Moving inputs and labels to the gpu/cpu
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Resetting the gradients
        optimizer.zero_grad()
        # Calculating model's output
        outputs = model(inputs)
        # Calculating the sample loss
        loss = loss_fn(outputs, labels)

        # Calculating the gradient
        loss.backward()
        # Updating model parameters
        optimizer.step()

        # Adding the last loss to the running loss
        running_loss += loss.item()

        # Calculate number of correct predictions
        _, predictions = torch.max(outputs.data, 1)
        running_correct += (predictions == labels).sum().item()

    # Averaging the loss for all samples in the batch
    running_loss /= (i + 1)

    # Calculate accuracy based on the total samples in the fold (train_sampler_size)
    train_accuracy = running_correct / train_sampler_size

    return running_loss, train_accuracy


def validation_epoch(model, validation_loader, loss_fn, val_sampler_size):
    """
    Validates the model for one epoch and returns the average validation loss and accuracy.
    """

    running_vloss = 0.
    running_vcorrect = 0.

    # Set model to evaluation mode
    model.eval()

    # Disable gradient calculations for validation set
    with torch.no_grad():
        # Looping through all batches in the validation set
        for i, v_data in enumerate(validation_loader):
            # Getting the tensors of the validation data
            vinputs = v_data[0]
            vlabels = v_data[1][:, 0]

            # Moving inputs and labels to the gpu/cpu
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)

            # Calculating the output of the model
            voutputs = model(vinputs)
            # Calculating the loss of the model in the validation sample
            vloss = loss_fn(voutputs, vlabels)
            # Adding this sample's loss to the total loss
            running_vloss += vloss.item()

            # Calculate number of correct predictions
            _, predictions = torch.max(voutputs.data, 1)
            running_vcorrect += (predictions == vlabels).sum().item()

    # Averaging the loss for all samples in the validation set
    running_vloss /= (i + 1)

    # Calculate accuracy based on the total samples in the fold (val_sampler_size)
    validation_accuracy = running_vcorrect / val_sampler_size

    return running_vloss, validation_accuracy


def train_multiple_folds(n_epochs, batch_size, splits, writer, loss_fn):

    best_vloss = 1_000

    epoch_tloss = [0. for _ in range(n_epochs)]
    epoch_tacc = [0. for _ in range(n_epochs)]
    epoch_vloss = [0. for _ in range(n_epochs)]
    epoch_vacc = [0. for _ in range(n_epochs)]

    for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(train_dataset)))):
        print(f"FOLD {fold+1}")

        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
        val_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=val_sampler)

        model = PieceToMoveNet()
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

        train_sampler_size = len(train_sampler)
        val_sampler_size = len(val_sampler)

        avg_vloss = 0.

        for epoch in range(n_epochs): 
            train_loss, train_correct = train_epoch(model, optimizer, train_loader, loss_fn, train_sampler_size)
            val_loss, val_correct = validation_epoch(model, val_loader, loss_fn, val_sampler_size)

            epoch_tloss[epoch] += train_loss
            epoch_tacc[epoch] += train_correct
            epoch_vloss[epoch] += val_loss
            epoch_vacc[epoch] += val_correct

            avg_vloss += val_loss

            print(f"Epoch: {epoch+1} Train Loss: {train_loss}, Valid Loss: {val_loss} | Train Acc: {train_correct}, Valid Acc: {val_correct}")

        avg_vloss = avg_vloss / (epoch + 1)

        # Saving the model if the loss on the validation is lower than the best one
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = f"models/piece_to_move_net_{timestamp}_{fold}"
            torch.save(model.state_dict(), model_path)


    for epoch in range(n_epochs):
        epoch_tloss[epoch] = epoch_tloss[epoch] / (fold + 1)
        epoch_tacc[epoch] = epoch_tacc[epoch] / (fold + 1)
        epoch_vloss[epoch] = epoch_vloss[epoch] / (fold + 1)
        epoch_vacc[epoch] = epoch_vacc[epoch] / (fold + 1)


        # Adding insights
        writer.add_scalars("Loss", {"Training": epoch_tloss[epoch], "Validation": epoch_vloss[epoch]}, epoch + 1)
        writer.add_scalars("Accuracy", {"Training": epoch_tacc[epoch], "Validation": epoch_vacc[epoch]}, epoch + 1)
        writer.flush()



EPOCHS = 50
BATCH_SIZE = 32
K = 5

loss_fn = torch.nn.CrossEntropyLoss()

# Logs training statistics for TensorBoard visualization
writer = SummaryWriter(f"runs/piece_to_move_{timestamp}")  
splits = KFold(n_splits=K, shuffle=True, random_state=42)

train_multiple_folds(EPOCHS, BATCH_SIZE, splits, writer, loss_fn)


FOLD 1
Epoch: 1 Train Loss: 4.079975537955761, Valid Loss: 3.9653835147619247 | Train Acc: 0.04189255791030064, Valid Acc: 0.09448818897637795
Epoch: 2 Train Loss: 3.5440512634813786, Valid Loss: 2.8456623554229736 | Train Acc: 0.20601281419418432, Valid Acc: 0.265748031496063
Epoch: 3 Train Loss: 2.1823711451143026, Valid Loss: 1.8272282183170319 | Train Acc: 0.26761951700345, Valid Acc: 0.297244094488189
Epoch: 4 Train Loss: 1.7911275923252106, Valid Loss: 1.7463356256484985 | Train Acc: 0.2932479053721045, Valid Acc: 0.31299212598425197
Epoch: 5 Train Loss: 1.7417344618588686, Valid Loss: 1.7153668031096458 | Train Acc: 0.29620502710694924, Valid Acc: 0.32086614173228345
Epoch: 6 Train Loss: 1.7144139222800732, Valid Loss: 1.7047735154628754 | Train Acc: 0.3035978314440611, Valid Acc: 0.3090551181102362
Epoch: 7 Train Loss: 1.6950831431895494, Valid Loss: 1.6889970526099205 | Train Acc: 0.30655495317890585, Valid Acc: 0.30118110236220474
Epoch: 8 Train Loss: 1.6770307272672653, Vali

In [None]:
def generate_mask(board: chess.Board, outputs: np.arry, labels: np.array) -> np.array:
    """Creates mask with the position of the pieces that are able to move"""

    mask = np.zeros((8,8)) # 8x8 mask for the chessboard

    # Obtaining legal moves from the board
    legal_moves = list(board.legal_moves)

    # Indicating with 1s the valid squares
    for move in legal_moves:
        to_square = move.to_square
        to_row, to_col = divmod(to_square, 8)
        mask[to_row, to_col] = 1 # A valid square

    # Reshaping mask to match output and labels
    move_mask = mask.flatten() # Converts 8*8 2D array to a 1D array with 64 elements

    masked_outputs = outputs * move_mask
    masked_labels = labels * move_mask

    return masked_outputs, masked_labels

def update_board( board: chess.Board, move: chess.Move ) -> chess.Board:
    """This function is responsible for updating the board everytime a move is made"""

    board.push(move) # Add move to the board
    
    return move

**CROSS VALIDATION**


In [None]:
def reset_weights(model):
    """Resets the weights of the model, so the model is trained with randomly initalized weights"""

    # List of layers containing reset parameters
    layer_types = [nn.Conv2d, nn.Linear, nn.BatchNorm2d]

    # Iterating through all layers of the model
    for layer in model.modules():
        # Check layers with reset parameters
        if type(layer) in layer_types:
            layer.reset_parameters()

    