In [1]:
# %pip install stockfish

Collecting stockfish
  Downloading stockfish-3.28.0-py3-none-any.whl.metadata (12 kB)
Downloading stockfish-3.28.0-py3-none-any.whl (13 kB)
Installing collected packages: stockfish
Successfully installed stockfish-3.28.0
Note: you may need to restart the kernel to use updated packages.


In [7]:
from stockfish import Stockfish
sf = Stockfish(path="stockfish\stockfish-windows-x86-64-avx2.exe")

In [1]:
import pandas as pd
import numpy as np
import torch
import torch_geometric
import torch_geometric.utils as pyg_utils
from torch.utils.data import Dataset
import chess
import networkx as nx
import numpy as np
from torch_geometric.data import Data
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv,SAGEConv,GATConv,GINConv,TransformerConv,global_add_pool, global_mean_pool,global_max_pool,max_pool_neighbor_x
from torch_geometric.loader import DataListLoader,DataLoader
from sklearn.metrics import recall_score
# from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from info_nce import InfoNCE, info_nce
from pytorch_metric_learning import losses
import warnings
import chess.pgn
import chess
from tqdm import tqdm
import io
import random

warnings.filterwarnings("ignore")

In [2]:
def process_pgn_file2(pgn_file_path,_len):
    pgn = open(pgn_file_path)

    posdict = {
        "fen": [],
        "capture": [],
        "first_five": [],
        "last_five": [],
        "label": []
    }

    for i in tqdm(range(_len)):  # Assuming 33255 games to process
        game = chess.pgn.read_game(pgn)
        board = chess.Board()

        result_tag = game.headers["Result"]
        if result_tag == "1/2-1/2":
            continue  # Skip draws

        white_wins = (result_tag == "1-0")
        moves = list(game.mainline_moves())

        # Filter positions: not a capture and not within the first five moves
        eligible_positions = []

        for idx, move in enumerate(moves):

            capture = board.is_capture(move)
            board.push(move)
            first_five = idx < 9

            if not capture and not first_five:
                eligible_positions.append((board.fen(), idx, len(moves)))

        # If there are eligible positions, select a random one
        if eligible_positions:
            fen, idx, total_moves = random.choice(eligible_positions)
            last_five = idx > (total_moves - 5)

            posdict["fen"].append(fen)
            posdict["capture"].append(False)  # By definition, not a capture
            posdict["first_five"].append(False)  # By definition, not in the first five
            posdict["last_five"].append(last_five)
            posdict["label"].append(white_wins)

    return posdict

In [3]:
dict_lst = []
test_pgn_path = 'test_data/2024-02.bare.[28747].pgn'
dict_lst.append(process_pgn_file2(test_pgn_path,28747))
test_pgn_path = 'test_data/2024-03.bare.[28641].pgn'
dict_lst.append(process_pgn_file2(test_pgn_path,28641))
test_pgn_path = 'test_data/2024-04.bare.[24382].pgn'
dict_lst.append(process_pgn_file2(test_pgn_path,24382))
test_pgn_path = 'test_data/2024-05.bare.[15540].pgn'
dict_lst.append(process_pgn_file2(test_pgn_path,15540))
test_pgn_path = 'test_data/2024-06.bare.[13317].pgn'
dict_lst.append(process_pgn_file2(test_pgn_path,13317))
posdict = {
        "fen": [],
        "capture": [],
        "first_five": [],
        "last_five": [],
        "label": []
    }
for dct in dict_lst:
    for key in posdict.keys():
        posdict[key] = posdict[key]+dct[key]

test = pd.DataFrame(posdict)
print(test.info())

  0%|          | 0/28747 [00:00<?, ?it/s]

