In [1]:
import chess
import chess_utils
import numpy as np
import torch
import torch.nn as nn 
import torch.optim as optim 
from torch.utils.data import DataLoader
import time
import tqdm

## Load the data

In [2]:
data = []
data += chess_utils.extract_training_data("games/jalba20-black.pgn", my_color="black")
data += chess_utils.extract_training_data("games/Jeedy20-black.pgn", my_color="black")
data += chess_utils.extract_training_data("games/jalba20-white.pgn", my_color="white")
data += chess_utils.extract_training_data("games/Jeedy20-black.pgn", my_color="white")


## Create tensors

In [3]:
# Convert list of tuples to separate arrays
X = np.array([item[0] for item in data], dtype=np.float32)  # Extract tensors
y = np.array([item[1] for item in data], dtype=np.long)     # Extract policy indices

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

## Set up

In [4]:
from dataset import ChessDataset
from model import ChessModel

In [5]:
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# check for cuda on device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

model = ChessModel(input_channels=19).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

Using device: cpu


## Train the model

### Model Architecture Hyperparameters:

input_channels: 19\
conv_filters: [64, 128, 256]\
kernel_size: 3\
padding: 1\
hidden_size: 1024\
output_size: 4288\
dropout_rate: 0.3

### Training Hyperparameters:

batch_size: 64\
learning_rate: 0.0001\
num_epochs: 50\
optimizer: Adam\
loss_function: CrossEntropyLoss\
gradient_clipping: max_norm=1.0\
shuffle: True

