<a href="https://colab.research.google.com/github/kaballas/chess_llm_interpretability/blob/main/chess_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install chess

Collecting chess
  Downloading chess-1.10.0-py3-none-any.whl (154 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.4/154.4 kB 3.5 MB/s eta 0:00:00
Installing collected packages: chess
Successfully installed chess-1.10.0


In [5]:
import chess
import chess.pgn
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import time
import io

# Seed PyTorch's global RNG so that weight initialisation and data shuffling
# are repeatable from run to run.
torch.manual_seed(42)

# Prefer the CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def read_stafford_gambit_games(pgn_file, opening_prefix="Petrov's Defense: Stafford Gambit"):
    """Return every game in a PGN file whose ``Opening`` header matches.

    Args:
        pgn_file: path of the PGN file; it is decoded as UTF-8 rather than
            with the platform's default encoding, so non-ASCII player or
            event names read the same on every OS.
        opening_prefix: string (or tuple of strings, as accepted by
            ``str.startswith``) that the ``Opening`` header must start with.
            Pass a tuple to accept several spellings of the opening name.

    Returns:
        List of ``chess.pgn.Game`` objects, in file order.  Games without an
        ``Opening`` header never match a non-empty prefix.
    """
    stafford_games = []
    with open(pgn_file, encoding="utf-8") as f:
        while True:
            game = chess.pgn.read_game(f)
            if game is None:  # end of file
                break
            if game.headers.get("Opening", "").startswith(opening_prefix):
                stafford_games.append(game)
    return stafford_games

def board_to_tensor(board):
    """Encode ``board`` as a 12x8x8 one-hot float tensor.

    Planes 0-5 mark Black's pawns, knights, bishops, rooks, queens and king;
    planes 6-11 mark the same piece kinds for White.  Row 0 is rank 8 and
    column 0 is the a-file, i.e. a diagram seen from White's side.
    """
    planes = torch.zeros(12, 8, 8)
    for sq, occupant in board.piece_map().items():
        # chess.PAWN .. chess.KING are the consecutive integers 1 .. 6.
        colour_offset = 6 if occupant.color == chess.WHITE else 0
        plane = occupant.piece_type - 1 + colour_offset
        row = 7 - chess.square_rank(sq)
        col = chess.square_file(sq)
        planes[plane, row, col] = 1
    return planes

class StaffordGambitDataset(Dataset):
    """Dataset of (position, move index) pairs for Black's moves.

    Each sample is the board *before* one of Black's mainline moves, encoded
    with ``board_to_tensor``, paired with that move encoded by
    ``move_to_output`` (``from_square * 64 + to_square``).

    Args:
        games: iterable of ``chess.pgn.Game`` objects.
        max_fullmove: stop reading a game as soon as its full-move number
            exceeds this value.  With the default of 40 and a game starting
            from the initial position, Black's moves 1-40 are kept.
    """

    def __init__(self, games, max_fullmove=40):
        self.positions = []
        self.moves = []
        for game in games:
            board = game.board()
            for move in game.mainline_moves():
                if board.turn == chess.BLACK:  # only learn Black's replies
                    self.positions.append(board_to_tensor(board))
                    self.moves.append(move_to_output(move))
                board.push(move)
                if board.fullmove_number > max_fullmove:
                    break

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        return self.positions[idx], self.moves[idx]

def move_to_output(move):
    """Return the class index of ``move`` in the model's 4096-way output.

    The index is ``from_square * 64 + to_square``; a promotion piece, if the
    move has one, is not represented in it.
    """
    origin = move.from_square
    destination = move.to_square
    return 64 * origin + destination

class SimplifiedStaffordGambitModel(nn.Module):
    """Small convolutional policy network over 4096 from/to move classes.

    Two 3x3 convolutions (12 -> 32 -> 64 channels, padding keeps the 8x8
    board size) with batch norm and ReLU feed a 512-unit hidden layer with
    dropout, followed by a linear layer that emits one unnormalised score
    (logit) per ``from * 64 + to`` move index.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys of saved checkpoints, and
        # creation order fixes the seeded initial weights: keep both.
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 64 * 64)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        features = self.bn1(self.conv1(x)).relu()
        features = self.bn2(self.conv2(features)).relu()
        flat = features.view(-1, 64 * 8 * 8)
        hidden = self.dropout(self.fc1(flat).relu())
        return self.fc2(hidden)

def fine_tune_model(model, train_loader, criterion, optimizer, num_epochs):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Every batch is moved to the module-level ``device``, cast to float and
    optimised with ``criterion`` / ``optimizer``.  The mean loss of each run
    of 10 consecutive batches is printed, then each epoch's wall time.
    """
    model.train()
    for epoch_idx in range(1, num_epochs + 1):
        epoch_start = time.time()
        window_loss = 0.0
        for batch_no, (positions, targets) in enumerate(train_loader, start=1):
            positions, targets = positions.to(device), targets.to(device)

            optimizer.zero_grad()
            logits = model(positions.float())
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            window_loss += batch_loss.item()
            if batch_no % 10 == 0:  # report every 10 mini-batches
                print(f'[{epoch_idx}, {batch_no:5d}] loss: {window_loss / 10:.3f}')
                window_loss = 0.0

        print(f'Epoch {epoch_idx} completed in {time.time() - epoch_start:.2f} seconds')

    print('Finished Fine-tuning')

def get_model_move(board, model):
    """Predict a move for ``board`` with ``model``.

    The model is switched to eval mode (a side effect that persists after
    the call).  The board is encoded with ``board_to_tensor``, and the
    highest-scoring of the 4096 ``from * 64 + to`` outputs is decoded into a
    ``chess.Move``.  The encoding has no promotion piece.  A predicted pawn
    move onto the first or last rank is therefore completed as a queen
    promotion, so that it can compare equal to (and be played as) the real
    move.  The prediction is not checked for legality.

    Returns:
        ``(move, output)``, where ``output`` is the raw logit tensor of
        shape ``(1, 4096)``.
    """
    board_tensor = board_to_tensor(board).unsqueeze(0).float().to(device)

    model.eval()  # dropout off, BatchNorm uses running statistics
    with torch.no_grad():
        output = model(board_tensor)
        predicted_move_code = output.argmax(dim=1).item()  # index of the highest logit

    from_square, to_square = divmod(predicted_move_code, 64)
    promotion = None
    piece = board.piece_at(from_square)
    if (piece is not None and piece.piece_type == chess.PAWN
            and chess.square_rank(to_square) in (0, 7)):
        promotion = chess.QUEEN
    model_move = chess.Move(from_square, to_square, promotion=promotion)

    return model_move, output



def main():
    """Fine-tune the Stafford Gambit model on the games it predicts poorly.

    1. Load ``stafford_gambit_model_finetuned.pth`` if it exists; otherwise
       save a freshly initialised model under that name.
    2. Replay every game in ``stafford.txt`` and count how often the model's
       prediction equals Black's actual move.  Games below 80% accuracy are
       selected; a game with no Black moves scores 0% and is selected too.
    3. Fine-tune on the selected games and save the weights back to the same
       file.  This step is skipped when the selection yields no positions.
    4. Print the model's suggestion for the position after 1.e4.
    """
    model = SimplifiedStaffordGambitModel().to(device)

    model_path = 'stafford_gambit_model_finetuned.pth'

    if os.path.exists(model_path):
        # map_location lets a checkpoint written from a GPU session load on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Best model loaded successfully.")
    else:
        print("No pre-trained model found. Initializing a new model.")
        torch.save(model.state_dict(), model_path)
        print("New model initialized and saved.")

    # Read games from stafford.txt
    with open('stafford.txt', 'r') as pgn_file:
        pgn_io = io.StringIO(pgn_file.read())

    # Identify poorly performing games (accuracy < 80%)
    poorly_performing_games = []
    game_count = 0

    while True:
        game = chess.pgn.read_game(pgn_io)
        if game is None:
            break

        game_count += 1
        board = game.board()
        correct_moves = 0
        total_moves = 0

        for move in game.mainline_moves():
            if board.turn == chess.BLACK:
                total_moves += 1
                model_move, _ = get_model_move(board, model)
                if move == model_move:
                    correct_moves += 1
            board.push(move)

        accuracy = correct_moves / total_moves * 100 if total_moves > 0 else 0
        if accuracy < 80:
            poorly_performing_games.append(game)

    print(f"Found {len(poorly_performing_games)} poorly performing games out of {game_count} total games.")

    # Create dataset from poorly performing games
    dataset = StaffordGambitDataset(poorly_performing_games)
    print(f"Fine-tuning dataset size: {len(dataset)}")

    fine_tuned = len(dataset) > 0
    if fine_tuned:
        train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Lower learning rate for fine-tuning
        num_epochs = 100  # Adjust as needed
        fine_tune_model(model, train_loader, criterion, optimizer, num_epochs)

        torch.save(model.state_dict(), model_path)
        print("Fine-tuned model saved.")
    else:
        # The shuffling sampler rejects an empty dataset, and there is nothing
        # to learn from anyway: keep the current weights.
        print("No positions to fine-tune on; skipping fine-tuning.")

    # Basic testing routine on the position after 1.e4 (Black to move).
    test_board = chess.Board("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1")
    start_time = time.perf_counter()
    # get_model_move also puts the model in eval mode; its .item() call waits
    # for the GPU, so the measured time covers the whole forward pass.
    move, _ = get_model_move(test_board, model)
    inference_time = time.perf_counter() - start_time
    print(f"After fine-tuning, for the given position, the model suggests the move: {move.uci()}")
    print(f"Inference time: {inference_time:.4f} seconds")

    if fine_tuned:
        print("Fine-tuning completed and model saved.")
    else:
        print("Analysis completed; model unchanged.")


if __name__ == "__main__":
    main()



Using device: cuda
Best model loaded successfully.
Found 238 poorly performing games out of 238 total games.
Fine-tuning dataset size: 2268
[1,    10] loss: 7.540
[1,    20] loss: 5.884
[1,    30] loss: 4.110
[1,    40] loss: 2.944
[1,    50] loss: 2.439
[1,    60] loss: 1.954
[1,    70] loss: 1.623
Epoch 1 completed in 0.98 seconds
[2,    10] loss: 1.293
[2,    20] loss: 0.973
[2,    30] loss: 0.938
[2,    40] loss: 0.620
[2,    50] loss: 0.562
[2,    60] loss: 0.521
[2,    70] loss: 0.382
Epoch 2 completed in 0.27 seconds
[3,    10] loss: 0.327
[3,    20] loss: 0.291
[3,    30] loss: 0.220
[3,    40] loss: 0.208
[3,    50] loss: 0.187
[3,    60] loss: 0.156
[3,    70] loss: 0.146
Epoch 3 completed in 0.26 seconds
[4,    10] loss: 0.108
[4,    20] loss: 0.102
[4,    30] loss: 0.086
[4,    40] loss: 0.088
[4,    50] loss: 0.085
[4,    60] loss: 0.074
[4,    70] loss: 0.088
Epoch 4 completed in 0.26 seconds
[5,    10] loss: 0.067
[5,    20] loss: 0.054
[5,    30] loss: 0.062
[5,    40] 

In [84]:
import chess
import chess.pgn
import torch
import torch.nn as nn
import torch.optim as optim
import io
import time

# Pick the compute device: the CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def board_to_tensor(board):
    """Encode ``board`` as a 12x8x8 one-hot float tensor.

    Black's pawn..king occupy planes 0-5 and White's occupy planes 6-11
    (piece kinds in python-chess order).  Row 0 is rank 8 and column 0 is
    the a-file.
    """
    encoded = torch.zeros(12, 8, 8)
    colour_offset = {chess.WHITE: 6, chess.BLACK: 0}
    for square, piece in board.piece_map().items():
        rank_idx, file_idx = divmod(square, 8)
        encoded[piece.piece_type - 1 + colour_offset[piece.color], 7 - rank_idx, file_idx] = 1
    return encoded

class SimplifiedStaffordGambitModel(nn.Module):
    """Convolutional policy network scoring all 64*64 from/to move indices.

    Architecture: conv(12->32) + BN + ReLU, conv(32->64) + BN + ReLU (both
    3x3 with padding 1, so the 8x8 board size is kept), flatten,
    linear(4096->512) + ReLU + dropout(0.3), linear(512->4096) logits.
    """

    def __init__(self):
        super().__init__()
        # Names must match the saved checkpoint's state_dict keys; creation
        # order is kept so seeded initialisation stays the same.
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 64 * 64)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = torch.relu(norm(conv(x)))
        x = self.dropout(torch.relu(self.fc1(x.view(-1, 64 * 8 * 8))))
        return self.fc2(x)

