In [3]:
import re
import csv
import json
import chess
import numpy as np
import pandas as pd
from tqdm import tqdm
from itertools import islice
from datasets import load_dataset


import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset

# Data Preparation: Extracting (Position, Move) Pairs

In [134]:
# Open the train split in streaming mode: rows are fetched lazily as the dataset is iterated.
ds = load_dataset("austindavis/lichess-uci", streaming=True, split="train")

In [157]:
def extract_positions_and_evals(moves_str):
    """Replay a whitespace-separated move transcript and record the position before each move.

    Despite the name, no evaluations are computed; only positions are collected.

    Args:
        moves_str: Space-separated move tokens. Each token is applied with
            ``board.push_san``.
            NOTE(review): the source dataset is "lichess-uci", so tokens are
            presumably UCI (e.g. ``e2e4``). python-chess's SAN parser accepts
            fully specified from/to moves. Confirm this holds for castling and
            promotions, or switch to ``push_uci``.

    Returns:
        ``(moves, fen_positions)``.
        - ``moves`` is the full token list.
        - ``fen_positions[i]`` is the FEN of the board *before* ``moves[i]`` was played.
        - Replay stops at the first token that cannot be applied. In that case
          ``fen_positions`` is shorter than ``moves``, and callers should pair them
          via ``moves[:len(fen_positions)]``.
        - The failing token is never paired with a position.
    """
    board = chess.Board()
    moves = moves_str.split()
    fen_positions = []
    for move in moves:
        fen_before = board.fen()
        try:
            board.push_san(move)
        except ValueError:
            # python-chess signals invalid/illegal/ambiguous moves with ValueError
            # subclasses. Stop replaying, and record the position only for moves that
            # were actually applied, so an unplayable token never becomes a label.
            break
        fen_positions.append(fen_before)
    return moves, fen_positions

In [136]:
# Take the first 10,000 games off the stream and keep them in memory,
# so the sample can be iterated more than once without re-streaming.
sample_iter = islice(ds, 10_000)
sample_list = [game for game in sample_iter]

In [158]:
# Flatten every sampled game into (position-before-move, move) pairs, one row per ply.
positions, all_moves = [], []
for row in tqdm(sample_list, desc="extracting positions"):
    moves_list, fen_list = extract_positions_and_evals(row["Transcript"])
    positions.extend(fen_list)
    # Replay may stop early on an unplayable token; keep only moves that have a position.
    all_moves.extend(moves_list[:len(fen_list)])

df = pd.DataFrame({"fen": positions, "move": all_moves})
# Persist the pairs so the next section, which reloads this CSV, works on a fresh kernel.
df.to_csv("chess_positions.csv", index=False)

10000it [00:19, 515.15it/s]

0.9838





# Encoding Positions as Model Inputs

In [None]:
# Reload the (fen, move) pairs from disk so this section can run without re-extracting games.
df = pd.read_csv(filepath_or_buffer="chess_positions.csv")

In [120]:
# Plane index per piece symbol: uppercase pieces use planes 0-5, lowercase pieces 6-11.
piece_map = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
}


def fen_to_tensor(fen):
    """Encode a FEN string as an (8, 8, 13) float32 array.

    Planes 0-11 are one-hot piece planes indexed via ``piece_map``. Row
    ``7 - rank`` and column ``file`` are used, so row 0 is the last rank.
    Plane 12 is filled with ``int(board.turn)``, which is 1.0 when python-chess
    reports White to move and 0.0 otherwise.
    """
    board = chess.Board(fen)
    planes = np.zeros((8, 8, 13), dtype=np.float32)

    # Only occupied squares need a write; empty squares stay all-zero.
    for square, piece in board.piece_map().items():
        r = 7 - chess.square_rank(square)
        c = chess.square_file(square)
        planes[r, c, piece_map[piece.symbol()]] = 1.0

    # Side-to-move plane.
    planes[:, :, 12] = int(board.turn)
    return planes

In [121]:
# Move vocabulary: each distinct move string gets an index, in sorted order so it is deterministic.
all_moves = sorted(set(df["move"]))
move2idx = {move: idx for idx, move in enumerate(all_moves)}
idx2move = dict(enumerate(all_moves))

# Encode every position as (8, 8, 13), stack to (N, 8, 8, 13), then move channels first: (N, 13, 8, 8).
X_list = [fen_to_tensor(fen) for fen in df["fen"]]
X = torch.tensor(np.array(X_list), dtype=torch.float32).permute(0, 3, 1, 2)

# Integer class label for the move actually played from each position.
y_move = torch.tensor([move2idx[m] for m in df["move"]], dtype=torch.long)


