In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from bots.mcts_cnn_bots.tensors import df_to_tensors

In [2]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
training_data = Path("./data/parquet/data.parquet")

In [4]:
df = pd.read_parquet(training_data)

In [5]:
def shuffle_data(df: pd.DataFrame, random_state=42):
    return df.sample(frac=1, random_state=random_state).reset_index(drop=True)

In [6]:
def split_data(state_tensors, policy_targets, value_targets, test_size=0.2, random_state=42):
    indices = np.arange(len(state_tensors))
    train_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=random_state)

    train_dataset = TensorDataset(state_tensors[train_idx], policy_targets[train_idx], value_targets[train_idx])
    test_dataset = TensorDataset(state_tensors[test_idx], policy_targets[test_idx], value_targets[test_idx])

    return train_dataset, test_dataset

In [18]:
def train_model(model, train_loader, test_loader, epochs, lr, weight_decay):
    training_loss_value = []
    training_loss_policy = []
    validation_loss_value = []
    validation_loss_policy = []
    
    criterion_policy = nn.CrossEntropyLoss()
    criterion_value = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_policy_loss = 0.0
        train_value_loss = 0.0

        for state_batch, policy_batch, value_batch in train_loader:
            state_batch = state_batch.to(device)
            policy_batch = torch.argmax(policy_batch, dim=1)  # Converts one-hot to class indices
            policy_batch = policy_batch.to(device)
            value_batch = value_batch.to(device)
            
            optimizer.zero_grad()

            # Forward pass
            policy_pred, value_pred = model(state_batch)
            
            # Compute the individual losses
            policy_loss = criterion_policy(policy_pred, policy_batch)
            value_loss = criterion_value(value_pred, value_batch)
            
            # Dynamically scale based on current loss values
            policy_weight = 1.0
            value_weight = policy_loss.item() / (value_loss.item() + 1e-8)
            loss = policy_weight * policy_loss + value_weight * value_loss

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_policy_loss += policy_loss.item()
            train_value_loss += value_loss.item()
            
        training_loss_value.append(train_value_loss / len(train_loader))
        training_loss_policy.append(train_policy_loss / len(train_loader))

        # Validation phase
        model.eval()
        val_policy_loss = 0.0
        val_value_loss = 0.0

        with torch.no_grad():
            for state_batch, policy_batch, value_batch in test_loader:
                policy_batch = torch.argmax(policy_batch, dim=1)  # Converts one-hot to class indices
                policy_pred, value_pred = model(state_batch)
                val_policy_loss += criterion_policy(policy_pred, policy_batch).item()
                val_value_loss += criterion_value(value_pred, value_batch).item()
                
        validation_loss_value.append(val_value_loss / len(test_loader))
        validation_loss_policy.append(val_policy_loss / len(test_loader))
        
        scheduler.step((val_value_loss + val_policy_loss) / 2)

        print(
            f"Epoch {epoch+1}/{epochs}, "
            f"Train Policy Loss: {training_loss_policy[-1]:.4f}, Train Value Loss: {training_loss_value[-1]:.4f}, "
            f"Val Policy Loss: {validation_loss_policy[-1]:.4f}, Val Value Loss: {validation_loss_value[-1]:.4f}"
        )

In [19]:
class JassCNN(nn.Module):
    def __init__(self, input_channels=19, num_actions=36):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 256, kernel_size=(2, 3), stride=1, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=(2, 3), stride=2, padding=0)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=(2, 2), stride=2, padding=0)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=(1, 2), stride=1, padding=0)

        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(256)
        
        self.fc1 = nn.Linear(256, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        
        self.fc_policy = nn.Linear(128, num_actions)
        self.fc_value = nn.Linear(128, 1)

    def forward(self, x):   
        x = torch.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.01)
        x = torch.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.01)
        x = torch.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.01)
        x = torch.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.01)

        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        shared_features = torch.relu(self.fc4(x))

        policy = torch.softmax(self.fc_policy(shared_features), dim=-1)
        value = torch.sigmoid(self.fc_value(shared_features))
        
        return policy, value

In [9]:
df = shuffle_data(df, random_state=42)

In [10]:
state_tensors, policy_targets, value_targets = df_to_tensors(df)

In [11]:
train_state, test_state, train_policy, test_policy, train_value, test_value = train_test_split(
    state_tensors, policy_targets, value_targets, test_size=0.2, random_state=42
)

train_state = train_state.to(device)
train_policy = train_policy.to(device)
train_value = train_value.to(device)
test_state = test_state.to(device)
test_policy = test_policy.to(device)
test_value = test_value.to(device)

train_dataset = TensorDataset(train_state, train_policy, train_value)
test_dataset = TensorDataset(test_state, test_policy, test_value)

In [12]:
model = JassCNN().to(device)

In [15]:
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

In [20]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

train_model(model, train_loader, test_loader, epochs=EPOCHS, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

Epoch 1/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0269, Val Policy Loss: 3.5835, Val Value Loss: 0.0231
Epoch 2/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0229, Val Policy Loss: 3.5835, Val Value Loss: 0.0208
Epoch 3/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0223, Val Policy Loss: 3.5835, Val Value Loss: 0.0209
Epoch 4/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0222, Val Policy Loss: 3.5835, Val Value Loss: 0.0210
Epoch 5/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0222, Val Policy Loss: 3.5835, Val Value Loss: 0.0210
Epoch 6/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0222, Val Policy Loss: 3.5835, Val Value Loss: 0.0206
Epoch 7/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0222, Val Policy Loss: 3.5835, Val Value Loss: 0.0211
Epoch 8/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0222, Val Policy Loss: 3.5835, Val Value Loss: 0.0208
Epoch 9/20, Train Policy Loss: 3.5835, Train Value Loss: 0.0222, Val Policy Loss: 3.5835, Val Va