# Load the fine-tuned model.  map_location remaps tensors saved from a GPU
# session onto the current device, so the checkpoint also loads on a
# CPU-only machine.
model = SimplifiedStaffordGambitModel().to(device)
model.load_state_dict(torch.load('stafford_gambit_model_finetuned_worst.pth', map_location=device))
model.eval()  # evaluation mode for the analysis that follows
print("Fine-tuned model loaded successfully.")

def get_model_move(board):
    """Predict a move for ``board`` with the module-level ``model``.

    The caller is expected to have put ``model`` in eval mode.  The board is
    encoded with ``board_to_tensor``, and the highest-scoring of the 4096
    ``from * 64 + to`` outputs is decoded into a ``chess.Move``.  That
    encoding has no promotion piece, so a predicted pawn move onto the first
    or last rank is completed as a queen promotion; this lets it compare
    equal to the real move.  Legality is not checked.

    Returns:
        ``(move, inference_time)`` with ``inference_time`` in seconds.
    """
    tensor = board_to_tensor(board).unsqueeze(0).float().to(device)
    with torch.no_grad():
        start_time = time.perf_counter()
        output = model(tensor)
        # On CUDA the forward pass is queued asynchronously.  .item() copies
        # the result to the host and so waits for it; stopping the clock only
        # afterwards times the real computation, not just the kernel launch.
        move_code = output.argmax(dim=1).item()
        inference_time = time.perf_counter() - start_time

    from_square, to_square = divmod(move_code, 64)
    promotion = None
    piece = board.piece_at(from_square)
    if (piece is not None and piece.piece_type == chess.PAWN
            and chess.square_rank(to_square) in (0, 7)):
        promotion = chess.QUEEN
    return chess.Move(from_square, to_square, promotion=promotion), inference_time

