In [16]:
import torch
import torch.nn as nn
import math
import re

In [17]:
def preprocess_game(game_str):
    moves = game_str.split()
    input_output_pairs = []
    for i in range(1, len(moves)):
        input_seq = ' '.join(moves[:i])
        output_move = moves[i]
        input_output_pairs.append((input_seq, output_move))
    return input_output_pairs

def preprocess_file(file_path):
    all_pairs = []
    with open(file_path, 'r') as file:
        for line in file:
            game = line.strip()
            all_pairs.extend(preprocess_game(game))
    return all_pairs

file_path = 'out/grandmaster.txt'
training_data = preprocess_file(file_path)
len(training_data)

6775496

In [18]:
vocab = set()

files = [
    'out/beginner.txt',
    'out/intermediate.txt',
    'out/master.txt',
    'out/grandmaster.txt',
]

for file_path in files:
    with open(file_path, 'r') as file:
        for line in file:
            for word in line.split(' '):
                vocab.add(word.strip())
vocab = sorted(list(vocab))

In [19]:
print(len(vocab))
vocab[:30]

14038


['0-1',
 '1-0',
 '1/2-1/2',
 'B1a3',
 'B1f2',
 'B1f3',
 'B2b3',
 'B2e3',
 'B2e4',
 'B3b4',
 'B3d4',
 'B3e2',
 'B3f4',
 'B3g2',
 'B4e5',
 'B5c4',
 'B5c6',
 'B5d6',
 'B5g4',
 'B6c7',
 'B6d5',
 'B6e7',
 'B7d4',
 'B7d6',
 'B7e6',
 'B7g6',
 'B7h6',
 'B8d6',
 'B8d7',
 'B8e6']

In [20]:
# Assuming 'vocab' is your list of moves as created in your code snippet
move_to_index = {move: idx for idx, move in enumerate(vocab)}
index_to_move = {idx: move for idx, move in enumerate(vocab)}

# Add special tokens
UNK_TOKEN = '<UNK>'
PAD_TOKEN = '<PAD>'
move_to_index[UNK_TOKEN] = len(move_to_index)
move_to_index[PAD_TOKEN] = len(move_to_index)
index_to_move[len(index_to_move)] = UNK_TOKEN
index_to_move[len(index_to_move)] = PAD_TOKEN

def encode_move(move):
    return move_to_index.get(move, move_to_index[UNK_TOKEN])

def decode_move(index):
    return index_to_move.get(index, UNK_TOKEN)

In [21]:
encode_move('Qe4#')

6645

In [22]:
decode_move(6645)

'Qe4#'

In [23]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Create a matrix of shape (max_len, d_model) filled with zeros
        # This will store the positional encodings for each position and dimension
        pe = torch.zeros(max_len, d_model)
        
        # Create a vector of positions from 0 to max_len-1
        # Unsqueeze to shape (max_len, 1) for broadcasting
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Create the division term for the sinusoidal function
        # This creates a vector of values that increase exponentially
        # We use log(10000.0) as it's a common choice that works well in practice
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices in the positional encoding
        # This creates a sinusoidal pattern that varies at different frequencies
        pe[:, 0::2] = torch.sin(position * div_term)
        
        # Apply cosine to odd indices in the positional encoding
        # This creates a cosinusoidal pattern that varies at different frequencies
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Unsqueeze and transpose to shape (1, max_len, d_model)
        # This allows for easy addition to the input embeddings later
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        # Register the positional encoding as a buffer
        # This means it won't be considered a model parameter (won't be updated during training)
        # but will be saved and loaded with the model
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the positional encoding to the input
        # x is expected to have shape (seq_len, batch_size, d_model)
        # We slice the positional encoding to match the input sequence length
        return x + self.pe[:x.size(0), :]

In [24]:
class ChessTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers):
        super().__init__()
        
        # Create an embedding layer to convert input tokens to vectors
        # vocab_size is the number of unique tokens in our vocabulary
        # d_model is the dimensionality of the embedding space
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        # Create a positional encoding layer
        # This adds information about the position of each token in the sequence
        self.pos_encoder = PositionalEncoding(d_model)
        
        # Create a single transformer encoder layer
        # This includes self-attention and feedforward neural network
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead)
        
        # Create the full transformer encoder by stacking multiple encoder layers
        # num_encoder_layers determines the depth of the network
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        
        # Store d_model for use in the forward pass
        self.d_model = d_model
        
        # Create a linear layer for the final output
        # This projects the transformer output back to vocabulary space
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, src):
        # Convert input tokens to embeddings
        # Multiply by sqrt(d_model) to scale the embeddings
        # This scaling helps maintain the variance of the forward pass
        src = self.embedding(src) * math.sqrt(self.d_model)
        
        # Add positional encoding to the embeddings
        src = self.pos_encoder(src)
        
        # Pass the encoded input through the transformer encoder
        output = self.transformer_encoder(src)
        
        # Project the transformer output to vocabulary space
        output = self.linear(output)
        
        return output

