In [3]:
import sys
sys.path.append('../../src/')
from dataset import TegameDataset
from models import TegameScoreModel
 
import torch
from torch.utils.data import DataLoader

%load_ext autoreload
%autoreload 2

In [4]:
# Load the pre-built decision groups and wrap them for PyTorch.
# NOTE(review): torch.load unpickles the file — only load trusted files.
groups = torch.load(f="training_set.pt")
dataset = TegameDataset(groups)

# One decision group per batch: groups presumably have different numbers of
# candidate moves, so they are not stacked together (TODO confirm).
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

In [5]:
# 1. Reproducibility
# Seed torch's global RNG so weight initialisation and the DataLoader's
# shuffle order (which draws from the global generator) are reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# 2. Input dimension
# The feature dimension is the last axis of a group's feature matrix
# ([num_moves, feature_dim]). Use throwaway names so this cell does not leak
# globals that shadow the loop variables used in later cells.
sample_features, _ = dataset[0]
INPUT_DIM = sample_features.shape[-1]

# 3. Model instance
# NOTE(review): hidden_dim=128 is an untuned choice — revisit if the model
# under- or over-fits.
HIDDEN_DIM = 128
model = TegameScoreModel(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM)

# 4. Device selection
# Use GPU if available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# 5. Optimizer
# NOTE(review): 1e-2 is 10x Adam's default learning rate (1e-3). The epoch
# losses plateau and oscillate after the first few epochs, which may mean the
# step size is too large — worth trying 1e-3.
LEARNING_RATE = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [6]:
def listwise_softmax_loss(logits, target_index):
    """
    Negative log-likelihood of the chosen move under a softmax over one group.

    Parameters
    ----------
    logits : tensor of shape [num_moves]
        Unnormalised scores the model assigned to each candidate move.

    target_index : int
        Position (within ``logits``) of the move that was actually played.

    Returns
    -------
    tensor
        Scalar (0-d) tensor equal to ``-log_softmax(logits)[target_index]``.

    Notes
    -----
    This is the listwise classification objective for a policy model: the
    softmax normalises over however many legal moves the group contains, so
    groups of different sizes can share the same loss. ``log_softmax`` is
    used instead of ``log(softmax(...))`` for numerical stability.
    """
    nll_per_move = -logits.log_softmax(dim=0)
    return nll_per_move[target_index]

In [7]:
model.train()

epochs = 20

for epoch in range(epochs):
    total_loss = 0.0
    n_groups = 0  # groups that actually contributed a gradient step this epoch

    for X_group, y_group in loader:
        # X_group: [1, num_moves, feature_dim]
        # y_group: [1]

        # Remove the batch dimension (always 1) and move to device
        X_group = X_group.squeeze(0).to(device)   # [num_moves, feature_dim]
        y_group = y_group.item()                  # int (index of the chosen move)

        # Forward pass
        logits = model(X_group).squeeze(-1)       # [num_moves]

        # Skip groups with fewer than two candidate moves (typically only NOOP):
        # the model has no choice to make, so there is no learning signal.
        # Test numel() rather than ndim == 0: depending on how the model shapes
        # its output, a single-move group can come out as a 0-d tensor or as a
        # 1-element tensor of shape [1] (e.g. [1, 1] squeezed on the last axis),
        # and the ndim check only caught the former. Empty groups are skipped too.
        if logits.numel() < 2:
            continue

        # Loss: negative log-likelihood of the chosen move (defined in the
        # listwise_softmax_loss cell; uses log_softmax for numerical stability).
        loss = listwise_softmax_loss(logits, y_group)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_groups += 1

    # The summed loss scales with the number of groups; the mean per group is
    # comparable across datasets of different sizes.
    mean_loss = total_loss / max(n_groups, 1)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f} (mean per group: {mean_loss:.4f})")

Epoch 1/20 - Loss: 4127.9447
Epoch 2/20 - Loss: 3044.9432
Epoch 3/20 - Loss: 2825.9132
Epoch 4/20 - Loss: 2712.5898
Epoch 5/20 - Loss: 2819.7274
Epoch 6/20 - Loss: 2664.1553
Epoch 7/20 - Loss: 2611.4773
Epoch 8/20 - Loss: 2667.2960
Epoch 9/20 - Loss: 2507.7724
Epoch 10/20 - Loss: 2495.1049
Epoch 11/20 - Loss: 2490.9058
Epoch 12/20 - Loss: 2688.3812
Epoch 13/20 - Loss: 2597.6471
Epoch 14/20 - Loss: 2478.3551
Epoch 15/20 - Loss: 2411.5257
Epoch 16/20 - Loss: 2483.6621
Epoch 17/20 - Loss: 2428.2955
Epoch 18/20 - Loss: 2430.6318
Epoch 19/20 - Loss: 2473.6905
Epoch 20/20 - Loss: 2513.8300


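> **Note:** The loss appears very large because it is the *sum* of the negative log‑likelihood over every decision group in the epoch, not an average. Its absolute value therefore scales with the number of groups and is not meaningful on its own. What matters is its trend across epochs: it drops sharply over the first few epochs, then plateaus with some epoch‑to‑epoch fluctuation.  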
> Now that the model is trained, we can evaluate something more interpretable: its accuracy in reproducing the moves from the training set.


In [8]:
# Training-set accuracy: how often the model's top-scoring move matches the
# move recorded in the data. This is measured on the training data itself,
# so it is not an estimate of performance on unseen games.

model.eval()

correct_groups = 0
total_groups = 0

with torch.no_grad():
    for X_group, y_group in dataset:
        # X_group: [num_moves, feature_dim] (iterating the dataset directly, no batch dim)
        # y_group: index of the correct move (int, or a tensor holding one value)

        # The model lives on `device`, so the inputs must be moved there too;
        # otherwise this cell fails whenever device == "cuda".
        logits = model(X_group.to(device)).squeeze(-1)   # [num_moves]

        # Skip groups with fewer than two candidate moves (typically NOOP).
        # They would always be "correct" and inflate the accuracy. numel() is
        # used because such a group may squeeze to a 0-d tensor or to shape [1].
        if logits.numel() < 2:
            continue

        pred_idx = logits.argmax().item()

        # int(...) makes the comparison work whether y_group is a Python int
        # or a one-element tensor.
        if pred_idx == int(y_group):
            correct_groups += 1

        total_groups += 1

# Guard against division by zero when no group had a real choice to make.
acc = correct_groups / total_groups if total_groups else float("nan")
print(f"Group accuracy: {acc*100:.1f}% ({correct_groups}/{total_groups} groups)")

Group accuracy: 77.2%


> **Comment:** Accuracy is a much more interpretable metric than the raw loss.  
> While the loss only tells us that the model is assigning higher probability to the correct moves over time, the accuracy directly measures how often the model reproduces the same decisions found in the training logs.  
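> This gives us a clearer sense of how well the learned policy matches the behavior encoded in the dataset. Note, however, that it is measured on the same data the model was trained on. It therefore reflects how well the model fits the logs, not how well it generalizes to unseen games; a held-out split would be needed for that.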

In [9]:
# Persist the trained model to disk.
# NOTE(review): this pickles the whole nn.Module, so reloading requires the
# TegameScoreModel class to be importable under the same module path (and, on
# recent PyTorch versions, torch.load(..., weights_only=False)). Saving
# model.state_dict() instead is the more portable option.
torch.save(obj=model, f="tegame_model.pth")