def test_game(game):
    """Compare the model's prediction with every Black move of ``game``'s mainline.

    Returns:
        ``(total_moves, correct_moves, incorrect_moves, total_inference_time)``.
        ``incorrect_moves`` is a list of dicts with these keys:
        ``move_number`` (Black's full-move number), ``actual_move`` and
        ``model_move`` (UCI strings), ``fen`` (the position before the move)
        and ``inference_time`` (seconds).
    """
    board = game.board()
    total_moves = 0
    correct_moves = 0
    incorrect_moves = []
    total_inference_time = 0

    for move in game.mainline_moves():
        if board.turn == chess.BLACK:  # We're only interested in Black's moves
            model_move, inference_time = get_model_move(board)
            total_inference_time += inference_time
            total_moves += 1
            if move == model_move:
                correct_moves += 1
            else:
                incorrect_moves.append({
                    # Take the number from the board itself.  Deriving it from
                    # the ply index is only right for games that start from
                    # the initial position with White to move.
                    'move_number': board.fullmove_number,
                    'actual_move': move.uci(),
                    'model_move': model_move.uci(),
                    'fen': board.fen(),
                    'inference_time': inference_time
                })
        board.push(move)

    return total_moves, correct_moves, incorrect_moves, total_inference_time

# Read games from stafford.txt and evaluate the model on each of them.
with open('stafford.txt', 'r') as file:
    pgn_text = file.read()

