<a href="https://colab.research.google.com/github/nurgumus/Chess-Engine-AI/blob/main/chessCNN_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install chess



In [2]:
import os
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from chess import pgn
from tqdm import tqdm

In [3]:
march_pgn_path = '/content/drive/MyDrive/chess/lichess_elite_2024-04.pgn'

In [4]:
import chess.pgn as pgn

def load_pgn(file_path, max_games=20000):
    games = []
    with open(file_path, 'r') as pgn_file:
        for _ in range(max_games):
            game = pgn.read_game(pgn_file)
            if game is None:
                break
            games.append(game)
    return games


In [5]:
games = []

In [6]:
 games.extend(load_pgn(march_pgn_path))

In [7]:
len(games)

20000

In [8]:
import numpy as np
from chess import Board


def board_to_matrix(board: Board):

    matrix = np.zeros((13, 8, 8))
    piece_map = board.piece_map()

    for square, piece in piece_map.items():
        row, col = divmod(square, 8)
        piece_type = piece.piece_type - 1
        piece_color = 0 if piece.color else 6
        matrix[piece_type + piece_color, row, col] = 1

    legal_moves = board.legal_moves
    for move in legal_moves:
        to_square = move.to_square
        row_to, col_to = divmod(to_square, 8)
        matrix[12, row_to, col_to] = 1

    return matrix


def create_input_for_nn(games):
    X = []
    y = []
    for game in games:
        board = game.board()
        for move in game.mainline_moves():
            X.append(board_to_matrix(board))
            y.append(move.uci())
            board.push(move)
    return np.array(X, dtype=np.float32), np.array(y)


def encode_moves(moves):
    move_to_int = {move: idx for idx, move in enumerate(set(moves))}
    return np.array([move_to_int[move] for move in moves], dtype=np.float32), move_to_int

In [9]:
X, y = create_input_for_nn(games)

In [10]:
X[8]

array([[[0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 1., 0., 1.],
        [0., 1., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]],

       [[0., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0

In [11]:
X = X[0:4000000]
y = y[0:4000000]

In [12]:
y, move_to_int = encode_moves(y)
num_classes = len(move_to_int)

In [13]:
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

In [15]:
from torch.utils.data import Dataset


class ChessDataset(Dataset):

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [16]:
import torch.nn as nn


class ChessModel(nn.Module):
    def __init__(self, num_classes):
        super(ChessModel, self).__init__()
        # conv1 -> relu -> conv2 -> relu -> flatten -> fc1 -> relu -> fc2
        self.conv1 = nn.Conv2d(13, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8 * 8 * 128, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()

        # Initialize weights
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits
        return x

In [17]:
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Model Initialization
model = ChessModel(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

Using device: cuda


In [18]:
num_epochs = 50
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to GPU
        optimizer.zero_grad()

        outputs = model(inputs)  # Raw logits

        # Compute loss
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        running_loss += loss.item()
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes: int = int(epoch_time // 60)
    seconds: int = int(epoch_time) - minutes * 60
    print(f'Epoch {epoch + 1 + 50}/{num_epochs + 1 + 50}, Loss: {running_loss / len(dataloader):.4f}, Time: {minutes}m{seconds}s')

100%|██████████| 27433/27433 [01:40<00:00, 272.70it/s]


Epoch 51/101, Loss: 3.8633, Time: 1m40s


100%|██████████| 27433/27433 [01:38<00:00, 277.82it/s]


Epoch 52/101, Loss: 2.8899, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.09it/s]


Epoch 53/101, Loss: 2.6236, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.16it/s]


Epoch 54/101, Loss: 2.4717, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.95it/s]


Epoch 55/101, Loss: 2.3648, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.28it/s]


Epoch 56/101, Loss: 2.2815, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.41it/s]


Epoch 57/101, Loss: 2.2117, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.52it/s]


Epoch 58/101, Loss: 2.1528, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.44it/s]


Epoch 59/101, Loss: 2.1005, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.62it/s]


Epoch 60/101, Loss: 2.0536, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.64it/s]


Epoch 61/101, Loss: 2.0112, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.92it/s]


Epoch 62/101, Loss: 1.9716, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.22it/s]


Epoch 63/101, Loss: 1.9356, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.65it/s]


Epoch 64/101, Loss: 1.9021, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.21it/s]


Epoch 65/101, Loss: 1.8712, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.78it/s]


Epoch 66/101, Loss: 1.8422, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.87it/s]


Epoch 67/101, Loss: 1.8148, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.81it/s]


Epoch 68/101, Loss: 1.7890, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 277.07it/s]


Epoch 69/101, Loss: 1.7659, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.29it/s]


Epoch 70/101, Loss: 1.7435, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.69it/s]


Epoch 71/101, Loss: 1.7217, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 277.04it/s]


Epoch 72/101, Loss: 1.7016, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.72it/s]


Epoch 73/101, Loss: 1.6828, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 275.62it/s]


Epoch 74/101, Loss: 1.6649, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.36it/s]


Epoch 75/101, Loss: 1.6480, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 277.05it/s]


Epoch 76/101, Loss: 1.6317, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.18it/s]


Epoch 77/101, Loss: 1.6164, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 275.81it/s]


Epoch 78/101, Loss: 1.6010, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 275.96it/s]


Epoch 79/101, Loss: 1.5869, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.19it/s]


Epoch 80/101, Loss: 1.5736, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.93it/s]


Epoch 81/101, Loss: 1.5607, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.84it/s]


Epoch 82/101, Loss: 1.5480, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.73it/s]


Epoch 83/101, Loss: 1.5361, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 275.75it/s]


Epoch 84/101, Loss: 1.5241, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.46it/s]


Epoch 85/101, Loss: 1.5131, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.85it/s]


Epoch 86/101, Loss: 1.5026, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.71it/s]


Epoch 87/101, Loss: 1.4918, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.79it/s]


Epoch 88/101, Loss: 1.4816, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.54it/s]


Epoch 89/101, Loss: 1.4718, Time: 1m39s


100%|██████████| 27433/27433 [01:39<00:00, 276.51it/s]


Epoch 90/101, Loss: 1.4617, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.52it/s]


Epoch 91/101, Loss: 1.4528, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.15it/s]


Epoch 92/101, Loss: 1.4434, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 278.25it/s]


Epoch 93/101, Loss: 1.4347, Time: 1m38s


100%|██████████| 27433/27433 [01:39<00:00, 276.58it/s]


Epoch 94/101, Loss: 1.4261, Time: 1m39s


100%|██████████| 27433/27433 [01:38<00:00, 277.71it/s]


Epoch 95/101, Loss: 1.4177, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.23it/s]


Epoch 96/101, Loss: 1.4094, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 278.61it/s]


Epoch 97/101, Loss: 1.4014, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.67it/s]


Epoch 98/101, Loss: 1.3932, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 277.79it/s]


Epoch 99/101, Loss: 1.3861, Time: 1m38s


100%|██████████| 27433/27433 [01:38<00:00, 278.18it/s]

Epoch 100/101, Loss: 1.3785, Time: 1m38s





In [21]:
torch.save(model.state_dict(), "/content/drive/MyDrive/Colab Notebooks/model/model1.pth")

In [22]:
import pickle

with open("/content/drive/MyDrive/Colab Notebooks/model/heavy_move_to_int", "wb") as file:
    pickle.dump(move_to_int, file)