In [133]:
import re
import csv
import json
import chess
import numpy as np
import pandas as pd
from tqdm import tqdm
from itertools import islice
from datasets import load_dataset


import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset

# Data Arrangement

In [134]:
ds = load_dataset("Lichess/tournament-chess-games", split="train", streaming=True)

In [157]:
def extract_positions_and_evals(moves_str):
    move_san, fen_positions, evals = [], [], []
    board = chess.Board()
    moves = re.findall(r'. (o-o|o-o-o|O-O|O-O-O|\w{2,5})[+?!]*? { \[%eval (.*?)]', moves_str)
    for pose, val in moves[:100]:
        try:
            move_san.append(board.parse_san(pose).uci())
            board.push_san(pose)
            if '#' in val:
                val = 1000 if not val.startswith('-') else -1000
            else:
                val = float(val)
            
            fen_positions.append(board.fen())
            evals.append(val)
        except:
            return move_san, fen_positions, evals, 0
    
    return move_san, fen_positions, evals, 1

In [136]:
sample_iter = islice(ds, 10_000)
sample_list = list(sample_iter)

In [158]:
all_moves, positions, labels, game_ids = [], [], [], []
success = 0
for game_id, row in tqdm(enumerate(sample_list)):
    moves_list, fen_list, eval_list, count = extract_positions_and_evals(row["movetext"])
    success+=count
    positions.extend(fen_list)
    labels.extend(eval_list)
    game_ids.extend([game_id] * len(fen_list))
    all_moves.extend(moves_list[:len(fen_list)])

print(success/10000)
df = pd.DataFrame({"id":game_ids, "fen": positions, "move": all_moves, "eval": labels})
df.to_csv("chess_positions.csv", index=False)

10000it [00:19, 515.15it/s]

0.9838





# Fitting The Data To The Model

In [None]:
df = pd.read_csv("chess_positions.csv")

In [120]:
piece_map = {'P':0,'N':1,'B':2,'R':3,'Q':4,'K':5,
             'p':6,'n':7,'b':8,'r':9,'q':10,'k':11}

def fen_to_tensor(fen):
    board = chess.Board(fen)
    mat = np.zeros((8,8,12), dtype=np.float32)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            row = 7 - chess.square_rank(square)
            col = chess.square_file(square)
            mat[row, col, piece_map[piece.symbol()]] = 1.0
    turn_channel = np.full((8,8,1), int(board.turn), dtype=np.float32)
    mat = np.concatenate([mat, turn_channel], axis=-1)
    return mat

In [121]:
all_moves = sorted(set(df["move"].to_list()))
move2idx = {m:i for i,m in enumerate(all_moves)}
idx2move = {i:m for m,i in move2idx.items()}

X_list = [fen_to_tensor(fen) for fen in df["fen"]]

X = torch.tensor(np.array(X_list), dtype=torch.float32).permute(0,3,1,2)
y_move = torch.tensor([move2idx[m] for m in df["move"]], dtype=torch.long)
y_eval = torch.tensor(df["eval"].values, dtype=torch.float32)

In [None]:
torch.save({"X": X, "y_move": y_move, "y_eval": y_eval}, "data.pt")

with open('my_dict.json', 'w') as f:
    json.dump(move2idx, f)

# Training

In [68]:
# simple model
import torch
import torch.nn as nn

class ChessMovePredictor(nn.Module):
    def __init__(self, num_moves):
        super().__init__()
        self.conv1 = nn.Conv2d(13, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64*4*4, 256)
        self.fc_move = nn.Linear(256, num_moves)  # ראש המהלך
        self.fc_eval = nn.Linear(256, 1)          # ראש הציון

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return self.fc_move(x), self.fc_eval(x)

In [None]:
data = torch.load("data.pt")
X = data["X"]
y_move = data["y_move"]
y_eval = data["y_eval"]

with open('my_dict.json', 'r') as f:
    move2idx = json.load(f)

In [64]:
dataset = TensorDataset(X, y_move, y_eval)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessMovePredictor(num_moves=len(move2idx)).to(device)

criterion_move = nn.CrossEntropyLoss(ignore_index=-1)
criterion_eval = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
beta = 1.0
lambda_eval = 0.5

for epoch in range(5):
    model.train()
    total_loss = 0.0
    total_weight = 0.0

    for X_batch, y_move_batch, y_eval_batch in loader:
        X_batch = X_batch.to(device)
        y_move_batch = y_move_batch.to(device)
        y_eval_batch = y_eval_batch.to(device)

        optimizer.zero_grad()
        logits_move, pred_eval = model(X_batch)

        loss_per_example = criterion_move(logits_move, y_move_batch)

        # normalized
        y_eval_min = y_eval_batch.min()
        y_eval_max = y_eval_batch.max()
        if y_eval_max - y_eval_min < 1e-8:
            y_eval_norm = torch.zeros_like(y_eval_batch)
        else:
            y_eval_norm = (y_eval_batch - y_eval_min) / (y_eval_max - y_eval_min)

        weights = torch.exp(beta * y_eval_norm)
        weighted_loss = (loss_per_example * weights).mean()
        loss_eval = criterion_eval(pred_eval.squeeze(), y_eval_batch)
        total_batch_loss = weighted_loss + lambda_eval * loss_eval

        if torch.isnan(total_batch_loss) or torch.isinf(total_batch_loss):
            print("NaN or Inf loss detected, skipping batch")
            continue

        total_batch_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += (loss_per_example * weights).sum().item()
        total_weight += weights.sum().item()

    epoch_loss = total_loss / (total_weight + 1e-8)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

In [None]:
torch.save(model.state_dict(), "chess_model.pth")