In [1]:
import torch


def get_device() -> torch.device:
    """
    Choose the torch device to run on.

    Preference order: Apple MPS, then CUDA, then CPU. CUDA availability is
    only queried when MPS is not available.

    Returns:
        torch.device: the selected device.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"

    return torch.device(backend)


# Resolve the compute device once for the notebook and report which backend was picked.
device = get_device()
print(f"Using device: {device}")

Using device: mps


In [None]:
from chet import Chet42, tokenize_board

# Instantiate the model and move its parameters onto the selected device.
model = Chet42()
model.to(device)
# Prints whatever `get_n_params()` reports (presumably the parameter count).
print(model.get_n_params())


def tokenizer(board):
    """
    Tokenize a board with `tokenize_board` and move the result to `device`.

    Defined as a named function instead of a lambda bound to a name (PEP 8),
    so it shows up with a real name in tracebacks. `device` is looked up at
    call time, exactly as the lambda did.
    """
    return tokenize_board(board).to(device)

42169964


In [16]:
from typing import Callable
import chess
import csv
from torch.utils.data import Dataset
from tqdm import tqdm


def load_dataset(
    csv_file: str,
    tokenizer: Callable[[chess.Board], torch.Tensor],
    *,
    limit: int | None = None,
    skip_header: bool = True,
) -> "ChessDataset":
    """
    Load a dataset from a CSV file.

    Expected format:
    ```
    board_fen,move_uci
    rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1,e2e4
    ...
    ```

    The FEN and UCI strings are stored as-is; they are only parsed when an
    item is fetched from the returned dataset.

    Args:
        csv_file: Path of the CSV file to read.
        tokenizer: Callable turning a `chess.Board` into a token tensor; it is
            handed to the returned dataset unchanged.
        limit: Stop after this many samples. `None` (or 0) reads the whole file.
        skip_header: Discard the first row of the file when True.

    Returns:
        ChessDataset: dataset holding the loaded (board, move) string pairs.

    Raises:
        ValueError: if a non-blank row does not have exactly two columns.
    """

    boards: list[str] = []
    moves: list[str] = []

    # newline="" is what the csv module asks for so that it can handle line
    # endings (and newlines embedded in quoted fields) itself.
    with open(csv_file, "r", newline="") as f, tqdm(
        desc="Loading dataset", total=limit
    ) as progress:
        reader = csv.reader(f)

        if skip_header:
            # Default of None keeps an empty file from raising StopIteration.
            next(reader, None)

        for row in reader:
            # csv.reader yields [] for blank lines (e.g. a trailing newline).
            if not row:
                continue

            if len(row) != 2:
                raise ValueError(
                    f"{csv_file}, line {reader.line_num}: expected 2 columns "
                    f"(board_fen, move_uci), got {len(row)}"
                )

            board_fen, move_uci = row
            boards.append(board_fen)
            moves.append(move_uci)

            # Updating by hand (instead of wrapping the reader) counts the last
            # sample before the limit check breaks out of the loop.
            progress.update(1)

            if limit and len(moves) == limit:
                break

    return ChessDataset(boards, moves, tokenizer)


class ChessDataset(Dataset):
    """
    Map-style dataset of (position, played move) pairs.

    Positions are kept as FEN strings and moves as UCI strings; parsing,
    tokenization and target/mask construction happen lazily in `__getitem__`.

    Moves are encoded as `from_square * 64 + to_square`, giving 64 * 64 = 4096
    classes. The promotion piece is not part of the encoding, so all
    promotions between the same two squares share one class.
    """

    # Number of (from_square, to_square) move classes.
    N_MOVE_CLASSES = 64 * 64

    boards: list[str]
    moves: list[str]
    tokenizer: Callable[[chess.Board], torch.Tensor]

    def __init__(
        self,
        boards: list[str],
        moves: list[str],
        tokenizer: Callable[[chess.Board], torch.Tensor],
    ) -> None:
        """
        Args:
            boards: FEN strings, one per example.
            moves: UCI move strings, aligned index-by-index with `boards`.
            tokenizer: Callable turning a `chess.Board` into a token tensor.

        Raises:
            ValueError: if `boards` and `moves` differ in length.
        """
        super().__init__()

        # __len__ is driven by `moves`, so a mismatch would otherwise only
        # surface later as an IndexError (or as silently ignored boards).
        if len(boards) != len(moves):
            raise ValueError(
                f"boards and moves must have the same length, "
                f"got {len(boards)} and {len(moves)}"
            )

        self.boards = boards
        self.moves = moves
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """Number of examples in the dataset."""
        return len(self.moves)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get a training example from the dataset.

        Args:
            idx (int): Index of the example to get

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing:
                - Board tokens, exactly as returned by the tokenizer
                  (earlier notes describe shape [65]; not verifiable here).
                - One-hot float target of shape [4096] marking the played move.
                - Boolean legal move mask of shape [4096]. True if the move is
                  legal in the position, False otherwise.

        Raises:
            AssertionError: if the position has no legal moves or the stored
                move is not legal in the position.
        """

        board = chess.Board(self.boards[idx])
        played_move = chess.Move.from_uci(self.moves[idx])
        target_index = played_move.from_square * 64 + played_move.to_square

        board_tokens = self.tokenizer(board)

        target = torch.zeros(self.N_MOVE_CLASSES)
        target[target_index] = 1.0

        legal_indices = [
            legal.from_square * 64 + legal.to_square for legal in board.legal_moves
        ]
        assert legal_indices, "No legal moves found"

        # One indexed write instead of one tensor write per legal move.
        legal_move_mask = torch.zeros(self.N_MOVE_CLASSES, dtype=torch.bool)
        legal_move_mask[torch.tensor(legal_indices, dtype=torch.long)] = True

        assert target.sum() == 1.0, "Target is not a one-hot vector"
        assert legal_move_mask[
            target_index
        ], f"Target is not in the legal move mask. FEN: {board.fen()} MOVE: {self.moves[idx]}"

        return board_tokens, target, legal_move_mask


