In [1]:
import chess
import chess_utils
import numpy as np
import torch
import torch.nn as nn 
import torch.optim as optim 
from torch.utils.data import DataLoader
import time
import tqdm

## Load the data

In [2]:
data = []
data += chess_utils.extract_training_data("games/jalba20-black.pgn", my_color="black")
data += chess_utils.extract_training_data("games/Jeedy20-black.pgn", my_color="black")
data += chess_utils.extract_training_data("games/jalba20-white.pgn", my_color="white")
data += chess_utils.extract_training_data("games/Jeedy20-white.pgn", my_color="white")

## Create tensors

In [3]:
# Convert list of tuples to separate arrays
X = np.array([item[0] for item in data], dtype=np.float32)  # Extract tensors
y = np.array([item[1] for item in data], dtype=np.long)     # Extract policy indices

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

## Set up

In [4]:
from dataset import ChessDataset
from ResNet_model import ResNetChessModel

In [5]:
dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# check for cuda on device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

model = ResNetChessModel(input_channels=19, num_blocks=19).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

Using device: cuda


## Train the model

### Model Architecture Hyperparameters:

input_channels: 19\
conv_filters: [64, 128, 256]\
kernel_size: 3\
padding: 1\
hidden_size: 1024\
output_size: 4288\
dropout_rate: 0.3

### Training Hyperparameters:

batch_size: 64\
learning_rate: 0.0001\
num_epochs: 250\
optimizer: Adam\
loss_function: CrossEntropyLoss\
gradient_clipping: max_norm=1.0\
shuffle: True

In [6]:
num_epochs = 250
for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0
    
    for inputs, labels in tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        
        policy_logits, _ = model(inputs)
        loss = criterion(policy_logits, labels)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()
    
    end_time = time.time()
    epoch_time = end_time - start_time
    minutes = int(epoch_time // 60)
    seconds = int(epoch_time) - minutes * 60
    
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Time: {minutes}m {seconds}s")

Epoch 1/250: 100%|██████████| 3484/3484 [01:45<00:00, 32.89it/s]


Epoch [1/250], Loss: 5.3813, Time: 1m 45s


Epoch 2/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [2/250], Loss: 4.2346, Time: 1m 48s


Epoch 3/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.14it/s]


Epoch [3/250], Loss: 3.5024, Time: 1m 48s


Epoch 4/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [4/250], Loss: 2.9996, Time: 1m 48s


Epoch 5/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [5/250], Loss: 2.6066, Time: 1m 48s


Epoch 6/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.13it/s]


Epoch [6/250], Loss: 2.2528, Time: 1m 48s


Epoch 7/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.13it/s]


Epoch [7/250], Loss: 1.9203, Time: 1m 48s


Epoch 8/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.11it/s]


Epoch [8/250], Loss: 1.5886, Time: 1m 48s


Epoch 9/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.14it/s]


Epoch [9/250], Loss: 1.2790, Time: 1m 48s


Epoch 10/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.14it/s]


Epoch [10/250], Loss: 1.0029, Time: 1m 48s


Epoch 11/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [11/250], Loss: 0.7855, Time: 1m 48s


Epoch 12/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [12/250], Loss: 0.6294, Time: 1m 48s


Epoch 13/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [13/250], Loss: 0.5260, Time: 1m 48s


Epoch 14/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.15it/s]


Epoch [14/250], Loss: 0.4587, Time: 1m 48s


Epoch 15/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [15/250], Loss: 0.4126, Time: 1m 48s


Epoch 16/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [16/250], Loss: 0.3777, Time: 1m 48s


Epoch 17/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [17/250], Loss: 0.3525, Time: 1m 48s


Epoch 18/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [18/250], Loss: 0.3318, Time: 1m 48s


Epoch 19/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [19/250], Loss: 0.3104, Time: 1m 48s


Epoch 20/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [20/250], Loss: 0.2994, Time: 1m 48s


Epoch 21/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [21/250], Loss: 0.2843, Time: 1m 48s


Epoch 22/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.15it/s]


Epoch [22/250], Loss: 0.2681, Time: 1m 48s


Epoch 23/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [23/250], Loss: 0.2587, Time: 1m 48s


Epoch 24/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [24/250], Loss: 0.2503, Time: 1m 48s


Epoch 25/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [25/250], Loss: 0.2399, Time: 1m 48s


