# Training

This notebook trains drafting models on 17lands draft data.

### Prerequisites
First download data from 17lands:
1. Create the folder `statistical-drafting/data/17lands/`.
2. Download draft data from [17lands](https://www.17lands.com/public_datasets) into that folder.
3. Download an updated `statistical-drafting/data/cards.csv` file from [17lands](https://www.17lands.com/public_datasets) so that new sets are supported.
4. Run the rest of this notebook to train the models.
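The prerequisites above can be checked before running anything else. This is a minimal sketch, assuming the notebook runs from a subfolder of `statistical-drafting/` (so data lives under `../data/`) and that draft files keep the `draft_data_public.<SET>.<MODE>Draft.csv.gz` naming seen in the training log below. `check_prerequisites` is a hypothetical helper, not part of `statisticaldrafting`.

```python
from pathlib import Path


def check_prerequisites(data_dir="../data"):
    """Return a list of problems; an empty list means the data looks ready.

    Hypothetical helper. Assumes draft files are named
    draft_data_public.<SET>.<MODE>Draft.csv.gz and that cards.csv sits in data_dir.
    """
    data = Path(data_dir)
    problems = []
    draft_dir = data / "17lands"
    if not draft_dir.is_dir():
        problems.append(f"missing folder {draft_dir}")
    elif not any(draft_dir.glob("draft_data_public.*.csv.gz")):
        problems.append(f"no draft_data_public.*.csv.gz files in {draft_dir}")
    if not (data / "cards.csv").is_file():
        problems.append(f"missing {data / 'cards.csv'}")
    return problems


# Print every problem found (prints nothing when all prerequisites are met).
for problem in check_prerequisites():
    print("Prerequisite issue:", problem)
```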

In [9]:
!mkdir -p ../data/17lands/

In [1]:
# Reinstall the statisticaldrafting package from the project root (quietly).
# If the import below still fails, run `pip install -e .` in the project root instead.
%pip install .. -q

import os
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

import statisticaldrafting as sd

Note: you may need to restart the kernel to use updated packages.


In [4]:
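# List all sets available for training, parsed from 17lands filenames of the
# form draft_data_public.<SET>.<MODE>Draft.csv.gz.
set_abbreviations, draft_modes = [], []
for fn in os.listdir("../data/17lands/"):
    if fn.startswith("draft_data_public.") and fn.endswith(".csv.gz"):
        parts = fn.split(".")
        sa, dm = parts[1], parts[2][:-len("Draft")]  # e.g. "TradDraft" -> "Trad"
        set_abbreviations.append(sa)
        draft_modes.append(dm)
        # print(sa, dm)  # Uncomment to print each set.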
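The cell above relies on the 17lands naming convention. A stricter alternative is to match the whole filename with a regular expression, so unrelated files are skipped rather than mis-parsed. This sketch assumes every draft file is named `draft_data_public.<SET>.<MODE>Draft.csv.gz`, as in the training log below; `parse_draft_filename` is a hypothetical helper.

```python
import re

# Assumed naming convention (matches the file seen in the training log):
#   draft_data_public.<SET>.<MODE>Draft.csv.gz
DRAFT_FILE_RE = re.compile(
    r"^draft_data_public\.([A-Za-z0-9]+)\.([A-Za-z]+)Draft\.csv\.gz$"
)


def parse_draft_filename(fn):
    """Return (set_abbreviation, draft_mode), or None for files that don't match."""
    match = DRAFT_FILE_RE.match(fn)
    return (match.group(1), match.group(2)) if match else None


print(parse_draft_filename("draft_data_public.STX.TradDraft.csv.gz"))  # ('STX', 'Trad')
print(parse_draft_filename(".DS_Store"))  # None
```

To build the same lists as the cell above: `pairs = [p for p in map(parse_draft_filename, os.listdir("../data/17lands/")) if p]`.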

In [6]:
# Train a model for every set and draft mode found above.
for set_abbreviation, draft_mode in zip(set_abbreviations, draft_modes):
    try:
        print(f"Starting training for {set_abbreviation}, {draft_mode}")
        sd.default_training_pipeline(set_abbreviation, draft_mode, overwrite_dataset=True)
    except Exception as e:
        # Report the failure and continue with the next set.
        print(f"Error for: {set_abbreviation}, {draft_mode}: {e}")

Starting training for STX, Trad
Using input file ../data/17lands/draft_data_public.STX.TradDraft.csv.gz
Completed initialization.
Filtering by match wins >= 3
Loaded 0 picks, t= 2.1 s
Filtering by match wins >= 3
Filtering by match wins >= 3
Filtering by match wins >= 3
Filtering by match wins >= 3
Filtering by match wins >= 3
Filtering by match wins >= 3
Filtering by match wins >= 3
Filtering by match wins >= 3
Loaded all draft data.
Saved training set to ../data/training_sets/STX_Trad_train.pth
Saved validation set to ../data/training_sets/STX_Trad_val.pth
Using existing cardname file, ../data/cards/STX.csv
Starting to train model
Validation set pick accuracy = 19.5%

Starting epoch 0
Training loss: 4.6263

Starting epoch 1
Training loss: 3.1535

Starting epoch 2
Training loss: 2.5951
Validation set pick accuracy = 57.5%
Saving model weights to ../data/models/STX_Trad.pt

Starting epoch 3
Training loss: 2.0446

Starting epoch 4
Training loss: 1.5638
Validation set pick accuracy = 63.
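The log reports *validation set pick accuracy*. The exact computation lives inside `statisticaldrafting` and is not shown here; a common definition, assumed for this sketch, is: score every card, restrict the scores to cards actually in the pack, and count the pick as correct when the top-scoring card is the one the human drafter took. `pick_accuracy` is a hypothetical helper written in plain Python for clarity.

```python
def pick_accuracy(scores, packs, picks):
    """Fraction of picks where the model's top in-pack card equals the human pick.

    scores: per-pick lists of model scores, one score per card in the set
    packs:  per-pick 0/1 lists marking which cards are in the pack
    picks:  index of the card the human actually picked
    """
    correct = 0
    for card_scores, pack, pick in zip(scores, packs, picks):
        in_pack = [i for i, present in enumerate(pack) if present]
        best = max(in_pack, key=lambda i: card_scores[i])  # argmax over the pack only
        correct += int(best == pick)
    return correct / len(picks)


# Card 0 scores highest in the first row but is not in that pack, so it is ignored.
scores = [[0.9, 0.2, 0.5], [0.1, 0.8, 0.3]]
packs = [[0, 1, 1], [1, 1, 0]]
picks = [2, 0]
print(pick_accuracy(scores, packs, picks))  # 0.5
```

Masking to the pack matters: without it, a card the drafter could not have taken (card 0 in the first pick) would be counted as the model's choice.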

In [45]:
model_path = "../data/models"

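# Parse saved model filenames (<SET>_<MODE>.pt) into [set, draft_mode] pairs.
draft_models = [fn.split(".")[0].split("_") for fn in os.listdir(model_path) if fn.endswith(".pt")]
sets = set([dm[0] for dm in draft_models])

# Count trained models per set and draft mode (1 = model exists, 0 = missing).
df = pd.DataFrame(draft_models, columns=["set", "draft_mode"])
x = df.groupby(["set", "draft_mode"]).size().unstack(fill_value=0)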
x

| set | Premier | Trad |
|-----|--------:|-----:|
| AFR | 1 | 0 |
| BLB | 1 | 1 |
| BRO | 1 | 1 |
| DMU | 1 | 1 |
| DSK | 1 | 1 |
| FDN | 1 | 1 |
| HBG | 1 | 1 |
| KTK | 1 | 1 |
| LCI | 1 | 1 |
| LTR | 1 | 1 |
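The table shows which models exist; a natural follow-up is which set and draft mode combinations still need training. This standard-library sketch assumes models are saved as `<SET>_<MODE>.pt` (as in the training log above) and that the only draft modes of interest are `Premier` and `Trad`; `missing_models` is a hypothetical helper.

```python
from collections import defaultdict


def missing_models(filenames, modes=("Premier", "Trad")):
    """Map each set to the draft modes that have no saved <SET>_<MODE>.pt file."""
    trained = defaultdict(set)
    for fn in filenames:
        if fn.endswith(".pt") and "_" in fn:
            set_abbreviation, draft_mode = fn[: -len(".pt")].split("_", 1)
            trained[set_abbreviation].add(draft_mode)
    missing = {}
    for s in sorted(trained):
        gaps = [m for m in modes if m not in trained[s]]
        if gaps:  # only report sets with at least one missing mode
            missing[s] = gaps
    return missing


print(missing_models(["AFR_Premier.pt", "BLB_Premier.pt", "BLB_Trad.pt", "notes.txt"]))
# {'AFR': ['Trad']}
```

For the sets shown in the table, `missing_models(os.listdir(model_path))` should include `'AFR': ['Trad']`.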


In [43]:
x.columns

Index(['set', 'Premier', 'Trad'], dtype='object', name='draft_mode')