In [24]:
import torch
from chess_model.model import ChessTransformer

# TRAINING_DATA_FILE = "out/training-data.csv"
# MODEL_OUTPUT_FILE = "trained_models/len_25_embd_128_layers_2_heads_2/model.pth"
# TOKENIZER_OUTPUT_FILE = "trained_models/len_25_embd_128_layers_2_heads_2/tokenizer.json"
MODEL_OUTPUT_FILE = "out/chess_transformer_model.pth"
TOKENIZER_OUTPUT_FILE = "out/chess_tokenizer.json"
MAX_LEN = 10
N_EMBD = 1024
N_LAYER = 4
N_HEAD = 4

# Change these to whatever the model was trained with!
# These values match the defaults in ChessTransformer:
# MAX_LEN = 10
# N_EMBD = 256
# N_LAYER = 4
# N_HEAD = 4

# MODEL_OUTPUT_FILE = "trained_models/len_25_embd_128_layers_2_heads_2/model.pth"
# TOKENIZER_OUTPUT_FILE = "trained_models/len_25_embd_128_layers_2_heads_2/tokenizer.json"
# MAX_LEN = 25
# N_EMBD = 128
# N_LAYER = 2
# N_HEAD = 2

In [25]:
from chess_model.model import ChessTokenizer

tokenizer = ChessTokenizer.load(f"../../{TOKENIZER_OUTPUT_FILE}")
print(f'Tokenizer initialized with vocab_size={tokenizer.vocab_size}')

Tokenizer initialized with vocab_size=4350


In [26]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

model = ChessTransformer(vocab_size=tokenizer.vocab_size, n_positions=MAX_LEN, n_embd=N_EMBD, n_layer=N_LAYER, n_head=N_HEAD)
model.load_state_dict(torch.load(f"../../{MODEL_OUTPUT_FILE}", map_location=device))

model.to(device)
print("Loaded model!")

Loaded model!


  model.load_state_dict(torch.load(f"../../{MODEL_OUTPUT_FILE}", map_location=device))


In [27]:
import torch
import torch.nn.functional as F
import random

def preprocess_input(move_sequence, tokenizer, max_length):
    # Convert move sequence to token IDs
    input_ids = tokenizer.encode(move_sequence)
    
    # Truncate or pad the sequence to max_length
    if len(input_ids) > max_length:
        input_ids = input_ids[-max_length:]
    else:
        input_ids = [0] * (max_length - len(input_ids)) + input_ids
    
    return torch.tensor(input_ids).unsqueeze(0)  # Add batch dimension

def predict_next_move(model, tokenizer, move_sequence, device, temperature=1.0, top_k=5):
    model.eval()  # Set the model to evaluation mode
    
    # Preprocess the input
    input_ids = preprocess_input(move_sequence, tokenizer, model.config.n_positions).to(device)
    
    with torch.no_grad():
        move_logits = model(input_ids)
    
    # Get the logits for the last position in the sequence
    last_move_logits = move_logits[0, -1, :]  # Shape: [vocab_size]
    
    # Apply temperature to logits
    last_move_logits = last_move_logits / temperature
    
    # Get probabilities
    move_probs = F.softmax(last_move_logits, dim=-1)
    
    # Normalize probabilities
    move_probs = move_probs / move_probs.sum()
    
    # Get top-k moves
    top_k_probs, top_k_indices = torch.topk(move_probs, top_k)
    
    # Sample from top-k moves
    sampled_index = torch.multinomial(top_k_probs, 1).item()
    predicted_move_id = top_k_indices[sampled_index].item()
    predicted_move = tokenizer.decode([predicted_move_id])[0]  # Decode returns a list, so we take the first item

    return predicted_move, move_probs

def interpret_prediction(predicted_move, move_probs, tokenizer):
    print(f"Predicted next move: {predicted_move}")
    
    # Debugging information
    print("\nDebugging Information:")
    print(f"Vocabulary size: {len(tokenizer.move_to_id)}")
    print("Top 5 predicted moves:")
    top_moves = torch.topk(move_probs.squeeze(), 5)
    for i, (prob, idx) in enumerate(zip(top_moves.values, top_moves.indices)):
        move = tokenizer.decode([idx.item()])
        print(f"  {i+1}. {move} (probability: {prob.item():.4f})")

In [30]:
move_sequence = "e4 e5 Nc3 Nf6".split(' ')
print(len(move_sequence))
predicted_move, move_probs = predict_next_move(model, tokenizer, move_sequence, device, temperature=0.5, top_k=5)
interpret_prediction(predicted_move, move_probs, tokenizer)

4
Predicted next move: Nf3

Debugging Information:
Vocabulary size: 4350
Top 5 predicted moves:
  1. ['Nf3'] (probability: 0.5933)
  2. ['Nc3'] (probability: 0.2029)
  3. ['d3'] (probability: 0.0951)
  4. ['d4'] (probability: 0.0180)
  5. ['Bxd8'] (probability: 0.0167)
