**IMPORTS**

In [1]:
from aux_functions import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


**DATA PROCESSING**

- Importing the PGN game data
- Transforming the data into sparse tensors
- Splitting the data into training and test sets

In [2]:
TEST_PERCENT = 0.25  # fraction (not a percentage) of positions held out for testing
SEED = 42            # fixes the train/test split so Restart & Run All is reproducible

# Load pgn paths
# NOTE(review): the meaning of the argument `1` is defined in aux_functions — confirm
pgns = import_data(1)

# Convert pgns to tensors (board states and the move played from each)
board_tensors, next_moves = parse_pgn_to_tensors(pgns)

# Wrap the tensors in the project's custom PyTorch dataset
dataset = ChessDataset(board_tensors, next_moves)

# Split into train and test. A seeded generator is passed explicitly; without it
# random_split draws from the global RNG and every kernel restart yields a
# different split (and therefore different metrics).
split_generator = torch.Generator().manual_seed(SEED)
train_data, test_data = torch.utils.data.random_split(
    dataset, [1 - TEST_PERCENT, TEST_PERCENT], generator=split_generator
)
assert len(train_data) + len(test_data) == len(dataset), "split lost or duplicated samples"

len(test_data)  # number of held-out board states

147


**NEURAL NETWORK DESIGN**

In [4]:
class PieceToMoveNet(nn.Module):
    """CNN that scores every board square as the origin of the next move.

    Two conv+ReLU+max-pool stages are followed by three fully connected
    layers. The output is one raw (unnormalised) logit per square, suitable
    for ``nn.CrossEntropyLoss``.

    Args:
        in_channels: Number of input feature planes (default 14).
        board_size: Side length of the square input board (default 8).
            Must be at least 4 so that two 2x2 poolings leave a non-empty map.
        num_outputs: Number of output logits (default 64, one per square).

    Raises:
        ValueError: If ``board_size`` is smaller than 4.
    """

    def __init__(self, in_channels: int = 14, board_size: int = 8, num_outputs: int = 64):
        super().__init__()
        if board_size < 4:
            raise ValueError(f"board_size must be >= 4, got {board_size}")

        # padding=1 makes each 3x3 conv preserve spatial size, so only the pools
        # shrink the board: 8x8 -> pool -> 4x4 -> pool -> 2x2. Without padding an
        # 8x8 input shrinks to 1x1 before the second pool, which then errors.
        self.conv1 = nn.Conv2d(in_channels, 6, 3, padding=1)  # 6 filters, 3x3 kernel
        self.pool = nn.MaxPool2d(2, 2)                         # 2x2 window, stride 2
        self.conv2 = nn.Conv2d(6, 16, 3, padding=1)            # 16 filters, 3x3 kernel

        # Two stride-2 pools reduce each side by floor(board_size / 4);
        # conv2 has 16 output channels.
        pooled_side = board_size // 4
        self.fc1 = nn.Linear(16 * pooled_side * pooled_side, 120)
        self.fc2 = nn.Linear(120, 84)
        # Predicts the tile to move the piece from (one logit per output class)
        self.fc3 = nn.Linear(84, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of boards to per-square logits.

        Args:
            x: Tensor of shape (batch, in_channels, board_size, board_size).

        Returns:
            Tensor of shape (batch, num_outputs) holding unnormalised logits.
        """
        x = self.pool(F.relu(self.conv1(x)))  # first conv + pooling
        x = self.pool(F.relu(self.conv2(x)))  # second conv + pooling
        x = torch.flatten(x, 1)               # flatten all dims except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # logits; softmax is applied by the loss

# Optimiser hyperparameters, named so they are easy to find and tune.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

piece_to_move_net = PieceToMoveNet()

# Cross-entropy over the network's raw output logits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    piece_to_move_net.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)