In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
import chess
import torch.nn.functional as F


In [None]:
def evaluate_puzzle_accuracy(
    model: ChessGPT_cls,
    puzzle_data: pd.DataFrame,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    *,
    top_k: int = 5,
) -> float:
    """
    Evaluates the model's accuracy on chess puzzles.
    A puzzle is considered solved if the first move listed in its "Moves" column
    is among the model's top ``top_k`` predictions for the "FEN" position.

    Each position is tokenized with ``tokenize_board_v1`` and passed, together with
    the ``chess.Board``, to ``model.get_top_moves``, which is expected to return
    ``(move, score)`` pairs.

    NOTE(review): the "FEN"/"Moves" column names match the Lichess puzzle export, in
    which the first move is the opponent's setup move and the solution starts at the
    second move. Confirm that ``puzzle_data`` has been preprocessed so that the first
    move is the one the model should find.

    Args:
        model (ChessGPT): The trained model to evaluate
        puzzle_data (pd.DataFrame): DataFrame with a "FEN" column (position) and a
            "Moves" column (space-separated UCI moves)
        device (str): Device to run evaluation on
        top_k (int, optional): Number of top predictions that count as a hit.
            Defaults to 5.

    Returns:
        float: Accuracy score between 0 and 1 (0.0 if ``puzzle_data`` is empty)
    """
    model = model.to(device)
    model.eval()  # Disable dropout etc. for deterministic predictions
    correct = 0
    total = 0

    # Create progress bar
    pbar = tqdm(
        puzzle_data.iterrows(), total=len(puzzle_data), desc="Evaluating puzzles"
    )

    with torch.no_grad():
        for _, row in pbar:
            # Only the first listed move is scored
            first_move_uci = row["Moves"].split()[0]
            intended_move = chess.Move.from_uci(first_move_uci)

            # Create board from FEN
            board = chess.Board(row["FEN"])

            # Tokenize board and add a batch dimension
            board_tokens = torch.tensor(
                tokenize_board_v1(board), device=device
            ).unsqueeze(0)

            # Get model's top moves; the scores are not needed here
            top_moves = model.get_top_moves(board_tokens, board, n=top_k)
            predicted_moves = [move for move, _ in top_moves]

            # Check if intended move is among the top_k predictions
            if intended_move in predicted_moves:
                correct += 1
            total += 1

            # Update progress bar with current accuracy
            pbar.set_postfix({"accuracy": f"{(correct/total):.2%}"})

    # Keep the return type a float even when there were no puzzles
    accuracy = correct / total if total > 0 else 0.0
    print(f"\nFinal puzzle accuracy: {accuracy:.2%}")
    return accuracy


def train(
    model: ChessGPT_cls,
    train_dataset: ChessDataset,
    val_dataset: ChessDataset | None = None,
    *,
    batch_size: int = 64,
    epochs: int = 100,
    learning_rate: float = 1e-4,
    warmup_steps: int | None = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> None:
    """
    Train the ChessGPT model on a dataset of chess positions and moves.

    Each batch from the datasets is unpacked as ``(boards, targets)``. The loss is
    the KL divergence between ``log_softmax`` of the model output and ``targets``,
    so ``targets`` are expected to be probability distributions over moves.

    Side effects (all paths relative to the current working directory):
        * ``model_best.pth`` is overwritten with the model's state dict whenever the
          validation loss improves, or after every epoch when no validation loss
          is available.
        * ``info_{epoch}.txt`` is written after each epoch in which validation ran.

    Args:
        model (ChessGPT): The model to train
        train_dataset (ChessDataset): Dataset containing training examples
        val_dataset (ChessDataset, optional): Dataset for validation
        batch_size (int, optional): Training batch size. Defaults to 64.
        epochs (int, optional): Number of training epochs. Defaults to 100.
        learning_rate (float, optional): Learning rate. Defaults to 1e-4.
        warmup_steps (int, optional): Number of warmup steps. If provided, learning rate
            will linearly increase from 0 to learning_rate over this many steps.
        device (str, optional): Device to train on. Defaults to CUDA if available.
    """
    model = model.to(device)
    model.train()

    # Pinned host memory only speeds up host-to-GPU copies; requesting it for a
    # CPU run has no benefit and makes the DataLoader emit a warning.
    pin_memory = torch.device(device).type == "cuda"

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=pin_memory,
    )

    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=pin_memory,
        )

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = torch.nn.KLDivLoss(reduction="batchmean")

    # Setup learning rate scheduler for warmup
    scheduler = None
    if warmup_steps is not None:
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            # The factor ramps linearly to 1.0; warmup_steps <= 0 disables the ramp
            lambda step: min(1.0, step / warmup_steps) if warmup_steps > 0 else 1.0,
        )

    # Training loop
    global_step = 0
    best_val_loss = float("inf")
    for epoch in range(epochs):
        # Training phase (validation below switches the model to eval mode)
        model.train()
        running_loss = 0.0
        batches_seen = 0

        # Create progress bar for this epoch
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)

        for batches_seen, (boards, targets) in enumerate(pbar, start=1):
            boards = boards.to(device)
            targets = targets.to(device)

            # Forward pass
            logits = model(boards)

            # Compute loss (KL divergence between predicted and target distributions);
            # KLDivLoss expects log-probabilities as its first argument
            loss = criterion(F.log_softmax(logits, dim=-1), targets)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update learning rate for warmup
            if scheduler is not None:
                scheduler.step()

            # Update running loss and progress bar
            running_loss += loss.item()
            pbar.set_postfix({"train_loss": f"{running_loss / batches_seen:.4f}"})

            global_step += 1

            # Print current learning rate during warmup
            if (
                warmup_steps is not None
                and global_step <= warmup_steps
                and global_step % 100 == 0
            ):
                current_lr = optimizer.param_groups[0]["lr"]
                print(f"Step {global_step}, LR: {current_lr:.6f}")

        # Mean training loss of the epoch; NaN if the loader produced no batches
        train_loss = running_loss / batches_seen if batches_seen > 0 else float("nan")

        # Validation phase
        val_loss = None
        if val_loader is not None:
            model.eval()
            val_loss_sum = 0.0
            val_steps = 0

            with torch.no_grad():
                for boards, targets in val_loader:
                    boards = boards.to(device)
                    targets = targets.to(device)

                    logits = model(boards)
                    loss = criterion(F.log_softmax(logits, dim=-1), targets)

                    val_loss_sum += loss.item()
                    val_steps += 1

            # An empty validation loader yields no loss; treat it as "no validation"
            # instead of dividing by zero.
            if val_steps > 0:
                val_loss = val_loss_sum / val_steps
                print(f"Validation loss: {val_loss:.4f}")

                with open(f"info_{epoch}.txt", "w") as fl:
                    fl.write(f"train_loss: {train_loss}, val_loss: {val_loss}")

        if val_loss is None:
            # Without a validation signal there is no "best" epoch to pick, so keep
            # the most recent weights.
            print(
                "Saving model (no validation loss available; "
                f"train loss {train_loss:.4f})"
            )
            torch.save(model.state_dict(), "model_best.pth")
        elif val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"Saving model with validation loss {val_loss:.4f} as best model")
            torch.save(model.state_dict(), "model_best.pth")