In [1]:
from typing import Any, Iterator, Type
from dataclasses import dataclass
from itertools import product
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from bots.mcts_cnn_bots.tensors import df_to_tensors

In [2]:
# Reproducibility: seed every RNG this notebook relies on (weight init,
# DataLoader shuffling, numpy) before any model or loader is built.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # torch.manual_seed also seeds all CUDA devices

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Location of the training table, resolved relative to the notebook's working directory.
training_data: Path = Path("./data/parquet/data.parquet")

In [4]:
# Load the raw training table from the parquet path configured above.
df = pd.read_parquet(path=training_data)

In [5]:
def shuffle_data(df: pd.DataFrame, random_state=42):
    """Return a copy of ``df`` with its rows in a random order.

    The permutation is fully determined by ``random_state``. The original index
    is discarded and replaced by a fresh 0..n-1 RangeIndex.
    """
    shuffled = df.sample(frac=1, random_state=random_state)
    return shuffled.reset_index(drop=True)

In [6]:
def split_data(state_tensors, policy_targets, value_targets, test_size=0.2, random_state=42):
    """Randomly partition three row-aligned tensors into train/test TensorDatasets.

    Only row *indices* are passed to ``train_test_split``, so the same rows are
    selected from every tensor and the three stay aligned.

    Returns:
        (train_dataset, test_dataset), each a ``TensorDataset`` of
        ``(state, policy, value)`` rows.
    """
    all_rows = np.arange(len(state_tensors))
    train_rows, test_rows = train_test_split(all_rows, test_size=test_size, random_state=random_state)

    aligned = (state_tensors, policy_targets, value_targets)
    return tuple(
        TensorDataset(*(tensor[rows] for tensor in aligned))
        for rows in (train_rows, test_rows)
    )

In [None]:
# Shuffle rows once with a fixed seed. NOTE: this rebinds `df`, so re-running
# this cell alone shuffles the already-shuffled frame and may change the order;
# re-run from the load cell instead.
df = shuffle_data(df=df, random_state=42)

In [None]:
state_tensors, policy_targets, value_targets = df_to_tensors(df)

# Sanity check: the three tensors must describe the same number of rows,
# otherwise the train/test split would pair states with the wrong targets.
assert len(state_tensors) == len(policy_targets) == len(value_targets), (
    "df_to_tensors returned tensors of differing lengths"
)

In [None]:
# Split with the split_data helper defined above rather than re-implementing it
# here; it partitions row indices with the same test_size / random_state.
train_split, test_split = split_data(
    state_tensors, policy_targets, value_targets, test_size=0.2, random_state=42
)
assert len(train_split) + len(test_split) == len(state_tensors)

# Move every split tensor to `device` once, up front, so the DataLoaders serve
# batches that already live on the training device.
train_state, train_policy, train_value = (t.to(device) for t in train_split.tensors)
test_state, test_policy, test_value = (t.to(device) for t in test_split.tensors)

train_dataset = TensorDataset(train_state, train_policy, train_value)
test_dataset = TensorDataset(test_state, test_policy, test_value)