In [6]:
num_epochs = 50
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        
        policy_logits = model(inputs)
        loss = criterion(policy_logits, labels)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()
    
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes = int(epoch_time // 60)
    seconds = int(epoch_time) - minutes * 60
    
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Time: {minutes}m {seconds}s")

Epoch 1/50:   0%|          | 0/3344 [00:00<?, ?it/s]

Epoch 1/50: 100%|██████████| 3344/3344 [16:18<00:00,  3.42it/s]   


Epoch [1/50], Loss: 6.2038, Time: 16m 18s


Epoch 2/50: 100%|██████████| 3344/3344 [33:06<00:00,  1.68it/s]  


Epoch [2/50], Loss: 5.9259, Time: 33m 6s


Epoch 3/50: 100%|██████████| 3344/3344 [08:02<00:00,  6.94it/s]


Epoch [3/50], Loss: 5.8391, Time: 8m 2s


Epoch 4/50: 100%|██████████| 3344/3344 [08:21<00:00,  6.67it/s]


Epoch [4/50], Loss: 5.7600, Time: 8m 21s


Epoch 5/50: 100%|██████████| 3344/3344 [06:27<00:00,  8.63it/s]


Epoch [5/50], Loss: 5.6727, Time: 6m 27s


Epoch 6/50: 100%|██████████| 3344/3344 [06:25<00:00,  8.68it/s]


Epoch [6/50], Loss: 5.5938, Time: 6m 25s


Epoch 7/50: 100%|██████████| 3344/3344 [06:34<00:00,  8.49it/s]


Epoch [7/50], Loss: 5.5292, Time: 6m 34s


Epoch 8/50: 100%|██████████| 3344/3344 [06:31<00:00,  8.54it/s]


Epoch [8/50], Loss: 5.4721, Time: 6m 31s


Epoch 9/50: 100%|██████████| 3344/3344 [08:10<00:00,  6.81it/s]


Epoch [9/50], Loss: 5.4162, Time: 8m 10s


Epoch 10/50: 100%|██████████| 3344/3344 [07:01<00:00,  7.93it/s]


Epoch [10/50], Loss: 5.3623, Time: 7m 1s


Epoch 11/50: 100%|██████████| 3344/3344 [06:41<00:00,  8.33it/s]


Epoch [11/50], Loss: 5.3123, Time: 6m 41s


Epoch 12/50: 100%|██████████| 3344/3344 [06:48<00:00,  8.18it/s]


Epoch [12/50], Loss: 5.2639, Time: 6m 49s


Epoch 13/50: 100%|██████████| 3344/3344 [07:04<00:00,  7.89it/s]


Epoch [13/50], Loss: 5.2138, Time: 7m 4s


Epoch 14/50: 100%|██████████| 3344/3344 [07:42<00:00,  7.23it/s]


Epoch [14/50], Loss: 5.1658, Time: 7m 42s


Epoch 15/50: 100%|██████████| 3344/3344 [07:21<00:00,  7.58it/s]


Epoch [15/50], Loss: 5.1168, Time: 7m 21s


Epoch 16/50: 100%|██████████| 3344/3344 [07:03<00:00,  7.90it/s]


Epoch [16/50], Loss: 5.0675, Time: 7m 3s


Epoch 17/50: 100%|██████████| 3344/3344 [07:06<00:00,  7.84it/s]


Epoch [17/50], Loss: 5.0213, Time: 7m 6s


Epoch 18/50: 100%|██████████| 3344/3344 [07:11<00:00,  7.75it/s]


Epoch [18/50], Loss: 4.9759, Time: 7m 11s


Epoch 19/50: 100%|██████████| 3344/3344 [07:13<00:00,  7.71it/s]


Epoch [19/50], Loss: 4.9309, Time: 7m 13s


Epoch 20/50: 100%|██████████| 3344/3344 [07:20<00:00,  7.59it/s]


Epoch [20/50], Loss: 4.8862, Time: 7m 20s


Epoch 21/50: 100%|██████████| 3344/3344 [07:43<00:00,  7.22it/s]


Epoch [21/50], Loss: 4.8420, Time: 7m 43s


Epoch 22/50: 100%|██████████| 3344/3344 [06:25<00:00,  8.67it/s]


Epoch [22/50], Loss: 4.8024, Time: 6m 25s


Epoch 23/50: 100%|██████████| 3344/3344 [06:28<00:00,  8.60it/s]


Epoch [23/50], Loss: 4.7610, Time: 6m 28s


Epoch 24/50: 100%|██████████| 3344/3344 [06:32<00:00,  8.53it/s]


Epoch [24/50], Loss: 4.7218, Time: 6m 32s


Epoch 25/50: 100%|██████████| 3344/3344 [06:51<00:00,  8.13it/s]


Epoch [25/50], Loss: 4.6812, Time: 6m 51s


Epoch 26/50: 100%|██████████| 3344/3344 [06:42<00:00,  8.30it/s]


Epoch [26/50], Loss: 4.6452, Time: 6m 42s


Epoch 27/50: 100%|██████████| 3344/3344 [07:36<00:00,  7.33it/s]


Epoch [27/50], Loss: 4.6059, Time: 7m 36s


Epoch 28/50: 100%|██████████| 3344/3344 [08:01<00:00,  6.94it/s]


Epoch [28/50], Loss: 4.5710, Time: 8m 1s


Epoch 29/50: 100%|██████████| 3344/3344 [08:24<00:00,  6.63it/s]


Epoch [29/50], Loss: 4.5347, Time: 8m 24s


Epoch 30/50: 100%|██████████| 3344/3344 [08:51<00:00,  6.29it/s]


Epoch [30/50], Loss: 4.4980, Time: 8m 51s


Epoch 31/50: 100%|██████████| 3344/3344 [08:40<00:00,  6.43it/s]


Epoch [31/50], Loss: 4.4659, Time: 8m 40s


Epoch 32/50: 100%|██████████| 3344/3344 [08:46<00:00,  6.36it/s]


Epoch [32/50], Loss: 4.4339, Time: 8m 46s


Epoch 33/50: 100%|██████████| 3344/3344 [09:18<00:00,  5.99it/s]


Epoch [33/50], Loss: 4.4025, Time: 9m 18s


Epoch 34/50: 100%|██████████| 3344/3344 [08:54<00:00,  6.25it/s]


Epoch [34/50], Loss: 4.3699, Time: 8m 54s


Epoch 35/50: 100%|██████████| 3344/3344 [09:02<00:00,  6.16it/s]


Epoch [35/50], Loss: 4.3380, Time: 9m 2s


Epoch 36/50: 100%|██████████| 3344/3344 [08:24<00:00,  6.63it/s]


Epoch [36/50], Loss: 4.3065, Time: 8m 24s


Epoch 37/50: 100%|██████████| 3344/3344 [08:43<00:00,  6.39it/s]


Epoch [37/50], Loss: 4.2827, Time: 8m 43s


Epoch 38/50: 100%|██████████| 3344/3344 [08:49<00:00,  6.32it/s]


Epoch [38/50], Loss: 4.2547, Time: 8m 49s


Epoch 39/50: 100%|██████████| 3344/3344 [08:16<00:00,  6.73it/s]


Epoch [39/50], Loss: 4.2271, Time: 8m 16s


Epoch 40/50: 100%|██████████| 3344/3344 [08:15<00:00,  6.75it/s]


Epoch [40/50], Loss: 4.2002, Time: 8m 15s


Epoch 41/50: 100%|██████████| 3344/3344 [07:54<00:00,  7.05it/s]


Epoch [41/50], Loss: 4.1738, Time: 7m 54s


Epoch 42/50: 100%|██████████| 3344/3344 [08:20<00:00,  6.68it/s]


Epoch [42/50], Loss: 4.1503, Time: 8m 20s


Epoch 43/50: 100%|██████████| 3344/3344 [08:15<00:00,  6.75it/s]


Epoch [43/50], Loss: 4.1250, Time: 8m 15s


Epoch 44/50: 100%|██████████| 3344/3344 [08:18<00:00,  6.71it/s]


Epoch [44/50], Loss: 4.1019, Time: 8m 18s


Epoch 45/50: 100%|██████████| 3344/3344 [08:21<00:00,  6.67it/s]


Epoch [45/50], Loss: 4.0780, Time: 8m 21s


Epoch 46/50: 100%|██████████| 3344/3344 [08:37<00:00,  6.46it/s]


Epoch [46/50], Loss: 4.0533, Time: 8m 37s


Epoch 47/50: 100%|██████████| 3344/3344 [08:12<00:00,  6.79it/s]


Epoch [47/50], Loss: 4.0338, Time: 8m 12s


Epoch 48/50: 100%|██████████| 3344/3344 [08:11<00:00,  6.81it/s]


Epoch [48/50], Loss: 4.0127, Time: 8m 11s


Epoch 49/50: 100%|██████████| 3344/3344 [08:11<00:00,  6.80it/s]


Epoch [49/50], Loss: 3.9934, Time: 8m 11s


Epoch 50/50: 100%|██████████| 3344/3344 [08:28<00:00,  6.58it/s]


Epoch [50/50], Loss: 3.9725, Time: 8m 28s


In [7]:
# Save the model
model_path = "models/TORCH_50EPOCH.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': num_epochs,
    'loss': avg_loss,
}, model_path)

## Convert to format compatible with C++

In [None]:
# An instance of your model.
model, device = chess_utils.load_model("models/TORCH_50EPOCH.pth")

# An example input you would normally provide to your model's forward() method.
board = chess.Board()
X = chess_utils.board_to_tensor(board)
example = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

In [10]:
traced_script_module.save("models/traced_50EPOCH_model.pt")