# Inital Supervised Learning performed on the model

### Imports

In [1]:
import os
import numpy as np # type: ignore
import time
import torch
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
import helper_functions as helper
from torch.utils.data import DataLoader, random_split # type: ignore
from chess import pgn # type: ignore
from tqdm import tqdm # type: ignore
from dataset import ChessPGNDataset
from model import ChessNet, ResBlock

#### Device Agnostic Code

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [3]:
print(torch.__version__)

2.10.0+cu126


# Processing the PGNs

### Creating Move Map

In [4]:
MOVE_MAP = helper.create_move_map()

In [5]:
MOVE_MAP_LENGTH = len(MOVE_MAP)
print(MOVE_MAP_LENGTH)

8192


### Loading the data

In [6]:
pgn_path = "../lichess_db_standard_rated_2026-01.pgn/lichess_db_standard_rated_2026-01.pgn"

In [7]:
pgn_dataset_v1 = ChessPGNDataset(pgn_file_path=pgn_path, move_map=MOVE_MAP)

In [8]:
print(len(pgn_dataset_v1))

159226


In [9]:
total_size = len(pgn_dataset_v1)
train_size = int(0.8 * total_size) 
val_size = total_size - train_size

In [10]:
train_dataset, val_dataset = random_split(pgn_dataset_v1, [train_size, val_size])

In [11]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [12]:
print(len(next(iter(train_loader)))) 

3


#### Stuff to do before training begins

In [13]:
train_data_iter = iter(train_loader)
images, labels, games = next(train_data_iter)

print(f"Batch shape (Images): {images.shape}") 
print(f"Batch shape (Labels): {labels.shape}") 
print(f"Batch shape (Labels): {games.shape}")
print(f"First label: {labels[0].item()}")      

Batch shape (Images): torch.Size([64, 12, 8, 8])
Batch shape (Labels): torch.Size([64])
Batch shape (Labels): torch.Size([64, 1])
First label: 3158


In [14]:
val_data_iter = iter(val_loader)
images, labels, games = next(val_data_iter)

print(f"Batch shape (Images): {images.shape}") 
print(f"Batch shape (Labels): {labels.shape}") 
print(f"Batch shape (Labels): {games.shape}")
print(f"First label: {labels[0].item()}")      

Batch shape (Images): torch.Size([64, 12, 8, 8])
Batch shape (Labels): torch.Size([64])
Batch shape (Labels): torch.Size([64, 1])
First label: 2360


# Making the model and training loop

#### Making the model

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [16]:
model = ChessNet(MOVE_MAP_LENGTH)
model.to(device)
policy_criterion = nn.CrossEntropyLoss()
value_criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

#### Making the training loop

In [17]:
test_batch = next(iter(train_loader))
test_images, test_labels, test_values = [x.to(device) for x in test_batch]
print(f"Test images: min={test_images.min():.4f}, max={test_images.max():.4f}")
print(f"Test values: min={test_values.min():.4f}, max={test_values.max():.4f}, mean={test_values.mean():.4f}")

with torch.no_grad():
    pred_pol, pred_val = model(test_images)
    print(f"Pred values: min={pred_val.min():.4f}, max={pred_val.max():.4f}")
    test_loss = torch.nn.functional.mse_loss(pred_val, test_values.float())
    print(f"Test MSE loss: {test_loss.item():.6f}")


Test images: min=0.0000, max=1.0000
Test values: min=-1.0000, max=1.0000, mean=-0.0312
Pred values: min=-0.4547, max=0.2546
Test MSE loss: 1.013683


In [18]:
print(f"Move map size: {len(MOVE_MAP)}")
print(f"Policy head output size: {MOVE_MAP_LENGTH}")
print(f"Move indices range: 0 to {max(MOVE_MAP.values())}")

out_of_range = [idx for idx in MOVE_MAP.values() if idx >= MOVE_MAP_LENGTH]
if out_of_range:
    print(f"WARNING: {len(out_of_range)} move indices are >= {MOVE_MAP_LENGTH}!")
else:
    print(f"All move indices are within valid range [0, {MOVE_MAP_LENGTH-1}]")

test_batch = next(iter(train_loader))
test_images, test_labels, test_values = test_batch
print(f"\nBatch shapes:")
print(f"  Images: {test_images.shape}")
print(f"  Labels (move indices): {test_labels.shape}, min={test_labels.min()}, max={test_labels.max()}")
print(f"  Values: {test_values.shape}")


