In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler, Dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import numpy as np
import h5py
import chess
import os
from tqdm import tqdm
from datetime import datetime

In [None]:
class Config:
    """Hyper-parameters and runtime settings shared across this notebook.

    Model-building code reads ``CONV_FILTERS``, ``NUM_RESIDUAL``, ``INPUT_SHAPE``
    and ``OUTPUT_SHAPE``; data/training code and the model-loading cell read
    ``DEVICE``.
    """

    def __init__(self):
        # NOTE(review): train_model uses the module-level LEARNING_RATE, not this one.
        self.LEARNING_RATE = 0.02
        # Trunk width and depth.
        self.CONV_FILTERS = 256
        self.NUM_RESIDUAL = 20
        # INPUT_SHAPE is (channels, height, width); OUTPUT_SHAPE is
        # (policy output size, value output size).
        self.INPUT_SHAPE = (19, 8, 8)
        self.OUTPUT_SHAPE = (4672, 1)
        self.DEVICE = torch.device(self._preferred_device_name())
        # Presumably search settings for self-play; not referenced by the
        # training code in this notebook.
        self.SIMULATIONS_PER_MOVE = 200
        self.DIRICHLET_NOISE = 0.25

    @staticmethod
    def _preferred_device_name():
        """Return "cuda" if available, otherwise "mps" if available, otherwise "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"


config = Config()
print(f"Using device: {config.DEVICE}")

In [None]:
def conv_block(in_channels, out_channels):
    """Build a 3x3 conv -> BatchNorm -> ReLU stage that keeps the spatial size.

    The convolution carries no bias because the BatchNorm that follows has its
    own learnable shift. Submodule positions 0/1/2 define the state_dict keys,
    so their order must not change if saved checkpoints are to keep loading.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)


def residual_block(channels):
    """Build the residual branch: conv_block, then a 3x3 conv + BatchNorm.

    No skip connection or final activation is added here; RLModel.forward
    applies ``relu(branch(x) + x)``. Channel count and spatial size are
    preserved so the branch output can be summed with its input.
    """
    # Modules are created in the same order as before so parameter
    # initialisation consumes the RNG identically.
    first_stage = conv_block(channels, channels)
    second_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
    second_norm = nn.BatchNorm2d(channels)
    return nn.Sequential(first_stage, second_conv, second_norm)


class PolicyHead(nn.Module):
    """Policy head: 1x1 conv to 73 planes, BN, ReLU, flatten, dropout, linear logits.

    The layer order inside ``self.head`` is fixed so state_dict keys
    (``head.0``, ``head.1``, ``head.5``) stay compatible with saved checkpoints.

    Args:
        input_channels: channels of the trunk feature map.
        output_size: number of policy logits produced.
        board_shape: optional ``(height, width)`` of the feature map. Defaults to
            ``config.INPUT_SHAPE[1:]`` (the previous global-config behaviour).
    """

    def __init__(self, input_channels, output_size, board_shape=None):
        super().__init__()
        height, width = config.INPUT_SHAPE[1:] if board_shape is None else board_shape
        self.head = nn.Sequential(
            nn.Conv2d(
                input_channels, 73, kernel_size=1
            ),  # Matches saved shape (73, 256, 1, 1)
            nn.BatchNorm2d(73),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(73 * height * width, output_size),
        )

    def forward(self, x):
        """Map a ``(batch, input_channels, H, W)`` feature map to ``(batch, output_size)`` logits."""
        return self.head(x)


