In [None]:
%pip install torch
%pip install transformers
%pip install ipywidgets

In [None]:
import csv
import random
from typing import Iterator, Tuple
from tqdm import tqdm

def process_game(game: str) -> Iterator[Tuple[str, str, str, str]]:
    """Expand one game record into per-move training examples.

    Args:
        game: Whitespace-separated tokens. The final token is taken as the
            game outcome and every token before it as a move.

    Yields:
        ``(context, next_move, is_checkmate, outcome)`` tuples, one per move:
        ``context`` is the space-joined moves preceding ``next_move``,
        ``is_checkmate`` is ``"1"`` when the move ends with ``#`` and ``"0"``
        otherwise, and ``outcome`` is the game's outcome token on the final
        move and ``""`` on every earlier move.

    A blank or whitespace-only ``game`` yields nothing (it used to raise
    ``IndexError`` when looking up the outcome token).
    """
    tokens = game.split()
    if not tokens:
        # Nothing to split into moves/outcome, e.g. an empty line in the input file.
        return

    *moves, outcome = tokens
    last_index = len(moves) - 1

    for i, next_move in enumerate(moves):
        context = " ".join(moves[:i])
        is_checkmate = "1" if next_move.endswith("#") else "0"

        # Only the final move of the game is labelled with the outcome.
        yield context, next_move, is_checkmate, outcome if i == last_index else ""

def prepare_training_data(input_file: str, train_file: str, val_file: str, max_context_length: int = 50, val_split: float = 0.1):
    """Convert a file of games into train/validation CSVs of per-move examples.

    Args:
        input_file: Text file with one game per line, in the format accepted
            by ``process_game``.
        train_file: Path of the training CSV to (over)write.
        val_file: Path of the validation CSV to (over)write.
        max_context_length: Keep only this many trailing moves of each context.
        val_split: Probability with which each individual example is written
            to the validation file instead of the training file.

    Both CSVs start with the header ``context,next_move,is_checkmate,outcome``.
    Blank input lines are skipped.
    """
    # Count total lines for the progress bar. Done first, and inside a context
    # manager, so the handle is closed and a missing input file fails before
    # the output files are truncated.
    with open(input_file, 'r') as infile:
        total_lines = sum(1 for _ in infile)

    headers = ['context', 'next_move', 'is_checkmate', 'outcome']

    with open(train_file, 'w', newline='') as train_outfile, open(val_file, 'w', newline='') as val_outfile:
        train_writer = csv.writer(train_outfile)
        val_writer = csv.writer(val_outfile)
        train_writer.writerow(headers)
        val_writer.writerow(headers)

        with open(input_file, 'r') as infile:
            for line in tqdm(infile, total=total_lines, desc="Processing games"):
                game = line.strip()
                if not game:
                    # A blank line carries no moves or outcome; skip it instead of crashing.
                    continue

                for context, next_move, is_checkmate, outcome in process_game(game):
                    # Limit context to last `max_context_length` moves
                    limited_context = " ".join(context.split()[-max_context_length:])

                    # The split is decided independently for every example.
                    writer = val_writer if random.random() < val_split else train_writer
                    writer.writerow([limited_context, next_move, is_checkmate, outcome])

In [None]:
# Build the train/validation CSVs from the raw games file (one game per line),
# using the default context length (50 moves) and validation split (0.1).
prepare_training_data("out/grandmaster.txt", "out/training-data.csv", "out/validation-data.csv")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model

class ChessTransformer(nn.Module):
    """GPT-2 backbone feeding three linear heads: next move, checkmate, outcome.

    All three heads read the backbone's hidden state at the last sequence
    position only.
    """

    # The defaults here are a relatively small and easy-to-train model
    def __init__(self, vocab_size, n_positions=50, n_embd=128, n_layer=2, n_head=2):
        super().__init__()

        backbone_config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
        )
        self.config = backbone_config
        self.transformer = GPT2Model(backbone_config)

        # One linear projection per task, all reading the same hidden vector.
        self.move_head = nn.Linear(n_embd, vocab_size)
        self.checkmate_head = nn.Linear(n_embd, 1)
        self.outcome_head = nn.Linear(n_embd, 3)  # Win, Loss, Draw

    def forward(self, input_ids, attention_mask=None):
        """Return ``(move_logits, checkmate_logits, outcome_logits)``.

        The logits have sizes ``vocab_size``, 1 and 3 respectively along their
        last dimension.
        """
        backbone_out = self.transformer(input_ids, attention_mask=attention_mask)

        # Summarise the whole sequence by the hidden state of its final position.
        final_position = backbone_out.last_hidden_state[:, -1, :]

        return (
            self.move_head(final_position),
            self.checkmate_head(final_position),
            self.outcome_head(final_position),
        )
        