Epoch 26/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [26/250], Loss: 0.2335, Time: 1m 47s


Epoch 27/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [27/250], Loss: 0.2264, Time: 1m 47s


Epoch 28/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [28/250], Loss: 0.2187, Time: 1m 48s


Epoch 29/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.13it/s]


Epoch [29/250], Loss: 0.2120, Time: 1m 48s


Epoch 30/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [30/250], Loss: 0.2081, Time: 1m 48s


Epoch 31/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [31/250], Loss: 0.2007, Time: 1m 48s


Epoch 32/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [32/250], Loss: 0.1942, Time: 1m 48s


Epoch 33/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [33/250], Loss: 0.1959, Time: 1m 47s


Epoch 34/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [34/250], Loss: 0.1869, Time: 1m 48s


Epoch 35/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [35/250], Loss: 0.1855, Time: 1m 48s


Epoch 36/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [36/250], Loss: 0.1801, Time: 1m 48s


Epoch 37/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [37/250], Loss: 0.1766, Time: 1m 48s


Epoch 38/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [38/250], Loss: 0.1731, Time: 1m 48s


Epoch 39/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [39/250], Loss: 0.1695, Time: 1m 48s


Epoch 40/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [40/250], Loss: 0.1668, Time: 1m 47s


Epoch 41/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [41/250], Loss: 0.1641, Time: 1m 48s


Epoch 42/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.14it/s]


Epoch [42/250], Loss: 0.1611, Time: 1m 48s


Epoch 43/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [43/250], Loss: 0.1593, Time: 1m 48s


Epoch 44/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.15it/s]


Epoch [44/250], Loss: 0.1585, Time: 1m 48s


Epoch 45/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [45/250], Loss: 0.1531, Time: 1m 48s


Epoch 46/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [46/250], Loss: 0.1532, Time: 1m 48s


Epoch 47/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [47/250], Loss: 0.1540, Time: 1m 47s


Epoch 48/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [48/250], Loss: 0.1493, Time: 1m 48s


Epoch 49/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [49/250], Loss: 0.1514, Time: 1m 48s


Epoch 50/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [50/250], Loss: 0.1432, Time: 1m 48s


Epoch 51/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [51/250], Loss: 0.1418, Time: 1m 48s


Epoch 52/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [52/250], Loss: 0.1404, Time: 1m 48s


Epoch 53/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [53/250], Loss: 0.1403, Time: 1m 47s


Epoch 54/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [54/250], Loss: 0.1371, Time: 1m 48s


Epoch 55/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [55/250], Loss: 0.1364, Time: 1m 48s


Epoch 56/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [56/250], Loss: 0.1355, Time: 1m 48s


Epoch 57/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [57/250], Loss: 0.1356, Time: 1m 48s


Epoch 58/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [58/250], Loss: 0.1324, Time: 1m 48s


Epoch 59/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [59/250], Loss: 0.1323, Time: 1m 48s


Epoch 60/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [60/250], Loss: 0.1314, Time: 1m 48s


Epoch 61/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [61/250], Loss: 0.1318, Time: 1m 47s


Epoch 62/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [62/250], Loss: 0.1276, Time: 1m 48s


Epoch 63/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [63/250], Loss: 0.1262, Time: 1m 48s


Epoch 64/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [64/250], Loss: 0.1243, Time: 1m 48s


Epoch 65/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [65/250], Loss: 0.1234, Time: 1m 48s


Epoch 66/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [66/250], Loss: 0.1237, Time: 1m 48s


Epoch 67/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [67/250], Loss: 0.1213, Time: 1m 47s


Epoch 68/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [68/250], Loss: 0.1216, Time: 1m 48s


Epoch 69/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [69/250], Loss: 0.1203, Time: 1m 47s


Epoch 70/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [70/250], Loss: 0.1196, Time: 1m 48s


Epoch 71/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [71/250], Loss: 0.1177, Time: 1m 48s


Epoch 72/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [72/250], Loss: 0.1166, Time: 1m 48s


Epoch 73/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [73/250], Loss: 0.1170, Time: 1m 48s


Epoch 74/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [74/250], Loss: 0.1161, Time: 1m 47s


Epoch 75/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.30it/s]


Epoch [75/250], Loss: 0.1148, Time: 1m 47s


Epoch 76/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [76/250], Loss: 0.1146, Time: 1m 48s


Epoch 77/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.13it/s]


