In [None]:
import re
import csv
import os
import json
import datetime
import typing as T
import functools
from copy import deepcopy

import torch
import plotly.express as px
import pytorch_lightning as pl
import mlflow
import pandas as pd

import mtgradient
from mtgradient import processing, datasets, models

In [None]:
# Load the nb_black IPython extension for this kernel session
# (presumably auto-formats executed cells with black — requires nb_black to be installed).
%load_ext nb_black

In [None]:
# Requires mlflow (install with `pip install mlflow` if it is missing).
# Enable MLflow autologging for PyTorch with a 100-step logging interval,
# and without logging the trained model itself.
mlflow.pytorch.autolog(
    log_every_n_step=100,
    log_models=False,
)

In [None]:
###################################
# Draft config
# -- location of raw data
# -- location to cache dataset
# -- time split
###################################

# When True, the processed dataset is read from `cache_path` instead of being
# re-parsed from `draft_csv_path`.
LOAD_CACHED = True

# Raw draft CSV to parse (alternatives kept for quick switching).
# draft_csv_path = "data/draft_data_public.NEO.PremierDraft.csv"
draft_csv_path = "data/HBG/draft_data_public.HBG.PremierDraft.csv"
# draft_csv_path = "tests/testdata/test_premier_draft_hbg.csv"
# draft_csv_path = "tests/testdata/test_premier_draft.csv"

# Directory holding the processed/cached dataset.
# cache_path = "data/cached/neo_premier_draft"
cache_path = "data/cached/hbg_premier_draft/"
# cache_path = "tests/testdata/"

# Drafts before this timestamp go to training, the rest to validation.
# test_split_cutoff = pd.Timestamp(datetime.date(2022, 3, 11))
test_split_cutoff = pd.Timestamp(datetime.date(year=2022, month=7, day=28))

# Ten days before the split; passed to the training dataset below.
recent_game_cutoff = test_split_cutoff - datetime.timedelta(days=10)

In [None]:
# Either reuse the processed dataset already cached on disk, or parse the raw
# CSV from scratch and persist the result so later runs can take the fast path.
if LOAD_CACHED:
    parsed_data, card_ids = processing.load_processed_dataset(cache_path)
else:
    parsed_data, card_ids = processing.parse_csv(draft_csv_path, verbose=True)
    # Make sure the cache directory exists before writing into it.
    os.makedirs(cache_path, exist_ok=True)
    processing.persist_processed_dataset(cache_path, parsed_data, card_ids)

In [None]:
# Split the parsed drafts by time: drafts strictly before `test_split_cutoff` go to
# the training subset, drafts at or after it go to the validation subset.
#
# Drafts with no "draft_time" entry fall back to a sentinel far in the past, so they
# always land in the training subset. The sentinel must be a pd.Timestamp rather than
# a datetime.date: `test_split_cutoff` is a pd.Timestamp, and pandas >= 2.0 raises
# TypeError when ordering a Timestamp against a plain date.
MISSING_DRAFT_TIME = pd.Timestamp(datetime.date(1999, 9, 9))

train_subset = {
    k: v
    for k, v in parsed_data.items()
    if v.get("draft_time", MISSING_DRAFT_TIME) < test_split_cutoff
}
val_subset = {
    k: v
    for k, v in parsed_data.items()
    if v.get("draft_time", MISSING_DRAFT_TIME) >= test_split_cutoff
}

# NOTE(review): the second positional argument looks like a recency cutoff used by
# DraftDataset for weighting, and `use_all=True` appears to disable it for
# validation — confirm against mtgradient.datasets.DraftDataset.
train_draft_dataset = datasets.DraftDataset(train_subset, recent_game_cutoff)
val_draft_dataset = datasets.DraftDataset(val_subset, test_split_cutoff, use_all=True)

# Small dataset built from the first 500 training drafts (in dict order), handy for
# quick sanity checks.
small = set(list(train_subset.keys())[:500])
check_draft_dataset = datasets.DraftDataset(
    {k: v for k, v in train_subset.items() if k in small},
    recent_game_cutoff,
)


In [None]:
# Number of entries in the validation and training subsets, in that order.
len(val_subset), len(train_subset)

In [None]:
# Total number of parsed entries before the train/validation split.
len(parsed_data)

In [None]:
# Build a flat table with one row per (training draft, pick round). Each row carries
# a few draft-level fields plus the two weights returned by
# `train_draft_dataset.get_weights` for that round. Collection stops once more than
# 10000 rows have been gathered (checked after each complete draft).
data_out = []
for draft_id, draft in train_subset.items():
    draft_fields = {
        # Falsy ranks (e.g. missing/empty) are shown as "NA".
        "rank": draft["rank"] or "NA",
        "event_match_wins": draft["event_match_wins"],
        "user_game_win_rate_bucket": draft["user_game_win_rate_bucket"],
        "draft_time": draft["draft_time"],
    }
    n_rounds = len(draft["pick_data"])
    for round_num in range(n_rounds):
        weights = train_draft_dataset.get_weights(draft, round_num)
        row = deepcopy(draft_fields)
        row.update({"w1": weights[0], "w2": weights[1], "round": round_num})
        data_out.append(row)
    if len(data_out) > 10000:
        break

df_out = pd.DataFrame(data_out)

In [None]:
# Preview the first rows of the per-round weight table built in the previous cell.
df_out.head()