# Create a PGN reader
pgn_io = io.StringIO(pgn_text)

game_results = []
games = []
game_count = 0

while True:
    game = chess.pgn.read_game(pgn_io)
    if game is None:
        break

    games.append(game)
    game_count += 1
    print(f"\nAnalyzing Game {game_count}")
    total_moves, correct_moves, incorrect_moves, total_inference_time = test_game(game)

    accuracy = correct_moves / total_moves * 100 if total_moves > 0 else 0
    avg_inference_time = total_inference_time / total_moves if total_moves > 0 else 0

    game_results.append({
        'game_number': game_count,
        'total_moves': total_moves,
        'correct_moves': correct_moves,
        'accuracy': accuracy,
        'avg_inference_time': avg_inference_time,
        'incorrect_moves': incorrect_moves
    })

    print(f"Total moves: {total_moves}")
    print(f"Correct moves: {correct_moves}")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Average inference time: {avg_inference_time:.4f} seconds")

print("\nOverall Analysis:")
print(f"Total games analyzed: {game_count}")
if game_count:
    # Unweighted means of the per-game values.
    total_accuracy = sum(game['accuracy'] for game in game_results) / game_count
    total_avg_inference_time = sum(game['avg_inference_time'] for game in game_results) / game_count
else:
    # Empty input file: avoid dividing by zero.
    total_accuracy = 0.0
    total_avg_inference_time = 0.0
