In [None]:
from our_model import  ChessGNN
from graph_encode import base_graph_edges,move_to_index,index_to_move,adjacency_list, node_feature_matrix,edge_feature_matrix, encode_node_features, create_batch_from_boards, static_edge_index
import pandas as pd
from tqdm import tqdm
import os

created 64 nodes
created 1792 edge
current state (FEN): r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3

encode result:
node matrix shape: (64, 21)
edge matrix shape: (1792, 11)
--- Static Graph Components ---
static_edge_index shape: torch.Size([2, 1792])
static_edge_map shape: torch.Size([1792])
------------------------------


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
import chess


NODE_IN_FEATURES = 21   
EDGE_IN_FEATURES = 11  
NODE_OUT_FEATURES = 1   

LEARNING_RATE = 0.001
EPOCHS = 100
BATCH_SIZE = 64

In [3]:
def uci_to_index(uci_move):
    try:
        move = chess.Move.from_uci(uci_move)
        return move_to_index[chess.Move(move.from_square, move.to_square)]
    except:
        print("error encode!!", chess.Move.from_uci(uci_move)) 
        return -1

def state_to_tensor(board: chess.Board):
    tensor = encode_node_features(board)
    return tensor.T.reshape((21, 8, 8))


def result_to_value(result: str):
    if result == '1-0': return 1.0
    elif result == '0-1': return -1.0
    return 0.0

In [4]:
class GraphChessDataset(Dataset):

    def __init__(self, dataframe):
        self.df = dataframe

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        moves_uci = row['Moves_UCI'].split()
        result = row['Result']

        if len(moves_uci) < 2:
            return self.__getitem__(random.randint(0, len(self) - 1))

        move_idx_to_play = random.randint(0, len(moves_uci) - 1)
        board = chess.Board()

        for move_uci in moves_uci[:move_idx_to_play]:
            try:
                board.push_uci(move_uci)
            except chess.IllegalMoveError:
                return self.__getitem__(random.randint(0, len(self) - 1))

        target_move_uci = moves_uci[move_idx_to_play]
        policy_target = uci_to_index(target_move_uci)
        
        if policy_target == -1:
             return self.__getitem__(random.randint(0, len(self) - 1))

        value_target = result_to_value(result)
        if board.turn == chess.BLACK:
            value_target = -value_target
        
        return board, policy_target, value_target

In [5]:
def collate_graph_data(batch):

    boards, policy_targets, value_targets = zip(*batch)

    batched_graph_data = create_batch_from_boards(list(boards))

    policy_targets = torch.tensor(policy_targets, dtype=torch.long)
    value_targets = torch.tensor(value_targets, dtype=torch.float32).unsqueeze(1) 

    return batched_graph_data, policy_targets, value_targets

In [6]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params


In [7]:
CSV_FILE_PATH = 'C:\\Users\\tan04\\Documents\\codeplace\\AI\\advance ML\\Advance_Machine_Learning_Project\\kingbase_processed_all.csv'
# df = pd.read_csv(CSV_FILE_PATH)

df = pd.read_csv('kingbase_processed_all.csv')

chess_dataset = GraphChessDataset(df)

BATCH_SIZE = 4
data_loader = DataLoader(
    dataset=chess_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_graph_data
)



In [10]:
from sklearn.model_selection import train_test_split