In [None]:
# Generate scatterplot of example weights
# px.scatter(
#     df_out,
#     x="round",
#     y="w1",
#     color="rank",
#     hover_data=["user_game_win_rate_bucket", "event_match_wins"],
#     category_orders={"rank": ["mythic", "diamond", "platinum", "gold", "silver", "NA"]},
#     color_discrete_sequence=["orange", "teal", "green", "yellow", "silver", "black"],
# )

In [None]:
# Generate scatterplot of win-rate weights and plot time of draft
# px.scatter(
#     df_out,
#     x="round",
#     y="w2",
#     color="rank",
#     hover_data=["user_game_win_rate_bucket", "event_match_wins"],
#     category_orders={"rank": ["mythic", "diamond", "platinum", "gold", "silver", "NA"]},
#     color_discrete_sequence=["orange", "teal", "green", "yellow", "silver", "black"],
# )
# px.histogram(df_out, x="draft_time", cumulative=True, histnorm="percent")

In [None]:
# Instantiate the model on the GPU and run one forward pass as a smoke test.
# `params` is kept as a plain dict because it is also logged to MLflow and dumped
# to JSON further down in the notebook.
params = {
    "n_cards": 500,
    "emb_dim": 512,
    "n_cards_in_pack": 16,
    "n_steps": 10000,
}
model = models.DraftTransformer(**params)
model = model.to("cuda:0")

# Collate a single training example into a batch and print the model's output for it.
first_example = next(iter(train_draft_dataset))
smoke_batch = models.collate_batch([first_example], device="cuda:0")
print(model(smoke_batch))

In [None]:
# Close any MLflow run left open by an earlier execution of this cell before
# starting a fresh one. This stays best-effort (a failure must not block training),
# but the error is reported instead of being silently discarded.
try:
    mlflow.end_run()
except Exception as exc:
    print(f"WARNING: could not end the previous MLflow run: {exc!r}")
mlflow.start_run()
# Values are stringified before being logged as MLflow params.
mlflow.log_params({k: str(v) for k, v in params.items()})

# Save weights only, under the file name "model".
checkpoint = pl.callbacks.ModelCheckpoint(save_weights_only=True, filename="model")
trainer = pl.Trainer(
    gpus=[0],
    # Training length is driven by the step budget stored on the model.
    max_steps=model.n_steps,
    #     profiler="simple",
    precision=16,
    callbacks=[checkpoint],
    #     benchmark=True,
    # Validate every second epoch, on 20% of the validation batches.
    check_val_every_n_epoch=2,
    limit_val_batches=0.2,
)

# Settings shared by the training and validation loaders.
loader_kwargs = dict(
    batch_size=200,
    collate_fn=models.collate_batch,
    num_workers=2,
    pin_memory=False,
)
train_dl = torch.utils.data.DataLoader(
    train_draft_dataset,
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)
val_dl = torch.utils.data.DataLoader(
    val_draft_dataset,
    shuffle=False,
    drop_last=False,
    **loader_kwargs,
)


trainer.fit(model, train_dl, val_dl)

In [None]:
# Locate the checkpoint written by the ModelCheckpoint callback and log it, together
# with the card-id mapping, as MLflow artifacts.
dp = checkpoint.dirpath

# Ask the callback for the checkpoint it saved instead of taking the first entry of
# os.listdir(): listing order is arbitrary, and this cell itself writes JSON files
# into the same directory, so a re-run could otherwise pick the wrong file.
checkpoint_path = checkpoint.best_model_path
if not checkpoint_path:
    # Fall back to the checkpoint files present in the directory, if any.
    ckpt_files = sorted(name for name in os.listdir(dp) if name.endswith(".ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in checkpoint directory {dp!r}")
    checkpoint_path = os.path.join(dp, ckpt_files[-1])
fn = os.path.basename(checkpoint_path)

# log our model weights
mlflow.log_artifact(checkpoint_path)

# log the card ids
card_id_path = os.path.join(dp, "card_ids.json")
with open(card_id_path, "w") as f:
    json.dump(card_ids, f)

mlflow.log_artifact(card_id_path)

In [None]:
# Inspect the metric stored under "train_accuracy_round_00" on the model.
# (The original line ended in a dangling "." — a SyntaxError that stopped the cell
# from running.)
model.metrics["train_accuracy_round_00"]

In [None]:
# Take the first ten validation examples, collate them into a CPU batch, and run the
# model on it in eval mode.
val_iter = iter(val_draft_dataset)
bt = []
for _ in range(10):
    bt.append(next(val_iter))
collated = models.collate_batch(bt, device="cpu")

# Switch to inference mode and move the model to the CPU to match the batch.
model.eval()
model.to("cpu")

# The model returns two outputs for the batch.
pred_a, pred_b = model(collated)

In [None]:
# Compare the shape of the batch's "picks" entry with the shape of the first model
# output, and show the batch's "pick_weights" entry.
collated["picks"].shape, pred_a.shape, collated["pick_weights"]

In [None]:
# Score the first model output against the batch's "picks" entry.
# NOTE(review): `acc` is not defined or imported anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel unless it is defined elsewhere — confirm
# where it is supposed to come from.
acc(pred_a, collated["picks"])

In [None]:
import numpy as np

# Write the model hyperparameters as JSON next to the checkpoint (in `dp`), then
# attach that file to the active MLflow run as an artifact.
model_config_path = f"{dp}/model_config.json"
with open(model_config_path, "w") as config_file:
    config_file.write(json.dumps(params))

mlflow.log_artifact(model_config_path)