# Deckbuild Model Training

This notebook trains deck completion models using 17lands game data.

### Steps:
1. Download **game data** (not draft data) from [17lands](https://www.17lands.com/public_datasets) into `statistical-drafting/data/17lands/`
2. Ensure card data exists in `statistical-drafting/data/cards/`
3. Run this notebook to train the deck completion model

In [None]:
# Install dependencies (restart kernel after running this cell)
# `-q` keeps pip's output quiet.
%pip install torch numpy pandas scikit-learn matplotlib -q

In [None]:
# Add parent directory to path for local imports.
# NOTE: '..' is resolved against the kernel's working directory, so this
# assumes the notebook is run from its own folder — TODO confirm for other
# launch setups.
import sys
sys.path.insert(0, '..')

import os
import torch
from torch.utils.data import DataLoader

# Local package; presumably located through the path entry added above —
# verify against the repository layout.
import statisticaldeckbuild as sdb

## Configuration

In [None]:
# Set configuration.
# These constants are read by the dataset-creation, training and evaluation
# cells below; edit them here rather than in the individual cells.
SET_ABBREVIATION = "FDN"  # Change to your target set
DRAFT_MODE = "Premier"     # "Premier", "Trad", etc.
N_HOLDOUT = 1              # Number of cards to hold out for prediction
OVERWRITE_DATASET = True   # Set to False to reuse existing dataset

## Create Dataset

In [None]:
# Build the train/validation datasets from game data, driven entirely by the
# configuration constants. The call returns the two dataset paths.
dataset_options = {
    "set_abbreviation": SET_ABBREVIATION,
    "draft_mode": DRAFT_MODE,
    "overwrite": OVERWRITE_DATASET,
    "n_holdout": N_HOLDOUT,
}
train_path, val_path = sdb.create_deckbuild_dataset(**dataset_options)

In [None]:
# Deserialize the saved dataset objects.
# NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
# Python objects, so only point these paths at files you generated yourself.
train_dataset, val_dataset = (
    torch.load(dataset_path, weights_only=False)
    for dataset_path in (train_path, val_path)
)

# Report the split sizes and the number of card names known to the dataset
n_train = len(train_dataset)
n_val = len(val_dataset)
n_cards = len(train_dataset.cardnames)

print(f"Training examples: {n_train}")
print(f"Validation examples: {n_val}")
print(f"Number of cards: {n_cards}")

In [None]:
# Wrap both datasets in loaders that share a single batch size
BATCH_SIZE = 10000

# Only the training loader shuffles; validation keeps dataset order
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Inspect Sample Data

In [None]:
# Decode the first training example into readable card names
partial_deck, available, label = train_dataset[0]
cardnames = train_dataset.cardnames

# List every card with a positive count in the partial deck
print("Partial deck cards:")
for card_idx, copies in enumerate(partial_deck):
    if copies > 0:
        print(f"  {cardnames[card_idx]}: {int(copies)}")

# Summary figures: partial deck size, size of the candidate pool, and the
# names of the card(s) flagged in the label vector
deck_size = int(partial_deck.sum())
n_available = int(available.sum())
held_out = [cardnames[card_idx] for card_idx, flag in enumerate(label) if flag > 0]

print(f"\nTotal cards in partial deck: {deck_size}")
print(f"Available cards to choose from: {n_available}")
print(f"Held out card(s): {held_out}")

## Train Model

In [None]:
# Instantiate the model for this dataset's card list
network = sdb.DeckbuildNet(cardnames=train_dataset.cardnames)

# Total element count across all parameter tensors, printed with thousands separators
n_params = sum(param.numel() for param in network.parameters())
print(f"Model parameters: {n_params:,}")

In [None]:
# Fit the network. The experiment name is derived from the configured set
# and draft mode; the call returns the network and a dict of training info.
experiment_name = f"{SET_ABBREVIATION}_{DRAFT_MODE}_deckbuild"
network, training_info = sdb.train_deckbuild_model(
    train_dataloader,
    val_dataloader,
    network,
    experiment_name=experiment_name,
    learning_rate=0.03,
)

In [None]:
# Print a header followed by one indented "key: value" line per entry
# in training_info
print(
    "Training Results:",
    *(f"  {key}: {value}" for key, value in training_info.items()),
    sep="\n",
)

## Evaluate Model

In [None]:
# Restore the saved weights for this experiment before scoring.
# NOTE(review): presumably this is the checkpoint written by the training
# cell — confirm the save location against train_deckbuild_model.
experiment_name = f"{SET_ABBREVIATION}_{DRAFT_MODE}_deckbuild"
model_path = f"../data/models/{experiment_name}.pt"
saved_state = torch.load(model_path, weights_only=True)
network.load_state_dict(saved_state)
network.eval()

# Final evaluation on the validation loader
accuracy = sdb.evaluate_deckbuild_model(val_dataloader, network)

## Example Predictions

In [None]:
# Show the model's top-ranked picks for the first few validation examples
import torch

N_EXAMPLES = 5  # how many validation examples to display
TOP_K = 5       # how many ranked predictions to list per example

network.eval()
with torch.no_grad():
    # Clamp the count so a validation set smaller than N_EXAMPLES does not
    # raise IndexError.
    for i in range(min(N_EXAMPLES, len(val_dataset))):
        partial_deck, available, label = val_dataset[i]

        # Get prediction: add a batch dimension for the forward pass, then drop it
        pred = network(partial_deck.unsqueeze(0).float(), available.unsqueeze(0).float())
        pred = pred.squeeze(0)

        # Mask unavailable cards so they rank last
        pred[available == 0] = float('-inf')

        # Plain Python ints index the card-name list regardless of its type
        top_indices = torch.argsort(pred, descending=True)[:TOP_K].tolist()

        # Collect every held-out card: the label can flag more than one entry
        # when N_HOLDOUT > 1, and taking only the first would mislabel the rest.
        actual_cards = [
            val_dataset.cardnames[j] for j, held in enumerate(label) if held > 0
        ]
        actual_text = ", ".join(actual_cards) if actual_cards else "(none)"

        print(f"\nExample {i+1}:")
        print(f"  Actual held-out card: {actual_text}")
        print(f"  Top {TOP_K} predictions:")
        for rank, idx in enumerate(top_indices, 1):
            card_name = val_dataset.cardnames[idx]
            marker = " <-- CORRECT" if card_name in actual_cards else ""
            print(f"    {rank}. {card_name}{marker}")

## Quick Training Pipeline

For convenience, you can also use the default pipeline function, `sdb.default_deckbuild_pipeline`, to run the full pipeline in a single call:

In [None]:
# Uncomment to run the full pipeline in one call.
# The literal arguments below mirror the defaults in the Configuration cell;
# change them here if you target a different set or draft mode.
# training_info = sdb.default_deckbuild_pipeline(
#     set_abbreviation="FDN",
#     draft_mode="Premier",
#     overwrite_dataset=True,
#     n_holdout=1,
# )