def split_dataset(
    dataset: ChessDataset, val_split: float
) -> tuple[ChessDataset, ChessDataset]:
    """
    Split a dataset into a training part and a validation part.

    The split is positional (no shuffling): the last
    `int(len(dataset) * val_split)` examples become the validation set and
    everything before them the training set. Both parts reuse the tokenizer of
    the input dataset.

    Args:
        dataset: Dataset to split.
        val_split: Fraction of examples to hold out, in the range [0, 1].

    Returns:
        tuple[ChessDataset, ChessDataset]: `(train_dataset, val_dataset)`.

    Raises:
        ValueError: if `val_split` is outside [0, 1].
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split}")

    n_val = int(len(dataset) * val_split)

    # Slice with a positive boundary: with negative slicing, n_val == 0 turns
    # into `[:-0]` (empty) and `[-0:]` (everything), i.e. the training set
    # would be empty and the validation set would be the whole dataset.
    boundary = len(dataset) - n_val

    train_part = ChessDataset(
        dataset.boards[:boundary], dataset.moves[:boundary], dataset.tokenizer
    )
    val_part = ChessDataset(
        dataset.boards[boundary:], dataset.moves[boundary:], dataset.tokenizer
    )

    return train_part, val_part

In [17]:
# Read every (FEN, move) row of the training CSV.
all_dataset = load_dataset(
    csv_file="data_processing/training_data_2.csv",
    tokenizer=tokenizer,
)

# Hold out the trailing ~10% of the rows for validation.
train_dataset, val_dataset = split_dataset(dataset=all_dataset, val_split=0.1)

Loading dataset: 100%|█████████▉| 999/1000 [00:00<00:00, 767841.25it/s]


In [30]:
from dataclasses import dataclass
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import functional as F


@dataclass
class TrainingConfig:
    """Hyperparameters and I/O settings read by `Trainer`."""

    # Batch size used for both the training and the validation DataLoader.
    batch_size: int
    # Base learning rate passed to AdamW.
    lr: float
    # Weight decay passed to AdamW.
    weight_decay: float
    # Number of scheduler steps over which the learning rate ramps up linearly
    # to `lr`; a value <= 0 disables warmup.
    warmup_steps: int

    # File the model state dict is written to whenever validation loss improves.
    checkpoint_path: str
    # Validation is run (and a checkpoint possibly saved) every this many steps.
    checkpoint_every: int

    # Size of the sliding window used to average the reported training metrics.
    training_loss_last_n_batches: int
    # CSV file the metric history is rewritten to after each validation run.
    metrics_path: str

    # Device each batch is moved to before the forward pass.
    device: torch.device


@dataclass
class TrainingMetrics:
    """Loss/accuracy pair for one batch, one validation pass, or an average of several."""

    # Loss value as a plain Python float.
    loss: float
    # Fraction of examples whose predicted move equals the target move (correct / total).
    accuracy: float


@dataclass
class TrainingHistoryPoint:
    """One row of the training history, recorded each time validation is run."""

    # Training metrics at this point (the trainer passes its sliding-window average).
    train: TrainingMetrics
    # Metrics computed on the validation set at this point.
    val: TrainingMetrics
    # Global training step at which the point was recorded.
    step: int
    # Progress expressed in epochs (the trainer passes step / batches-per-epoch).
    p_epoch: float


class TrainingHistory:
    """Append-only log of training/validation metrics that can be dumped to CSV."""

    data: list[TrainingHistoryPoint]

    def __init__(self):
        self.data = []

    def add_point(
        self, train: TrainingMetrics, val: TrainingMetrics, step: int, p_epoch: float
    ):
        """Record one history point built from the given metrics and progress."""
        new_point = TrainingHistoryPoint(train, val, step, p_epoch)
        self.data.append(new_point)

    def as_csv(self, path: str) -> None:
        """
        Write the whole history to `path`, overwriting any existing file.

        The file starts with a header row followed by one row per recorded
        point, in insertion order.
        """
        header = "train_loss,train_accuracy,val_loss,val_accuracy,global_step,p_epoch\n"

        with open(path, "w") as f:
            f.write(header)
            # Rows are produced lazily, so they are formatted as they are written.
            f.writelines(
                f"{point.train.loss},{point.train.accuracy},{point.val.loss},{point.val.accuracy},{point.step},{point.p_epoch}\n"
                for point in self.data
            )


class SlidingAverageTrainingMetrics:
    """Keeps the most recent `n` metrics and reports their arithmetic mean."""

    n: int
    _metrics: list[TrainingMetrics]

    def __init__(self, n: int):
        self._metrics = []
        self.n = n

    def add_metric(self, metric: TrainingMetrics):
        """Append a metric, evicting the oldest one if the window is over capacity."""
        window = self._metrics
        window.append(metric)

        if len(window) > self.n:
            del window[0]

    def get_average(self) -> TrainingMetrics:
        """
        Mean loss and mean accuracy over the metrics currently in the window.

        Raises:
            ZeroDivisionError: if no metric is stored.
        """
        loss_total = sum(entry.loss for entry in self._metrics)
        mean_loss = loss_total / len(self._metrics)

        accuracy_total = sum(entry.accuracy for entry in self._metrics)
        mean_accuracy = accuracy_total / len(self._metrics)

        return TrainingMetrics(mean_loss, mean_accuracy)


class MaskedCrossEntropyLoss(nn.Module):
    """
    A custom loss function for scenarios where some classes are invalid for specific examples.

    Logits of invalid classes are pushed to the most negative finite value of
    the logits' dtype before cross-entropy is computed, so invalid classes get
    (numerically) zero softmax probability and the distribution is normalized
    over the valid classes only.

    Args:
        reduction (str): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. Default: 'mean'
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(
        self, logits: torch.Tensor, targets: torch.Tensor, valid_mask: torch.Tensor
    ):
        """
        Args:
            logits (torch.Tensor): Raw model output of shape [batch_size, num_classes]
            targets (torch.Tensor): Either class indices of shape [batch_size] or
                                    class probabilities (e.g. one-hot) of shape
                                    [batch_size, num_classes]; both forms are
                                    forwarded to `F.cross_entropy` unchanged
            valid_mask (torch.Tensor): Mask of shape [batch_size, num_classes] where
                                      True (non-zero) indicates a valid class and
                                      False (zero) indicates an invalid class

        Returns:
            torch.Tensor: The computed loss
        """

        invalid = ~valid_mask.to(torch.bool)

        # Filling with a small positive number (as 1e-9 did) leaves the invalid
        # classes at logit ~0, so they still take softmax mass. A large negative
        # value removes them. It has to stay finite: with probability targets
        # the loss multiplies log-probabilities by the target, and 0 * -inf
        # would produce NaN.
        fill_value = torch.finfo(logits.dtype).min
        masked_logits = logits.masked_fill(invalid, fill_value)

        return F.cross_entropy(masked_logits, targets, reduction=self.reduction)