Move map size: 8192
Policy head output size: 8192
Move indices range: 0 to 8191
All move indices are within valid range [0, 8191]

Batch shapes:
  Images: torch.Size([64, 12, 8, 8])
  Labels (move indices): torch.Size([64]), min=48, max=8136
  Values: torch.Size([64, 1])


In [19]:
def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               optimizer: torch.optim.Optimizer,
               policy_criterion,
               value_criterion,
               device: torch.device = device):
        
    running_policy_loss = 0.0
    running_value_loss = 0.0
    model.train()
    model.to(device)

    for batch, (image, move_labels, game_values) in enumerate(data_loader):

        image, move_labels, game_values = image.to(device), move_labels.to(device), game_values.to(device).float()
        
        # Validate inputs
        if move_labels.min() < 0 or move_labels.max() >= MOVE_MAP_LENGTH:
            print(f"WARNING: Invalid move indices in batch {batch}: min={move_labels.min()}, max={move_labels.max()}")
            continue
            
        optimizer.zero_grad()

        pred_policy, pred_value = model(image)
        
        # Check for NaN/Inf
        if torch.isnan(pred_policy).any() or torch.isinf(pred_policy).any():
            print(f"WARNING: NaN/Inf in policy predictions at batch {batch}")
            continue
        if torch.isnan(pred_value).any() or torch.isinf(pred_value).any():
            print(f"WARNING: NaN/Inf in value predictions at batch {batch}")
            continue

        policy_loss = policy_criterion(pred_policy, move_labels)
        value_loss = torch.nn.functional.mse_loss(pred_value, game_values)

        total_loss = policy_loss + value_loss
        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_policy_loss += policy_loss.item()
        running_value_loss += value_loss.item()

    print(f"Train policy loss: {running_policy_loss/len(data_loader):.5f} | Train value loss: {running_value_loss/len(data_loader):.5f}\n")


In [20]:
def test_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               optimizer: torch.optim.Optimizer,
               policy_criterion,
               value_criterion,
               device: torch.device = device):
    running_policy_loss = 0.0
    running_value_loss = 0.0
    model.eval()
    model.to(device)

    with torch.inference_mode(): 
        for batch, (image, move_labels, game_values) in enumerate(data_loader):
            image, move_labels, game_values = image.to(device), move_labels.to(device), game_values.to(device).float()
            
            # Validate inputs
            if move_labels.min() < 0 or move_labels.max() >= MOVE_MAP_LENGTH:
                print(f"WARNING: Invalid move indices in batch {batch}: min={move_labels.min()}, max={move_labels.max()}")
                continue
            
            pred_policy, pred_value = model(image)

            policy_loss = policy_criterion(pred_policy, move_labels)
            value_loss = torch.nn.functional.mse_loss(pred_value, game_values)

            running_policy_loss += policy_loss.item()
            running_value_loss += value_loss.item()

            avg_test_policy_loss = running_policy_loss / len(data_loader)
    print(f"Test policy loss: {avg_test_policy_loss:.5f} | Test value loss: {running_value_loss/len(data_loader):.5f}\n")

    return avg_test_policy_loss


In [21]:
from pathlib import Path

best_test_loss = float('inf')

MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, 
                 exist_ok=True 
)

MODEL_NAME = "supervised_learning_chess_model_1.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
epochs = 200
for epoch in tqdm(range(epochs)):
    tqdm.write(f"Epoch: {epoch}\n---------")
    train_step(model,
               train_loader,
               optimizer,
               policy_criterion,
               value_criterion,
               device
               )
    current_test_loss = test_step(model,
               val_loader,
               optimizer,
               policy_criterion,
               value_criterion,
               device
               )
    scheduler.step(current_test_loss)
    if current_test_loss < best_test_loss:
        best_test_loss = current_test_loss
        print(f"⭐ New Best Model! Loss improved to: {best_test_loss:.5f}. Saving...")
        torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)
    else:
        print(f"No improvement in Test Loss ({current_test_loss:.5f} vs best {best_test_loss:.5f}).")

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current Learning Rate: {current_lr}")


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0
---------
Train policy loss: 4.11416 | Train value loss: 0.94101

Test policy loss: 3.28409 | Test value loss: 0.95004

⭐ New Best Model! Loss improved to: 3.28409. Saving...


  0%|          | 1/200 [01:16<4:14:16, 76.67s/it]

Current Learning Rate: 0.001
Epoch: 1
---------
Train policy loss: 3.09335 | Train value loss: 0.91678

Test policy loss: 3.09996 | Test value loss: 0.91846