def train_chess_model(checkpoint_dir='checkpoints of our model'):

    NODE_IN_FEATURES = 21
    EDGE_IN_FEATURES = 11
    GNN_NODE_OUT_FEATURES = 56
    NUM_POSSIBLE_MOVES = 1792

    EPOCHS = 40
    BATCH_SIZE = 128
    LEARNING_RATE = 0.001
    TEST_SIZE = 0.1
    VAL_SIZE = 0.1 

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")



    full_df = pd.read_csv('kingbase_processed_all.csv')
    train_val_df, test_df = train_test_split(full_df, test_size=TEST_SIZE, random_state=42)
    train_df, val_df = train_test_split(train_val_df, test_size=(VAL_SIZE / (1 - TEST_SIZE)), random_state=42)


    dataset = GraphChessDataset(dataframe=train_df)
    data_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_graph_data
    )

    model = ChessGNN(
        node_in_features=NODE_IN_FEATURES,
        edge_in_features=EDGE_IN_FEATURES,
        gnn_hidden_features=GNN_NODE_OUT_FEATURES,
        num_possible_moves=NUM_POSSIBLE_MOVES,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    policy_loss_fn = nn.CrossEntropyLoss()
    value_loss_fn = nn.MSELoss()

    start_epoch = 0
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    latest_checkpoint_path = None
    if os.listdir(checkpoint_dir):
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_epoch_') and f.endswith('.pth')]
        if checkpoint_files:
            latest_epoch = max([int(f.split('_')[-1].split('.')[0]) for f in checkpoint_files])
            latest_checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{latest_epoch}.pth')

    if latest_checkpoint_path:
        print(f"Resuming training from checkpoint: {latest_checkpoint_path}")
        checkpoint = torch.load(latest_checkpoint_path, map_location=device)
        
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] 
        print(f"Loaded model state from epoch {checkpoint['epoch']}. Starting from epoch {start_epoch + 1}.")
    else:
        print("No checkpoint found. Starting training from scratch.")


    print("Starting training...")
    total, trainable = count_parameters(model)
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}\n\n")
    for epoch in range(start_epoch, EPOCHS):
        model.train()
        total_loss = 0.0

        progress_bar = tqdm(data_loader, desc=f"Epoch [{epoch+1}/{EPOCHS}]", leave=False)
        
        for batched_graph_data, policy_targets, value_targets in progress_bar:
            node_features = batched_graph_data["node_feature_matrix"].to(device)
            edge_features = batched_graph_data["edge_feature_matrix"].to(device)
            edge_index = batched_graph_data["edge_index"].to(device)
            edge_map = batched_graph_data["edge_map"].to(device)
            policy_targets = policy_targets.to(device)
            value_targets = value_targets.to(device).float() 
            
            current_batch_size = len(policy_targets)

            optimizer.zero_grad()

            policy_logits, value_pred = model(
                node_feature_matrix=node_features,
                edge_feature_matrix=edge_features,
                edge_index=edge_index,
                edge_map=edge_map,
                batch_size=current_batch_size
            )

            loss_policy = policy_loss_fn(policy_logits, policy_targets)
            loss_value = value_loss_fn(value_pred, value_targets)
            combined_loss = loss_policy + loss_value

            combined_loss.backward()
            optimizer.step()

            total_loss += combined_loss.item()
            progress_bar.set_postfix(loss=combined_loss.item())

        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch+1}/{EPOCHS}], Average Loss: {avg_loss:.4f}")

        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")


    print("Training finished.")

if __name__ == '__main__':
    print("Training script is defined. Uncomment 'train_chess_model()' to run.")
    train_chess_model()

Training script is defined. Uncomment 'train_chess_model()' to run.
Using device: cuda


  checkpoint = torch.load(latest_checkpoint_path, map_location=device)


Resuming training from checkpoint: checkpoints of our model\checkpoint_epoch_30.pth
Loaded model state from epoch 30. Starting from epoch 31.
Starting training...
Total parameters: 7,234,433
Trainable parameters: 7,234,433




                                                                               

Epoch [31/40], Average Loss: 3.3841
Checkpoint saved to checkpoints of our model\checkpoint_epoch_31.pth


                                                                               

Epoch [32/40], Average Loss: 3.3778
Checkpoint saved to checkpoints of our model\checkpoint_epoch_32.pth


                                                                               

Epoch [33/40], Average Loss: 3.3707
Checkpoint saved to checkpoints of our model\checkpoint_epoch_33.pth


                                                                               

Epoch [34/40], Average Loss: 3.3611
Checkpoint saved to checkpoints of our model\checkpoint_epoch_34.pth


                                                                               

Epoch [35/40], Average Loss: 3.3518
Checkpoint saved to checkpoints of our model\checkpoint_epoch_35.pth


                                                                               

Epoch [36/40], Average Loss: 3.3493
Checkpoint saved to checkpoints of our model\checkpoint_epoch_36.pth


                                                                               

Epoch [37/40], Average Loss: 3.3413
Checkpoint saved to checkpoints of our model\checkpoint_epoch_37.pth


                                                                               

Epoch [38/40], Average Loss: 3.3342
Checkpoint saved to checkpoints of our model\checkpoint_epoch_38.pth


                                                                               

Epoch [39/40], Average Loss: 3.3271
Checkpoint saved to checkpoints of our model\checkpoint_epoch_39.pth


                                                                               

Epoch [40/40], Average Loss: 3.3239
Checkpoint saved to checkpoints of our model\checkpoint_epoch_40.pth
Training finished.


In [9]:
def load_trained_model(checkpoint_path, device):

    NODE_IN_FEATURES = 21   
    EDGE_IN_FEATURES = 11   
    GNN_NODE_OUT_FEATURES = 64 
    NUM_POSSIBLE_MOVES = 1792 
    
    model = ChessGNN(
        node_in_features=NODE_IN_FEATURES,
        edge_in_features=EDGE_IN_FEATURES,
        gnn_hidden_features=GNN_NODE_OUT_FEATURES,
        num_possible_moves=NUM_POSSIBLE_MOVES
    )
    

    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_path}")
        return None
        
    model.load_state_dict(checkpoint['model_state_dict'])
    
    model.to(device)
    
    model.eval()
    
    print(f"Model loaded from {checkpoint_path} and is ready for inference.")
    return model