In [7]:
def train_model(model, train_loader, test_loader, epochs, lr, weight_decay, optimizer, criterion):
    """Train a value-only model and report per-epoch train/validation loss.

    Each epoch runs one optimisation pass over ``train_loader`` (with gradient
    clipping to norm 1.0), then a no-grad evaluation pass over ``test_loader``.
    ``ReduceLROnPlateau`` halves the learning rate after 5 epochs without
    validation improvement.

    Args:
        model: ``nn.Module`` mapping a state batch to a value prediction. Batches
            are moved to the device of the model's parameters.
        train_loader, test_loader: iterables with ``len`` that yield
            ``(state, policy, value)`` triples. The policy element is ignored.
        epochs: number of passes over ``train_loader``.
        lr, weight_decay: forwarded to the optimizer constructor.
        optimizer: optimizer *class* (e.g. ``optim.Adam``); instantiated here.
        criterion: loss instance (e.g. ``nn.MSELoss()``) or loss class
            (e.g. ``nn.MSELoss``). A class is instantiated with its defaults.

    Returns:
        float: mean per-batch validation loss of the final epoch, or ``inf``
        when ``epochs`` is 0 (nothing was evaluated).

    Raises:
        ValueError: if either loader yields no batches, because the per-batch
        averages would divide by zero.
    """
    if len(train_loader) == 0 or len(test_loader) == 0:
        raise ValueError("train_loader and test_loader must each yield at least one batch")

    # Accept both `nn.MSELoss` and `nn.MSELoss()`. Calling the bare class with
    # (prediction, target) would construct a module instead of computing a loss.
    criterion_value = criterion() if isinstance(criterion, type) else criterion
    device = next(model.parameters()).device

    optimizer = optimizer(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    training_loss_value = []
    validation_loss_value = []

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_value_loss = 0.0

        for state_batch, _, value_batch in train_loader:
            state_batch = state_batch.to(device)
            value_batch = value_batch.to(device)

            optimizer.zero_grad()
            value_pred = model(state_batch)
            # Match the target's shape so a (B, 1) prediction is never broadcast
            # against a (B,) target into a (B, B) loss.
            value_loss = criterion_value(value_pred.reshape(value_batch.shape), value_batch)

            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_value_loss += value_loss.item()

        training_loss_value.append(train_value_loss / len(train_loader))

        # Validation phase
        model.eval()
        val_value_loss = 0.0

        with torch.no_grad():
            for state_batch, _, value_batch in test_loader:
                state_batch = state_batch.to(device)
                value_batch = value_batch.to(device)
                value_pred = model(state_batch)
                val_value_loss += criterion_value(value_pred.reshape(value_batch.shape), value_batch).item()

        validation_loss_value.append(val_value_loss / len(test_loader))

        # Step on the same mean validation loss that is logged below.
        scheduler.step(validation_loss_value[-1])

        print(
            f"Epoch {epoch+1}/{epochs}, "
            f"Train Value Loss: {training_loss_value[-1]:.4f}, "
            f"Val Value Loss: {validation_loss_value[-1]:.4f}, "
            f"LR: {optimizer.param_groups[0]['lr']:.2e}"
        )

    return validation_loss_value[-1] if validation_loss_value else float('inf')

In [None]:
@dataclass
class Parameters:
    """One concrete hyperparameter configuration for a single training run."""

    criterion: Any  # loss object used for training, e.g. nn.HuberLoss(...)
    optimizer: Any  # optimizer class, e.g. optim.Adam
    epochs: int
    learning_rate: float
    weight_decay: float
    batch_size: int

    def __str__(self):
        labelled = (
            ("crit", self.criterion),
            ("optm", self.optimizer),
            ("epcs", self.epochs),
            ("lrnt", self.learning_rate),
            ("wedc", self.weight_decay),
            ("basz", self.batch_size),
        )
        return ", ".join(f"{label}: {value}" for label, value in labelled)


@dataclass
class ParamGrid:
    """Candidate values per hyperparameter; ``iter_product`` enumerates every combination."""

    criterion: tuple[Any, ...]
    optimizer: tuple[Any, ...]
    epochs: tuple[int, ...]
    learning_rate: tuple[float, ...]
    weight_decay: tuple[float, ...]
    batch_size: tuple[int, ...]

    def iter_product(self) -> Iterator[Parameters]:
        """Lazily yield one ``Parameters`` per point of the Cartesian product.

        Axes are combined in field order, so ``batch_size`` (the last axis)
        varies fastest, as with ``itertools.product``.
        """
        axes = (
            self.criterion,
            self.optimizer,
            self.epochs,
            self.learning_rate,
            self.weight_decay,
            self.batch_size,
        )
        for values in product(*axes):
            yield Parameters(*values)

In [None]:
def _with_batch_size(loader, batch_size, shuffle):
    """Rebuild ``loader`` over its dataset with ``batch_size``; pass through loaders without ``.dataset``."""
    dataset = getattr(loader, "dataset", None)
    if dataset is None:
        return loader
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _validation_loss(model, loader, criterion, device):
    """Mean per-batch ``criterion`` loss of ``model`` over ``loader``, without gradients."""
    criterion = criterion() if isinstance(criterion, type) else criterion
    model.eval()
    total = 0.0
    batches = 0
    with torch.no_grad():
        for state_batch, _, value_batch in loader:
            state_batch = state_batch.to(device)
            value_batch = value_batch.to(device)
            value_pred = model(state_batch)
            # Match the target's shape to avoid (B, 1) vs (B,) broadcasting.
            total += criterion(value_pred.reshape(value_batch.shape), value_batch).item()
            batches += 1
    return total / batches if batches else float('inf')


def grid_search(t_loader, v_loader, param_grid: ParamGrid, model_class: Type[nn.Module], device=None, score_criterion=None):
    """Train one fresh model per grid point and return the best configuration.

    Args:
        t_loader, v_loader: training / validation loaders. Their ``.dataset`` is
            re-batched with each configuration's ``batch_size`` (training data is
            reshuffled every epoch). Loaders without a ``.dataset`` are used as-is.
        param_grid: search space; every combination is trained once.
        model_class: zero-argument constructor for the model under test.
        device: where each model is placed. Defaults to CUDA when available,
            otherwise CPU.
        score_criterion: loss used to *score* every trained model. Defaults to
            ``nn.MSELoss()``. A single criterion keeps scores comparable across
            configurations trained with different losses.

    Returns:
        (best_params, best_score): the lowest-scoring ``Parameters`` and its
        score, or ``(None, inf)`` when the grid is empty.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if score_criterion is None:
        score_criterion = nn.MSELoss()

    best_params = None
    best_score = float('inf')

    for params in param_grid.iter_product():
        # Honour the grid's batch_size instead of the loaders' fixed one.
        train_loader = _with_batch_size(t_loader, params.batch_size, shuffle=True)
        val_loader = _with_batch_size(v_loader, params.batch_size, shuffle=False)

        # Fresh, untrained model per configuration, placed on the target device.
        model = model_class().to(device)

        train_model(model, train_loader, val_loader, params.epochs, params.learning_rate,
                    params.weight_decay, params.optimizer, params.criterion)

        # Score independently of train_model's return value and of the loss the
        # configuration trained with.
        score = _validation_loss(model, val_loader, score_criterion, device)
        print(f"Params: {params}, Validation Loss: {score}")

        if score < best_score:
            best_score = score
            best_params = params

    return best_params, best_score

In [8]:
class JassCNN(nn.Module):
    """Convolutional value network: stacked feature planes -> one value in (0, 1).

    The network has five conv -> BatchNorm -> LeakyReLU(0.01) stages, then two
    ReLU fully connected layers and a sigmoid value head.

    Args:
        input_channels: number of feature planes in the input (dim 1).
        num_actions: accepted for API compatibility but unused; this model has
            no policy head.
        board_shape: (height, width) of each input plane. It is only used to size
            ``fc1``; the default (4, 9) yields 512 flattened features.
            NOTE(review): (4, 9) presumably means 4 suits x 9 ranks (36 cards);
            confirm against df_to_tensors.

    Input: (batch, input_channels, *board_shape). Output: (batch, 1).
    """

    def __init__(self, input_channels=19, num_actions=36, board_shape=(4, 9)):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 256, kernel_size=(3, 3), stride=1, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=(2, 3), stride=2, padding=1)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=(2, 2), stride=1, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=(2, 2), stride=2, padding=0)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=(2, 2), stride=1, padding=0)

        self.bn1 = nn.BatchNorm2d(256)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(256)

        # Derive fc1's input width from the conv stack instead of hard-coding it,
        # so a different board_shape does not fail with a shape mismatch.
        self.fc1 = nn.Linear(self._flattened_size(input_channels, board_shape), 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc_value = nn.Linear(128, 1)

    def _flattened_size(self, input_channels, board_shape):
        """Feature count after the conv stack for a single board of ``board_shape``.

        BatchNorm and LeakyReLU preserve shape, so only the convolutions are run.
        This also leaves BatchNorm running statistics untouched.
        """
        with torch.no_grad():
            probe = torch.zeros(1, input_channels, *board_shape)
            for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
                probe = conv(probe)
        return probe.numel()

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in stages:
            x = nn.functional.leaky_relu(bn(conv(x)), negative_slope=0.01)

        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        value = torch.sigmoid(self.fc_value(x))

        return value

In [12]:
# Fresh value network, moved to the selected device before training.
model = JassCNN()
model = model.to(device)

In [13]:
# Hyperparameters for the single training run.
BATCH_SIZE: int = 64          # samples per optimizer step
EPOCHS: int = 20              # full passes over the training split
LEARNING_RATE: float = 5e-4   # initial learning rate
WEIGHT_DECAY: float = 1e-5    # forwarded to the optimizer as weight_decay

In [None]:
# Reshuffle the training split every epoch so batch composition varies between
# epochs. Evaluation order stays fixed so validation losses are comparable.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Pass a loss *instance*: calling the bare nn.MSELoss class with
# (prediction, target) would construct a module rather than compute a loss.
train_model(model, train_loader, test_loader, epochs=EPOCHS, lr=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY, optimizer=optim.Adam, criterion=nn.MSELoss())

In [None]:
# Search space for grid_search: every one of the 2*3*1*3*2*3 = 108 combinations
# is trained once.
# NOTE(review): per the PyTorch docs, HuberLoss(delta=d) differs from
# SmoothL1Loss(beta=d) only by a constant factor of d. The two criteria here
# therefore optimise the same objective at different scales; confirm this is intended.
param_grid = ParamGrid(
    criterion=(
        nn.HuberLoss(delta=1.35),
        nn.SmoothL1Loss(beta=1.35),
    ),
    optimizer=(optim.Adam, optim.AdamW, optim.SGD),
    epochs=(7,),
    learning_rate=(5e-3, 1e-3, 5e-4),
    weight_decay=(1e-4, 1e-5),
    batch_size=(32, 64, 128),
)

In [None]:
# grid_search returns a (best_params, best_score) pair. Unpack it rather than
# binding the tuple to a name that reads like a single Parameters object.
best_params, best_score = grid_search(train_loader, test_loader, param_grid, JassCNN)
print(f"Best params: {best_params}, Validation Loss: {best_score}")