In [1]:
import os
import json

import torch
import numpy as np
import pandas as pd

from mtgc.ai.data import load_cards_data, load_draft_data
from mtgc.ai.preprocessing import CardPreprocessor, filter_draft_data
from mtgc.ai.model import DraftPicker
from mtgc.ai.inference import DraftPickerInference

In [2]:
# Input/output locations for the MKM draft-picker notebook (relative to the notebook's working directory).
_mkm_data_dir = "../data/mkm"
_draft_export_stem = f"{_mkm_data_dir}/17lands/draft_data_public.MKM.PremierDraft"

draft_data_dtypes_path = f"{_draft_export_stem}.columns.json"  # column dtypes, read by load_draft_data
draft_data_path = f"{_draft_export_stem}.csv"                  # 17lands public draft export
card_folder = f"{_mkm_data_dir}/cards"                         # card data, read by load_cards_data
model_folder = "../models/draft_picker/mkm"                    # holds hyper_parameters.json and model.pt

In [3]:
# Load the hyper-parameters stored next to the trained model; they are used to
# rebuild the DraftPicker architecture before its weights are restored.
# JSON text is UTF-8, so the encoding is given explicitly instead of relying on
# the platform's locale-dependent default.
with open(os.path.join(model_folder, "hyper_parameters.json"), "r", encoding="utf-8") as f:
    hyper_parameters = json.load(f)

In [4]:
# Only the first rows of the (large) CSV are read to keep the notebook quick;
# raise this limit to work with a bigger sample.
DRAFT_ROWS_TO_LOAD = 100_000

draft_data_df = load_draft_data(
    draft_data_path,
    draft_data_dtypes_path,
    nrows=DRAFT_ROWS_TO_LOAD,
)

Load draft data from '../data/mkm/17lands/draft_data_public.MKM.PremierDraft.csv'


In [5]:
# Card metadata; card names derived from the pool/pack columns are looked up
# in this mapping when building the example situation.
cards_data_dict = load_cards_data(
    card_folder
)

Load card data from '../data/mkm/cards'


In [6]:
# Filter the loaded draft rows against the card data. The result replaces
# draft_data_df, so re-running this cell filters the already-filtered frame
# (safe only if filter_draft_data is idempotent — NOTE(review): confirm).
draft_data_df = filter_draft_data(
    draft_data_df,
    cards_data_dict,
)

Filter draft data (current shape: (100000, 666))
Filtering done (new shape: (6438, 666))


In [7]:
# Vocabularies CardPreprocessor uses to featurize cards.
# NOTE(review): the entry order presumably has to match the feature layout the
# saved model was trained with — do not reorder or extend without retraining.
CARD_TYPE_VOCABULARY = [
    "Land",
    "Creature",
    "Artifact",
    "Enchantment",
    "Planeswalker",
    "Battle",
    "Instant",
    "Sorcery",
]

KEYWORD_VOCABULARY = [
    "Attach",
    "Counter",
    "Exile",
    "Fight",
    "Mill",
    "Sacrifice",
    "Scry",
    "Tap",
    "Untap",
    "Deathtouch",
    "Defender",
    "Double strike",
    "Enchant",
    "Equip",
    "First strike",
    "Flash",
    "Flying",
    "Haste",
    "Hexproof",
    "Indestructible",
    "Lifelink",
    "Menace",
    "Protection",
    "Prowess",
    "Reach",
    "Trample",
    "Vigilance",
]

card_preprocessor = CardPreprocessor(
    card_type_vocabulary=CARD_TYPE_VOCABULARY,
    keyword_vocabulary=KEYWORD_VOCABULARY,
)

In [8]:
# Rebuild the architecture from the saved hyper-parameters, then restore the weights.
model = DraftPicker(**hyper_parameters)
state_dict = torch.load(
    os.path.join(model_folder, "model.pt"),
    # Map tensors to the CPU so a checkpoint saved from a GPU run also loads on
    # CPU-only machines; the model is never moved to a GPU in this notebook.
    map_location="cpu",
    # The checkpoint is expected to be a plain state_dict of tensors, so refuse
    # to unpickle arbitrary Python objects from the file.
    weights_only=True,
)
model.load_state_dict(state_dict)
# Disable dropout for inference; as the last expression, the module repr is displayed.
model.eval()



DraftPicker(
  (input_mlp): Sequential(
    (0): Linear(in_features=45, out_features=128, bias=True)
    (1): ReLU()
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-9): 10 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_mlp): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [9]:
# Bundle the trained model with the card data and the preprocessor it needs
# to featurize the cards of a draft situation.
inference_engine = DraftPickerInference(
    model,
    cards_data_dict,
    card_preprocessor,
)

In [10]:
# Draw one recorded pick from the filtered draft data and let the model explain it.
# A fixed seed makes Restart & Run All reproduce the same example; set it to
# None to draw a different example on every run.
SAMPLE_SEED = 0


def _cards_from_columns(row: dict, prefix: str, known_cards) -> dict:
    """Return ``{card name: count}`` built from the ``prefix``-ed columns of ``row``.

    The prefix is stripped from each matching column name and "/" is replaced
    with " - " to form the card name. Only cards with a positive count whose
    name is a key of ``known_cards`` are kept.
    """
    cards = {}
    for column, value in row.items():
        if not column.startswith(prefix):
            continue
        # Strip only the leading prefix (str.replace would also hit later occurrences).
        card_name = column[len(prefix):].replace("/", " - ")
        if value > 0 and card_name in known_cards:
            cards[card_name] = value
    return cards


if draft_data_df.empty:
    # Sampling an index from an empty frame would otherwise fail with an opaque error.
    raise ValueError("draft_data_df is empty after filtering; nothing to sample")

rng = np.random.default_rng(SAMPLE_SEED)
index = int(rng.integers(len(draft_data_df)))

row_dict = draft_data_df.iloc[index].to_dict()
# NOTE(review): the pick is used as stored, without the "/" -> " - " mapping
# applied to pool/pack card names; confirm this matches what
# DraftPickerInference.run expects for split cards such as "Hustle // Bustle".
picked_card = row_dict["pick"]
pool_cards = _cards_from_columns(row_dict, "pool_", cards_data_dict)
pack_cards = _cards_from_columns(row_dict, "pack_card_", cards_data_dict)

inference_engine.run(pool_cards, pack_cards, picked_card, explain=True)


===== Situation =====

# Pool

1 Deadly Complication
1 Dog Walker
1 Expose the Culprit
1 Frantic Scapegoat
1 Galvanize
1 Gearbane Orangutan
1 Glint Weaver
1 Gravestone Strider
1 Greenbelt Radical
1 Harried Dronesmith
1 Innocent Bystander
1 Jaded Analyst
1 Leering Onlooker
1 Murder
1 Nervous Gardener
1 Offender at Large
1 Pick Your Poison
1 Public Thoroughfare
1 Push // Pull
1 Reckless Detective
3 Red Herring
1 Riftburst Hellion
1 Rubblebelt Braggart
2 Shock
1 Slice from the Shadows
2 Torch the Witness

# Pack

1 A Killer Among Us
1 Connecting the Dots
1 Crowd-Control Warden
1 Fanatical Strength
1 Hustle // Bustle
1 Lumbering Laundry
1 Out Cold
1 Person of Interest
1 They Went This Way


===== Predictions =====

Pack - A Killer Among Us                 ##########(0.11)
Pack - Connecting the Dots               ###################(0.20)
Pack - Crowd-Control Warden              #####################(0.22)
Pack - Fanatical Strength                ################(0.17)
Pack - Hustle // Bus