In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import chess
import numpy as np
from tqdm import tqdm
import os
import random
import time
from sklearn.model_selection import train_test_split
from graph_encode import move_to_index, index_to_move, encode_node_features
import matplotlib.pyplot as plt 





# Path to the CSV of processed games; the dataset reads its 'Moves_UCI' and 'Result' columns.
CSV_FILE_PATH = 'kingbase_processed_all.csv'
# Number of samples per DataLoader batch.
BATCH_SIZE = 128
# Learning rate handed to the Adam optimizer in train_model.
LEARNING_RATE = 0.001
# Number of full passes over the training loader.
EPOCHS = 40
# Destination of the final model state_dict written once training finishes.
MODEL_SAVE_PATH = 'simple_chess_cnn_v2.pth'
# Fraction of all games held out as the test split.
TEST_SIZE = 0.1
# Fraction of all games used as the validation split (rescaled after the test split is removed).
VAL_SIZE = 0.1


# Directory that receives one checkpoint file per epoch.
CHECKPOINT_DIR = 'checkpoints_cnn'

# Width of the policy output: one logit per entry of index_to_move.
NUM_POSSIBLE_MOVES = len(index_to_move)

def count_parameters(model):
    """Count the parameter elements of ``model``.

    Returns:
        A ``(total, trainable)`` tuple, where ``trainable`` only includes
        parameters whose ``requires_grad`` flag is set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    return total_params, trainable_params

def uci_to_index(uci_move):
    """Translate a UCI move string into its index in ``move_to_index``.

    The lookup key is rebuilt from the from/to squares only, so any
    promotion piece carried by the UCI string is ignored.

    Returns:
        The integer index, or ``-1`` (after printing a diagnostic) when the
        string cannot be parsed or the move is missing from the table.
    """
    try:
        parsed = chess.Move.from_uci(uci_move)
        # Key on squares only; the promotion suffix is deliberately dropped.
        lookup_key = chess.Move(parsed.from_square, parsed.to_square)
        index = move_to_index[lookup_key]
    except Exception as e:
        print(f"Error encoding UCI move '{uci_move}': {e}")
        return -1
    return index

def state_to_tensor(board: chess.Board):
    """Encode ``board`` as a (21, 8, 8) stack of feature planes.

    ``encode_node_features`` yields the flat feature vector (21 * 8 * 8 =
    1344 entries), which is reshaped into channel-first planes here.
    """
    plane_shape = (21, 8, 8)
    flat_features = encode_node_features(board)
    return flat_features.reshape(plane_shape)


def result_to_value(result: str):
    """Convert a game result string into a scalar from White's point of view.

    ``'1-0'`` maps to ``1.0``, ``'0-1'`` maps to ``-1.0``; every other
    value is scored as ``0.0``.
    """
    decisive_outcomes = (('1-0', 1.0), ('0-1', -1.0))
    for label, value in decisive_outcomes:
        if result == label:
            return value
    return 0.0


class ChessDataset(Dataset):
    """Dataset that draws one random position from a game per access.

    Each row of the wrapped dataframe must expose a ``'Moves_UCI'`` column
    (space-separated UCI moves) and a ``'Result'`` column. Indexing replays
    the game up to a randomly chosen ply and returns
    ``(state_tensor, target_move_index, game_value)``, where ``game_value``
    is expressed from the perspective of the side to move.

    Rows that cannot produce a sample (non-string or too-short move list,
    a move that fails to replay, a move that cannot be encoded) are
    replaced by a randomly chosen other row.
    """

    # Upper bound on resampling so that a dataset made of unusable rows fails
    # with a clear error instead of exhausting the recursion limit.
    _MAX_ATTEMPTS = 1000

    def __init__(self, dataframe):
        self.df = dataframe

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a training sample for ``idx``, resampling on bad rows.

        Raises:
            RuntimeError: if no usable row is found within
                ``_MAX_ATTEMPTS`` draws.
        """
        for _ in range(self._MAX_ATTEMPTS):
            sample = self._try_build_sample(idx)
            if sample is not None:
                return sample
            # Current row was unusable: fall back to a random replacement.
            idx = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"Could not build a valid sample after {self._MAX_ATTEMPTS} attempts"
        )

    def _try_build_sample(self, idx):
        """Build the sample for row ``idx``; return ``None`` if the row is unusable."""
        row = self.df.iloc[idx]
        moves_text = row['Moves_UCI']
        # Missing cells come back from pandas as NaN floats, not strings.
        if not isinstance(moves_text, str):
            return None
        moves_uci = moves_text.split()
        if len(moves_uci) < 2:
            return None

        move_idx_to_play = random.randint(0, len(moves_uci) - 1)
        board = chess.Board()
        try:
            for move_uci in moves_uci[:move_idx_to_play]:
                board.push_uci(move_uci)
        except ValueError:
            # push_uci signals invalid or illegal moves with ValueError
            # (sub)classes; anything else, e.g. KeyboardInterrupt, propagates.
            return None

        target_move_index = uci_to_index(moves_uci[move_idx_to_play])
        if target_move_index == -1:
            return None

        state_tensor = state_to_tensor(board)
        game_value = result_to_value(row['Result'])
        # The result is stored from White's side; flip it when Black is to move.
        if board.turn == chess.BLACK:
            game_value = -game_value

        return state_tensor, target_move_index, game_value

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped by a skip connection.

    Channel count and spatial size are preserved, so the input can be added
    directly to the output of the second stage before the final ReLU.
    """

    def __init__(self, num_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        """Apply the residual transformation to ``x`` and return the result."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        # Skip connection: accumulate the untouched input onto the branch.
        branch += x
        return self.relu(branch)