In [26]:
import torch
from torch.utils.data import TensorDataset

def create_dataset(data, max_seq_length=50):
    """
    Create a TensorDataset from preprocessed chess move data.
    
    Args:
    data (list): List of tuples, each containing (input_sequence, target_move)
    max_seq_length (int): Maximum sequence length to consider
    
    Returns:
    TensorDataset: Dataset containing input sequences and target moves
    """
    input_sequences = []
    target_moves = []
    
    for input_seq, target in data:
        # Split the input sequence into individual moves
        moves = input_seq.split()
        
        # Truncate or pad the sequence to max_seq_length
        if len(moves) > max_seq_length:
            moves = moves[-max_seq_length:]  # Take the last max_seq_length moves
        else:
            moves = [PAD_TOKEN] * (max_seq_length - len(moves)) + moves  # Pad with PAD_TOKEN
        
        # Encode the moves
        encoded_moves = [encode_move(move) for move in moves]
        
        # Encode the target move
        encoded_target = encode_move(target)
        
        input_sequences.append(encoded_moves)
        target_moves.append(encoded_target)
    
    # Convert to PyTorch tensors
    input_tensor = torch.LongTensor(input_sequences)
    target_tensor = torch.LongTensor(target_moves)
    
    # Create and return the TensorDataset
    return TensorDataset(input_tensor, target_tensor)

In [28]:
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Prepare the data
input_sequences, target_moves = zip(*training_data)
input_tensors = [torch.tensor([encode_move(m) for m in seq.split()]) for seq in input_sequences]
target_tensors = torch.tensor([encode_move(m) for m in target_moves])

# Pad sequences to the same length
max_len = max(len(seq) for seq in input_tensors)
input_tensors = [torch.nn.functional.pad(seq, (0, max_len - len(seq))) for seq in input_tensors]
input_tensors = torch.stack(input_tensors)

In [32]:
print(input_tensors)
print(target_tensors)
print(max_len)

tensor([[13504,     0,     0,  ...,     0,     0,     0],
        [13504, 13628,     0,  ...,     0,     0,     0],
        [13504, 13628,  2296,  ...,     0,     0,     0],
        ...,
        [13389, 13625, 13857,  ...,     0,     0,     0],
        [13389, 13625, 13857,  ...,     0,     0,     0],
        [13389, 13625, 13857,  ...,     0,     0,     0]])
tensor([13628,  2296, 13282,  ...,   753,  8403,     1])
593


In [34]:
vocab_size = len(vocab) + 500  # Slightly larger than the actual vocabulary size
d_model = 512
nhead = 8
num_encoder_layers = 6

# Create the model
model = ChessTransformer(vocab_size, d_model, nhead, num_encoder_layers)

In [None]:
# Create DataLoader
dataset = TensorDataset(input_tensors, target_tensors)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Training loop
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print(f"Device: {device}")
for epoch in range(num_epochs):
    print(f"Starting epoch {epoch+1}/{num_epochs}...")
    model.train()
    total_loss = 0
    for batch_input, batch_target in dataloader:
        batch_input, batch_target = batch_input.to(device), batch_target.to(device)
        
        optimizer.zero_grad()
        output = model(batch_input)
        loss = criterion(output[:, -1, :], batch_target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

# Save the model
torch.save(model.state_dict(), 'chess_model.pth')

Device: cpu
Starting epoch 1/10...


In [None]:
# Inference function
def predict_next_move(model, move_sequence):
    model.eval()
    with torch.no_grad():
        input_tensor = torch.tensor([encode_move(m) for m in move_sequence.split()]).unsqueeze(0).to(device)
        output = model(input_tensor)
        predicted_move_index = output[0, -1, :].argmax().item()
        return decode_move(predicted_move_index)

# Example usage
game_so_far = "e4 e5 Nf3 Nc6 Bb5"
next_move = predict_next_move(model, game_so_far)
print(f"Predicted next move: {next_move}")