⭐ New Best Model! Loss improved to: 3.09996. Saving...


  1%|          | 2/200 [02:35<4:17:28, 78.02s/it]

Current Learning Rate: 0.001
Epoch: 2
---------
Train policy loss: 2.78744 | Train value loss: 0.90686

Test policy loss: 3.01524 | Test value loss: 0.93473

⭐ New Best Model! Loss improved to: 3.01524. Saving...


  2%|▏         | 3/200 [03:56<4:19:56, 79.17s/it]

Current Learning Rate: 0.001
Epoch: 3
---------
Train policy loss: 2.60630 | Train value loss: 0.90046

Test policy loss: 2.99753 | Test value loss: 0.91145

⭐ New Best Model! Loss improved to: 2.99753. Saving...


  2%|▏         | 4/200 [05:15<4:19:19, 79.38s/it]

Current Learning Rate: 0.001
Epoch: 4
---------
Train policy loss: 2.47950 | Train value loss: 0.89534

Test policy loss: 2.96390 | Test value loss: 0.91956

⭐ New Best Model! Loss improved to: 2.96390. Saving...


  2%|▎         | 5/200 [06:35<4:17:45, 79.31s/it]

Current Learning Rate: 0.001
Epoch: 5
---------
Train policy loss: 2.39244 | Train value loss: 0.88952

Test policy loss: 2.94506 | Test value loss: 0.90553

⭐ New Best Model! Loss improved to: 2.94506. Saving...


  3%|▎         | 6/200 [07:54<4:16:52, 79.44s/it]

Current Learning Rate: 0.001
Epoch: 6
---------
Train policy loss: 2.33992 | Train value loss: 0.88338



  4%|▎         | 7/200 [09:13<4:15:04, 79.30s/it]

Test policy loss: 2.96346 | Test value loss: 0.90526

No improvement in Test Loss (2.96346 vs best 2.94506).
Current Learning Rate: 0.001
Epoch: 7
---------
Train policy loss: 2.29258 | Train value loss: 0.87817



  4%|▍         | 8/200 [10:32<4:12:49, 79.01s/it]

Test policy loss: 2.98807 | Test value loss: 0.91790

No improvement in Test Loss (2.98807 vs best 2.94506).
Current Learning Rate: 0.001
Epoch: 8
---------
Train policy loss: 2.25561 | Train value loss: 0.87382



  4%|▍         | 9/200 [11:50<4:10:40, 78.74s/it]

Test policy loss: 2.97497 | Test value loss: 0.90479

No improvement in Test Loss (2.97497 vs best 2.94506).
Current Learning Rate: 0.001
Epoch: 9
---------
Train policy loss: 2.22145 | Train value loss: 0.86934



  5%|▌         | 10/200 [13:10<4:10:24, 79.08s/it]

Test policy loss: 2.98932 | Test value loss: 0.90607

No improvement in Test Loss (2.98932 vs best 2.94506).
Current Learning Rate: 0.001
Epoch: 10
---------
Train policy loss: 2.20478 | Train value loss: 0.86332



  6%|▌         | 11/200 [14:28<4:08:27, 78.87s/it]

Test policy loss: 3.00801 | Test value loss: 0.92428

No improvement in Test Loss (3.00801 vs best 2.94506).
Current Learning Rate: 0.001
Epoch: 11
---------
Train policy loss: 2.18867 | Train value loss: 0.85946



  6%|▌         | 12/200 [15:54<4:13:44, 80.98s/it]

Test policy loss: 3.01970 | Test value loss: 0.90601

No improvement in Test Loss (3.01970 vs best 2.94506).
Current Learning Rate: 0.0001
Epoch: 12
---------
Train policy loss: 1.55602 | Train value loss: 0.81641

Test policy loss: 2.87047 | Test value loss: 0.90803

⭐ New Best Model! Loss improved to: 2.87047. Saving...


  6%|▋         | 13/200 [17:13<4:10:39, 80.43s/it]

Current Learning Rate: 0.0001
Epoch: 13
---------
Train policy loss: 1.41547 | Train value loss: 0.79426



  7%|▋         | 14/200 [18:32<4:07:45, 79.92s/it]

Test policy loss: 2.87796 | Test value loss: 0.91317

No improvement in Test Loss (2.87796 vs best 2.87047).
Current Learning Rate: 0.0001
Epoch: 14
---------
Train policy loss: 1.35318 | Train value loss: 0.77766



  8%|▊         | 15/200 [19:50<4:04:39, 79.35s/it]