Epoch [77/250], Loss: 0.1133, Time: 1m 48s


Epoch 78/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [78/250], Loss: 0.1117, Time: 1m 48s


Epoch 79/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [79/250], Loss: 0.1117, Time: 1m 48s


Epoch 80/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [80/250], Loss: 0.1115, Time: 1m 48s


Epoch 81/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [81/250], Loss: 0.1113, Time: 1m 48s


Epoch 82/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [82/250], Loss: 0.1099, Time: 1m 47s


Epoch 83/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [83/250], Loss: 0.1098, Time: 1m 48s


Epoch 84/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [84/250], Loss: 0.1095, Time: 1m 48s


Epoch 85/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [85/250], Loss: 0.1086, Time: 1m 48s


Epoch 86/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [86/250], Loss: 0.1081, Time: 1m 48s


Epoch 87/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [87/250], Loss: 0.1070, Time: 1m 48s


Epoch 88/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.31it/s]


Epoch [88/250], Loss: 0.1066, Time: 1m 47s


Epoch 89/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.31it/s]


Epoch [89/250], Loss: 0.1085, Time: 1m 47s


Epoch 90/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [90/250], Loss: 0.1069, Time: 1m 48s


Epoch 91/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [91/250], Loss: 0.1047, Time: 1m 48s


Epoch 92/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [92/250], Loss: 0.1051, Time: 1m 48s


Epoch 93/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [93/250], Loss: 0.1035, Time: 1m 48s


Epoch 94/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [94/250], Loss: 0.1065, Time: 1m 48s


Epoch 95/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.33it/s]


Epoch [95/250], Loss: 0.1049, Time: 1m 47s


Epoch 96/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.30it/s]


Epoch [96/250], Loss: 0.1028, Time: 1m 47s


Epoch 97/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [97/250], Loss: 0.1021, Time: 1m 47s


Epoch 98/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [98/250], Loss: 0.1027, Time: 1m 48s


Epoch 99/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [99/250], Loss: 0.1014, Time: 1m 48s


Epoch 100/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [100/250], Loss: 0.1032, Time: 1m 48s


Epoch 101/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [101/250], Loss: 0.1004, Time: 1m 47s


Epoch 102/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [102/250], Loss: 0.1010, Time: 1m 47s


Epoch 103/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.33it/s]


Epoch [103/250], Loss: 0.0996, Time: 1m 47s


Epoch 104/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [104/250], Loss: 0.0992, Time: 1m 48s


Epoch 105/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [105/250], Loss: 0.0973, Time: 1m 48s


Epoch 106/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [106/250], Loss: 0.1028, Time: 1m 48s


Epoch 107/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [107/250], Loss: 0.0971, Time: 1m 48s


Epoch 108/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.29it/s]


Epoch [108/250], Loss: 0.1000, Time: 1m 47s


Epoch 109/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [109/250], Loss: 0.0975, Time: 1m 47s


Epoch 110/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.31it/s]


Epoch [110/250], Loss: 0.0968, Time: 1m 47s


Epoch 111/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [111/250], Loss: 0.1039, Time: 1m 48s


Epoch 112/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [112/250], Loss: 0.0958, Time: 1m 48s


Epoch 113/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [113/250], Loss: 0.0963, Time: 1m 48s


Epoch 114/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [114/250], Loss: 0.0955, Time: 1m 48s


Epoch 115/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.29it/s]


Epoch [115/250], Loss: 0.0969, Time: 1m 47s


Epoch 116/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [116/250], Loss: 0.0972, Time: 1m 48s


Epoch 117/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [117/250], Loss: 0.0982, Time: 1m 47s


Epoch 118/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [118/250], Loss: 0.0942, Time: 1m 48s


Epoch 119/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [119/250], Loss: 0.0946, Time: 1m 48s


Epoch 120/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [120/250], Loss: 0.0931, Time: 1m 48s


Epoch 121/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [121/250], Loss: 0.0929, Time: 1m 48s


Epoch 122/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [122/250], Loss: 0.0936, Time: 1m 47s


Epoch 123/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.31it/s]


Epoch [123/250], Loss: 0.0935, Time: 1m 47s


Epoch 124/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.32it/s]


Epoch [124/250], Loss: 0.0930, Time: 1m 47s


Epoch 125/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [125/250], Loss: 0.0933, Time: 1m 48s


Epoch 126/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [126/250], Loss: 0.0926, Time: 1m 48s