100%|██████████| 28747/28747 [03:11<00:00, 150.23it/s]
100%|██████████| 28641/28641 [03:12<00:00, 148.82it/s]
100%|██████████| 24382/24382 [02:37<00:00, 154.37it/s]
100%|██████████| 15540/15540 [01:35<00:00, 162.51it/s]
100%|██████████| 13317/13317 [01:15<00:00, 177.17it/s]


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40125 entries, 0 to 40124
Data columns (total 5 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   fen         40125 non-null  object
 1   capture     40125 non-null  bool  
 2   first_five  40125 non-null  bool  
 3   last_five   40125 non-null  bool  
 4   label       40125 non-null  bool  
dtypes: bool(4), object(1)
memory usage: 470.3+ KB
None


In [4]:
import networkx as nx
import matplotlib.pyplot as plt

def chess_position_to_graph(board):
    G = nx.DiGraph()  # Directed graph

    # Ensure all squares are included as nodes
    for square in chess.SQUARES:
        G.add_node(chess.square_name(square))
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            # Get all squares attacked by the piece on the current square
            attacked_squares = board.attacks(square)
            for target_square in attacked_squares:
                # Add a directed edge from the attacking square to the attacked square
                from_square_name = chess.square_name(square)
                to_square_name = chess.square_name(target_square)
                G.add_edge(from_square_name, to_square_name)
    # Get the legal moves for the current player
    current_moves = list(board.legal_moves)


    # Add edges for the current player's moves
    for move in current_moves:
        from_square_name = chess.square_name(move.from_square)
        to_square_name = chess.square_name(move.to_square)
        G.add_edge(from_square_name, to_square_name)

    # Temporarily switch turns to the opponent
    board.push(chess.Move.null())

    # Get the legal moves for the opponent
    opponent_moves = list(board.legal_moves)


    # Add edges for the opponent's moves
    for move in opponent_moves:
        from_square_name = chess.square_name(move.from_square)
        to_square_name = chess.square_name(move.to_square)
        G.add_edge(from_square_name, to_square_name)

    G.add_node("global")

    # Connect the global node to all squares
    for square in chess.SQUARES:
        square_name = chess.square_name(square)
        G.add_edge("global", square_name)  # Edge from global node to square
        G.add_edge(square_name, "global")  # Edge from square to global node

    # Restore the board to the original state
    # G = G.to_undirected()
    board.pop()

    return G

In [5]:
def square_to_coordinates(square):
    """Convert a square index to board coordinates."""
    row = square // 8 + 1
    col = square % 8 + 1
    return [col, row]

def distance_to_center(square):
    """
    Calculate the distance of a square to the center of the board.
    """
    row, col = divmod(square, 8)
    center_row, center_col = 3.5, 3.5  # Center of the board is at (3.5, 3.5)
    return np.sqrt((row - center_row) ** 2 + (col - center_col) ** 2) / 4.95  # Normalize distance


def piece_to_one_hot(piece):
    """Convert a chess piece to a one-hot encoding including color."""
    pieces = {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
        'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
    }
    one_hot = [0] * 12
    if piece:
        one_hot[pieces[piece.symbol()]] = 1
    return one_hot
def piece_to_value(piece):
    pieces = {
        'P': 100, 'N': 320, 'B': 330, 'R': 500, 'Q': 900, 'K': 20000,
        'p': -100, 'n': -320, 'b': -330, 'r': -500, 'q': -900, 'k': -20000
    }
    if piece:
        return [pieces[piece.symbol()]]
    return [0]

def create_node_embeddings(board:chess.Board):
    embeddings = []
    for square in chess.SQUARES:
        coordinates = square_to_coordinates(square)
        piece = board.piece_at(square)
        piece_one_hot = piece_to_one_hot(piece)
        piece_value = piece_to_value(piece)
        center_distance = distance_to_center(square)
        is_attacked_white = int(board.is_attacked_by(True, square))
        is_attacked_black = int(board.is_attacked_by(False, square))
        is_check = int(board.is_check())
        embeddings.append(coordinates+piece_one_hot+piece_value+[center_distance]+[-1**(1-int(board.turn)),is_attacked_white,is_attacked_black,is_check])
    embeddings.append(torch.zeros(20))
    return torch.tensor(embeddings, dtype=torch.float)

def chess_position_to_torch_geometric_data(board):
    # Create the graph using the previous function
    G = chess_position_to_graph(board)

    # Convert the NetworkX graph to edge_index format for torch_geometric
    edge_index = pyg_utils.from_networkx(G).edge_index

    # Create node embeddings
    node_embeddings = create_node_embeddings(board)

    # Create the torch_geometric data object
    data = Data(x=node_embeddings, edge_index=edge_index)

    return data

In [6]:
class ChessPositionDataset(Dataset):
    def __init__(self, df):
        """
        Initialize the dataset by filtering the DataFrame to exclude
        capture moves and positions marked as first_five.

        :param df: DataFrame containing chess positions with columns:
                   'fen', 'capture', 'first_five', 'last_five', 'label'
        """
        # Filter the DataFrame to exclude capture moves and first five moves
        self.df = df[(~df['capture']) & (~df['first_five'])].reset_index(drop=True)
        self.df = self.df.groupby('fen').apply(self.determine_majority_label).reset_index().reset_index(drop=True)
        print(self.df.info())
        self.df.columns = ['fen', 'label']


    def determine_majority_label(self,group):
    # Determine the majority label (mode); if tied, select the first mode
        return group['label'].mode()[0]

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.df)

    def get_graph(self, fen):
        """
        Converts a FEN string into a graph representation using the provided
        function chess_position_to_torch_geometric_data.

        :param fen: FEN string representing the chess board state
        :return: Graph representation of the position
        """
        # Create a chess board from the FEN string
        board = chess.Board(fen=fen)

        # Generate the graph data using the chess_position_to_torch_geometric_data function
        graph_data = chess_position_to_torch_geometric_data(board)

        return graph_data

    def __getitem__(self, idx):
        """
        Returns a single sample from the dataset.

        :param idx: Index of the sample
        :return: Tuple (graph_data, label) where graph_data is the graph representation
                 of the position and label is its target label
        """
        # Get the FEN string and the label
        fen = self.df.loc[idx, 'fen']
        label = self.df.loc[idx, 'label']


        # Generate the graph representation of the position
        graph_data = self.get_graph(fen)

        # Convert label to a tensor
        label_tensor = torch.tensor(label, dtype=torch.float)

        return graph_data, label_tensor