Test policy loss: 2.88822 | Test value loss: 0.92302

No improvement in Test Loss (2.88822 vs best 2.87047).
Current Learning Rate: 0.0001
Epoch: 15
---------
Train policy loss: 1.30334 | Train value loss: 0.76108



  8%|▊         | 16/200 [21:12<4:06:05, 80.24s/it]

Test policy loss: 2.90200 | Test value loss: 0.93955

No improvement in Test Loss (2.90200 vs best 2.87047).
Current Learning Rate: 0.0001
Epoch: 16
---------
Train policy loss: 1.27518 | Train value loss: 0.74533



  8%|▊         | 17/200 [22:40<4:12:05, 82.65s/it]

Test policy loss: 2.92169 | Test value loss: 0.93511

No improvement in Test Loss (2.92169 vs best 2.87047).
Current Learning Rate: 0.0001
Epoch: 17
---------
Train policy loss: 1.24338 | Train value loss: 0.72847



  9%|▉         | 18/200 [24:08<4:14:59, 84.06s/it]

Test policy loss: 2.95954 | Test value loss: 0.95513

No improvement in Test Loss (2.95954 vs best 2.87047).
Current Learning Rate: 0.0001
Epoch: 18
---------
Train policy loss: 1.22726 | Train value loss: 0.71019



 10%|▉         | 19/200 [25:41<4:21:46, 86.77s/it]

Test policy loss: 2.96449 | Test value loss: 0.95657

No improvement in Test Loss (2.96449 vs best 2.87047).
Current Learning Rate: 1e-05
Epoch: 19
---------
Train policy loss: 1.12786 | Train value loss: 0.67110



 10%|█         | 20/200 [27:09<4:21:14, 87.08s/it]

Test policy loss: 2.95585 | Test value loss: 0.96139

No improvement in Test Loss (2.95585 vs best 2.87047).
Current Learning Rate: 1e-05
Epoch: 20
---------
Train policy loss: 1.11744 | Train value loss: 0.66516



 10%|█         | 21/200 [28:38<4:22:10, 87.88s/it]

Test policy loss: 2.95022 | Test value loss: 0.96577

No improvement in Test Loss (2.95022 vs best 2.87047).
Current Learning Rate: 1e-05
Epoch: 21
---------
Train policy loss: 1.12060 | Train value loss: 0.66046



 11%|█         | 22/200 [30:10<4:24:20, 89.10s/it]

Test policy loss: 2.96770 | Test value loss: 0.97285

No improvement in Test Loss (2.96770 vs best 2.87047).
Current Learning Rate: 1e-05
Epoch: 22
---------
Train policy loss: 1.11540 | Train value loss: 0.65583



 12%|█▏        | 23/200 [31:42<4:25:00, 89.83s/it]

Test policy loss: 2.96477 | Test value loss: 0.97512

No improvement in Test Loss (2.96477 vs best 2.87047).
Current Learning Rate: 1e-05
Epoch: 23
---------
Train policy loss: 1.11398 | Train value loss: 0.65276



 12%|█▏        | 24/200 [33:11<4:22:39, 89.54s/it]

Test policy loss: 2.95481 | Test value loss: 0.97687

No improvement in Test Loss (2.95481 vs best 2.87047).
Current Learning Rate: 1e-05
Epoch: 24
---------
Train policy loss: 1.11026 | Train value loss: 0.64878



 12%|█▎        | 25/200 [34:34<4:15:36, 87.64s/it]

Test policy loss: 2.95747 | Test value loss: 0.97993

No improvement in Test Loss (2.95747 vs best 2.87047).
Current Learning Rate: 1.0000000000000002e-06
Epoch: 25
---------
Train policy loss: 1.09816 | Train value loss: 0.64194



 13%|█▎        | 26/200 [35:57<4:10:25, 86.35s/it]

Test policy loss: 2.95573 | Test value loss: 0.97962

No improvement in Test Loss (2.95573 vs best 2.87047).
Current Learning Rate: 1.0000000000000002e-06
Epoch: 26
---------
Train policy loss: 1.10081 | Train value loss: 0.64210



 14%|█▎        | 27/200 [37:25<4:09:46, 86.63s/it]

Test policy loss: 2.97384 | Test value loss: 0.98018

No improvement in Test Loss (2.97384 vs best 2.87047).
Current Learning Rate: 1.0000000000000002e-06
Epoch: 27
---------


 14%|█▎        | 27/200 [37:37<4:01:07, 83.63s/it]


KeyboardInterrupt: 