Epoch 127/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [127/250], Loss: 0.0923, Time: 1m 48s


Epoch 128/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [128/250], Loss: 0.0969, Time: 1m 48s


Epoch 129/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [129/250], Loss: 0.0933, Time: 1m 48s


Epoch 130/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.30it/s]


Epoch [130/250], Loss: 0.0924, Time: 1m 47s


Epoch 131/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.32it/s]


Epoch [131/250], Loss: 0.0905, Time: 1m 47s


Epoch 132/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [132/250], Loss: 0.0908, Time: 1m 48s


Epoch 133/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [133/250], Loss: 0.0919, Time: 1m 48s


Epoch 134/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [134/250], Loss: 0.0913, Time: 1m 48s


Epoch 135/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [135/250], Loss: 0.0917, Time: 1m 48s


Epoch 136/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [136/250], Loss: 0.0905, Time: 1m 48s


Epoch 137/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [137/250], Loss: 0.0903, Time: 1m 48s


Epoch 138/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.30it/s]


Epoch [138/250], Loss: 0.0900, Time: 1m 47s


Epoch 139/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [139/250], Loss: 0.0900, Time: 1m 47s


Epoch 140/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [140/250], Loss: 0.0889, Time: 1m 48s


Epoch 141/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [141/250], Loss: 0.0887, Time: 1m 48s


Epoch 142/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [142/250], Loss: 0.0887, Time: 1m 48s


Epoch 143/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [143/250], Loss: 0.0885, Time: 1m 48s


Epoch 144/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [144/250], Loss: 0.0912, Time: 1m 47s


Epoch 145/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.31it/s]


Epoch [145/250], Loss: 0.0910, Time: 1m 47s


Epoch 146/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [146/250], Loss: 0.0877, Time: 1m 48s


Epoch 147/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [147/250], Loss: 0.0885, Time: 1m 48s


Epoch 148/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [148/250], Loss: 0.0885, Time: 1m 48s


Epoch 149/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [149/250], Loss: 0.0876, Time: 1m 48s


Epoch 150/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [150/250], Loss: 0.0879, Time: 1m 48s


Epoch 151/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [151/250], Loss: 0.0927, Time: 1m 48s


Epoch 152/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [152/250], Loss: 0.0874, Time: 1m 48s


Epoch 153/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [153/250], Loss: 0.0874, Time: 1m 48s


Epoch 154/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [154/250], Loss: 0.0876, Time: 1m 48s


Epoch 155/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [155/250], Loss: 0.0875, Time: 1m 48s


Epoch 156/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [156/250], Loss: 0.0868, Time: 1m 48s


Epoch 157/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.15it/s]


Epoch [157/250], Loss: 0.0856, Time: 1m 48s


Epoch 158/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [158/250], Loss: 0.0869, Time: 1m 48s


Epoch 159/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [159/250], Loss: 0.0863, Time: 1m 47s


Epoch 160/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [160/250], Loss: 0.0859, Time: 1m 48s


Epoch 161/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [161/250], Loss: 0.0866, Time: 1m 48s


Epoch 162/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [162/250], Loss: 0.0871, Time: 1m 48s


Epoch 163/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [163/250], Loss: 0.0908, Time: 1m 48s


Epoch 164/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [164/250], Loss: 0.0849, Time: 1m 48s


Epoch 165/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [165/250], Loss: 0.0860, Time: 1m 48s


Epoch 166/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.29it/s]


Epoch [166/250], Loss: 0.0861, Time: 1m 47s


Epoch 167/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [167/250], Loss: 0.0877, Time: 1m 47s


Epoch 168/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [168/250], Loss: 0.0864, Time: 1m 48s


Epoch 169/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [169/250], Loss: 0.0855, Time: 1m 48s


Epoch 170/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [170/250], Loss: 0.0852, Time: 1m 48s


Epoch 171/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [171/250], Loss: 0.0850, Time: 1m 48s


Epoch 172/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [172/250], Loss: 0.0851, Time: 1m 48s


Epoch 173/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.24it/s]


Epoch [173/250], Loss: 0.0838, Time: 1m 48s


Epoch 174/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.30it/s]


Epoch [174/250], Loss: 0.0836, Time: 1m 47s


Epoch 175/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [175/250], Loss: 0.0855, Time: 1m 48s


Epoch 176/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.14it/s]


