In [None]:
import torch
from chess_model import fit_tokenizer, ChessTransformer

# Model hyperparameters: these MUST match the values the checkpoint in
# MODEL_OUTPUT_FILE was trained with, otherwise the checkpoint's tensor shapes
# will most likely not fit the freshly constructed model.
# NOTE(review): the previous note said "Default length is 10, embeddings is 128"
# — presumably ChessTransformer's own defaults (confirm); this checkpoint was
# trained with the overriding values below.
MAX_LEN = 50  # context length in tokens: used as n_positions and as the pad/truncate length
N_EMBD = 64   # embedding width: used as n_embd

# Paths relative to the repository root; the cells below prefix them with "../../".
TRAINING_DATA_FILE = "out/training-data-trunc.csv"
MODEL_OUTPUT_FILE = "out/chess_transformer_model.pth"

In [None]:
# Re-fit the tokenizer on the training CSV (config paths are repo-root
# relative, hence the "../../" prefix).
# NOTE(review): this presumably reproduces the training-time vocabulary —
# confirm fit_tokenizer is deterministic for a given file.
tokenizer = fit_tokenizer("../../" + TRAINING_DATA_FILE)
print('Tokenizer initialized with vocab_size={}'.format(tokenizer.vocab_size))

In [None]:
# Prefer Apple MPS, then CUDA, then fall back to CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Sizes come from the config cell and must match the trained checkpoint.
model = ChessTransformer(vocab_size=tokenizer.vocab_size, n_positions=MAX_LEN, n_embd=N_EMBD)

# map_location="cpu": by default torch.load restores tensors onto the device they
# were saved from, which fails when the checkpoint came from a CUDA/MPS machine
# that is not available here. The model is moved to `device` right after.
# weights_only=True: only a state_dict (tensors) is needed, so refuse to unpickle
# arbitrary Python objects from the file.
model.load_state_dict(
    torch.load(f"../../{MODEL_OUTPUT_FILE}", map_location="cpu", weights_only=True)
)
model.to(device)
model.eval()  # inference only: disable dropout etc.
print("Loaded model!")

In [None]:
import torch
import torch.nn.functional as F
import random

