In [1]:
import time
from pathlib import Path

import chess
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader

import chess_utils

## Load the data

In [None]:
# Each PGN file is paired with the side passed as `my_color` when its
# positions are extracted; the results are concatenated into one list.
pgn_sources = [
    ("games/jalba20-black.pgn", "black"),
    ("games/Jeedy20-black.pgn", "black"),
    ("games/jalba20-white.pgn", "white"),
    ("games/Jeedy20-white.pgn", "white"),
]

data = []
for pgn_path, side in pgn_sources:
    data.extend(chess_utils.extract_training_data(pgn_path, my_color=side))


## Create tensors

In [3]:
# Convert list of tuples to separate arrays
X = np.array([item[0] for item in data], dtype=np.float32)  # Extract tensors
y = np.array([item[1] for item in data], dtype=np.long)     # Extract policy indices

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

## Set up

In [4]:
from dataset import ChessDataset
from model import ChessModel

In [5]:
# Hyperparameters, kept in sync with the "Training Hyperparameters" notes.
SEED = 0
BATCH_SIZE = 64
LEARNING_RATE = 0.0001
# Presumably must equal the number of planes chess_utils.board_to_tensor
# produces — TODO confirm against chess_utils.
INPUT_CHANNELS = 19

# Seed torch's global RNG before building the DataLoader and model. This
# makes the shuffle order (and any torch-RNG-based weight init) reproducible
# under Restart & Run All.
torch.manual_seed(SEED)

dataset = ChessDataset(X, y)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Use a CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

model = ChessModel(input_channels=INPUT_CHANNELS).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

Using device: cpu


## Train the model

### Model Architecture Hyperparameters:

input_channels: 19\
conv_filters: [64, 128, 256]\
kernel_size: 3\
padding: 1\
hidden_size: 1024\
output_size: 4288\
dropout_rate: 0.3

### Training Hyperparameters:

batch_size: 64\
learning_rate: 0.0001\
num_epochs: 250\
optimizer: Adam\
loss_function: CrossEntropyLoss\
gradient_clipping: max_norm=1.0\
shuffle: True

In [None]:
num_epochs = 250
# Defined up front so the checkpoint cell still has a value when num_epochs is 0.
avg_loss = float("nan")

for epoch in range(num_epochs):
    # perf_counter is monotonic: clock adjustments during a multi-hour run
    # cannot produce negative or skewed epoch timings (unlike time.time()).
    start_time = time.perf_counter()
    model.train()
    running_loss = 0.0

    for inputs, labels in tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        policy_logits = model(inputs)
        loss = criterion(policy_logits, labels)
        loss.backward()

        # Cap the total gradient norm across all parameters before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item()

    minutes, seconds = divmod(int(time.perf_counter() - start_time), 60)

    # Mean per-batch loss; NaN rather than ZeroDivisionError if the loader is empty.
    num_batches = len(dataloader)
    avg_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Time: {minutes}m {seconds}s")

Epoch 1/1:   0%|          | 0/3344 [00:00<?, ?it/s]

Epoch 1/1: 100%|██████████| 3344/3344 [07:29<00:00,  7.44it/s]  

Epoch [1/1], Loss: 5.9110, Time: 7m 29s





In [None]:
# Save a checkpoint containing model and optimizer state, the number of
# epochs trained, and the last epoch's mean loss.
model_path = "models/TORCH_250EPOCH.pth"
# torch.save does not create missing parent directories, so ensure models/ exists.
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': num_epochs,
    'loss': avg_loss,
}, model_path)

## Export the model to TorchScript for use from C++ (LibTorch)

In [None]:
# Reload the saved checkpoint via the project helper; this rebinds `model`
# and `device` for the export cells that follow.
model, device = chess_utils.load_model("models/TORCH_250EPOCH.pth")
# torch.jit.trace records the module in its current mode, so switch to eval
# to keep training-mode dropout out of the exported graph. load_model may
# already do this; calling eval() again is harmless.
model.eval()

# Build a single-board example input from the starting position. A distinct
# name is used so the training tensor `X` from earlier cells is not clobbered.
start_board = chess.Board()
board_planes = chess_utils.board_to_tensor(start_board)
# as_tensor accepts either a NumPy array or an existing tensor without warning.
example = torch.as_tensor(board_planes, dtype=torch.float32).unsqueeze(0).to(device)

# Run the model once on `example` and record the executed ops as a ScriptModule.
traced_script_module = torch.jit.trace(model, example)

### Check TorchScript compatibility (try scripting, fall back to tracing)

In [11]:
# Try torch.jit.script first: it compiles the module's Python code, so it
# preserves control flow that a single trace would not capture. Fall back to
# tracing only if scripting raises.
# NOTE(review): these output names say "1EPOCH" while the checkpoint loaded
# above is TORCH_250EPOCH — confirm which names downstream consumers expect.
try:
    scripted_module = torch.jit.script(model)
except Exception as e:
    print(f"Scripting failed: {e}")
    print("Falling back to tracing...")

    # strict=False lets the tracer record mutable container (list/dict) outputs;
    # optimized_execution(False) disables TorchScript executor optimisations here.
    with torch.jit.optimized_execution(False):
        traced_script_module = torch.jit.trace(model, example, strict=False)
    traced_script_module.save("models/traced_1EPOCH_model.pt")
else:
    # Saving happens outside the try so an I/O error surfaces as itself rather
    # than being misreported as a scripting failure followed by a re-trace.
    print("Model is TorchScript compatible via scripting")
    scripted_module.save("models/scripted_1EPOCH_model.pt")

Model is TorchScript compatible via scripting


In [None]:
# Persist the traced TorchScript module; the next cell reloads this same path.
traced_output_path = "models/traced_250EPOCH_model.pt"
traced_script_module.save(traced_output_path)

### Test loading the model

In [None]:
# Reload the serialized TorchScript file independently of the in-memory
# objects, run it on the same example input, and check it matches the eager model.
try:
    # map_location keeps the loaded module on the same device as `example`.
    loaded_model = torch.jit.load("models/traced_250EPOCH_model.pt", map_location=device)
    loaded_model.eval()

    # Test with the same input
    with torch.no_grad():
        output = loaded_model(example)
except Exception as e:
    print(f"Error loading traced model in Python: {e}")
    # Re-raise so a broken export fails this cell instead of passing silently.
    raise

print("Model loads and runs successfully in Python")
print(f"Output shape: {output.shape}")

# Parity check: the exported module should reproduce the eager model's output
# (both in eval mode, so dropout does not introduce randomness).
model.eval()
with torch.no_grad():
    expected = model(example)
torch.testing.assert_close(output, expected)

Model loads and runs successfully in Python
Output shape: torch.Size([1, 4288])
