In [1]:
import pandas as pd
import numpy as np
import torch
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torch.nn import functional as F
import pickle
from torch.utils.data import DataLoader, TensorDataset
import ast

from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer, PorterStemmer
import gensim.downloader

from torch.distributions import Categorical

from tokenizers import Tokenizer

type_dict={
    0:'Creature',
    1:'Sorcery',
    2:'Artifact',
    3:'Enchantment',
    4:'Instant',
    5:'Land',
}

In [2]:
# For ZY no csv read
#with open('mtgdata.pickle', 'rb') as file:
#    mtg_df=pickle.load(file)
mtg_df=pd.read_csv('mtg_data.csv', index_col=0)
mtg_df=mtg_df.dropna(subset=['text_prompt', 'card_description'])
mtg_df.head()

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,rarity,flavor_text,text,text_prompt,card_description
0,Fury Sliver,{5}{R},6.0,Creature — Sliver,All Sliver creatures have double strike.,3.0,3.0,['R'],['R'],[],uncommon,"""A rift opened, and our arrows were abruptly s...",Fury Sliver: [SEP] {5}{R} [SEP] Creature — Sli...,Fury Sliver: [SEP] {5}{R},Creature — Sliver [SEP] All Sliver creatures h...
1,Kor Outfitter,{W}{W},2.0,Creature — Kor Soldier,"When ~ enters the battlefield, you may attach ...",2.0,2.0,['W'],['W'],[],common,"""We take only what we need to survive. Believe...",Kor Outfitter: [SEP] {W}{W} [SEP] Creature — K...,Kor Outfitter: [SEP] {W}{W},Creature — Kor Soldier [SEP] When ~ enters the...
2,Spirit,,0.0,Token Creature — Spirit,Flying,1.0,1.0,['W'],['W'],[Flying],common,,Spirit: [SEP] [SEP] Token Creature — Spirit [...,Spirit: [SEP],Token Creature — Spirit [SEP] Flying
3,Siren Lookout,{2}{U},3.0,Creature — Siren Pirate,"Flying\nWhen ~ enters the battlefield, it expl...",1.0,2.0,['U'],['U'],"[Flying, Explore]",common,,Siren Lookout: [SEP] {2}{U} [SEP] Creature — S...,Siren Lookout: [SEP] {2}{U},Creature — Siren Pirate [SEP] Flying\nWhen ~ e...
4,Web,{G},1.0,Enchantment — Aura,Enchant creature (Target a creature as you cas...,,,['G'],['G'],[Enchant],rare,,Web: [SEP] {G} [SEP] Enchantment — Aura [SEP] ...,Web: [SEP] {G},Enchantment — Aura [SEP] Enchant creature (Tar...


In [3]:
#load pretrained embeddings
wiki_vectors = gensim.downloader.load('glove-wiki-gigaword-50')

In [4]:
def generate_embedding(phrase):
    words = word_tokenize(phrase)
    words = [word.lower() for word in words]
    lemmatizer=WordNetLemmatizer()
    #stemmer=PorterStemmer()
    words = [lemmatizer.lemmatize(word) for word in words]
    #words = [stemmer.stem(word) for word in words]
    total_vector=[]
    for word in words:
        try:
            total_vector.append(wiki_vectors.word_vec(word))
        except KeyError:
            pass
    if len(total_vector)!=0:
        out = np.mean(total_vector, axis=0)
    else:
        out = np.zeros(50)
    return out

In [5]:
types=['Creature', 'Sorcery', 'Artifact', 'Enchantment', 'Instant', 'Land']
type_list={'Name':[],
           'Type':[]}
for idx, row in mtg_df.iterrows():
    for word in row['type_line'].split():
        if word in types:
            type_list['Name'].append(row['name'])
            type_list['Type'].append(word)
type_data=pd.DataFrame(type_list).drop_duplicates()
for type in types:
    type_data[type]=type_data['Type']==type
type_data

Unnamed: 0,Name,Type,Creature,Sorcery,Artifact,Enchantment,Instant,Land
0,Fury Sliver,Creature,True,False,False,False,False,False
1,Kor Outfitter,Creature,True,False,False,False,False,False
2,Spirit,Creature,True,False,False,False,False,False
3,Siren Lookout,Creature,True,False,False,False,False,False
4,Web,Enchantment,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...
83968,Born to Drive,Enchantment,False,False,False,True,False,False
83978,Stern Mentor,Creature,True,False,False,False,False,False
83981,Youthful Valkyrie,Creature,True,False,False,False,False,False
83982,Fallaji Vanguard,Creature,True,False,False,False,False,False


In [6]:
total_embeddings=[]
for idx, row in type_data.iterrows():
    total_embeddings.append(generate_embedding(row['Name']))

tensor_list = [torch.tensor(nameembedding) for nameembedding in list(total_embeddings)]
name_embeddings = torch.stack(tensor_list, dim=0)
name_embeddings

  total_vector.append(wiki_vectors.word_vec(word))


tensor([[ 0.1918, -0.0154,  0.1995,  ..., -0.2442, -0.3086, -0.6314],
        [-0.4758,  1.0394, -0.2369,  ...,  0.2503,  0.1288,  0.0442],
        [-0.0175,  0.7970, -1.3675,  ..., -0.3775, -0.2226, -0.3518],
        ...,
        [ 0.4284,  0.2326, -0.1057,  ..., -0.2599,  0.0209,  0.2470],
        [ 0.4418, -0.1677,  0.5107,  ...,  0.3453, -0.3805,  0.7088],
        [ 0.3030,  0.8218, -0.1229,  ..., -0.3748, -0.0187, -0.1198]],
       dtype=torch.float64)

In [7]:
card_type=torch.tensor(type_data.drop(['Name', 'Type'], axis=1).values, dtype=torch.long)
n_train = int(0.9*type_data.shape[0])
train_X = name_embeddings[:n_train]
val_X = name_embeddings[n_train:]
train_Y = card_type[:n_train]
val_Y = card_type[n_train:]

In [8]:
class MLPClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.softmax(out)
        return out

# Define the input size, hidden layer size, and number of classes
input_size = 50
hidden_size = 128
num_classes = 6

# Create an instance of the MLPClassifier
model = MLPClassifier(input_size, hidden_size, num_classes)

# You can print the model to see its architecture
print(model)

MLPClassifier(
  (fc1): Linear(in_features=50, out_features=128, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=128, out_features=6, bias=True)
  (softmax): Softmax(dim=-1)
)


In [9]:
num_epochs=100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_X=train_X.float()
train_Y=train_Y.float()
dataset=TensorDataset(train_X, train_Y)
train_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    total_loss = 0.0

    for inputs, labels in train_loader:
        # Forward pass
        outputs = model(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print the average loss for this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader)}')

print('Training finished.')
torch.save(model.state_dict(), 'type.pt')

Epoch [1/100], Loss: 1.5129064028033146
Epoch [2/100], Loss: 1.455908931148839
Epoch [3/100], Loss: 1.4481157242949239
Epoch [4/100], Loss: 1.4431171862001952
Epoch [5/100], Loss: 1.4401456831675497
Epoch [6/100], Loss: 1.437087743718007
Epoch [7/100], Loss: 1.4348347311092513
Epoch [8/100], Loss: 1.4297137027464542
Epoch [9/100], Loss: 1.4215940817963653
Epoch [10/100], Loss: 1.4176938524706109
Epoch [11/100], Loss: 1.414722086209331
Epoch [12/100], Loss: 1.4125382443369947
Epoch [13/100], Loss: 1.4100610156954847
Epoch [14/100], Loss: 1.408063372379632
Epoch [15/100], Loss: 1.4061916158889152
Epoch [16/100], Loss: 1.4039863735286113
Epoch [17/100], Loss: 1.4021400785083094
Epoch [18/100], Loss: 1.4001241610740043
Epoch [19/100], Loss: 1.3983870681167254
Epoch [20/100], Loss: 1.3969136951538512
Epoch [21/100], Loss: 1.395307086143397
Epoch [22/100], Loss: 1.393453420721335
Epoch [23/100], Loss: 1.3919726120033844
Epoch [24/100], Loss: 1.3906046660418439
Epoch [25/100], Loss: 1.3891908

In [10]:
def generate_type(cardname):
    probs = model(torch.tensor(generate_embedding(cardname)))
    distribution = Categorical(probs)
    sampled_index = distribution.sample()
    for key, type in type_dict.items():
        print(f'{type}:{probs[key]}')
    return type_dict[int(sampled_index)]

In [11]:
generate_type('sword')

Creature:4.3889525013961355e-17
Sorcery:3.0426184471538017e-19
Artifact:1.0
Enchantment:5.477421901461943e-14
Instant:3.1649355358709386e-10
Land:2.7021122496589227e-30


  total_vector.append(wiki_vectors.word_vec(word))


'Artifact'

In [12]:
colors_wheel=['W', 'U', 'R', 'G', 'B', 'E']
color_list={'Name':[],
           'color':[]}
for idx, row in mtg_df.iterrows():
    colors = ast.literal_eval(row['color_identity'])
    if len(colors)!=0:
        for color in colors:
            color_list['Name'].append(row['name'])
            color_list['color'].append(color)
    else:
            color_list['Name'].append(row['name'])
            color_list['color'].append('E')
color_data=pd.DataFrame(color_list).drop_duplicates()
for color in colors_wheel:
     color_data[color]=color_data['color']==color
color_data

Unnamed: 0,Name,color,W,U,R,G,B,E
0,Fury Sliver,R,False,False,True,False,False,False
1,Kor Outfitter,W,True,False,False,False,False,False
2,Spirit,W,True,False,False,False,False,False
3,Siren Lookout,U,False,True,False,False,False,False
4,Web,G,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...
99030,Youthful Valkyrie,W,True,False,False,False,False,False
99031,Fallaji Vanguard,R,False,False,True,False,False,False
99032,Fallaji Vanguard,W,True,False,False,False,False,False
99035,Hold at Bay,W,True,False,False,False,False,False


In [13]:
total_embeddings=[]
for idx, row in color_data.iterrows():
    total_embeddings.append(generate_embedding(row['Name']))
tensor_list = [torch.tensor(nameembedding) for nameembedding in list(total_embeddings)]
name_embeddings = torch.stack(tensor_list, dim=0)

card_color=torch.tensor(color_data.drop(['Name', 'color'], axis=1).values, dtype=torch.long)
n_train = int(0.9*color_data.shape[0])
color_train_X = name_embeddings[:n_train]
color_val_X = name_embeddings[n_train:]
color_train_Y = card_color[:n_train]
color_val_Y = card_color[n_train:]

  total_vector.append(wiki_vectors.word_vec(word))


In [14]:
# Define the input size, hidden layer size, and number of classes
input_size = 50
hidden_size = 128
num_classes = 6

class MLPColorClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLPColorClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.softmax(out)
        return out

# Create an instance of the MLPClassifier
color_model = MLPColorClassifier(input_size, hidden_size, num_classes)

In [15]:
num_epochs=300

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(color_model.parameters(), lr=1e-3)
color_train_X=color_train_X.float()
color_train_Y=color_train_Y.float()
dataset=TensorDataset(color_train_X, color_train_Y)
train_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    total_loss = 0.0

    for inputs, labels in train_loader:
        # Forward pass
        outputs = color_model(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print the average loss for this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader)}')

print('Training finished.')
torch.save(color_model.state_dict(), 'color.pt')

Epoch [1/300], Loss: 1.7166238084916146
Epoch [2/300], Loss: 1.6834184023641772
Epoch [3/300], Loss: 1.6731772938082294
Epoch [4/300], Loss: 1.6654338854615407
Epoch [5/300], Loss: 1.6564930051885625
Epoch [6/300], Loss: 1.6491090407935522
Epoch [7/300], Loss: 1.6428221238556728
Epoch [8/300], Loss: 1.638009016744552
Epoch [9/300], Loss: 1.6309334080706361
Epoch [10/300], Loss: 1.6257848526841852
Epoch [11/300], Loss: 1.6213571458734493
Epoch [12/300], Loss: 1.6172445356204945
Epoch [13/300], Loss: 1.6130547361989176
Epoch [14/300], Loss: 1.6090626652522753
Epoch [15/300], Loss: 1.606383599260802
Epoch [16/300], Loss: 1.6025322673141316
Epoch [17/300], Loss: 1.5992948552613617
Epoch [18/300], Loss: 1.5959225703311224
Epoch [19/300], Loss: 1.5946246552210983
Epoch [20/300], Loss: 1.591557885754493
Epoch [21/300], Loss: 1.5890918293306904
Epoch [22/300], Loss: 1.5864567325961205
Epoch [23/300], Loss: 1.5858366553501417
Epoch [24/300], Loss: 1.5821018931686237
Epoch [25/300], Loss: 1.5798

In [16]:
color_dict={
    0:'W',
    1:'U',
    2:'R',
    3:'G',
    4:'B',
    5:'Colorless',
}

def generate_color(cardname):
    probs = color_model(torch.tensor(generate_embedding(cardname)))
    distribution = Categorical(probs)
    sampled_index = distribution.sample()
    for key, type in color_dict.items():
        print(f'{type}:{probs[key]}')
    return color_dict[int(sampled_index)]

In [17]:
mtg_df=mtg_df.dropna(subset='oracle_text')
mtg_df.head(5)

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,rarity,flavor_text,text,text_prompt,card_description
0,Fury Sliver,{5}{R},6.0,Creature — Sliver,All Sliver creatures have double strike.,3.0,3.0,['R'],['R'],[],uncommon,"""A rift opened, and our arrows were abruptly s...",Fury Sliver: [SEP] {5}{R} [SEP] Creature — Sli...,Fury Sliver: [SEP] {5}{R},Creature — Sliver [SEP] All Sliver creatures h...
1,Kor Outfitter,{W}{W},2.0,Creature — Kor Soldier,"When ~ enters the battlefield, you may attach ...",2.0,2.0,['W'],['W'],[],common,"""We take only what we need to survive. Believe...",Kor Outfitter: [SEP] {W}{W} [SEP] Creature — K...,Kor Outfitter: [SEP] {W}{W},Creature — Kor Soldier [SEP] When ~ enters the...
2,Spirit,,0.0,Token Creature — Spirit,Flying,1.0,1.0,['W'],['W'],[Flying],common,,Spirit: [SEP] [SEP] Token Creature — Spirit [...,Spirit: [SEP],Token Creature — Spirit [SEP] Flying
3,Siren Lookout,{2}{U},3.0,Creature — Siren Pirate,"Flying\nWhen ~ enters the battlefield, it expl...",1.0,2.0,['U'],['U'],"[Flying, Explore]",common,,Siren Lookout: [SEP] {2}{U} [SEP] Creature — S...,Siren Lookout: [SEP] {2}{U},Creature — Siren Pirate [SEP] Flying\nWhen ~ e...
4,Web,{G},1.0,Enchantment — Aura,Enchant creature (Target a creature as you cas...,,,['G'],['G'],[Enchant],rare,,Web: [SEP] {G} [SEP] Enchantment — Aura [SEP] ...,Web: [SEP] {G},Enchantment — Aura [SEP] Enchant creature (Tar...