Epoch [176/250], Loss: 0.0847, Time: 1m 48s


Epoch 177/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [177/250], Loss: 0.0834, Time: 1m 48s


Epoch 178/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [178/250], Loss: 0.0849, Time: 1m 48s


Epoch 179/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [179/250], Loss: 0.0835, Time: 1m 48s


Epoch 180/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [180/250], Loss: 0.0833, Time: 1m 48s


Epoch 181/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [181/250], Loss: 0.0823, Time: 1m 47s


Epoch 182/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [182/250], Loss: 0.0836, Time: 1m 48s


Epoch 183/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [183/250], Loss: 0.0828, Time: 1m 48s


Epoch 184/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [184/250], Loss: 0.0830, Time: 1m 48s


Epoch 185/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.15it/s]


Epoch [185/250], Loss: 0.0834, Time: 1m 48s


Epoch 186/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [186/250], Loss: 0.0848, Time: 1m 48s


Epoch 187/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [187/250], Loss: 0.0832, Time: 1m 47s


Epoch 188/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [188/250], Loss: 0.0821, Time: 1m 48s


Epoch 189/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [189/250], Loss: 0.0897, Time: 1m 48s


Epoch 190/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [190/250], Loss: 0.0832, Time: 1m 48s


Epoch 191/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [191/250], Loss: 0.0834, Time: 1m 48s


Epoch 192/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [192/250], Loss: 0.0829, Time: 1m 48s


Epoch 193/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [193/250], Loss: 0.0868, Time: 1m 48s


Epoch 194/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [194/250], Loss: 0.0836, Time: 1m 48s


Epoch 195/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [195/250], Loss: 0.0831, Time: 1m 48s


Epoch 196/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.26it/s]


Epoch [196/250], Loss: 0.0848, Time: 1m 47s


Epoch 197/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [197/250], Loss: 0.0821, Time: 1m 48s


Epoch 198/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [198/250], Loss: 0.0814, Time: 1m 48s


Epoch 199/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [199/250], Loss: 0.0829, Time: 1m 48s


Epoch 200/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [200/250], Loss: 0.0828, Time: 1m 48s


Epoch 201/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [201/250], Loss: 0.0821, Time: 1m 48s


Epoch 202/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [202/250], Loss: 0.0858, Time: 1m 48s


Epoch 203/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.25it/s]


Epoch [203/250], Loss: 0.0811, Time: 1m 48s


Epoch 204/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [204/250], Loss: 0.0818, Time: 1m 48s


Epoch 205/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [205/250], Loss: 0.0814, Time: 1m 48s


Epoch 206/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [206/250], Loss: 0.0816, Time: 1m 48s


Epoch 207/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [207/250], Loss: 0.0824, Time: 1m 48s


Epoch 208/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [208/250], Loss: 0.0813, Time: 1m 48s


Epoch 209/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.23it/s]


Epoch [209/250], Loss: 0.0808, Time: 1m 48s


Epoch 210/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [210/250], Loss: 0.0812, Time: 1m 47s


Epoch 211/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [211/250], Loss: 0.0830, Time: 1m 47s


Epoch 212/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [212/250], Loss: 0.0811, Time: 1m 48s


Epoch 213/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [213/250], Loss: 0.0805, Time: 1m 48s


Epoch 214/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [214/250], Loss: 0.0813, Time: 1m 48s


Epoch 215/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [215/250], Loss: 0.0810, Time: 1m 48s


Epoch 216/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [216/250], Loss: 0.0806, Time: 1m 48s


Epoch 217/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [217/250], Loss: 0.0825, Time: 1m 48s


Epoch 218/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [218/250], Loss: 0.0795, Time: 1m 48s


Epoch 219/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [219/250], Loss: 0.0799, Time: 1m 48s


Epoch 220/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.13it/s]


Epoch [220/250], Loss: 0.0813, Time: 1m 48s


Epoch 221/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.13it/s]


Epoch [221/250], Loss: 0.0810, Time: 1m 48s


Epoch 222/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.12it/s]


Epoch [222/250], Loss: 0.0821, Time: 1m 48s


Epoch 223/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.12it/s]


Epoch [223/250], Loss: 0.0806, Time: 1m 48s


Epoch 224/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [224/250], Loss: 0.0789, Time: 1m 48s


Epoch 225/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [225/250], Loss: 0.0800, Time: 1m 48s