# Example usage:
fen_dataset = test[["fen","label"]]
dataset = ChessPositionDataset(test)
print(len(dataset))
print(len(fen_dataset))

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 39965 entries, 0 to 39964
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   fen     39965 non-null  object
 1   0       39965 non-null  bool  
dtypes: bool(1), object(1)
memory usage: 351.4+ KB
None
39965
40125


In [7]:
def label_fen_dataframe_with_stockfish(fen_df, stockfish_path):
    labeled_data = []

    # Initialize Stockfish
    with chess.engine.SimpleEngine.popen_uci(stockfish_path) as engine:
        for index, row in tqdm(fen_df.iterrows(), total=len(fen_df)):
            board = chess.Board(row["fen"])

            # Evaluate position with Stockfish
            info = engine.analyse(board, chess.engine.Limit(time=0.1))
            centipawns = info["score"].white().score(mate_score=100000)  # Arbitrary high for mate

            # Determine label based on centipawn evaluation
            if centipawns is not None:
                if centipawns > 1:  # Winning position
                    label = 1
                elif centipawns < -1:  # Losing position
                    label = 0
                else:  # Drawish position
                    label = -1
            else:
                label = 10

            labeled_data.append(label)

    # Add the labels to the DataFrame
    fen_df["stock_fish_label"] = labeled_data
    return fen_df

stockfish_path="stockfish\stockfish-windows-x86-64-avx2.exe"
# labeled_fen_dataset = label_fen_dataframe_with_stockfish(fen_dataset, stockfish_path)
# labeled_fen_dataset.to_csv("stockfish_labels.csv")
labeled_fen_dataset = pd.read_csv("stockfish_labels.csv")

In [18]:
device = "cuda:0"
seed = 42 # age of SIPL
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# torch.use_deterministic_algorithms(True)
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.benchmark=False

dataloader=DataLoader(dataset, batch_size=16, shuffle=False,num_workers = 0)

In [11]:
for i, data in enumerate(dataloader):
    # Process the data here
    print(data)  # Example processing; replace with your logic

    # Stop after the first 5 batches
    if i >= 4:  # 4 because i starts from 0
        break