class ChessTokenizer:
    """Bidirectional move <-> integer-id vocabulary with [PAD]=0 and [UNK]=1."""

    def __init__(self):
        self.move_to_id = {"[PAD]": 0, "[UNK]": 1}
        self.id_to_move = {index: token for token, index in self.move_to_id.items()}
        self.vocab_size = len(self.move_to_id)  # Start with PAD and UNK tokens

    def fit(self, moves):
        """Assign the next free id to every move not yet in the vocabulary."""
        for token in moves:
            if token in self.move_to_id:
                continue
            new_id = self.vocab_size
            self.move_to_id[token] = new_id
            self.id_to_move[new_id] = token
            self.vocab_size = new_id + 1

    def encode(self, moves):
        """Map moves to ids; unknown moves map to the [UNK] id."""
        lookup = self.move_to_id
        encoded = []
        for token in moves:
            encoded.append(lookup.get(token, lookup["[UNK]"]))
        return encoded

    def decode(self, ids):
        """Map ids back to moves; unknown ids map to "[UNK]"."""
        return [self.id_to_move.get(token_id, "[UNK]") for token_id in ids]

In [None]:
import torch
from torch.utils.data import Dataset
import csv
import mmap

class ChessDataset(Dataset):
    """Random-access dataset over a CSV with columns
    ``context, next_move, is_checkmate, outcome`` (first line is a header).

    Rows are read from disk on demand: ``__init__`` records the byte offset of
    every line, and ``__getitem__`` seeks to the requested row.

    Args:
        csv_file: Path to the CSV file.
        tokenizer: Object with an ``encode(list_of_moves) -> list_of_ids`` method.
        max_length: Number of context tokens per example; shorter contexts are
            left-padded with id 0, longer ones keep only their last tokens.
    """

    # Position of each recognised outcome string in the one-hot outcome vector.
    _OUTCOME_INDEX = {'1-0': 0, '0-1': 1, '1/2-1/2': 2}

    def __init__(self, csv_file, tokenizer, max_length=50):
        self.csv_file = csv_file
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Byte offset of the start of every line, header included. Plain binary
        # iteration also copes with an empty file, which mmap cannot map.
        self.line_offsets = []
        position = 0
        with open(self.csv_file, 'rb') as f:
            for raw_line in f:
                self.line_offsets.append(position)
                position += len(raw_line)

        # Keep a *binary* handle open for row lookups: the offsets above are
        # byte positions, and seeking a text-mode file to arbitrary offsets is
        # undefined. Opened only after indexing succeeded so a failure above
        # does not leak the handle.
        self.file = open(self.csv_file, 'rb')

    def __len__(self):
        # Subtract 1 to account for the header; never negative for an empty file.
        return max(len(self.line_offsets) - 1, 0)

    def __getitem__(self, idx):
        """Return one example as a dict of tensors.

        Keys: ``input_ids`` (long, ``max_length`` ids), ``labels`` (long scalar,
        id of the next move), ``is_checkmate`` (float scalar) and ``outcome``
        (float vector of 3, one-hot for a recognised outcome, all zeros
        otherwise).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_rows = len(self)
        if idx < 0:
            # Support Python-style negative indexing instead of silently
            # landing on the header row.
            idx += num_rows
        if not 0 <= idx < num_rows:
            raise IndexError(f"index {idx} out of range for dataset of size {num_rows}")

        # Add 1 to idx to skip the header
        self.file.seek(self.line_offsets[idx + 1])
        line = self.file.readline().decode('utf-8').strip()

        # Parse the CSV line
        row = next(csv.reader([line]))
        context, next_move, is_checkmate, outcome = row

        context = context.split() if context else []
        is_checkmate = float(is_checkmate)

        # Tokenize input (context)
        input_ids = self.tokenizer.encode(context)
        input_ids = input_ids[-self.max_length:]  # Keep only the last max_length tokens
        input_ids = [0] * (self.max_length - len(input_ids)) + input_ids  # Pad from the left

        # Create labels (next_move)
        labels = self.tokenizer.encode([next_move])[0]

        # Convert outcome to one-hot encoding (as float); unknown/empty stays all-zero
        outcome_label = torch.zeros(3, dtype=torch.float)
        outcome_index = self._OUTCOME_INDEX.get(outcome)
        if outcome_index is not None:
            outcome_label[outcome_index] = 1.0

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'is_checkmate': torch.tensor(is_checkmate, dtype=torch.float),
            'outcome': outcome_label
        }

    def __del__(self):
        # Close the file when the dataset object is destroyed. The attribute is
        # missing if __init__ failed before the handle was opened.
        if hasattr(self, 'file'):
            self.file.close()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random

def fit_tokenizer(csv_file):
    """Build a ChessTokenizer from the moves found in a prepared CSV file.

    Args:
        csv_file: CSV whose first two columns are ``context`` (space-separated
            moves) and ``next_move``; a leading header row with those names is
            skipped.

    Returns:
        A ``ChessTokenizer`` fitted on every move seen in either column.

    Moves are fitted in sorted order so that token ids are identical every
    time the tokenizer is rebuilt from the same file. Set iteration order for
    strings can change between interpreter runs, which would otherwise give a
    reloaded checkpoint a different move <-> id mapping.
    """
    import csv  # local import: this cell does not import csv itself

    unique_moves = set()
    with open(csv_file, 'r', newline='') as data:
        for row_number, row in enumerate(csv.reader(data)):
            if len(row) < 2:
                continue  # blank or malformed line
            if row_number == 0 and row[:2] == ['context', 'next_move']:
                continue  # header row: column names are not moves
            context, next_move = row[0], row[1]
            unique_moves.update(context.split())
            # Targets must be in the vocabulary too; a move that never occurs
            # inside a context would otherwise be encoded as [UNK].
            if next_move:
                unique_moves.add(next_move)

    tokenizer = ChessTokenizer()
    tokenizer.fit(sorted(unique_moves))
    return tokenizer

def _batch_loss(model, batch, criteria, device):
    """Run `model` on one batch and return the weighted multi-task loss tensor."""
    move_criterion, checkmate_criterion, outcome_criterion = criteria

    input_ids = batch['input_ids'].to(device)
    move_labels = batch['labels'].to(device)
    checkmate_labels = batch['is_checkmate'].to(device)
    outcome_labels = batch['outcome'].to(device)

    move_logits, checkmate_logits, outcome_logits = model(input_ids)

    move_loss = move_criterion(move_logits, move_labels)
    # squeeze(-1), not squeeze(): a bare squeeze() also removes the batch
    # dimension when the batch holds a single sample, and BCEWithLogitsLoss
    # then rejects the input/target shape mismatch.
    checkmate_loss = checkmate_criterion(checkmate_logits.squeeze(-1), checkmate_labels)
    outcome_loss = outcome_criterion(outcome_logits, outcome_labels)

    return move_loss + 0.1 * checkmate_loss + 0.1 * outcome_loss


def train_model(model, train_dataloader, val_dataloader, num_epochs, learning_rate, device):
    """Train `model` with AdamW on a combined move / checkmate / outcome loss.

    Args:
        model: Module returning ``(move_logits, checkmate_logits, outcome_logits)``
            when called with a batch of input ids.
        train_dataloader: Sized iterable of batches; each batch is a dict with
            keys ``input_ids``, ``labels``, ``is_checkmate`` and ``outcome``.
        val_dataloader: Same format, evaluated after every epoch without
            gradient tracking.
        num_epochs: Number of passes over ``train_dataloader``.
        learning_rate: AdamW learning rate.
        device: Device the model and every batch are moved to.

    Returns:
        The same ``model`` object, trained in place.

    The per-epoch average train and validation losses are printed; an average
    over an empty loader is reported as NaN instead of raising
    ``ZeroDivisionError``.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    criteria = (nn.CrossEntropyLoss(), nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss())

    num_train_batches = len(train_dataloader)
    num_val_batches = len(val_dataloader)
    progress_bar = tqdm(total=num_epochs * num_train_batches, desc="Training Progress")

    try:
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0

            for batch in train_dataloader:
                optimizer.zero_grad()
                loss = _batch_loss(model, batch, criteria, device)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss

                # Update progress bar
                progress_bar.update(1)
                progress_bar.set_postfix({'epoch': epoch+1, 'loss': f'{batch_loss:.4f}'})

            avg_loss = total_loss / num_train_batches if num_train_batches else float('nan')

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch in val_dataloader:
                    val_loss += _batch_loss(model, batch, criteria, device).item()

            avg_val_loss = val_loss / num_val_batches if num_val_batches else float('nan')

            print(f"\nEpoch {epoch+1}/{num_epochs}, Train Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    finally:
        # Close the bar even when training is interrupted or raises.
        progress_bar.close()

    return model


def calculate_random_baseline(dataloader, vocab_size, device):
    """Average the training loss formula over `dataloader` using random logits.

    For every batch, logits for all three heads are drawn uniformly from
    [0, 1) and scored against the real labels with the same weighting the
    training loop uses (move + 0.1 * checkmate + 0.1 * outcome).

    Args:
        dataloader: Iterable of batches; each batch is a dict with keys
            ``labels``, ``is_checkmate`` and ``outcome``.
        vocab_size: Number of move classes for the random move logits.
        device: Device on which the random logits and labels are placed.

    Returns:
        The mean per-batch loss as a float, or NaN if ``dataloader`` yields no
        batches (previously a ``ZeroDivisionError``).
    """
    total_loss = 0.0
    num_batches = 0
    move_criterion = nn.CrossEntropyLoss()
    checkmate_criterion = nn.BCEWithLogitsLoss()
    outcome_criterion = nn.BCEWithLogitsLoss()

    for batch in tqdm(dataloader, desc="Calculating random baseline"):
        move_labels = batch['labels'].to(device)
        batch_size = move_labels.size(0)

        random_move_logits = torch.rand(batch_size, vocab_size, device=device)
        random_checkmate_logits = torch.rand(batch_size, 1, device=device)
        random_outcome_logits = torch.rand(batch_size, 3, device=device)

        move_loss = move_criterion(random_move_logits, move_labels)
        # squeeze(-1) keeps the batch dimension for single-sample batches;
        # a bare squeeze() would collapse it and break the loss shape check.
        checkmate_loss = checkmate_criterion(random_checkmate_logits.squeeze(-1), batch['is_checkmate'].to(device))
        outcome_loss = outcome_criterion(random_outcome_logits, batch['outcome'].to(device))

        loss = move_loss + 0.1 * checkmate_loss + 0.1 * outcome_loss
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

In [None]:
def get_device():
    """Pick the torch device to run on: CUDA if available, else MPS, else CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)

print(f'Using device {get_device()}')

In [None]:
# Build the move vocabulary from the training CSV.
# Only the tokenizer is created in this cell; the model is built in a later cell.
print("Initializing tokenizer...")
tokenizer = fit_tokenizer('out/training-data.csv')
print(f'Tokenizer initialized with vocab_size={tokenizer.vocab_size}')

In [None]:
# Context window (in moves) shared by the model's position embeddings and the datasets.
MAX_LEN=16

model = ChessTransformer(vocab_size=tokenizer.vocab_size, n_positions=MAX_LEN, n_embd=64, n_layer=4, n_head=4)
# model = ChessTransformer(vocab_size=tokenizer.vocab_size, n_positions=MAX_LEN) # use defaults for small model

# Load and prepare data
print("Loading training/validation data...")
train_dataset = ChessDataset('out/training-data.csv', tokenizer, max_length=MAX_LEN)
val_dataset = ChessDataset('out/validation-data.csv', tokenizer, max_length=MAX_LEN)

train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=256)

# Get the appropriate device. Reuse the notebook's get_device() helper rather
# than a second inline selection with a different backend priority.
device = get_device()
print(f"Using device: {device}")

# Calculate random baseline loss (a reference point for the training loss below)
random_baseline_loss = calculate_random_baseline(train_dataloader, model.config.vocab_size, device)
print(f"Random Baseline Loss: {random_baseline_loss:.4f}")

In [None]:
# Train the model for 5 epochs; train_model prints the train/validation loss after each epoch
trained_model = train_model(model, train_dataloader, val_dataloader, num_epochs=5, learning_rate=5e-3, device=device)

# Save only the weights (state_dict). Loading them later requires rebuilding
# ChessTransformer with the same constructor arguments as above.
torch.save(trained_model.state_dict(), 'chess_transformer_model.pth')

In [None]:
import torch
import torch.nn.functional as F
import random

def preprocess_input(move_sequence, tokenizer, max_length=25):
    """Encode a move list into a fixed-width batch of one.

    The encoded sequence keeps its last `max_length` ids when it is too long
    and is left-padded with id 0 when it is too short. The result is a tensor
    with a leading batch dimension of size 1.
    """
    token_ids = tokenizer.encode(move_sequence)

    overflow = len(token_ids) - max_length
    if overflow > 0:
        # Too long: drop the oldest moves.
        token_ids = token_ids[-max_length:]
    else:
        # Too short (or exact): pad on the left so the newest move stays last.
        token_ids = [0] * (-overflow) + token_ids

    batch_of_one = torch.tensor(token_ids).unsqueeze(0)  # Add batch dimension
    return batch_of_one

def predict_next_move(model, tokenizer, move_sequence, device, temperature=1.0, top_k=5, max_length=None):
    """Sample the next move for `move_sequence` from the model's top-k moves.

    Args:
        model: Module returning ``(move_logits, checkmate_logits, outcome_logits)``
            for a batch of input ids.
        tokenizer: Object with ``encode`` / ``decode`` methods.
        move_sequence: List of moves played so far (may be empty).
        device: Device the input tensor is moved to.
        temperature: Positive divisor applied to the move logits before softmax.
        top_k: Number of highest-probability moves to sample from; clamped to
            the vocabulary size.
        max_length: Input width passed to ``preprocess_input``. When ``None``
            it is taken from ``model.config.n_positions`` if the model exposes
            it, so the input never exceeds the model's position embeddings;
            otherwise ``preprocess_input``'s own default is used.

    Returns:
        ``(predicted_move, checkmate_prob, outcome_probs, move_probs)`` where
        ``predicted_move`` is the one-element list returned by
        ``tokenizer.decode``, ``checkmate_prob`` is a float, ``outcome_probs``
        is the softmax over the three outcome logits, and ``move_probs`` is the
        (renormalised) move distribution with a leading batch dimension of 1.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()  # Set the model to evaluation mode

    # Match the input width to the model's context window when it is known.
    if max_length is None:
        max_length = getattr(getattr(model, "config", None), "n_positions", None)
    if max_length is None:
        input_ids = preprocess_input(move_sequence, tokenizer)
    else:
        input_ids = preprocess_input(move_sequence, tokenizer, max_length=max_length)
    input_ids = input_ids.to(device)

    with torch.no_grad():
        move_logits, checkmate_logits, outcome_logits = model(input_ids)

    # Apply temperature to logits
    move_logits = move_logits / temperature

    # Get probabilities
    move_probs = F.softmax(move_logits, dim=-1)

    if move_sequence:
        # Zero out the probability of the last move in the sequence, then
        # renormalize. Skipped for an empty sequence, which has no last move.
        last_move_id = tokenizer.encode([move_sequence[-1]])[0]
        move_probs[0, last_move_id] = 0
        move_probs = move_probs / move_probs.sum()

    # Get top-k moves (k cannot exceed the number of candidate moves)
    k = max(1, min(top_k, move_probs.size(-1)))
    top_k_probs, top_k_indices = torch.topk(move_probs, k)

    # Sample from top-k moves. squeeze(0) removes only the batch dimension, so
    # k == 1 still leaves a 1-D tensor for multinomial and indexing.
    sampled_index = torch.multinomial(top_k_probs.squeeze(0), 1).item()
    predicted_move_id = top_k_indices.squeeze(0)[sampled_index].item()
    predicted_move = tokenizer.decode([predicted_move_id])

    # Get the checkmate probability
    checkmate_prob = torch.sigmoid(checkmate_logits).item()

    # Get the game outcome probabilities
    outcome_probs = F.softmax(outcome_logits, dim=-1).squeeze()

    return predicted_move, checkmate_prob, outcome_probs, move_probs

def interpret_prediction(predicted_move, checkmate_prob, outcome_probs, move_probs, tokenizer):
    """Print a human-readable summary of one prediction.

    Args:
        predicted_move: The sampled move, either a string or the one-element
            list returned by ``tokenizer.decode``.
        checkmate_prob: Checkmate probability as a float.
        outcome_probs: Tensor of three probabilities, printed in the order
            Win, Loss, Draw.
        move_probs: Tensor of move probabilities over the vocabulary.
        tokenizer: Object with a ``move_to_id`` mapping and a ``decode`` method.

    Prints the predicted move, the checkmate probability, the outcome
    probabilities with the most likely outcome, and the highest-probability
    moves (up to five, fewer when the vocabulary is smaller).
    """
    def _as_text(decoded):
        # tokenizer.decode returns a list; show a single move without the list brackets.
        if isinstance(decoded, (list, tuple)) and len(decoded) == 1:
            return decoded[0]
        return decoded

    outcomes = ['Win', 'Loss', 'Draw']
    outcome_dict = {outcome: prob.item() for outcome, prob in zip(outcomes, outcome_probs)}
    most_likely_outcome = max(outcome_dict, key=outcome_dict.get)

    print(f"Predicted next move: {_as_text(predicted_move)}")
    print(f"Checkmate probability: {checkmate_prob:.2f}")
    print("Game outcome probabilities:")
    for outcome, prob in outcome_dict.items():
        print(f"  {outcome}: {prob:.2f}")
    print(f"Most likely outcome: {most_likely_outcome}")

    # Debugging information
    print("\nDebugging Information:")
    print(f"Vocabulary size: {len(tokenizer.move_to_id)}")

    # Flatten instead of squeeze() and clamp k, so tiny vocabularies
    # (fewer than five entries) do not make topk raise.
    flat_probs = move_probs.reshape(-1)
    k = min(5, flat_probs.numel())
    print(f"Top {k} predicted moves:")
    top_moves = torch.topk(flat_probs, k)
    for i, (prob, idx) in enumerate(zip(top_moves.values, top_moves.indices)):
        move = _as_text(tokenizer.decode([idx.item()]))
        print(f"  {i+1}. {move} (probability: {prob.item():.4f})")

In [1]:
# Rebuild the tokenizer and model, then restore the trained weights.
tokenizer = fit_tokenizer('out/training-data.csv')
print(f'Tokenizer initialized with vocab_size={tokenizer.vocab_size}')

# The architecture must match the one used in the training cell
# (n_embd=64, n_layer=4, n_head=4). With the class defaults (n_layer=2,
# n_head=2) load_state_dict rejects the checkpoint because of its extra layers.
model = ChessTransformer(vocab_size=tokenizer.vocab_size, n_positions=MAX_LEN, n_embd=64, n_layer=4, n_head=4)

# map_location="cpu" lets a checkpoint saved on a GPU/MPS machine load anywhere;
# the model is moved to the target device afterwards.
model.load_state_dict(torch.load('chess_transformer_model.pth', map_location="cpu"))

device = get_device()
model = model.to(device)  # assignment keeps the cell from echoing the module repr

NameError: name 'fit_tokenizer' is not defined

In [None]:
# Example move sequence (moves played so far, oldest first)
move_sequence = ["e4", "c6", "d4", "e5", "d5", "e5"]

# Make multiple predictions. predict_next_move samples from the top-k moves,
# so repeated calls on the same sequence can return different moves.
for _ in range(5):
    predicted_move, checkmate_prob, outcome_probs, move_probs = predict_next_move(model, tokenizer, move_sequence, device, temperature=0.8, top_k=5)
    print("\n--- New Prediction ---")
    interpret_prediction(predicted_move, checkmate_prob, outcome_probs, move_probs, tokenizer)