class Trainer:
    """
    Runs the training loop for a move-prediction model.

    Owns the training/validation DataLoaders, an AdamW optimizer, a linear
    warmup learning-rate schedule, the masked cross-entropy criterion and a
    sliding window of recent training metrics.
    """

    model: nn.Module
    train_loader: DataLoader
    val_loader: DataLoader
    config: TrainingConfig

    optimizer: torch.optim.Optimizer
    scheduler: torch.optim.lr_scheduler.LambdaLR
    criterion: MaskedCrossEntropyLoss

    train_metrics: SlidingAverageTrainingMetrics

    def __init__(
        self,
        *,
        model: nn.Module,
        train_dataset: ChessDataset,
        val_dataset: ChessDataset,
        config: TrainingConfig,
    ):
        """
        Args:
            model: Model to optimize. It is used as given; it is not moved to
                `config.device` here.
            train_dataset: Examples to train on (reshuffled every epoch).
            val_dataset: Examples used for validation (fixed order).
            config: Hyperparameters and I/O settings.
        """
        self.model = model
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
        )
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size,
            shuffle=False,
        )
        self.config = config

        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=config.lr, weight_decay=config.weight_decay
        )

        # Linear warmup: the LR factor grows from 0 to 1 over `warmup_steps`
        # scheduler steps and stays at 1 afterwards. With warmup disabled the
        # factor is always 1.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda step: (
                min(1.0, step / config.warmup_steps) if config.warmup_steps > 0 else 1.0
            ),
        )

        self.criterion = MaskedCrossEntropyLoss()

        self.train_metrics = SlidingAverageTrainingMetrics(
            config.training_loss_last_n_batches
        )

    def train(self, max_epochs: int | None = None) -> None:
        """
        Run the training loop.

        Every `config.checkpoint_every` completed steps the model is
        validated, a point is appended to the metric history, the history is
        rewritten to `config.metrics_path`, and the model is checkpointed if
        the validation loss is the best seen so far.

        Args:
            max_epochs: Number of passes over the training data to run. The
                default `None` keeps training until interrupted.

        Raises:
            ValueError: if the training loader yields no batches.
        """
        n_batches = len(self.train_loader)
        if n_batches == 0:
            # Without this guard the epoch loop below would spin forever.
            raise ValueError("Training dataset is empty; nothing to train on")

        best_val_loss = None
        best_val_acc = None
        global_step = 0  # number of completed optimizer steps
        epoch = 0
        history = TrainingHistory()

        while max_epochs is None or epoch < max_epochs:
            pbar = tqdm(self.train_loader, leave=True)

            for boards, targets, legal_move_masks in pbar:
                boards = boards.to(self.config.device)
                targets = targets.to(self.config.device)
                legal_move_masks = legal_move_masks.to(self.config.device)

                metrics = self.train_batch(boards, targets, legal_move_masks)
                self.train_metrics.add_metric(metrics)

                # Count the step before the checkpoint check so validation runs
                # after exactly `checkpoint_every` steps (checking first made it
                # run one batch late).
                global_step += 1
                train_average = self.train_metrics.get_average()

                # after `config.checkpoint_every` steps, compute validation metrics,
                # save the best model, and add the training metrics to the history
                if global_step % self.config.checkpoint_every == 0:
                    val_metrics = self.compute_validation()
                    history.add_point(
                        train_average,
                        val_metrics,
                        global_step,
                        global_step / n_batches,
                    )

                    if best_val_loss is None or val_metrics.loss < best_val_loss:
                        best_val_loss = val_metrics.loss
                        best_val_acc = val_metrics.accuracy
                        self.save_checkpoint()
                    history.as_csv(self.config.metrics_path)

                pbar.set_postfix(
                    {
                        "train_loss": f"{train_average.loss:.4f}",
                        "train_acc": f"{train_average.accuracy:.4f}",
                        "best_val_loss": (
                            f"{best_val_loss:.4f}"
                            if best_val_loss is not None
                            else "N/A"
                        ),
                        "best_val_acc": (
                            f"{best_val_acc:.4f}" if best_val_acc is not None else "N/A"
                        ),
                        "lr": f"{self.optimizer.param_groups[0]['lr']:.6f}",
                        "global_step": f"{global_step:,}",
                    }
                )

            epoch += 1

    def save_checkpoint(self) -> None:
        """Write the model's state dict to `config.checkpoint_path`."""
        with open(self.config.checkpoint_path, "wb") as f:
            torch.save(self.model.state_dict(), f)

    @staticmethod
    def _count_correct(
        logits: torch.Tensor,
        targets: torch.Tensor,
        legal_move_masks: torch.Tensor,
    ) -> int:
        """Count examples whose best-scoring legal move equals the target move."""
        # Illegal moves are excluded from the argmax by forcing them to -inf.
        masked_logits = logits.masked_fill(legal_move_masks == 0, float("-inf"))
        _, predicted = torch.max(masked_logits, 1)
        _, target_moves = torch.max(targets, 1)
        return int((predicted == target_moves).sum().item())

    def train_batch(
        self,
        boards: torch.Tensor,
        targets: torch.Tensor,
        legal_move_masks: torch.Tensor,
    ) -> TrainingMetrics:
        """
        Run one optimization step on a batch.

        Args:
            boards: Batch of board tokens fed to the model.
            targets: Batch of target distributions, one row per example.
            legal_move_masks: Batch of legal move masks, one row per example.

        Returns:
            TrainingMetrics: loss and accuracy measured on this batch.
        """
        self.model.train()
        self.optimizer.zero_grad()

        logits = self.model(boards)
        loss = self.criterion(logits, targets, legal_move_masks)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        with torch.no_grad():
            correct = self._count_correct(logits, targets, legal_move_masks)
            total = targets.size(0)
            accuracy = correct / total

        return TrainingMetrics(loss.item(), accuracy)

    def compute_validation(self) -> TrainingMetrics:
        """
        Evaluate the model on the whole validation set.

        The loss is averaged per example: each batch's mean loss (the
        criterion built in `__init__` uses mean reduction) is weighted by the
        batch size, so a smaller final batch does not skew the result.

        Returns:
            TrainingMetrics: loss and accuracy over the validation set.

        Raises:
            ValueError: if the validation loader yields no examples.
        """
        with torch.no_grad():
            self.model.eval()
            correct = 0
            total = 0
            loss_sum = 0.0

            for boards, targets, legal_move_masks in self.val_loader:
                boards = boards.to(self.config.device)
                targets = targets.to(self.config.device)
                legal_move_masks = legal_move_masks.to(self.config.device)

                batch_size = targets.size(0)

                logits = self.model(boards)
                loss = self.criterion(logits, targets, legal_move_masks)
                loss_sum += loss.item() * batch_size

                correct += self._count_correct(logits, targets, legal_move_masks)
                total += batch_size

            if total == 0:
                raise ValueError(
                    "Validation dataset is empty; cannot compute validation metrics"
                )

            return TrainingMetrics(loss_sum / total, correct / total)

In [None]:
training_config = TrainingConfig(
    # optimization
    batch_size=256,
    lr=1e-4,
    weight_decay=1e-4,
    warmup_steps=10_000,
    # validation / checkpoint cadence and output files
    checkpoint_every=15_000,
    checkpoint_path="checkpoint.pth",
    metrics_path="metrics.csv",
    # progress reporting
    training_loss_last_n_batches=20,
    device=get_device(),
)

trainer = Trainer(
    config=training_config,
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)

# Called without arguments, training keeps going until the kernel is interrupted.
trainer.train()