def preprocess_input(move_sequence, tokenizer, max_length=MAX_LEN, pad_id=0, verbose=True):
    """Encode a move sequence into a fixed-length ``(1, max_length)`` tensor of token ids.

    Sequences longer than ``max_length`` keep only their most recent
    ``max_length`` tokens; shorter ones are left-padded with ``pad_id`` so the
    latest move always ends up in the final position.

    Args:
        move_sequence: Moves to encode, in whatever form ``tokenizer.encode``
            accepts (e.g. ``["e4"]``).
        tokenizer: Object whose ``encode(move_sequence)`` returns an iterable
            of int token ids.
        max_length: Output sequence length (>= 0). Defaults to ``MAX_LEN``.
        pad_id: Token id used for left padding. Defaults to 0.
            NOTE(review): confirm 0 is the tokenizer's padding id.
        verbose: If True, print the padded id list (debug output).

    Returns:
        A ``torch.long`` tensor of shape ``(1, max_length)``.

    Raises:
        ValueError: If ``max_length`` is negative.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    input_ids = list(tokenizer.encode(move_sequence))

    if len(input_ids) > max_length:
        # Keep the most recent tokens. Slice from an explicit start index:
        # `input_ids[-max_length:]` would return the WHOLE list when max_length == 0.
        input_ids = input_ids[len(input_ids) - max_length:]
    else:
        input_ids = [pad_id] * (max_length - len(input_ids)) + input_ids

    if verbose:
        print(f"with padding {input_ids}")

    # Explicit integer dtype: an empty list would otherwise become a float tensor,
    # and token ids must be integers.
    return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)  # add batch dimension

def predict_next_move(model, tokenizer, move_sequence, device, temperature=1.0, top_k=5):
    """Sample the next move from the model's ``top_k`` most likely predictions.

    Args:
        model: Module called on a ``(1, seq_len)`` id tensor.
            NOTE(review): assumes it returns next-move logits of shape
            ``(1, vocab_size)`` — confirm against ChessTransformer.forward.
        tokenizer: Used by ``preprocess_input`` to encode and here to ``decode``.
        move_sequence: Moves played so far, e.g. ``["e4"]``.
        device: Device the input ids are moved to (should match the model's).
        temperature: Softmax temperature, must be > 0. Values below 1 sharpen
            the distribution, values above 1 flatten it.
        top_k: Number of highest-probability moves to sample from, must be
            >= 1. Clamped to the vocabulary size.

    Returns:
        ``(predicted_move, move_probs)``: whatever ``tokenizer.decode`` returns
        for the sampled id, and the full temperature-scaled probability tensor.

    Raises:
        ValueError: If ``temperature <= 0`` or ``top_k < 1``.
    """
    # Dividing by a non-positive temperature yields inf/NaN logits and an opaque
    # multinomial error later; fail early with a clear message instead.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    model.eval()

    input_ids = preprocess_input(move_sequence, tokenizer).to(device)
    with torch.no_grad():
        move_logits = model(input_ids)

    # softmax already yields a normalized distribution; no re-normalization needed.
    move_probs = F.softmax(move_logits / temperature, dim=-1)

    # torch.topk raises if k exceeds the size of the ranked dimension.
    k = min(top_k, move_probs.size(-1))
    top_k_probs, top_k_indices = torch.topk(move_probs, k, dim=-1)

    # reshape(-1) instead of squeeze(): with k == 1, squeeze() yields a 0-d
    # tensor, which torch.multinomial rejects. multinomial does not require the
    # weights to sum to 1, so the top-k slice can be used as-is.
    sampled_index = torch.multinomial(top_k_probs.reshape(-1), 1).item()
    predicted_move_id = top_k_indices.reshape(-1)[sampled_index].item()
    predicted_move = tokenizer.decode([predicted_move_id])

    return predicted_move, move_probs

def interpret_prediction(predicted_move, move_probs, tokenizer, n_top=5):
    """Print the sampled move plus the ``n_top`` most likely candidates.

    Args:
        predicted_move: Move returned by ``predict_next_move``.
        move_probs: Probability tensor over the vocabulary. It is flattened
            before ranking, so it is expected to hold a single distribution,
            e.g. shape ``(1, vocab_size)``.
        tokenizer: Provides ``move_to_id`` (its length is reported as the
            vocabulary size) and ``decode``.
        n_top: How many candidates to list (default 5). Clamped to the number
            of probabilities available, so small vocabularies do not make
            ``torch.topk`` raise.
    """
    print(f"Predicted next move: {predicted_move}")

    # Debugging information
    print("\nDebugging Information:")
    print(f"Vocabulary size: {len(tokenizer.move_to_id)}")

    flat_probs = move_probs.reshape(-1)
    k = min(n_top, flat_probs.numel())
    print(f"Top {k} predicted moves:")
    top_moves = torch.topk(flat_probs, k)
    for rank, (prob, idx) in enumerate(zip(top_moves.values, top_moves.indices), start=1):
        move = tokenizer.decode([idx.item()])
        print(f"  {rank}. {move} (probability: {prob.item():.4f})")

In [None]:
# Sample several continuations of the same opening to see how varied the
# model's top-k sampling is.

# Seed torch's RNG (torch.multinomial draws from it) so this cell produces the
# same samples on every Restart & Run All.
torch.manual_seed(0)

move_sequence = ["e4"]

for _ in range(5):
    # Print the header first so each prediction's debug output appears under it.
    print("\n--- New Prediction ---")
    predicted_move, move_probs = predict_next_move(
        model, tokenizer, move_sequence, device, temperature=0.8, top_k=5
    )
    interpret_prediction(predicted_move, move_probs, tokenizer)