In [None]:
# Cache the encoded tensors together with the raw strings they were built from.
# The raw "fens"/"moves" lists let the vocabulary be rebuilt or extended later;
# the "extend move2idx" cell reads exactly these keys.
assert len(df) == len(y_move), "df and y_move are out of sync; rebuild the tensors"
torch.save(
    {
        "X": X,
        "y_move": y_move,
        "fens": df["fen"].tolist(),
        "moves": df["move"].tolist(),
    },
    "data.pt",
)

# Save the vocabulary under the same file name the training section loads it from.
with open('move_idx.json', 'w') as f:
    json.dump(move2idx, f)

# Training

In [68]:
import torch
import torch.nn as nn

class ChessMovePredictor(nn.Module):
    """Small CNN mapping an encoded board to logits over the move vocabulary.

    Args:
        num_moves: Size of the move vocabulary (number of output logits).
        in_channels: Number of input planes; 13 matches the 12 piece planes plus
            the side-to-move plane produced by ``fen_to_tensor``.
        hidden_dim: Width of the fully connected hidden layer.

    Input is expected as (batch, in_channels, 8, 8). Output is (batch, num_moves) raw logits.
    Layer names are unchanged, so previously saved state_dicts still load.
    """

    def __init__(self, num_moves, in_channels=13, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Padded 3x3 convs keep the 8x8 board size; one 2x2 max-pool halves it to 4x4.
        self.fc1 = nn.Linear(64 * 4 * 4, hidden_dim)
        self.fc_move = nn.Linear(hidden_dim, num_moves)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(torch.relu(self.conv2(x)))
        # flatten instead of view: inputs produced by an NHWC->NCHW permute are
        # channels_last-strided. view() can reject such activations; flatten/reshape
        # handles any memory layout and keeps the same logical (C, H, W) order.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc_move(x)

In [13]:
# Reload the cached tensors (data.pt) and the move -> index vocabulary (move_idx.json).
data = torch.load("data.pt")
X, y_move = data["X"], data["y_move"]

with open('move_idx.json') as f:
    move2idx = json.load(f)

In [64]:
# Reproducibility: seed weight initialisation and the DataLoader's shuffle order.
SEED = 42
torch.manual_seed(SEED)

# Every label must index a row of the output layer; a stale vocabulary file would break this.
assert y_move.numel() == 0 or int(y_move.max()) < len(move2idx), \
    "move2idx does not cover all labels in y_move"

dataset = TensorDataset(X, y_move)
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessMovePredictor(num_moves=len(move2idx)).to(device)

criterion_move = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Supervised training: cross-entropy between the predicted move logits and the move actually played.
for epoch in range(5):
    model.train()
    total_loss, total_samples = 0.0, 0

    for X_batch, y_move_batch in loader:
        X_batch, y_move_batch = X_batch.to(device), y_move_batch.to(device)

        # One optimisation step: clear grads -> forward -> loss -> backprop -> update.
        optimizer.zero_grad()
        logits_move = model(X_batch)
        loss = criterion_move(logits_move, y_move_batch)
        loss.backward()
        optimizer.step()

        # criterion_move uses the default mean reduction, so re-weight by batch size
        # to report a per-sample average over the whole epoch.
        total_samples += X_batch.size(0)
        total_loss += loss.item() * X_batch.size(0)

    epoch_loss = total_loss / total_samples
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")


In [None]:
# Persist only the learned weights; build a matching ChessMovePredictor before loading them back.
torch.save(obj=model.state_dict(), f="chess_model.pth")

# Extending the Move Vocabulary (`move2idx`)

In [None]:
# Reload the raw FEN strings and move strings; this needs a data.pt saved with "fens" and "moves" keys.
data = torch.load("data.pt")
fens, moves = data["fens"], data["moves"]

In [None]:
# Extend the vocabulary while keeping every existing move at its original index.
# Re-sorting the union of old and new moves would shift the index of any old move that
# sorts after a new one. That silently scrambles the saved y_move labels and the rows
# of the trained model's fc_move layer. Appending new moves at the end avoids this.
# NOTE(review): `df_old` / `df_new` are not defined anywhere in this notebook. TODO: define them
# before running this cell. df_old is presumably the frame the current vocabulary was built from;
# sorting its moves reproduces the indices from the "sorted(set(...))" vocabulary build.
old_moves = sorted(set(df_old["move"]))
new_moves = sorted(set(df_new["move"]) - set(old_moves))
all_moves = old_moves + new_moves
move2idx = {m: i for i, m in enumerate(all_moves)}
idx2move = {i: m for m, i in move2idx.items()}