# Draft Assistant

This notebook trains drafting models based on 17lands data

Currently hacked together

In [1]:
from collections import Counter
import pandas as pd
import torch

import warnings
warnings.filterwarnings("ignore")

import statisticaldrafting as sd

In [2]:
target_set = "FDN"

# Pull data. 
model_path = f"../data/models/{target_set}.pt"
card_path = f"../data/cards/{target_set}.csv"
pick_table = pd.read_csv(card_path)
pick_table["distance"] = [1] * len(pick_table)


# Load model. 
mlp_network = sd.DraftMLP(cardnames=pick_table["name"].tolist(), hidden_dims=[300, 300, 300])
mlp_network.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [3]:
def get_card_distances(collection_list, cur_network):
    """ Get card distances for current network. Used for visualization """
    
    # Cardnames - for validation. 
    cardnames = pick_table["name"].tolist()
    ss = len(cardnames)

    # Get collection vector
    collection_vector = torch.zeros([1, ss])
    cnt = Counter(collection_list)
    for card in cnt:
        
        # Validate cardname. 
        if card not in cardnames:
            raise Exception(f"{card} not in set. Please correct cardname.")

        # Add to collection vector. 
        card_index = cardnames.index(card)
        collection_vector[0, card_index] = cnt[card]
        
    # Get card and collection embeddings. 
    cur_network.eval()
    with torch.no_grad():
        card_distances = cur_network(collection_vector, torch.ones(ss))
    return card_distances

def get_percentile(card_distances, top_score=150):
    # TODO: omit basic lands. 
    card_distances = card_distances.reshape(-1) # Ensure correct shape. 
    min_distance = min(card_distances).item()
    max_distance = max(card_distances).item()
    percentiles = [top_score * (cd - min_distance) / (max_distance - min_distance) for cd in card_distances.tolist()]
    return [round(p, 1) for p in percentiles]

In [5]:
# Initialize an empty collection DataFrame
collection = pd.DataFrame(columns=pick_table.columns)

In [None]:
# THIS IS THE BROKEN UI. 
from IPython.display import display, clear_output
import ipywidgets as widgets

# State variables for filters
rarity_options = ["All", "common", "uncommon", "common+uncommon", "rare", "mythic"]
color_options = ["All", "W", "G", "U", "R", "B", "Multicolor", "Colorless"]

rarity_filter = widgets.Dropdown(
    options=rarity_options,
    value="rare",
    description="Rarity:",
)

color_filter = widgets.Dropdown(
    options=color_options,
    value="All",
    description="Color:",
)

def update_table():
    """Re-render the pick table and collection."""
    clear_output(wait=True)
    display_tables()

def make_pick(card):
    """Add card to the collection and update tables."""
    global collection
    collection = pd.concat([collection, pd.DataFrame([card])], ignore_index=True)
    update_table()

def remove_card(index):
    """Remove a card from the collection by index and update tables."""
    global collection
    collection = collection.drop(index).reset_index(drop=True)
    update_table()
    
# Function to reset the collection
def reset_collection(change=None):
    """Reset the collection (clear all cards)."""
    global collection
    collection = pd.DataFrame(columns=pick_table.columns)  # Empty collection
    update_table()