# # Test accessing individual items in the dataset
# for i in range(5):
#     item = dataset[i]
#     print(f"Loaded item {i}: {item}")

[DataBatch(x=[1040, 20], edge_index=[2, 3670], batch=[1040], ptr=[17]), tensor([0., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0.])]
[DataBatch(x=[1040, 20], edge_index=[2, 3936], batch=[1040], ptr=[17]), tensor([1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 1.])]
[DataBatch(x=[1040, 20], edge_index=[2, 3666], batch=[1040], ptr=[17]), tensor([1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0., 1., 1., 1., 1.])]
[DataBatch(x=[1040, 20], edge_index=[2, 3717], batch=[1040], ptr=[17]), tensor([0., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1.])]
[DataBatch(x=[1040, 20], edge_index=[2, 3744], batch=[1040], ptr=[17]), tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1.])]


In [11]:
class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers=2,aggr="mean"):
        super(GNN, self).__init__()

        # First GNN layer
        self.conv1 = SAGEConv(input_dim, hidden_dim,aggr=aggr)

        # Additional GNN layers
        self.convs = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim,aggr=aggr))
        # Fully connected layer to produce graph embedding
        self.fc1 = nn.Linear(hidden_dim, 128)
        self.fc2 = nn.Linear(128,1)

    def forward(self, data):
        # Extract relevant data from the Data object
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Apply the first GCN layer
        x = F.tanh(self.conv1(x, edge_index))

        # Apply the remaining GCN layers
        for conv in self.convs:
            x = F.tanh(conv(x, edge_index))

        # Pooling to get the graph-level embedding
        x = global_add_pool(x, batch)


        # Final fully connected layer
        x = F.leaky_relu(self.fc1(x),negative_slope=0.2)
        x= self.fc2(x)

        return x

# Example usage:
input_dim = 20  # Example: 2 for location + 12 for one-hot encoding of pieces
hidden_dim = 4*256
output_dim = 128  # Size of the graph embedding

test_model = GNN(input_dim=input_dim, hidden_dim=hidden_dim,num_layers=3,aggr="mean").to(device)

In [12]:
path = "trained model\chess_gnn_model.pth"
test_model.load_state_dict(torch.load(path, map_location=device))  # Load pre-trained weights and map to the correct device
test_model.eval()  # Set the model to evaluation mode
output_logits = []  # List to store all processed logits
output_labels = []  # List to store all corresponding labels

# Disable gradient calculation for evaluation
with torch.no_grad():
    for batch in tqdm(dataloader):
        graphs, labels = batch  # Unpack graphs and labels from the batch
        
        # Move data to the correct device
        graphs = graphs.to(device)
        labels = labels.to(device)

        # Forward pass to get the logits
        outputs = test_model(graphs).squeeze()  # Squeeze if the output has extra dimensions
        
        # Process each logit and label
        for logit, label in zip(outputs, labels):
            # Clip logits based on the given rules
            if logit > 0:
                output_logits.append(1)
            elif logit < 0:
                output_logits.append(0)
            else:
                output_logits.append(-1)
            
            # Append the corresponding label
            output_labels.append(label.item())  # Convert label tensor to a Python scalar

100%|██████████| 2498/2498 [05:37<00:00,  7.41it/s]


In [19]:
avg = 0
for i in range(len(output_logits)):
    if output_logits[i]==output_labels[i]:
        avg+=1
print (avg/len(output_logits))
print(len(output_logits))

0.7156192907375428
39957


In [25]:
stock_fish_fen_list = labeled_fen_dataset["stock_fish_label"].to_list()
stock_fish__true_labels = labeled_fen_dataset["label"].to_list()
print(len(stock_fish_fen_list))
avg = 0
for i in range(len(stock_fish_fen_list)):
    if stock_fish_fen_list[i]==-1:
        avg+=1
print (avg)
print(len(output_logits))

40125
470
39957


In [21]:
avg=0
for i in range(len(output_logits)):
    if stock_fish_fen_list[i]==output_logits[i]:
        avg+=1
print(avg/len(output_logits))

0.5386790800110118