print(f"Average accuracy across all games: {total_accuracy:.2f}%")
print(f"Average inference time across all games: {total_avg_inference_time:.4f} seconds")

# Sort games by accuracy (worst first)
sorted_games = sorted(zip(games, game_results), key=lambda pair: pair[1]['accuracy'])

# Select the worst 20% of games (at least one, if any exist) for fine-tuning
num_worst_games = min(len(games), max(1, int(0.2 * len(games))))
worst_games = [game for game, _ in sorted_games[:num_worst_games]]

print(f"\nFine-tuning on the worst {num_worst_games} games")

def fine_tune_model(model, games, epochs=200, lr=0.0001):
    """Fine-tune ``model`` in place on every Black move of ``games``.

    Training uses one position per optimiser step (batch size 1), in game
    order, with Adam and cross-entropy over the 4096 ``from * 64 + to``
    move classes.  The summed loss of each epoch is printed.  ``model`` is
    left in train mode.

    Args:
        model: the network to train.
        games: iterable of ``chess.pgn.Game``; it is read only once.
        epochs: number of passes over all positions.
        lr: Adam learning rate.
    """
    # Replaying the games and encoding positions does not depend on the
    # epoch, so build the (input, target) pairs once instead of every epoch.
    samples = []
    for game in games:
        board = game.board()
        for move in game.mainline_moves():
            if board.turn == chess.BLACK:
                tensor = board_to_tensor(board).unsqueeze(0).float().to(device)
                target = torch.tensor([move.from_square * 64 + move.to_square], device=device)
                samples.append((tensor, target))
            board.push(move)

    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for tensor, target in samples:
            optimizer.zero_grad()
            loss = criterion(model(tensor), target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# Fine-tune the model
fine_tune_model(model, worst_games)

# Save the fine-tuned model
torch.save(model.state_dict(), 'stafford_gambit_model_finetuned_worst.pth')
print("Model fine-tuned on worst games and saved.")

# Re-test on all games
model.eval()
new_game_results = []

for game in games:
    total_moves, correct_moves, _, _ = test_game(game)
    accuracy = correct_moves / total_moves * 100 if total_moves > 0 else 0
    new_game_results.append({
        'total_moves': total_moves,
        'correct_moves': correct_moves,
        'accuracy': accuracy,
    })

if new_game_results:
    # Unweighted mean of the per-game accuracies.
    new_total_accuracy = sum(game['accuracy'] for game in new_game_results) / len(new_game_results)
    print(f"\nNew average accuracy across all games after fine-tuning: {new_total_accuracy:.2f}%")
else:
    # No games were read: the average is undefined (would divide by zero).
    print("\nNo games available to re-test.")

print("\nFine-tuning and re-testing complete.")


Using device: cuda
Fine-tuned model loaded successfully.

Analyzing Game 1
Total moves: 4
Correct moves: 4
Accuracy: 100.00%
Average inference time: 0.0005 seconds

Analyzing Game 2
Total moves: 8
Correct moves: 7
Accuracy: 87.50%
Average inference time: 0.0004 seconds

Analyzing Game 3
Total moves: 9
Correct moves: 9
Accuracy: 100.00%
Average inference time: 0.0004 seconds

Analyzing Game 4
Total moves: 10
Correct moves: 10
Accuracy: 100.00%
Average inference time: 0.0004 seconds

Analyzing Game 5
Total moves: 8
Correct moves: 7
Accuracy: 87.50%
Average inference time: 0.0004 seconds

Analyzing Game 6
Total moves: 9
Correct moves: 8
Accuracy: 88.89%
Average inference time: 0.0004 seconds

Analyzing Game 7
Total moves: 10
Correct moves: 10
Accuracy: 100.00%
Average inference time: 0.0004 seconds

Analyzing Game 8
Total moves: 15
Correct moves: 11
Accuracy: 73.33%
Average inference time: 0.0004 seconds

Analyzing Game 9
Total moves: 8
Correct moves: 7
Accuracy: 87.50%
Average inference 

In [85]:
import chess
import chess.pgn
import torch
import torch.nn as nn
import io
import time

# Choose where the model runs: CUDA if PyTorch sees a GPU, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def board_to_tensor(board):
    """Encode ``board`` as a 12x8x8 one-hot float tensor.

    Planes 0-5 mark Black's pawns, knights, bishops, rooks, queens and king;
    planes 6-11 mark White's.  Row 0 is rank 8 and column 0 is the a-file.
    """
    planes = torch.zeros(12, 8, 8)
    for sq, occupant in board.piece_map().items():
        # chess.PAWN .. chess.KING are the consecutive integers 1 .. 6.
        plane = occupant.piece_type - 1 + (6 if occupant.color == chess.WHITE else 0)
        planes[plane, 7 - chess.square_rank(sq), chess.square_file(sq)] = 1
    return planes

class SimplifiedStaffordGambitModel(nn.Module):
    """Small convolutional policy network over 4096 from/to move classes.

    Two padded 3x3 convolutions (12 -> 32 -> 64 channels) with batch norm
    and ReLU, a 512-unit hidden layer with dropout, and a final linear layer
    producing one logit per ``from * 64 + to`` move index.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are checkpoint state_dict keys; creation order fixes
        # the seeded initial weights.  Keep both unchanged.
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 64 * 64)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        features = self.bn1(self.conv1(x)).relu()
        features = self.bn2(self.conv2(features)).relu()
        flat = features.view(-1, 64 * 8 * 8)
        hidden = self.dropout(self.fc1(flat).relu())
        return self.fc2(hidden)

# Load the fine-tuned model.  map_location remaps tensors saved from a GPU
# session onto the current device, so the checkpoint also loads on a
# CPU-only machine.
model = SimplifiedStaffordGambitModel().to(device)
model.load_state_dict(torch.load('stafford_gambit_model_finetuned_worst.pth', map_location=device))
model.eval()  # evaluation mode: dropout off, BatchNorm running statistics
print("Fine-tuned model loaded successfully.")

def get_model_move(board):
    """Predict a move for ``board`` with the module-level ``model``.

    The caller is expected to have put ``model`` in eval mode.  The board is
    encoded with ``board_to_tensor``, and the highest-scoring of the 4096
    ``from * 64 + to`` outputs is decoded into a ``chess.Move``.  That
    encoding has no promotion piece, so a predicted pawn move onto the first
    or last rank is completed as a queen promotion; this lets it compare
    equal to the real move.  Legality is not checked.

    Returns:
        ``(move, inference_time)`` with ``inference_time`` in seconds.
    """
    tensor = board_to_tensor(board).unsqueeze(0).float().to(device)
    with torch.no_grad():
        start_time = time.perf_counter()
        output = model(tensor)
        # On CUDA the forward pass is queued asynchronously.  .item() copies
        # the result to the host and so waits for it; stopping the clock only
        # afterwards times the real computation, not just the kernel launch.
        move_code = output.argmax(dim=1).item()
        inference_time = time.perf_counter() - start_time

    from_square, to_square = divmod(move_code, 64)
    promotion = None
    piece = board.piece_at(from_square)
    if (piece is not None and piece.piece_type == chess.PAWN
            and chess.square_rank(to_square) in (0, 7)):
        promotion = chess.QUEEN
    return chess.Move(from_square, to_square, promotion=promotion), inference_time

def test_game(game):
    """Compare the model's prediction with every Black move of ``game``'s mainline.

    Returns:
        ``(total_moves, correct_moves, incorrect_moves, total_inference_time)``.
        ``incorrect_moves`` is a list of dicts with these keys:
        ``move_number`` (Black's full-move number), ``actual_move`` and
        ``model_move`` (UCI strings), ``fen`` (the position before the move)
        and ``inference_time`` (seconds).
    """
    board = game.board()
    total_moves = 0
    correct_moves = 0
    incorrect_moves = []
    total_inference_time = 0

    for move in game.mainline_moves():
        if board.turn == chess.BLACK:  # We're only interested in Black's moves
            model_move, inference_time = get_model_move(board)
            total_inference_time += inference_time
            total_moves += 1
            if move == model_move:
                correct_moves += 1
            else:
                incorrect_moves.append({
                    # Take the number from the board itself.  Deriving it from
                    # the ply index is only right for games that start from
                    # the initial position with White to move.
                    'move_number': board.fullmove_number,
                    'actual_move': move.uci(),
                    'model_move': model_move.uci(),
                    'fen': board.fen(),
                    'inference_time': inference_time
                })
        board.push(move)

    return total_moves, correct_moves, incorrect_moves, total_inference_time

# Read games from stafford.txt and report the model's accuracy on each one.
with open('stafford.txt', 'r') as file:
    pgn_text = file.read()

# Create a PGN reader
pgn_io = io.StringIO(pgn_text)

game_results = []
game_count = 0

while True:
    game = chess.pgn.read_game(pgn_io)
    if game is None:
        break

    game_count += 1
    print(f"\nAnalyzing Game {game_count}")
    total_moves, correct_moves, incorrect_moves, total_inference_time = test_game(game)

    accuracy = correct_moves / total_moves * 100 if total_moves > 0 else 0
    avg_inference_time = total_inference_time / total_moves if total_moves > 0 else 0

    game_results.append({
        'game_number': game_count,
        'total_moves': total_moves,
        'correct_moves': correct_moves,
        'accuracy': accuracy,
        'avg_inference_time': avg_inference_time,
        'incorrect_moves': incorrect_moves
    })

    print(f"Total moves: {total_moves}")
    print(f"Correct moves: {correct_moves}")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Average inference time: {avg_inference_time:.4f} seconds")

    if incorrect_moves:
        print("\nIncorrect Moves:")
        for incorrect in incorrect_moves:
            print(f"Move {incorrect['move_number']} (Black)")
            print(f"Actual move: {incorrect['actual_move']}")
            print(f"Model's suggestion: {incorrect['model_move']}")
            print(f"Board position (FEN): {incorrect['fen']}")
            print(f"Inference time: {incorrect['inference_time']:.4f} seconds")
            print("-" * 40)

print("\nOverall Analysis:")
print(f"Total games analyzed: {game_count}")

if not game_results:
    # No games were parsed.  The averages would divide by zero, and max()/min()
    # would raise on the empty list.
    print("No games found in stafford.txt.")
else:
    # Unweighted means of the per-game values.
    total_accuracy = sum(game['accuracy'] for game in game_results) / game_count
    total_avg_inference_time = sum(game['avg_inference_time'] for game in game_results) / game_count
    print(f"Average accuracy across all games: {total_accuracy:.2f}%")
    print(f"Average inference time across all games: {total_avg_inference_time:.4f} seconds")

    # Find best and worst performing games
    best_game = max(game_results, key=lambda x: x['accuracy'])
    worst_game = min(game_results, key=lambda x: x['accuracy'])

    print(f"\nBest performing game: Game {best_game['game_number']} with accuracy {best_game['accuracy']:.2f}%")
    print(f"Worst performing game: Game {worst_game['game_number']} with accuracy {worst_game['accuracy']:.2f}%")

print("\nAnalysis complete.")


Using device: cuda
Fine-tuned model loaded successfully.

Analyzing Game 1
Total moves: 4
Correct moves: 4
Accuracy: 100.00%
Average inference time: 0.0005 seconds

Analyzing Game 2
Total moves: 8
Correct moves: 8
Accuracy: 100.00%
Average inference time: 0.0005 seconds

Analyzing Game 3
Total moves: 9
Correct moves: 9
Accuracy: 100.00%
Average inference time: 0.0004 seconds

Analyzing Game 4
Total moves: 10
Correct moves: 10
Accuracy: 100.00%
Average inference time: 0.0004 seconds

Analyzing Game 5
Total moves: 8
Correct moves: 7
Accuracy: 87.50%
Average inference time: 0.0005 seconds

Incorrect Moves:
Move 8 (Black)
Actual move: d8d1
Model's suggestion: f6e4
Board position (FEN): r1bqk2r/ppp2ppp/2p5/4P3/4P3/8/PPP2KPP/RNBQ1B1R b kq - 0 8
Inference time: 0.0006 seconds
----------------------------------------

Analyzing Game 6
Total moves: 9
Correct moves: 8
Accuracy: 88.89%
Average inference time: 0.0006 seconds

Incorrect Moves:
Move 9 (Black)
Actual move: e4h1
Model's suggestion: e8