In [7]:
import torch
from chess_model import fit_tokenizer, ChessTransformer

MAX_LEN=16

tokenizer = fit_tokenizer('../../out/training-data.csv')
print(f'Tokenizer initialized with vocab_size={tokenizer.vocab_size}')
model = ChessTransformer(vocab_size=tokenizer.vocab_size, n_positions=MAX_LEN, n_embd=64) # use defaults for small model
model.load_state_dict(torch.load('../../out/chess_transformer_model.pth'))

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
""

Tokenizer initialized with vocab_size=5717


  model.load_state_dict(torch.load('../../out/chess_transformer_model.pth'))


''

In [8]:
import torch
import torch.nn.functional as F
import random

def preprocess_input(move_sequence, tokenizer, max_length=25):
    # Tokenize the input sequence
    input_ids = tokenizer.encode(move_sequence)
    
    # Truncate or pad the sequence to max_length
    if len(input_ids) > max_length:
        input_ids = input_ids[-max_length:]
    else:
        input_ids = [0] * (max_length - len(input_ids)) + input_ids
    
    return torch.tensor(input_ids).unsqueeze(0)  # Add batch dimension

def predict_next_move(model, tokenizer, move_sequence, device, temperature=1.0, top_k=5):
    model.eval()  # Set the model to evaluation mode
    
    # Preprocess the input
    input_ids = preprocess_input(move_sequence, tokenizer).to(device)
    
    with torch.no_grad():
        move_logits, checkmate_logits, outcome_logits = model(input_ids)
    
    # Apply temperature to logits
    move_logits = move_logits / temperature
    
    # Get probabilities
    move_probs = F.softmax(move_logits, dim=-1)
    
    # Zero out the probability of the last move in the sequence
    last_move_id = tokenizer.encode([move_sequence[-1]])[0]
    move_probs[0, last_move_id] = 0
    
    # Normalize probabilities after zeroing out the last move
    move_probs = move_probs / move_probs.sum()
    
    # Get top-k moves
    top_k_probs, top_k_indices = torch.topk(move_probs, top_k)
    
    # Sample from top-k moves
    sampled_index = torch.multinomial(top_k_probs.squeeze(), 1).item()
    predicted_move_id = top_k_indices.squeeze()[sampled_index].item()
    predicted_move = tokenizer.decode([predicted_move_id])
    
    # Get the checkmate probability
    checkmate_prob = torch.sigmoid(checkmate_logits).item()
    
    # Get the game outcome probabilities
    outcome_probs = F.softmax(outcome_logits, dim=-1).squeeze()
    
    return predicted_move, checkmate_prob, outcome_probs, move_probs

def interpret_prediction(predicted_move, checkmate_prob, outcome_probs, move_probs, tokenizer):
    outcomes = ['Win', 'Loss', 'Draw']
    outcome_dict = {outcome: prob.item() for outcome, prob in zip(outcomes, outcome_probs)}
    most_likely_outcome = max(outcome_dict, key=outcome_dict.get)
    
    print(f"Predicted next move: {predicted_move}")
    print(f"Checkmate probability: {checkmate_prob:.2f}")
    print("Game outcome probabilities:")
    for outcome, prob in outcome_dict.items():
        print(f"  {outcome}: {prob:.2f}")
    print(f"Most likely outcome: {most_likely_outcome}")
    
    # Debugging information
    print("\nDebugging Information:")
    print(f"Vocabulary size: {len(tokenizer.move_to_id)}")
    print("Top 5 predicted moves:")
    top_moves = torch.topk(move_probs.squeeze(), 5)
    for i, (prob, idx) in enumerate(zip(top_moves.values, top_moves.indices)):
        move = tokenizer.decode([idx.item()])
        print(f"  {i+1}. {move} (probability: {prob.item():.4f})")

In [12]:
# Example move sequence
move_sequence = ["e4"]

# Make multiple predictions
for _ in range(5):
    predicted_move, checkmate_prob, outcome_probs, move_probs = predict_next_move(model, tokenizer, move_sequence, device, temperature=0.8, top_k=5)
    print("\n--- New Prediction ---")
    interpret_prediction(predicted_move, checkmate_prob, outcome_probs, move_probs, tokenizer)


--- New Prediction ---
Predicted next move: ['Qgf5']
Checkmate probability: 0.00
Game outcome probabilities:
  Win: 0.33
  Loss: 0.23
  Draw: 0.43
Most likely outcome: Draw

Debugging Information:
Vocabulary size: 5717
Top 5 predicted moves:
  1. ['Qgf5'] (probability: 0.3778)
  2. ['R7d4'] (probability: 0.1927)
  3. ['Nexg1'] (probability: 0.1263)
  4. ['Nxd2+'] (probability: 0.0942)
  5. ['Ngf3'] (probability: 0.0568)

--- New Prediction ---
Predicted next move: ['Qgf5']
Checkmate probability: 0.00
Game outcome probabilities:
  Win: 0.33
  Loss: 0.23
  Draw: 0.43
Most likely outcome: Draw

Debugging Information:
Vocabulary size: 5717
Top 5 predicted moves:
  1. ['Qgf5'] (probability: 0.3778)
  2. ['R7d4'] (probability: 0.1927)
  3. ['Nexg1'] (probability: 0.1263)
  4. ['Nxd2+'] (probability: 0.0942)
  5. ['Ngf3'] (probability: 0.0568)

--- New Prediction ---
Predicted next move: ['Qgf5']
Checkmate probability: 0.00
Game outcome probabilities:
  Win: 0.33
  Loss: 0.23
  Draw: 0.43
Mo