In [1]:
from datasets import load_dataset

# Tournament chess games (positions + moves) from the Hugging Face Hub;
# the 'train' and 'valid' splits are consumed by the training setup below.
ds = load_dataset(path="noor-zalouk/tournament-chess-games-modified")

In [2]:
import torch
from torch.utils.data import Dataset

class ChessDataset(Dataset):
    """Map-style dataset turning chess position/move rows into model inputs.

    Each row of ``ds`` must provide a ``'positions'`` string (whitespace-separated
    tokens) and a ``'moves'`` string. The positions are prefixed with ``[CLS]``,
    mapped through ``tokenizer`` and truncated to ``max_length`` ids. The move is
    split into a start label (``moves[0:2]``) and an end label (``moves[2:4]``);
    the four castling king moves instead map to one dedicated label that is used
    for both start and end.

    Args:
        ds: Indexable sequence of row dicts (e.g. a Hugging Face ``Dataset`` split).
        tokenizer: Mapping from position token to integer id; must contain ``'[CLS]'``.
        label_to_id: Mapping from square / castling-move string to class id.
        max_length: Maximum number of token ids kept per example.
    """

    # Castling king moves are encoded as a single label rather than from/to squares.
    _CASTLING_MOVES = frozenset({'e1g1', 'e8g8', 'e1c1', 'e8c8'})

    def __init__(self, ds, tokenizer, label_to_id, max_length):
        self.ds = ds
        self.tokenizer = tokenizer
        self.label_to_id = label_to_id
        self.max_length = max_length

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``start_labels`` and ``end_labels`` long tensors for row ``idx``.

        Raises:
            KeyError: if a position token or move label is missing from the mappings.
        """
        # Fetch the row once and reuse it. Column-first access (``self.ds['moves'][idx]``)
        # may materialise the whole column on a Hugging Face Dataset and fails on a
        # plain list of dicts.
        row = self.ds[idx]
        move = row['moves']

        text = '[CLS] ' + row['positions']
        input_ids = [self.tokenizer[word] for word in text.split()]
        input_ids = input_ids[:self.max_length]

        if move in self._CASTLING_MOVES:
            start_labels = end_labels = self.label_to_id[move]
        else:
            start_labels = self.label_to_id[move[0:2]]
            end_labels = self.label_to_id[move[2:4]]

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'start_labels': torch.tensor(start_labels, dtype=torch.long),
            'end_labels': torch.tensor(end_labels, dtype=torch.long)
        }

In [3]:
import architecture as arch

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model hyperparameters and the vocab / label mappings come from the local
# `architecture` module so training and inference share one definition.
config = arch.Config()
tokenizer, label_to_id = arch.tokenizer, arch.label_to_id

In [6]:
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch.nn as nn


# Seed torch (all devices) before building the model so weight init and
# DataLoader shuffling are reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

model = arch.ChessMoveClassifier(config, device)
model.to(device)

batch_size = 128
gradient_accumulation_steps = 1
epochs = 15
lr = 8e-4

criterion = nn.CrossEntropyLoss()

train_set = ChessDataset(ds['train'], tokenizer, label_to_id, config.max_position_embeddings)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
valid_set = ChessDataset(ds['valid'], tokenizer, label_to_id, config.max_position_embeddings)
# Validation order does not affect what is measured; a fixed order keeps batch
# composition (and therefore the reported val loss) identical across epochs.
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0)

# Define optimizer
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.01)

# Scheduler setup: one scheduler step per optimizer update. Trailing batches that do
# not fill a whole accumulation group are not stepped, hence the floor division.
num_update_steps_per_epoch = len(train_loader) // gradient_accumulation_steps
num_training_steps = num_update_steps_per_epoch * epochs
num_warmup_steps = int(0.1 * num_training_steps)  # 10% warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

In [7]:
def evaluate(model, dataloader, device):
    """Compute the mean start/end cross-entropy loss of ``model`` over ``dataloader``.

    Each batch's loss is the average of the start-square and end-square
    cross-entropies, matching the training objective. Batches are weighted by
    their size, so a smaller final batch does not skew the result.

    Leaves ``model`` in eval mode.

    Args:
        model: Module whose ``forward(input_ids=...)`` returns ``(start_logits, end_logits)``.
        dataloader: Iterable of dicts holding ``'input_ids'``, ``'start_labels'``
            and ``'end_labels'`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        float: Sample-weighted average loss.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            start_labels = batch['start_labels'].to(device)
            end_labels = batch['end_labels'].to(device)

            start_logits, end_logits = model(input_ids=input_ids)
            loss_start = criterion(start_logits, start_labels)
            loss_end = criterion(end_logits, end_labels)
            loss = (loss_start + loss_end) / 2

            num_samples = start_labels.size(0)
            # criterion returns a per-batch mean; turn it back into a sum so the
            # final division averages over samples rather than over batches.
            total_loss += loss.item() * num_samples
            total_samples += num_samples

    if total_samples == 0:
        raise ValueError("evaluate() received an empty dataloader")
    return total_loss / total_samples

In [8]:
import mlflow
import mlflow.pytorch
from tqdm import tqdm

# Train for `epochs` epochs inside one MLflow run. Hyperparameters are logged once;
# per-epoch train/validation loss, learning rate and a model checkpoint are logged each epoch.
with mlflow.start_run():
    # Log hyperparameters
    mlflow.log_params({
        "epochs": epochs,
        "batch_size": batch_size,
        "initial_lr": lr,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "num_warmup_steps": num_warmup_steps,
        "config": config.to_dict()
    })

    for epoch in range(1, epochs + 1):
        model.train()
        # Drop any partially accumulated gradients. These come from a trailing group of
        # fewer than `gradient_accumulation_steps` batches in the previous epoch, or from
        # an earlier execution of this cell. This keeps them out of the next update.
        optimizer.zero_grad()
        total_loss = 0.0
        total_batches = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for step, batch in enumerate(loop):
            input_ids = batch['input_ids'].to(device)
            start_labels = batch['start_labels'].to(device)
            end_labels = batch['end_labels'].to(device)

            start_logits, end_logits = model(input_ids=input_ids)
            loss_start = criterion(start_logits, start_labels)
            loss_end = criterion(end_logits, end_labels)
            loss = (loss_start + loss_end) / 2

            # Scale only the gradient contribution. The unscaled loss is what gets reported,
            # so the logged train loss stays comparable to the validation loss for any
            # accumulation setting.
            (loss / gradient_accumulation_steps).backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            batch_loss = loss.item()  # one host sync per step, reused below
            total_loss += batch_loss
            total_batches += 1
            loop.set_postfix(loss=batch_loss, lr=scheduler.get_last_lr()[0])

        train_loss = total_loss / total_batches
        valid_loss = evaluate(model, valid_loader, device)

        # Log metrics
        mlflow.log_metrics({
            "train_loss": train_loss,
            "val_loss": valid_loss,
            "lr": scheduler.get_last_lr()[0]
        }, step=epoch)

        # Save model checkpoint with MLflow
        mlflow.pytorch.log_model(model, f"chess-epoch-{epoch}")
        print(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}")



Epoch 1/15 - Train Loss: 3.0261, Validation Loss: 2.4235




Epoch 2/15 - Train Loss: 2.4564, Validation Loss: 2.2156




Epoch 3/15 - Train Loss: 2.3035, Validation Loss: 2.1000




Epoch 4/15 - Train Loss: 2.2260, Validation Loss: 2.0515




Epoch 5/15 - Train Loss: 2.1766, Validation Loss: 2.0087




Epoch 6/15 - Train Loss: 2.1399, Validation Loss: 1.9765




Epoch 7/15 - Train Loss: 2.1106, Validation Loss: 1.9494




Epoch 8/15 - Train Loss: 2.0859, Validation Loss: 1.9264




Epoch 9/15 - Train Loss: 2.0626, Validation Loss: 1.9006




Epoch 10/15 - Train Loss: 2.0411, Validation Loss: 1.8879




Epoch 11/15 - Train Loss: 2.0205, Validation Loss: 1.8645




Epoch 12/15 - Train Loss: 1.9997, Validation Loss: 1.8507




Epoch 13/15 - Train Loss: 1.9801, Validation Loss: 1.8303




Epoch 14/15 - Train Loss: 1.9601, Validation Loss: 1.8110




Epoch 15/15 - Train Loss: 1.9432, Validation Loss: 1.8023


In [9]:
# Identify the checkpoint to restore: the MLflow run id (replace with the run you
# want) and the artifact name passed to mlflow.pytorch.log_model during training.
run_id = "0e8284ddafe4494099b88c7c85725df7"
artifact_path = "chess-epoch-15"

model_loaded = mlflow.pytorch.load_model("runs:/" + run_id + "/" + artifact_path)
# Module.to and Module.eval both return the module itself, so the chained call
# moves it in place, switches it to eval mode, and displays it as the cell output.
model_loaded.to(device).eval()

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

ChessMoveClassifier(
  (encoder): TransformerEncoder(
    (embeddings): Embeddings(
      (token_embeddings): Embedding(20, 128)
      (position_embeddings): Embedding(68, 128)
      (layer_norm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (layer_norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (layer_norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attention): MultiHeadAttention(
          (heads): ModuleList(
            (0-3): 4 x AttentionHead(
              (q): Linear(in_features=128, out_features=32, bias=True)
              (k): Linear(in_features=128, out_features=32, bias=True)
              (v): Linear(in_features=128, out_features=32, bias=True)
            )
          )
          (output_linear): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace

In [12]:
# Persist only the weights (state_dict), not the pickled module object.
torch.save(obj=model_loaded.state_dict(), f="chess_model.pth")

In [23]:
import pipe
import chess

# Encode the standard starting position and prepend a batch dimension.
# NOTE(review): `input` shadows the Python builtin; later cells refer to this name,
# so renaming it requires updating those cells too.
board = chess.Board()
input = pipe.prepare_input(board).unsqueeze(0).to(device)

In [24]:
# Show the encoded starting position: token ids with the leading batch dimension added by unsqueeze(0).
input

tensor([[ 0,  1,  4,  6, 11,  9, 10, 12, 13, 10,  9, 11,  8,  8,  8,  8,  8,  8,
          8,  8,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
          7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7, 14, 14,
         14, 14, 14, 14, 14, 14, 17, 15, 16, 18, 19, 16, 15, 17]],
       device='cuda:0')

In [25]:
# Inference only: disable autograd so the forward pass builds no graph.
# With autograd on, the logits would carry a grad_fn and keep activations alive.
with torch.no_grad():
    output = model_loaded(input)
output

(tensor([[ -0.2780,   0.9939,  -2.5925,   0.0333,  -6.8306,  -5.6342,  -6.0418,
           -5.6006,   3.9611,   5.6951,  -2.5400,  -4.5057,  -1.8292,  -4.9075,
           -4.5986,  -6.2393,   3.0092,   9.1952,  -1.0495,  -3.6861,  -2.1646,
           -5.9491,  -4.2649,  -8.5509,   0.9781,  10.6924,   2.9767,  -7.2969,
           -0.1877,  -7.8628,  -3.8501,  -5.3964,  -0.6315,  10.9902,   1.4613,
           -6.5034,  -0.5225,  -7.2955,  -4.8630,  -6.2635,   2.6128,   2.2747,
           -1.1401,  -2.1958,   1.1575,  -4.4839,  -5.8524,  -8.3616,   9.3948,
            5.2283,  -3.1537,  -2.1141,  -5.0036,  -4.0353,  -6.4315,  -4.5495,
           -4.6077,   1.6003,  -1.5516,  -2.3726,  -1.7077, -11.0678,  -4.8503,
           -9.0453,  -2.0666,  -2.7445]], device='cuda:0',
        grad_fn=<AddmmBackward0>),
 tensor([[-5.9376, -2.6215,  0.3279, -0.5883, -2.5371, -4.7404, -5.1698, -4.6630,
          -2.0081, -2.0776,  4.5738,  1.2005, -0.7996, -5.6397, -3.4416, -6.6084,
          -2.7797, -1.

In [37]:
# Rank every class of the first model output (the start logits) from highest to lowest logit.
output[0].topk(k=output[0].size(1), dim=-1).indices

tensor([[33, 25, 48, 17,  9, 49,  8, 16, 26, 40, 41, 57, 34, 44,  1, 24,  3, 28,
          0, 36, 32, 18, 42, 58, 60, 12, 64, 51, 20, 43, 59, 10,  2, 65, 50, 19,
         30, 53, 22, 45, 11, 55, 14, 56, 62, 38, 13, 52, 31,  7,  5, 46, 21,  6,
         15, 39, 54, 35,  4, 37, 27, 29, 47, 23, 63, 61]], device='cuda:0')