In [3]:
import torch
from chess_model import ChessTransformer

# TRAINING_DATA_FILE = "out/training-data.csv"
MODEL_OUTPUT_FILE = "out/chess_len_25_embd_128_layers_2_heads_2.pth"
TOKENIZER_OUTPUT_FILE = "out/chess_tokenizer_bigger.json"

# Change these to whatever the model was trained with!
# These values match the defaults in ChessTransformer:
# MAX_LEN = 10
# N_EMBD = 256
# N_LAYER = 4
# N_HEAD = 4
MAX_LEN = 25
N_EMBD = 128
N_LAYER = 2
N_HEAD = 2

In [4]:
from chess_model import ChessTokenizer

tokenizer = ChessTokenizer.load(f"../../{TOKENIZER_OUTPUT_FILE}")
print(f'Tokenizer initialized with vocab_size={tokenizer.vocab_size}')

Tokenizer initialized with vocab_size=3257


In [6]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

model = ChessTransformer(vocab_size=tokenizer.vocab_size, n_positions=MAX_LEN, n_embd=N_EMBD, n_layer=N_LAYER, n_head=N_HEAD)
model.load_state_dict(torch.load(f"../../{MODEL_OUTPUT_FILE}", map_location=device))

model.to(device)
print("Loaded model!")

Loaded model!


  model.load_state_dict(torch.load(f"../../{MODEL_OUTPUT_FILE}", map_location=device))


In [7]:
import torch
import torch.nn.functional as F
import random

def preprocess_input(move_sequence, tokenizer, max_length=MAX_LEN):
    # Tokenize the input sequence
    input_ids = tokenizer.encode(move_sequence)
    
    # Truncate or pad the sequence to max_length
    if len(input_ids) > max_length:
        input_ids = input_ids[-max_length:]
    else:
        input_ids = [0] * (max_length - len(input_ids)) + input_ids

    decoded = tokenizer.decode(input_ids)

    # For debugging
    # print(f'Move sequence: {move_sequence}')
    # print(f'Encoded:       {input_ids}')
    # print(f'Decoded:       {decoded}')
    
    return torch.tensor(input_ids).unsqueeze(0)  # Add batch dimension


def predict_next_move(model, tokenizer, move_sequence, device, temperature=1.0, top_k=5):
    model.eval()  # Set the model to evaluation mode
    
    # Preprocess the input
    input_ids = preprocess_input(move_sequence, tokenizer).to(device)
    
    with torch.no_grad():
        move_logits = model(input_ids)
    
    # Apply temperature to logits
    move_logits = move_logits / temperature
    
    # Get probabilities
    move_probs = F.softmax(move_logits, dim=-1)
    
    # Normalize probabilities after zeroing out the last move
    move_probs = move_probs / move_probs.sum()
    
    # Get top-k moves
    top_k_probs, top_k_indices = torch.topk(move_probs, top_k)
    
    # Sample from top-k moves
    sampled_index = torch.multinomial(top_k_probs.squeeze(), 1).item()
    predicted_move_id = top_k_indices.squeeze()[sampled_index].item()
    predicted_move = tokenizer.decode([predicted_move_id])

    return predicted_move, move_probs

def interpret_prediction(predicted_move, move_probs, tokenizer):
    print(f"Predicted next move: {predicted_move}")
    
    # Debugging information
    print("\nDebugging Information:")
    print(f"Vocabulary size: {len(tokenizer.move_to_id)}")
    print("Top 5 predicted moves:")
    top_moves = torch.topk(move_probs.squeeze(), 5)
    for i, (prob, idx) in enumerate(zip(top_moves.values, top_moves.indices)):
        move = tokenizer.decode([idx.item()])
        print(f"  {i+1}. {move} (probability: {prob.item():.4f})")

In [53]:
move_sequence = "a4 e6 a5 d5 a6 Nf6 axb7 Bxb7 e3 Bd6 f4 O-O fxe5 Bxe5 Nf3 Bd6 d4 Ne4 Bd3 f5 Nbd2 Nd7 Ra5 c5 c3 c4 Rxd5".split(' ')
print(len(move_sequence))
predicted_move, move_probs = predict_next_move(model, tokenizer, move_sequence, device, temperature=0.5, top_k=5)
interpret_prediction(predicted_move, move_probs, tokenizer)

27
Predicted next move: ['b5']

Debugging Information:
Vocabulary size: 3257
Top 5 predicted moves:
  1. ['b5'] (probability: 0.7944)
  2. ['Nb6'] (probability: 0.0654)
  3. ['Qc7'] (probability: 0.0644)
  4. ['e6'] (probability: 0.0277)
  5. ['Nc5'] (probability: 0.0081)
