In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

import torchvision
import torchvision.transforms.v2 as transforms
import torch.nn.functional as F
from torch.utils.data import TensorDataset

from data_collection.data_collector import DataCollector

import chess
import numpy as np

from sklearn.model_selection import train_test_split

data_collector = DataCollector(username="Hikaru")  # Create an instance of DataCollector

In [2]:
data = data_collector.get_data()  

In [3]:
x, y = zip(*data)

In [4]:
len(x)

457599

In [5]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

In [6]:
x_train_tensors = torch.tensor(x_train, dtype=torch.float32)
y_train_tensors = torch.tensor(y_train, dtype=torch.long)
x_test_tensors = torch.tensor(x_test, dtype=torch.float32)
y_test_tensors = torch.tensor(y_test, dtype=torch.long)

  x_train_tensors = torch.tensor(x_train, dtype=torch.float32)


In [7]:
n_classes = 8*8
input_size = 8*8*12

model = nn.Sequential(
    # Convolutional layers
    nn.Conv2d(in_channels=12, out_channels=32, kernel_size=3, stride=1, padding=1),  # (12,8,8) -> (32,8,8)
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),  # (32,8,8) -> (64,8,8)
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Dropout(.2),
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),  # (64,8,8) -> (128,8,8)
    nn.BatchNorm2d(128),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),  # Reduce spatial dimensions from 8x8 to 4x4

    # Flatten the output from conv layers
    nn.Flatten(),  # Output shape: (batch_size, 128 * 4 * 4) = (batch_size, 2048)

    # Fully connected layers
    nn.Linear(128 * 4 * 4, 512),
    nn.Dropout(.3),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, n_classes)
)

In [8]:
model = torch.compile(model)

In [9]:
loss_function = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

In [10]:
def get_batch_accuracy(output, y, N):
    pred = output.argmax(dim=1, keepdim=True)
    correct = pred.eq(y.view_as(pred)).sum().item()
    return correct / N

In [11]:
train_set = TensorDataset(x_train_tensors, y_train_tensors)
valid_set = TensorDataset(x_test_tensors, y_test_tensors)

In [12]:
batch_size = 16
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

In [13]:
train_N = len(train_loader.dataset)
valid_N = len(valid_loader.dataset)

In [14]:
def train():
    loss = 0
    accuracy = 0

    model.train()
    for x, y in train_loader:
        output = model(x)
        optimizer.zero_grad()
        batch_loss = loss_function(output, y)
        batch_loss.backward()
        optimizer.step()

        loss += batch_loss.item()
        accuracy += get_batch_accuracy(output, y, train_N)
    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

In [15]:
def validate():
    loss = 0
    accuracy = 0

    model.eval()
    with torch.no_grad():
        for x, y in valid_loader:
            output = model(x)

            loss += loss_function(output, y).item()
            accuracy += get_batch_accuracy(output, y, valid_N)
    print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

In [16]:
epochs = 5

for epoch in range(epochs):
    print('Epoch: {}'.format(epoch))
    train()
    validate()

Epoch: 0
Train - Loss: 57817.8011 Accuracy: 0.2939
Valid - Loss: 10931.7975 Accuracy: 0.4030
Epoch: 1
Train - Loss: 44682.8344 Accuracy: 0.3886
Valid - Loss: 9986.0411 Accuracy: 0.4183
Epoch: 2
Train - Loss: 42120.9007 Accuracy: 0.4034
Valid - Loss: 9698.5490 Accuracy: 0.4227
Epoch: 3
Train - Loss: 40895.7784 Accuracy: 0.4117
Valid - Loss: 9585.1376 Accuracy: 0.4258
Epoch: 4
Train - Loss: 40137.0994 Accuracy: 0.4175
Valid - Loss: 9440.3701 Accuracy: 0.4269


In [22]:
def fen_to_tensor(fen: str) -> torch.Tensor:
    """
    Convert a FEN position into an 8x8x12 torch tensor.
    
    The 12 channels represent:
      Channel 0: White Pawn
      Channel 1: White Knight
      Channel 2: White Bishop
      Channel 3: White Rook
      Channel 4: White Queen
      Channel 5: White King
      Channel 6: Black Pawn
      Channel 7: Black Knight
      Channel 8: Black Bishop
      Channel 9: Black Rook
      Channel 10: Black Queen
      Channel 11: Black King

    Empty squares are 0, and the presence of a piece is indicated by 1.
    """
    board = chess.Board(fen)
    # Initialize an 8x8x12 array of zeros.
    board_array = np.zeros((8, 8, 12), dtype=np.float32)
    
    # Iterate over all 64 squares.
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is None:
            continue
        
        # chess.square_rank returns 0 for rank 1, ..., 7 for rank 8.
        # We want row 0 to represent rank 8 and row 7 to represent rank 1.
        row = 7 - chess.square_rank(square)
        col = chess.square_file(square)
        
        # Determine the channel.
        # piece.piece_type: Pawn=1, Knight=2, Bishop=3, Rook=4, Queen=5, King=6.
        # White pieces go to channels 0-5, black pieces to channels 6-11.
        piece_type = piece.piece_type
        if piece.color:  # White piece
            channel = piece_type - 1
        else:  # Black piece
            channel = piece_type - 1 + 6
        
        board_array[row, col, channel] = 1.0

    # Convert the numpy array to a torch tensor.
    tensor_board = torch.tensor(board_array, dtype=torch.float32)
    
    tensor_board = np.transpose(tensor_board, (2, 0, 1))
    return tensor_board


In [79]:
fen_str = "8/5p2/2Rp1pkp/4p3/4P3/2pr3P/5PP1/6K1 b - - 5 28"
tensor_board = fen_to_tensor(fen_str)

In [80]:
tensor_board = tensor_board.unsqueeze(0)
prediction = softmaxed_tensor = F.softmax(model(tensor_board), dim=1)
matrix = prediction.view(8, 8)  
torch.set_printoptions(precision=3, sci_mode=False)

# Flatten the tensor
flattened_matrix = matrix.flatten()

# Find the top 3 maximum values and their indices (flattened)
top_k = torch.topk(flattened_matrix, k=3)

# Get the values and indices of the top 3 maximum values
max_values = top_k.values
max_indices = top_k.indices

# Convert the flattened indices to 2D indices using torch.div and torch.remainder
rows = torch.div(max_indices, matrix.size(1), rounding_mode='floor')  # Integer division (row)
cols = torch.remainder(max_indices, matrix.size(1))  # Remainder (column)

# Print the top 3 maximum values and their locations in chess notation
for i in range(3):
    chess_column = chr(cols[i].item() + ord('a'))  # Map column index to chess notation (a-h)
    chess_row = 8 - rows[i].item()  # Adjust row to match chess notation (8 is at the bottom)
    print(f"Max value: {max_values[i].item()} at {chess_column}{chess_row}")

Max value: 0.6800079345703125 at c6
Max value: 0.1525385081768036 at g1
Max value: 0.04776259511709213 at g2
