In [53]:
import numpy as np
import torch
import torch
import torch.nn as nn

from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
import gensim.downloader

from torch.distributions import Categorical

type_dict={
    0:'Creature',
    1:'Sorcery',
    2:'Artifact',
    3:'Enchantment',
    4:'Instant',
    5:'Land',
}

color_dict={
    0:'W',
    1:'U',
    2:'R',
    3:'G',
    4:'B',
    5:'Colorless',
}

wiki_vectors = gensim.downloader.load('glove-wiki-gigaword-50')

In [54]:
class MLPClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.softmax(out)
        return out


class MLPColorClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLPColorClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.softmax(out)
        return out


# Define the input size, hidden layer size, and number of classes
input_size = 50
hidden_size = 128
num_classes = 6

# Create an instance of the MLPClassifier
model = MLPClassifier(input_size, hidden_size, num_classes)
model.load_state_dict(torch.load('type.pt'))
model.eval()

# Create an instance of the MLPClassifier
color_model = MLPColorClassifier(input_size, hidden_size, num_classes)
color_model.load_state_dict(torch.load('color.pt'))
color_model.eval()

MLPColorClassifier(
  (fc1): Linear(in_features=50, out_features=128, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=128, out_features=6, bias=True)
  (softmax): Softmax(dim=-1)
)

In [55]:
def generate_embedding(phrase):
    words = word_tokenize(phrase)
    words = [word.lower() for word in words]
    lemmatizer=WordNetLemmatizer()
    #stemmer=PorterStemmer()
    words = [lemmatizer.lemmatize(word) for word in words]
    #words = [stemmer.stem(word) for word in words]
    total_vector=[]
    for word in words:
        try:
            total_vector.append(wiki_vectors.word_vec(word))
        except KeyError:
            pass
    if len(total_vector)!=0:
        out = np.mean(total_vector, axis=0)
    else:
        out = np.zeros(50)
    return out

def generate_type(cardname):
    probs = model(torch.tensor(generate_embedding(cardname)))
    distribution = Categorical(probs)
    sampled_index = distribution.sample()
    for key, type in type_dict.items():
        print(f'{type}:{probs[key]}')
    return type_dict[int(sampled_index)]

def generate_color(cardname):
    probs = color_model(torch.tensor(generate_embedding(cardname)))
    distribution = Categorical(probs)
    sampled_index = distribution.sample()
    for key, type in color_dict.items():
        print(f'{type}:{probs[key]}')
    return color_dict[int(sampled_index)]

In [56]:
generate_type('black hole')

Creature:0.9998857975006104
Sorcery:8.379576684092171e-06
Artifact:0.00010513646702747792
Enchantment:4.258950525581895e-07
Instant:2.9796856537700478e-08
Land:2.688680353912787e-07


  total_vector.append(wiki_vectors.word_vec(word))


'Creature'

In [57]:
def generate(cardname):
    return generate_type(cardname=cardname), generate_color(cardname=cardname)

In [73]:
generate('infinity sea')

Creature:0.9999287128448486
Sorcery:1.4220736943570955e-07
Artifact:4.2257350287400186e-05
Enchantment:2.8071557608200237e-05
Instant:2.376560948036399e-09
Land:8.712710837244231e-07
W:2.067809327854775e-05
U:0.9999793767929077
R:5.584361948092092e-28
G:8.457973875733299e-18
B:3.7592337046359875e-19
Colorless:0.0


  total_vector.append(wiki_vectors.word_vec(word))


('Creature', 'U')