def display_tables():
    """Display pick table and collection with interactive buttons."""
    global pick_table

    # Update distances pick table.  
    collection_list = [n for n in collection["name"]]
    cur_distances = get_card_distances(collection_list, mlp_network)
    percentiles = get_percentile(cur_distances) 
    pick_table["distance"] = percentiles # Use percentiles for now. 
    
    if "p1p1_distance" not in pick_table.columns:
        p1p1_distances = get_card_distances([], mlp_network)
        p1p1_percentiles = get_percentile(p1p1_distances)
        pick_table["p1p1_distance"] = p1p1_percentiles
        
    pick_table["synergy"] = (pick_table["distance"] - pick_table["p1p1_distance"]).round(1)
        
    
    # Hide distances in collection table. 
    collection["distance"] = [""] * len(collection)
    
    # Apply filters to the pick table
    filtered_table = pick_table.copy()

    # If the rarity filter is "All", exclude cards with "Basic" rarity
    if rarity_filter.value == "All":
        filtered_table = filtered_table[filtered_table['rarity'] != "basic"]
    elif rarity_filter.value == "common+uncommon":
        filtered_table = filtered_table[filtered_table['rarity'].isin(["common", "uncommon"])]
    else:
        filtered_table = filtered_table[filtered_table['rarity'] == rarity_filter.value]
    
    if color_filter.value != "All":
        filtered_table = filtered_table[filtered_table['color_identity'] == color_filter.value]

    # Sort the filtered pick table by distance (ascending order)
    filtered_table = filtered_table.sort_values(by="distance", ascending=False)
    
    # Add the "New Draft" button to reset the collection
    new_draft_button = widgets.Button(description="New Draft", button_style="warning")
    new_draft_button.on_click(reset_collection)
    display(new_draft_button)
    
    # Display the filters
    filter_box = widgets.HBox([rarity_filter, color_filter])
    display(filter_box)

    # Get the maximum length of card names to align them
    max_name_length = filtered_table['name'].apply(len).max()
    max_name_length = max(max_name_length, 15)  # Minimum width for the name column is 12
    
    # Formatting function to align columns and display as text
    def format_row(row):
        return f"{row['name']:<{max_name_length}} | {row['rarity']:<9} | {row['color_identity']:<12} | {row['synergy']:>+7}| {row['distance']:>6}"

    # Display the filtered pick table with buttons
    print(f'{" Card Name":<{max_name_length}} | {"Rarity":<9} | {"Color":<12} | {"Synergy":>7}| {"Rating":>6}')
    for row_count, (index, row) in enumerate(filtered_table.iterrows()):
        # print("This is called once", row["name"])
        row_widget = widgets.Output()
        with row_widget:
            print(format_row(row))
        pick_button = widgets.Button(description=f"Pick: {row['name']}", button_style="success")
        pick_button.on_click(lambda btn, card=row: make_pick(card.to_dict()))
        display(widgets.HBox([row_widget, pick_button]))

        MAX_CARDS_TO_SHOW = 10
        if row_count >= MAX_CARDS_TO_SHOW:
          break

    # Display the collection with remove buttons (same format as pick table)
    print("\nCollection:")
    if not collection.empty:
        print("in collection")
        collection_widget = widgets.Output()
        with collection_widget:
            # Use the same format_row for collection as for pick table
            for _, row in collection.iterrows():
                print(format_row(row))

        remove_buttons = []
        for idx, row in collection.iterrows():
            remove_button = widgets.Button(description=f"Remove: {row['name']}", button_style="danger")
            remove_button.on_click(lambda btn, index=idx: remove_card(index))

            # Align text and remove button together in the same layout
            row_widget = widgets.Output()
            with row_widget:
                print(format_row(row))
            
            remove_buttons.append(widgets.HBox([row_widget, remove_button]))

        remove_buttons_box = widgets.VBox(remove_buttons)
        display(remove_buttons_box)
    else:
        print("Collection is empty.")

# Add observers to filters to trigger table updates
rarity_filter.observe(lambda change: update_table(), names='value')
color_filter.observe(lambda change: update_table(), names='value')

# Note: restarting the notebook fixes most visual bugs.

display_tables()



HBox(children=(Dropdown(description='Rarity:', index=4, options=('All', 'common', 'uncommon', 'common+uncommon…

 Card Name                  | Rarity    | Color        | Synergy| Rating


HBox(children=(Output(), Button(button_style='success', description='Pick: Exemplar of Light', style=ButtonSty…

HBox(children=(Output(), Button(button_style='success', description='Pick: Celestial Armor', style=ButtonStyle…

HBox(children=(Output(), Button(button_style='success', description='Pick: High-Society Hunter', style=ButtonS…

HBox(children=(Output(), Button(button_style='success', description='Pick: Curator of Destinies', style=Button…

HBox(children=(Output(), Button(button_style='success', description='Pick: Kiora, the Rising Tide', style=Butt…

HBox(children=(Output(), Button(button_style='success', description='Pick: Arahbo, the First Fang', style=Butt…

HBox(children=(Output(), Button(button_style='success', description='Pick: Alesha, Who Laughs at Fate', style=…

HBox(children=(Output(), Button(button_style='success', description='Pick: Sylvan Scavenging', style=ButtonSty…

HBox(children=(Output(), Button(button_style='success', description='Pick: Elenda, Saint of Dusk', style=Butto…

HBox(children=(Output(), Button(button_style='success', description='Pick: Spinner of Souls', style=ButtonStyl…

HBox(children=(Output(), Button(button_style='success', description='Pick: Skyknight Squire', style=ButtonStyl…


Collection:
Collection is empty.