class AlphaZeroLikeCNN(nn.Module):
    """Residual CNN with a shared trunk and separate value / policy heads.

    Args:
        num_input_channels: number of feature planes in the input.
        num_residual_blocks: how many ``ResidualBlock`` modules form the tower.
        num_filters: channel width used throughout the trunk.

    ``forward`` returns ``(value, policy_logits)``: ``value`` has one
    tanh-squashed entry per sample and ``policy_logits`` has
    ``NUM_POSSIBLE_MOVES`` raw logits per sample.
    """

    def __init__(self, num_input_channels=21, num_residual_blocks=24, num_filters=128):
        super().__init__()
        board_cells = 8 * 8
        policy_planes = 2

        # Stem: lift the input planes to the trunk width.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(num_input_channels, num_filters, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(inplace=True),
        )

        tower_blocks = [ResidualBlock(num_filters) for _ in range(num_residual_blocks)]
        self.residual_tower = nn.Sequential(*tower_blocks)

        # Value head: 1x1 conv down to a single plane, then an MLP ending in tanh.
        self.value_head = nn.Sequential(
            nn.Conv2d(num_filters, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(board_cells, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Tanh(),
        )

        # Policy head: 1x1 conv down to two planes, then one linear layer of logits.
        self.policy_head = nn.Sequential(
            nn.Conv2d(num_filters, policy_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(policy_planes),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(policy_planes * board_cells, NUM_POSSIBLE_MOVES),
        )

    def forward(self, x):
        """Run the network on ``x`` of shape (batch, channels, 8, 8)."""
        trunk = self.residual_tower(self.initial_conv(x))
        return self.value_head(trunk), self.policy_head(trunk)


def _combined_loss(model, batch, device, criterion_value, criterion_policy):
    """Move one batch to ``device`` and return value loss + policy loss."""
    states, target_moves, target_values = batch
    states = states.to(device)
    target_moves = target_moves.to(device)
    target_values = target_values.to(device).float()

    pred_values, pred_policies = model(states)
    # Drop only the trailing singleton dimension. A bare squeeze() would also
    # collapse a batch of size one to a 0-d tensor, which no longer matches
    # the (1,) target and triggers the MSELoss broadcasting warning.
    loss_v = criterion_value(pred_values.squeeze(-1), target_values)
    loss_p = criterion_policy(pred_policies, target_moves)
    return loss_v + loss_p


def train_model(model, train_loader, val_loader, epochs, checkpoint_dir):
    """Train ``model`` on value + policy targets and checkpoint every epoch.

    Args:
        model: network returning ``(value, policy_logits)`` for a batch of states.
        train_loader: iterable of ``(states, target_moves, target_values)`` batches.
        val_loader: same batch layout, evaluated without gradients after each epoch.
        epochs: number of passes over ``train_loader``.
        checkpoint_dir: directory (created if missing) that receives one
            ``checkpoint_epoch_<n>.pth`` file per epoch.

    Returns:
        ``{'train_loss': [...], 'val_loss': [...]}`` with one per-batch
        average loss per epoch.

    Side effects:
        Moves the model to CUDA when available, writes per-epoch checkpoints
        and finally saves the state_dict to ``MODEL_SAVE_PATH``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"train on{device}")
    model.to(device)

    os.makedirs(checkpoint_dir, exist_ok=True)
    print(f"Checkpoints will be saved in '{checkpoint_dir}'")

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion_value = nn.MSELoss()
    criterion_policy = nn.CrossEntropyLoss()
    total, trainable = count_parameters(model)
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}\n\n")

    train_loss_history = []
    val_loss_history = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [train]")
        for batch in train_pbar:
            optimizer.zero_grad()
            loss = _combined_loss(model, batch, device, criterion_value, criterion_policy)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss
            train_pbar.set_postfix({'train_loss': f'{batch_loss:.4f}'})

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [val]")
            for batch in val_pbar:
                loss = _combined_loss(model, batch, device, criterion_value, criterion_policy)
                val_loss += loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_train_loss = running_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)

        train_loss_history.append(avg_train_loss)
        val_loss_history.append(avg_val_loss)

        print(f"Epoch {epoch+1} end | avg train loss: {avg_train_loss:.4f} | avg val loss: {avg_val_loss:.4f}")

        # Store optimizer state alongside the weights so training can be resumed.
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss
        }, checkpoint_path)
        print(f"Checkpoint for epoch {epoch+1} saved to {checkpoint_path}")

    print("done!")
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Final model saved to: {MODEL_SAVE_PATH}")

    return {'train_loss': train_loss_history, 'val_loss': val_loss_history}




def plot_loss_curves(history):
    """Draw the per-epoch training and validation losses on one figure.

    ``history`` must provide ``'train_loss'`` and ``'val_loss'`` sequences;
    the x axis numbers epochs starting at 1.
    """
    epoch_axis = range(1, len(history['train_loss']) + 1)
    curves = (
        ('train_loss', 'b-o', 'Training Loss'),
        ('val_loss', 'r-o', 'Validation Loss'),
    )

    plt.figure(figsize=(10, 5))
    for key, line_style, legend_label in curves:
        plt.plot(epoch_axis, history[key], line_style, label=legend_label)
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.show()


# Script entry point: load the games, split them, train the CNN and plot the losses.
if __name__ == '__main__':
    if not os.path.exists(CSV_FILE_PATH):
        print(f"error {CSV_FILE_PATH} not exist")
    else:
        full_df = pd.read_csv(CSV_FILE_PATH)

        # Carve off the test split first, then rescale VAL_SIZE so the
        # validation split is still VAL_SIZE of the full dataset.
        train_val_df, test_df = train_test_split(full_df, test_size=TEST_SIZE, random_state=42)
        val_split_ratio = VAL_SIZE / (1 - TEST_SIZE)
        train_df, val_df = train_test_split(train_val_df, test_size=val_split_ratio, random_state=42)

        # The first line previously reported the training size under the
        # label "testing dataset"; each split is now reported under its own name.
        print(f"dataset split:")
        print(f" - training dataset: {len(train_df)} ")
        print(f" - validation dataset: {len(val_df)} ")
        print(f" - testing dataset: {len(test_df)} ")

        train_dataset = ChessDataset(dataframe=train_df)
        val_dataset = ChessDataset(dataframe=val_df)
        test_dataset = ChessDataset(dataframe=test_df)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) # num_workers can speed up data loading
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        # Built for completeness; train_model does not consume the test loader.
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        cnn_model = AlphaZeroLikeCNN()

        training_history = train_model(cnn_model, train_loader, val_loader, epochs=EPOCHS, checkpoint_dir=CHECKPOINT_DIR)

        if training_history:
            plot_loss_curves(training_history)
        


dataset split:
 - testing dataset: 547713 
 - validation dataset: 68465 
train oncuda
Checkpoints will be saved in 'checkpoints_cnn'
Total parameters: 7,354,631
Trainable parameters: 7,354,631




Epoch 1/40 [train]:   0%|          | 13/4280 [00:04<24:57,  2.85it/s, train_loss=8.1985]


KeyboardInterrupt: 