Epoch 226/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [226/250], Loss: 0.0798, Time: 1m 48s


Epoch 227/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [227/250], Loss: 0.0801, Time: 1m 48s


Epoch 228/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [228/250], Loss: 0.0789, Time: 1m 48s


Epoch 229/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.18it/s]


Epoch [229/250], Loss: 0.0806, Time: 1m 48s


Epoch 230/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.15it/s]


Epoch [230/250], Loss: 0.0792, Time: 1m 48s


Epoch 231/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [231/250], Loss: 0.0794, Time: 1m 48s


Epoch 232/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [232/250], Loss: 0.0802, Time: 1m 48s


Epoch 233/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.26it/s]


Epoch [233/250], Loss: 0.0798, Time: 1m 48s


Epoch 234/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [234/250], Loss: 0.0807, Time: 1m 48s


Epoch 235/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [235/250], Loss: 0.0796, Time: 1m 48s


Epoch 236/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.11it/s]


Epoch [236/250], Loss: 0.0796, Time: 1m 48s


Epoch 237/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [237/250], Loss: 0.0811, Time: 1m 48s


Epoch 238/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]


Epoch [238/250], Loss: 0.0792, Time: 1m 48s


Epoch 239/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [239/250], Loss: 0.0791, Time: 1m 48s


Epoch 240/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.27it/s]


Epoch [240/250], Loss: 0.0785, Time: 1m 47s


Epoch 241/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.22it/s]


Epoch [241/250], Loss: 0.0793, Time: 1m 48s


Epoch 242/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [242/250], Loss: 0.0796, Time: 1m 48s


Epoch 243/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.19it/s]


Epoch [243/250], Loss: 0.0795, Time: 1m 48s


Epoch 244/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.17it/s]


Epoch [244/250], Loss: 0.0791, Time: 1m 48s


Epoch 245/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.16it/s]


Epoch [245/250], Loss: 0.0805, Time: 1m 48s


Epoch 246/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [246/250], Loss: 0.0791, Time: 1m 48s


Epoch 247/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.28it/s]


Epoch [247/250], Loss: 0.0793, Time: 1m 47s


Epoch 248/250: 100%|██████████| 3484/3484 [01:47<00:00, 32.29it/s]


Epoch [248/250], Loss: 0.0830, Time: 1m 47s


Epoch 249/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.21it/s]


Epoch [249/250], Loss: 0.0783, Time: 1m 48s


Epoch 250/250: 100%|██████████| 3484/3484 [01:48<00:00, 32.20it/s]

Epoch [250/250], Loss: 0.0783, Time: 1m 48s





In [7]:
# Save the model
model_path = "models/CNN_ResNet.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': num_epochs,
    'loss': avg_loss,
}, model_path)

## Convert to format compatible with C++

In [10]:
from RL_utils import load_resnet_model

# An instance of your model.
model, device = load_resnet_model("models/CNN_ResNet.pth")
model.eval()

# An example input you would normally provide to your model's forward() method.
board = chess.Board()
X = chess_utils.board_to_tensor(board)
example = torch.tensor(X, dtype=torch.float32).unsqueeze(0).to(device)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

Model loaded from models/CNN_ResNet.pth


### Check that the model is compatible with TorchScript

In [11]:
try:
    # Try scripting first (more comprehensive than tracing)
    scripted_module = torch.jit.script(model)
    print("Model is TorchScript compatible via scripting")
    scripted_module.save("models/scripted_CNN_ResNet.pt")
except Exception as e:
    print(f"Scripting failed: {e}")
    print("Falling back to tracing...")
    
    # If scripting fails, use tracing with warnings
    with torch.jit.optimized_execution(False):
        traced_script_module = torch.jit.trace(model, example, strict=False)
    traced_script_module.save("models/traced_CNN_ResNet.pt")

Model is TorchScript compatible via scripting


In [12]:
traced_script_module.save("models/traced_CNN_ResNet.pt")

### Test loading the model

In [13]:
try:
    loaded_model = torch.jit.load("models/scripted_250EPOCH_model.pt")
    loaded_model.eval()
    
    # Test with the same input
    with torch.no_grad():
        output = loaded_model(example)
    print("Model loads and runs successfully in Python")
    print(f"Output shape: {output.shape}")
except Exception as e:
    print(f"Error loading traced model in Python: {e}")

Model loads and runs successfully in Python
Output shape: torch.Size([1, 4288])