class ValueHead(nn.Module):
    """Value head: 1x1 conv to 1 plane, BN, ReLU, MLP (hidden 512), tanh output.

    The layer order inside ``self.head`` is fixed so state_dict keys
    (``head.0``, ``head.1``, ``head.4``, ``head.7``) stay checkpoint-compatible.

    Args:
        input_channels: channels of the trunk feature map.
        output_size: number of value outputs (each squashed to [-1, 1] by tanh).
        board_shape: optional ``(height, width)`` of the feature map. Defaults to
            ``config.INPUT_SHAPE[1:]`` (the previous global-config behaviour).
    """

    def __init__(self, input_channels, output_size, board_shape=None):
        super().__init__()
        height, width = config.INPUT_SHAPE[1:] if board_shape is None else board_shape
        self.head = nn.Sequential(
            nn.Conv2d(input_channels, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(height * width, 512),  # Matches saved shape (512, 64)
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, output_size),  # Extra layer (head.7 in state_dict)
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a ``(batch, input_channels, H, W)`` feature map to ``(batch, output_size)`` values."""
        return self.head(x)


class RLModel(nn.Module):
    """Residual conv trunk with separate policy and value heads.

    Args:
        input_shape: ``(channels, height, width)`` of an encoded position.
        output_shape: ``(policy_size, value_size)``.
        conv_filters: trunk width; defaults to ``config.CONV_FILTERS``.
        num_residual: number of residual blocks; defaults to ``config.NUM_RESIDUAL``.

    The heads are sized from ``input_shape`` itself rather than from the global
    config, so a model built for another board shape gets consistent layers.
    """

    def __init__(self, input_shape, output_shape, conv_filters=None, num_residual=None):
        super().__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape

        filters = config.CONV_FILTERS if conv_filters is None else conv_filters
        depth = config.NUM_RESIDUAL if num_residual is None else num_residual
        channels, height, width = input_shape

        # Attribute names and creation order are unchanged so state_dict keys
        # (conv1.*, residuals.N.*, policy_head.*, value_head.*) still match checkpoints.
        self.conv1 = conv_block(channels, filters)
        self.residuals = nn.Sequential(*[residual_block(filters) for _ in range(depth)])
        self.policy_head = PolicyHead(filters, output_shape[0], board_shape=(height, width))
        self.value_head = ValueHead(filters, output_shape[1], board_shape=(height, width))

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded positions."""
        x = self.conv1(x)
        for res in self.residuals:
            # Skip connection plus post-addition ReLU around each residual branch.
            x = F.relu(res(x) + x)
        return self.policy_head(x), self.value_head(x)

In [None]:
import h5py
import torch
from torch.utils.data import Dataset

class ChessDataset(Dataset):
    """Map-style dataset reading ``inputs``/``policies``/``values`` from an HDF5 file.

    The file is opened lazily on first item access, so each DataLoader worker
    opens its own handle instead of sharing one created in the parent process.
    Open handles are dropped when the dataset is pickled, so it can still be
    sent to workers after it has been indexed in the main process.

    Items are ``(input_tensor, policy_target, value_tensor)``: float32 input
    planes, the argmax index of the flattened policy array (a 0-d int64
    tensor, usable as a CrossEntropyLoss class index), and the float32 value.
    """

    def __init__(self, hdf5_path):
        self.hdf5_path = hdf5_path
        self.file = None  # opened lazily in _lazy_init()

        # Use a short-lived handle only to read the sample count.
        with h5py.File(hdf5_path, 'r') as f:
            self.num_samples = f['inputs'].shape[0]

    def _lazy_init(self):
        """Open the HDF5 file once per process and cache the three datasets."""
        if self.file is None:
            self.file = h5py.File(self.hdf5_path, 'r')
            self.inputs = self.file['inputs']
            self.policies = self.file['policies']
            self.values = self.file['values']

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        self._lazy_init()
        input_tensor = torch.tensor(self.inputs[idx], dtype=torch.float32)     # expected (19, 8, 8)
        policy_tensor = torch.tensor(self.policies[idx], dtype=torch.float32)  # expected (73, 8, 8)
        value_tensor = torch.tensor(self.values[idx], dtype=torch.float32)     # expected (1,)
        # Collapse the policy array to the index of its largest entry.
        policy_target = policy_tensor.flatten().argmax()
        return input_tensor, policy_target, value_tensor

    def __getstate__(self):
        """Return picklable state without open HDF5 objects; they reopen lazily."""
        state = self.__dict__.copy()
        state['file'] = None
        for key in ('inputs', 'policies', 'values'):
            state.pop(key, None)
        return state

    def close(self):
        """Close the lazily opened HDF5 handle, if any. Safe to call repeatedly."""
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()
            self.file = None

    def __del__(self):
        self.close()


In [None]:
# Training hyper-parameters read by train_model.
# -- data loading --
BATCH_SIZE: int = 256
NUM_WORKERS: int = 4  # DataLoader worker processes
VALIDATION_SPLIT: float = 0.1  # fraction of samples held out for validation
# -- optimisation --
LEARNING_RATE: float = 0.002  # initial Adam LR (independent of config.LEARNING_RATE)
NUM_EPOCHS: int = 6
# -- logging / output --
LOG_INTERVAL: int = 100  # print running metrics every N training batches
SAVE_PATH: str = "latest-trained.pth"  # destination for the best-loss weights

In [None]:
# Per-split metrics recorded in the history dict as f"{split}_{metric}".
_HISTORY_METRICS = ("loss", "policy_loss", "value_loss", "policy_acc", "policy_top3_acc")

# (metric, panel title, y-axis label, legend suffix) for each plotted panel.
_PLOT_PANELS = (
    ("loss", "Total Loss", "Loss", "Loss"),
    ("policy_loss", "Policy Loss", "Loss", "Policy Loss"),
    ("value_loss", "Value Loss", "Loss", "Value Loss"),
    ("policy_acc", "Policy Top-1 Accuracy", "Accuracy (%)", "Top-1 Acc"),
    ("policy_top3_acc", "Policy Top-3 Accuracy", "Accuracy (%)", "Top-3 Acc"),
)


def _topk_correct(logits, targets, k=3):
    """Count rows of ``logits`` whose top-``k`` classes contain the target index.

    Vectorised: a single device-side comparison and one ``.item()`` sync per
    batch, instead of a Python loop with one sync per sample.
    """
    k = min(k, logits.size(1))
    topk_indices = logits.topk(k, dim=1).indices  # (batch, k)
    return (topk_indices == targets.unsqueeze(1)).any(dim=1).sum().item()


def _nan_metrics():
    """Metrics placeholder for a pass that saw no samples."""
    nan = float("nan")
    return {name: nan for name in _HISTORY_METRICS + ("value_min", "value_max")}


def _run_epoch(model, loader, device, criterion_policy, criterion_value,
               optimizer=None, log_interval=None):
    """Run one pass over ``loader``; trains when ``optimizer`` is given, else evaluates.

    Loss = 0.4 * policy cross-entropy + 0.6 * value MSE. Training steps clip
    the gradient norm to 1.0.

    Returns:
        dict with sample-weighted mean ``loss``/``policy_loss``/``value_loss``,
        top-1/top-3 policy accuracy in percent, and the min/max value-head
        prediction seen. All entries are NaN if the loader yielded nothing.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = policy_loss_sum = value_loss_sum = 0.0
    correct = top3_correct = seen = 0
    value_min, value_max = float("inf"), float("-inf")

    with (torch.enable_grad() if training else torch.no_grad()):
        for batch_idx, (inputs, policy_targets, value_targets) in enumerate(loader):
            inputs = inputs.to(device, non_blocking=True)  # (batch, 19, 8, 8)
            policy_targets = policy_targets.to(device, non_blocking=True, dtype=torch.long)
            value_targets = value_targets.to(device, non_blocking=True)

            policy_pred, value_pred = model(inputs)
            policy_loss = criterion_policy(policy_pred, policy_targets)
            # reshape(-1) keeps both sides 1-D regardless of batch size.
            value_loss = criterion_value(value_pred.reshape(-1), value_targets.reshape(-1))
            loss = 0.4 * policy_loss + 0.6 * value_loss

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Weight by batch size so a short final batch is not over-counted.
            batch_size = policy_targets.size(0)
            loss_sum += loss.item() * batch_size
            policy_loss_sum += policy_loss.item() * batch_size
            value_loss_sum += value_loss.item() * batch_size
            correct += (policy_pred.argmax(dim=1) == policy_targets).sum().item()
            top3_correct += _topk_correct(policy_pred, policy_targets, k=3)
            seen += batch_size
            value_min = min(value_min, value_pred.min().item())
            value_max = max(value_max, value_pred.max().item())

            if training and log_interval and (batch_idx + 1) % log_interval == 0:
                print(
                    f"Batch {batch_idx+1}/{len(loader)}: "
                    f"Loss={loss.item():.4f}, Policy Loss={policy_loss.item():.4f}, "
                    f"Value Loss={value_loss.item():.4f}, "
                    f"Top-1 Acc={100 * correct/seen:.2f}%, "
                    f"Top-3 Acc={100 * top3_correct/seen:.2f}%"
                )

    if seen == 0:
        return _nan_metrics()
    return {
        "loss": loss_sum / seen,
        "policy_loss": policy_loss_sum / seen,
        "value_loss": value_loss_sum / seen,
        "policy_acc": 100 * correct / seen,
        "policy_top3_acc": 100 * top3_correct / seen,
        "value_min": value_min,
        "value_max": value_max,
    }


def _format_summary(title, m):
    """One-line epoch summary for a metrics dict produced by _run_epoch."""
    return (
        f"{title}: Loss={m['loss']:.4f}, Policy Loss={m['policy_loss']:.4f}, "
        f"Value Loss={m['value_loss']:.4f}, Top-1 Acc={m['policy_acc']:.2f}%, "
        f"Top-3 Acc={m['policy_top3_acc']:.2f}%, "
        f"Value Range=[{m['value_min']:.4f}, {m['value_max']:.4f}]"
    )


def _plot_history(history, plot_path):
    """Plot train/val curves per tracked metric in a 2x3 grid and save to ``plot_path``."""
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, axes = plt.subplots(2, 3, figsize=(15, 12))
    flat_axes = axes.flatten()
    for ax, (metric, title, ylabel, label) in zip(flat_axes, _PLOT_PANELS):
        ax.plot(epochs, history[f"train_{metric}"], label=f"Train {label}")
        ax.plot(epochs, history[f"val_{metric}"], label=f"Val {label}")
        ax.set(title=title, xlabel="Epoch", ylabel=ylabel)
        ax.legend()
        ax.grid(True)
    for ax in flat_axes[len(_PLOT_PANELS):]:
        ax.axis("off")  # unused grid cell
    fig.tight_layout()
    fig.savefig(plot_path)
    print(f"Saved training plots to '{plot_path}'")
    return fig


def train_model(model, hdf5_path, save_path=SAVE_PATH, *, seed=None,
                checkpoint_dir="checkpoints", plot_path="training_plots-3-stockfish.png"):
    """Train ``model`` on an HDF5 chess dataset and return the per-epoch history.

    Uses the module-level hyper-parameters (BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS,
    VALIDATION_SPLIT, LOG_INTERVAL, NUM_WORKERS) and ``config.DEVICE``. The model
    is expected to already live on ``config.DEVICE``.

    Each epoch runs a training pass (Adam, grad-norm clipping at 1.0) and a
    validation pass. It then writes a checkpoint to ``checkpoint_dir`` and steps
    a cosine-annealing LR schedule. Whenever the monitored loss improves, the
    weights are also saved to ``save_path``. The monitored loss is the
    validation loss, or the training loss if the split leaves no validation
    samples.

    Args:
        model: network returning ``(policy_logits, value)`` for a batch.
        hdf5_path: HDF5 file readable by ``ChessDataset``.
        save_path: destination for the best-loss weights.
        seed: optional seed for the train/validation split (None = random split).
        checkpoint_dir: directory for per-epoch checkpoints; created if missing.
        plot_path: file the training-curve figure is saved to.

    Returns:
        dict mapping ``"{train|val}_{metric}"`` to a list with one value per epoch.
    """
    device = config.DEVICE
    print(f"Using device: {device}")
    print(f"Training for {NUM_EPOCHS} epochs")

    criterion_policy = nn.CrossEntropyLoss()
    criterion_value = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train/validation split. Pass ``seed`` for a reproducible split, e.g. when
    # resuming from a checkpoint, to keep validation samples out of training.
    dataset = ChessDataset(hdf5_path)
    dataset_size = len(dataset)
    split = int(np.floor(VALIDATION_SPLIT * dataset_size))
    indices = np.random.default_rng(seed).permutation(dataset_size).tolist()
    train_indices, val_indices = indices[split:], indices[:split]

    def make_loader(subset):
        return DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            sampler=SubsetRandomSampler(subset),
            num_workers=NUM_WORKERS,
            # Pinned host memory only helps host->CUDA copies; elsewhere it just warns.
            pin_memory=device.type == "cuda",
        )

    train_loader = make_loader(train_indices)
    val_loader = make_loader(val_indices)
    print(
        f"Training samples: {len(train_indices)}, Validation samples: {len(val_indices)}"
    )

    history = {f"{split_name}_{metric}": [] for split_name in ("train", "val")
               for metric in _HISTORY_METRICS}

    # torch.save does not create parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_loss = float("inf")
    start_time = datetime.now()

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} [{'='*20}]")

        train_metrics = _run_epoch(
            model, train_loader, device, criterion_policy, criterion_value,
            optimizer=optimizer, log_interval=LOG_INTERVAL,
        )
        for metric in _HISTORY_METRICS:
            history[f"train_{metric}"].append(train_metrics[metric])
        print(_format_summary("Train Epoch Summary", train_metrics))

        if val_indices:
            val_metrics = _run_epoch(
                model, val_loader, device, criterion_policy, criterion_value,
            )
            print(_format_summary("Validation Summary", val_metrics))
            monitor_name, monitor_loss = "validation", val_metrics["loss"]
        else:
            # Avoid dividing by an empty validation loader; keep history aligned.
            val_metrics = _nan_metrics()
            print("No validation samples; skipping validation.")
            monitor_name, monitor_loss = "training", train_metrics["loss"]
        for metric in _HISTORY_METRICS:
            history[f"val_{metric}"].append(val_metrics[metric])

        epoch_save_path = os.path.join(checkpoint_dir, f"epoch{epoch+1}-training2.pth")
        torch.save(model.state_dict(), epoch_save_path)
        print(f"Saved model for epoch {epoch+1} to {epoch_save_path}")

        if monitor_loss < best_loss:
            best_loss = monitor_loss
            torch.save(model.state_dict(), save_path)
            print(
                f"Saved best model with {monitor_name} loss {monitor_loss:.4f} to {save_path}"
            )

        scheduler.step()
        print(f"Learning rate: {scheduler.get_last_lr()[0]:.6f}")

    training_time = (datetime.now() - start_time).total_seconds() / 60
    print(f"\nTraining completed in {training_time:.2f} minutes")

    _plot_history(history, plot_path)
    return history

In [None]:
# Verify the HDF5 dataset layout matches what ChessDataset and RLModel expect.
# NOTE(review): this inspects "dataset-1-6M.h5" while the training cell uses
# "chess_dataset-1M.h5" — confirm which file is intended.
with h5py.File("dataset-1-6M.h5", "r") as f:
    print(f"Inputs Shape: {f['inputs'].shape}")
    print(f"Policies Shape: {f['policies'].shape}")
    print(f"Values Shape: {f['values'].shape}")

    n_samples = f["inputs"].shape[0]
    # ChessDataset takes its length from 'inputs' only, so all three must agree.
    assert f["policies"].shape[0] == n_samples and f["values"].shape[0] == n_samples, (
        "inputs/policies/values have different sample counts"
    )
    # Each input sample must match the model's expected (channels, height, width).
    assert tuple(f["inputs"].shape[1:]) == config.INPUT_SHAPE, (
        f"input sample shape {f['inputs'].shape[1:]} != config.INPUT_SHAPE {config.INPUT_SHAPE}"
    )
    # The flattened-policy argmax is used as a class index over the policy logits.
    assert int(np.prod(f["policies"].shape[1:])) == config.OUTPUT_SHAPE[0], (
        f"flattened policy size != policy output size {config.OUTPUT_SHAPE[0]}"
    )

In [None]:
# Build the network on the selected device and warm-start it from saved weights.
# (Loss and optimizer are created inside train_model.)
torch.cuda.empty_cache()  # release cached CUDA memory from earlier runs in this kernel
model = RLModel(config.INPUT_SHAPE, config.OUTPUT_SHAPE).to(config.DEVICE)
# The checkpoint is consumed by load_state_dict, so it only needs tensors;
# weights_only=True refuses to unpickle arbitrary objects from the file.
state_dict = torch.load("./latest.pth", map_location=config.DEVICE, weights_only=True)
model.load_state_dict(state_dict)

In [None]:
# Train on the 1M-position dataset. The returned per-epoch metrics are kept in
# `history` for later inspection rather than dumped as the cell's output.
hdf5_path = "chess_dataset-1M.h5"
history = train_model(model, hdf5_path)