In [1]:
import pandas as pd
import numpy as np
import torch
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torch.nn import functional as F
import pickle
from torch.utils.data import DataLoader, TensorDataset
import ast

from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer, PorterStemmer
import gensim.downloader

from torch.distributions import Categorical

from tokenizers import Tokenizer

type_dict={
    0:'Creature',
    1:'Sorcery',
    2:'Artifact',
    3:'Enchantment',
    4:'Instant',
    5:'Land',
}

In [2]:
# For ZY no csv read
#with open('mtgdata.pickle', 'rb') as file:
#    mtg_df=pickle.load(file)
mtg_df=pd.read_csv('mtg_data.csv', index_col=0)
mtg_df=mtg_df.dropna(subset=['text_prompt', 'card_description'])
mtg_df.head()

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,rarity,flavor_text,text,text_prompt,card_description
0,Fury Sliver,{5}{R},6.0,Creature — Sliver,All Sliver creatures have double strike.,3.0,3.0,['R'],['R'],[],uncommon,"""A rift opened, and our arrows were abruptly s...",Fury Sliver: [SEP] {5}{R} [SEP] Creature — Sli...,Fury Sliver: [SEP] {5}{R},Creature — Sliver [SEP] All Sliver creatures h...
1,Kor Outfitter,{W}{W},2.0,Creature — Kor Soldier,"When ~ enters the battlefield, you may attach ...",2.0,2.0,['W'],['W'],[],common,"""We take only what we need to survive. Believe...",Kor Outfitter: [SEP] {W}{W} [SEP] Creature — K...,Kor Outfitter: [SEP] {W}{W},Creature — Kor Soldier [SEP] When ~ enters the...
2,Spirit,,0.0,Token Creature — Spirit,Flying,1.0,1.0,['W'],['W'],[Flying],common,,Spirit: [SEP] [SEP] Token Creature — Spirit [...,Spirit: [SEP],Token Creature — Spirit [SEP] Flying
3,Siren Lookout,{2}{U},3.0,Creature — Siren Pirate,"Flying\nWhen ~ enters the battlefield, it expl...",1.0,2.0,['U'],['U'],"[Flying, Explore]",common,,Siren Lookout: [SEP] {2}{U} [SEP] Creature — S...,Siren Lookout: [SEP] {2}{U},Creature — Siren Pirate [SEP] Flying\nWhen ~ e...
4,Web,{G},1.0,Enchantment — Aura,Enchant creature (Target a creature as you cas...,,,['G'],['G'],[Enchant],rare,,Web: [SEP] {G} [SEP] Enchantment — Aura [SEP] ...,Web: [SEP] {G},Enchantment — Aura [SEP] Enchant creature (Tar...


In [3]:
#load pretrained embeddings
wiki_vectors = gensim.downloader.load('glove-wiki-gigaword-50')

In [4]:
def generate_embedding(phrase):
    words = word_tokenize(phrase)
    words = [word.lower() for word in words]
    lemmatizer=WordNetLemmatizer()
    #stemmer=PorterStemmer()
    words = [lemmatizer.lemmatize(word) for word in words]
    #words = [stemmer.stem(word) for word in words]
    total_vector=[]
    for word in words:
        try:
            total_vector.append(wiki_vectors.word_vec(word))
        except KeyError:
            pass
    if len(total_vector)!=0:
        out = np.mean(total_vector, axis=0)
    else:
        out = np.zeros(50)
    return out

In [5]:
types=['Creature', 'Sorcery', 'Artifact', 'Enchantment', 'Instant', 'Land']
type_list={'Name':[],
           'Type':[]}
for idx, row in mtg_df.iterrows():
    for word in row['type_line'].split():
        if word in types:
            type_list['Name'].append(row['name'])
            type_list['Type'].append(word)
type_data=pd.DataFrame(type_list).drop_duplicates()
for type in types:
    type_data[type]=type_data['Type']==type
type_data

Unnamed: 0,Name,Type,Creature,Sorcery,Artifact,Enchantment,Instant,Land
0,Fury Sliver,Creature,True,False,False,False,False,False
1,Kor Outfitter,Creature,True,False,False,False,False,False
2,Spirit,Creature,True,False,False,False,False,False
3,Siren Lookout,Creature,True,False,False,False,False,False
4,Web,Enchantment,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...
83968,Born to Drive,Enchantment,False,False,False,True,False,False
83978,Stern Mentor,Creature,True,False,False,False,False,False
83981,Youthful Valkyrie,Creature,True,False,False,False,False,False
83982,Fallaji Vanguard,Creature,True,False,False,False,False,False


In [6]:
total_embeddings=[]
for idx, row in type_data.iterrows():
    total_embeddings.append(generate_embedding(row['Name']))

tensor_list = [torch.tensor(nameembedding) for nameembedding in list(total_embeddings)]
name_embeddings = torch.stack(tensor_list, dim=0)
name_embeddings

  total_vector.append(wiki_vectors.word_vec(word))


tensor([[ 0.1918, -0.0154,  0.1995,  ..., -0.2442, -0.3086, -0.6314],
        [-0.4758,  1.0394, -0.2369,  ...,  0.2503,  0.1288,  0.0442],
        [-0.0175,  0.7970, -1.3675,  ..., -0.3775, -0.2226, -0.3518],
        ...,
        [ 0.4284,  0.2326, -0.1057,  ..., -0.2599,  0.0209,  0.2470],
        [ 0.4418, -0.1677,  0.5107,  ...,  0.3453, -0.3805,  0.7088],
        [ 0.3030,  0.8218, -0.1229,  ..., -0.3748, -0.0187, -0.1198]],
       dtype=torch.float64)

In [7]:
card_type=torch.tensor(type_data.drop(['Name', 'Type'], axis=1).values, dtype=torch.long)
n_train = int(0.9*type_data.shape[0])
train_X = name_embeddings[:n_train]
val_X = name_embeddings[n_train:]
train_Y = card_type[:n_train]
val_Y = card_type[n_train:]

In [8]:
class MLPClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.softmax(out)
        return out

# Define the input size, hidden layer size, and number of classes
input_size = 50
hidden_size = 128
num_classes = 6

# Create an instance of the MLPClassifier
model = MLPClassifier(input_size, hidden_size, num_classes)

# You can print the model to see its architecture
print(model)

MLPClassifier(
  (fc1): Linear(in_features=50, out_features=128, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=128, out_features=6, bias=True)
  (softmax): Softmax(dim=-1)
)


In [9]:
num_epochs=100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_X=train_X.float()
train_Y=train_Y.float()
dataset=TensorDataset(train_X, train_Y)
train_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    total_loss = 0.0

    for inputs, labels in train_loader:
        # Forward pass
        outputs = model(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print the average loss for this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader)}')

print('Training finished.')
torch.save(model.state_dict(), 'type.pt')

Epoch [1/100], Loss: 1.5129064028033146
Epoch [2/100], Loss: 1.455908931148839
Epoch [3/100], Loss: 1.4481157242949239
Epoch [4/100], Loss: 1.4431171862001952
Epoch [5/100], Loss: 1.4401456831675497
Epoch [6/100], Loss: 1.437087743718007
Epoch [7/100], Loss: 1.4348347311092513
Epoch [8/100], Loss: 1.4297137027464542
Epoch [9/100], Loss: 1.4215940817963653
Epoch [10/100], Loss: 1.4176938524706109
Epoch [11/100], Loss: 1.414722086209331
Epoch [12/100], Loss: 1.4125382443369947
Epoch [13/100], Loss: 1.4100610156954847
Epoch [14/100], Loss: 1.408063372379632
Epoch [15/100], Loss: 1.4061916158889152
Epoch [16/100], Loss: 1.4039863735286113
Epoch [17/100], Loss: 1.4021400785083094
Epoch [18/100], Loss: 1.4001241610740043
Epoch [19/100], Loss: 1.3983870681167254
Epoch [20/100], Loss: 1.3969136951538512
Epoch [21/100], Loss: 1.395307086143397
Epoch [22/100], Loss: 1.393453420721335
Epoch [23/100], Loss: 1.3919726120033844
Epoch [24/100], Loss: 1.3906046660418439
Epoch [25/100], Loss: 1.3891908

In [10]:
def generate_type(cardname):
    probs = model(torch.tensor(generate_embedding(cardname)))
    distribution = Categorical(probs)
    sampled_index = distribution.sample()
    for key, type in type_dict.items():
        print(f'{type}:{probs[key]}')
    return type_dict[int(sampled_index)]

In [11]:
generate_type('sword')

Creature:4.3889525013961355e-17
Sorcery:3.0426184471538017e-19
Artifact:1.0
Enchantment:5.477421901461943e-14
Instant:3.1649355358709386e-10
Land:2.7021122496589227e-30


  total_vector.append(wiki_vectors.word_vec(word))


'Artifact'

In [12]:
colors_wheel=['W', 'U', 'R', 'G', 'B', 'E']
color_list={'Name':[],
           'color':[]}
for idx, row in mtg_df.iterrows():
    colors = ast.literal_eval(row['color_identity'])
    if len(colors)!=0:
        for color in colors:
            color_list['Name'].append(row['name'])
            color_list['color'].append(color)
    else:
            color_list['Name'].append(row['name'])
            color_list['color'].append('E')
color_data=pd.DataFrame(color_list).drop_duplicates()
for color in colors_wheel:
     color_data[color]=color_data['color']==color
color_data

Unnamed: 0,Name,color,W,U,R,G,B,E
0,Fury Sliver,R,False,False,True,False,False,False
1,Kor Outfitter,W,True,False,False,False,False,False
2,Spirit,W,True,False,False,False,False,False
3,Siren Lookout,U,False,True,False,False,False,False
4,Web,G,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...
99030,Youthful Valkyrie,W,True,False,False,False,False,False
99031,Fallaji Vanguard,R,False,False,True,False,False,False
99032,Fallaji Vanguard,W,True,False,False,False,False,False
99035,Hold at Bay,W,True,False,False,False,False,False


In [13]:
total_embeddings=[]
for idx, row in color_data.iterrows():
    total_embeddings.append(generate_embedding(row['Name']))
tensor_list = [torch.tensor(nameembedding) for nameembedding in list(total_embeddings)]
name_embeddings = torch.stack(tensor_list, dim=0)

card_color=torch.tensor(color_data.drop(['Name', 'color'], axis=1).values, dtype=torch.long)
n_train = int(0.9*color_data.shape[0])
color_train_X = name_embeddings[:n_train]
color_val_X = name_embeddings[n_train:]
color_train_Y = card_color[:n_train]
color_val_Y = card_color[n_train:]

  total_vector.append(wiki_vectors.word_vec(word))


In [14]:
# Define the input size, hidden layer size, and number of classes
input_size = 50
hidden_size = 128
num_classes = 6

class MLPColorClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLPColorClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.softmax(out)
        return out

# Create an instance of the MLPClassifier
color_model = MLPColorClassifier(input_size, hidden_size, num_classes)

In [15]:
num_epochs=300

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(color_model.parameters(), lr=1e-3)
color_train_X=color_train_X.float()
color_train_Y=color_train_Y.float()
dataset=TensorDataset(color_train_X, color_train_Y)
train_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    total_loss = 0.0

    for inputs, labels in train_loader:
        # Forward pass
        outputs = color_model(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print the average loss for this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader)}')

print('Training finished.')
torch.save(color_model.state_dict(), 'color.pt')

Epoch [1/300], Loss: 1.7166238084916146
Epoch [2/300], Loss: 1.6834184023641772
Epoch [3/300], Loss: 1.6731772938082294
Epoch [4/300], Loss: 1.6654338854615407
Epoch [5/300], Loss: 1.6564930051885625
Epoch [6/300], Loss: 1.6491090407935522
Epoch [7/300], Loss: 1.6428221238556728
Epoch [8/300], Loss: 1.638009016744552
Epoch [9/300], Loss: 1.6309334080706361
Epoch [10/300], Loss: 1.6257848526841852
Epoch [11/300], Loss: 1.6213571458734493
Epoch [12/300], Loss: 1.6172445356204945
Epoch [13/300], Loss: 1.6130547361989176
Epoch [14/300], Loss: 1.6090626652522753
Epoch [15/300], Loss: 1.606383599260802
Epoch [16/300], Loss: 1.6025322673141316
Epoch [17/300], Loss: 1.5992948552613617
Epoch [18/300], Loss: 1.5959225703311224
Epoch [19/300], Loss: 1.5946246552210983
Epoch [20/300], Loss: 1.591557885754493
Epoch [21/300], Loss: 1.5890918293306904
Epoch [22/300], Loss: 1.5864567325961205
Epoch [23/300], Loss: 1.5858366553501417
Epoch [24/300], Loss: 1.5821018931686237
Epoch [25/300], Loss: 1.5798

In [16]:
color_dict={
    0:'W',
    1:'U',
    2:'R',
    3:'G',
    4:'B',
    5:'Colorless',
}

def generate_color(cardname):
    probs = color_model(torch.tensor(generate_embedding(cardname)))
    distribution = Categorical(probs)
    sampled_index = distribution.sample()
    for key, type in color_dict.items():
        print(f'{type}:{probs[key]}')
    return color_dict[int(sampled_index)]

In [17]:
mtg_df=mtg_df.dropna(subset='oracle_text')
mtg_df.head(5)

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,rarity,flavor_text,text,text_prompt,card_description
0,Fury Sliver,{5}{R},6.0,Creature — Sliver,All Sliver creatures have double strike.,3.0,3.0,['R'],['R'],[],uncommon,"""A rift opened, and our arrows were abruptly s...",Fury Sliver: [SEP] {5}{R} [SEP] Creature — Sli...,Fury Sliver: [SEP] {5}{R},Creature — Sliver [SEP] All Sliver creatures h...
1,Kor Outfitter,{W}{W},2.0,Creature — Kor Soldier,"When ~ enters the battlefield, you may attach ...",2.0,2.0,['W'],['W'],[],common,"""We take only what we need to survive. Believe...",Kor Outfitter: [SEP] {W}{W} [SEP] Creature — K...,Kor Outfitter: [SEP] {W}{W},Creature — Kor Soldier [SEP] When ~ enters the...
2,Spirit,,0.0,Token Creature — Spirit,Flying,1.0,1.0,['W'],['W'],[Flying],common,,Spirit: [SEP] [SEP] Token Creature — Spirit [...,Spirit: [SEP],Token Creature — Spirit [SEP] Flying
3,Siren Lookout,{2}{U},3.0,Creature — Siren Pirate,"Flying\nWhen ~ enters the battlefield, it expl...",1.0,2.0,['U'],['U'],"[Flying, Explore]",common,,Siren Lookout: [SEP] {2}{U} [SEP] Creature — S...,Siren Lookout: [SEP] {2}{U},Creature — Siren Pirate [SEP] Flying\nWhen ~ e...
4,Web,{G},1.0,Enchantment — Aura,Enchant creature (Target a creature as you cas...,,,['G'],['G'],[Enchant],rare,,Web: [SEP] {G} [SEP] Enchantment — Aura [SEP] ...,Web: [SEP] {G},Enchantment — Aura [SEP] Enchant creature (Tar...


In [19]:
types=['Creature', 'Sorcery', 'Artifact', 'Enchantment', 'Instant', 'Land']
colors_wheel=['W', 'U', 'R', 'G', 'B', 'E']
type_color_list={'Name':[],
                 'Type':[],
                 'Color':[],
                 'oracle_text':[]}
for idx, row in mtg_df.iterrows():
    colors = ast.literal_eval(row['color_identity'])
    for word in row['type_line'].split():
        if word in types:
            if len(colors)!=0:
                for color in colors:
                    type_color_list['Name'].append(row['name'])
                    type_color_list['Type'].append(word)
                    type_color_list['Color'].append(color)
                    type_color_list['oracle_text'].append(row['oracle_text'])
            else:
                type_color_list['Name'].append(row['name'])
                type_color_list['Type'].append(word)
                type_color_list['Color'].append('E')
                type_color_list['oracle_text'].append(row['oracle_text'])

type_color_data=pd.DataFrame(type_color_list)
type_color_data

Unnamed: 0,Name,Type,Color,oracle_text
0,Fury Sliver,Creature,R,All Sliver creatures have double strike.
1,Kor Outfitter,Creature,W,"When ~ enters the battlefield, you may attach ..."
2,Spirit,Creature,W,Flying
3,Siren Lookout,Creature,U,"Flying\nWhen ~ enters the battlefield, it expl..."
4,Web,Enchantment,G,Enchant creature (Target a creature as you cas...
...,...,...,...,...
96994,Angel's Tomb,Artifact,E,Whenever a creature enters the battlefield und...
96995,Horned Troll,Creature,G,{G}: Regenerate ~.
96996,Faerie Bladecrafter,Creature,B,Flying\nWhenever one or more Faeries you contr...
96997,Exultant Skymarcher,Creature,W,Flying


In [20]:
rare_char={
    '¡®°²½˝̶π’„•…™−∞☐œŠ':'',
    'Äàáâãä':'a',
    'Éèéêë':'e',
    'Ææ':'ae',
    'Óóö':'o',
    'úûü':'u',
    'íī':'i',
    'Ññ':'n'
}
for rarechar, target in rare_char.items():
    for char in [*rarechar]:
        type_color_data['oracle_text']=type_color_data['oracle_text'].str.replace(char, target)

In [21]:
tokenizer = Tokenizer.from_file("mtggenerator_tokenizer_v6.json")
vocab_size=tokenizer.get_vocab_size()
encode = lambda text: tokenizer.encode(text).ids
decode = lambda list: tokenizer.decode(list)

In [22]:
text_list=list(type_color_data['oracle_text'])
#convert data to 2d tensor
encoded_text_list=[torch.Tensor(encode(text)) for text in text_list]
max_len=max([len(item) for item in encoded_text_list])
padded_text_list=[torch.cat((item, torch.full((max_len - len(item),), 3))) for item in encoded_text_list] # the [PAD] token has id=3
padded_text_list_with_CLS = [torch.cat((torch.tensor([1]), item)) for item in padded_text_list]

data = pad_sequence(padded_text_list_with_CLS, batch_first=True).long()

In [23]:
for type in types:
    type_color_data[type]=type_color_data['Type']==type
for color in colors_wheel:
     type_color_data[color]=type_color_data['Color']==color
type_color_encodings = torch.tensor(type_color_data.drop(['Name', 'Type', 'Color', 'oracle_text'], axis=1).values, dtype=torch.long)

In [24]:
n_train=int(0.9*type_color_data.shape[0])
train_context=type_color_encodings[:n_train]
val_context=type_color_encodings[n_train:]
train_targets=data[:n_train]
val_targets=data[n_train:]

In [None]:
torch.manual_seed(69)
vector_len=12
batch_size=512
block_size=20
sampling_size=6
max_iters=8000
eval_interval=300
learning_rate=2e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 36
n_heads = 4
n_layers = 4
dropout=0.3

In [36]:
def get_batch(split):
    #generates a small batch of data input x and target y
    targets = train_targets if split == 'train' else val_targets
    contexts = train_context if split == 'train' else val_context
    ix = torch.stack([torch.randint(targets.shape[0], (batch_size, )), torch.randint(sampling_size, (batch_size, ))]).T
    x = torch.stack(tuple(targets[i[0]][i[1]:i[1] + block_size] for i in ix))
    y = torch.stack(tuple(targets[i[0]][i[1] + 1:i[1] + block_size + 1] for i in ix))
    x, y = x.to(device), y.to(device)
    prompts = torch.stack(tuple(contexts[i[0]] for i in ix))
    prompts = prompts.to(device)
    return x, y, prompts

In [41]:
def get_batch(split):
    #generates a small batch of data input x and target y
    targets = train_targets if split == 'train' else val_targets
    contexts = train_context if split == 'train' else val_context
    ix = torch.randint(targets.shape[0], (batch_size, )).T
    x = torch.stack(tuple(targets[i][:block_size] for i in ix))
    y = torch.stack(tuple(targets[i][1:block_size+1] for i in ix))
    x, y = x.to(device), y.to(device)
    prompts = torch.stack(tuple(contexts[i] for i in ix))
    prompts = prompts.to(device)
    return x, y, prompts

In [42]:
@torch.no_grad()
def estimate_loss():
    out={}
    mtg_model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y, P = get_batch(split)
            logits, loss = mtg_model(X, P, Y)
            losses[k] = loss.item()
        out[split]=losses.mean()
    mtg_model.train()
    return out

class Head(nn.Module):
    #one self attention head

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias= False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**0.5
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v=self.value(x)
        out=wei @ v
        return out

class MultiHeadAttention(nn.Module):
    """multi head attention"""
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads=nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj=nn.Linear(head_size*num_heads, n_embd)
        self.dropout=nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)
    
class CrossAttentionHead(nn.Module):
    #one self attention head

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(vector_len, head_size, bias=False)
        self.value = nn.Linear(vector_len, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context):
        B, T, C = x.shape
        k = self.key(context)
        q = self.query(x)
        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**0.5
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v=self.value(context)
        out=wei @ v
        return out
    
class MultiHeadCrossAttention(nn.Module):
    """multi head cross attention"""
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads=nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj=nn.Linear(head_size*num_heads, n_embd)
        self.dropout=nn.Dropout(dropout)
    
    def forward(self, x, context):
        out = torch.cat([h(x, context) for h in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)

'''   
class FeedForward(nn.Module):
    """simple feedforward perceptron layer"""
    def __init__(self, n_embd):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(n_embd+12, 4*(n_embd+12)),
            nn.ReLU(),
            nn.Linear(4*(n_embd+12), n_embd),
            nn.Dropout(dropout),
        )
    
    def forward(self, x, context):
        x = torch.cat([x, torch.cat(tuple(context.unsqueeze(-2) for i in range(x.shape[-2])), dim=-2)], dim=-1)
        return self.net(x)
'''
class FeedForward(nn.Module):
    """simple feedforward perceptron layer"""
    def __init__(self, n_embd):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(n_embd, 4*(n_embd)),
            nn.ReLU(),
            nn.Linear(4*(n_embd), n_embd),
            nn.Dropout(dropout),
        )
    
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Transformer block: multihead self attention followed by one Feedforward layer"""
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd//n_head
        self.sa=MultiHeadAttention(n_head, head_size)
        self.ca=MultiHeadCrossAttention(n_head, head_size)
        self.ffwd=FeedForward(n_embd)
        self.ln1=nn.LayerNorm(n_embd)
        self.ln2=nn.LayerNorm(n_embd)
        self.ln3=nn.LayerNorm(n_embd)
    
    def forward(self, x, context):
        # context is B * 12
        # x is B * T * C
        x = x+self.sa(self.ln1(x))
        x = x+self.ca(self.ln2(x), context)
        x = x+self.ffwd(self.ln3(x))

        return x, context


class MTGCardGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table=nn.Embedding(vocab_size, n_embd) #each token directly look up the logit of the next token from a lookup table
        self.lmhead=nn.Linear(n_embd, vocab_size)
        self.position_embedding_table=nn.Embedding(block_size, n_embd) #each token gets a position embeding of block_size, stores the relative position of token in the block
        self.block=nn.Sequential(*[Block(n_embd, n_head=n_heads) for _ in range(n_layers)])
        #self.block=Block(n_embd, n_head=n_heads)
    
    def forward(self, idx, context, targets=None):
        
        B, T = idx.shape

        #idx and targets are both (B,T) tensors of integers, where B=batch number, T=position in batch
        token_embeddings=self.token_embedding_table(idx) #look up value corresponding to own position in the token embedding table to form C (channel value)
        position_embeddings=self.position_embedding_table(torch.arange(T, device=device)) #add position embeddings to token embedding
        x= token_embeddings + position_embeddings

        for layer in self.block:
            x, context = layer(x, context)
            
        logits=self.lmhead(x)

        if targets is None:
            loss=None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            #logits are therefore values associated with each character
            loss=F.cross_entropy(logits, targets) #evaluate loss

        return logits, loss
    
    def generate(self, idx, context, max_new_tokens):
        for _ in range(max_new_tokens):
            #crop idx to max block size
            idx_cond=idx[:, -block_size:]
            #get the predictions
            logits, loss = self(idx_cond, context)
            #use logits only, focus only on last time step
            logits = logits[:, -1, :] #keep only last time step ---> (B, C)
            #apply softmax on logit to get distribution
            probs = F.softmax(logits, dim=-1) #get a (B, C) matrix of probabilities, sum(prob) of each B = 1
            #sample from the distribution
            idx_next=torch.multinomial(probs, num_samples=1) #get a (B, 1) array of predictions
            #append prediction to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) #now a (B, T+1) matrix of returned results
        return idx

In [43]:
mtg_model=MTGCardGenerator()
m=mtg_model.to(device)

In [44]:
#create new optimizer
optimizer=torch.optim.AdamW(mtg_model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # every once in a while evaluate the loss of train and val
    if iter % eval_interval == 0:
        losses=estimate_loss()
        print(f"step {iter}: train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")
    
    #sample a batch of data
    xb, yb, promptb = get_batch('train')

    #evaluate the loss
    logits, loss = mtg_model(xb, promptb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

torch.save(m.state_dict(), 'mtggenerator_v6.pt')

  ix = torch.randint(targets.shape[0], (batch_size, )).T


step 0: train loss: 9.4774, val loss: 9.4740
step 300: train loss: 4.0448, val loss: 4.0613
step 600: train loss: 3.2150, val loss: 3.2246
step 900: train loss: 2.7402, val loss: 2.7675
step 1200: train loss: 2.4517, val loss: 2.4780
step 1500: train loss: 2.2685, val loss: 2.2911
step 1800: train loss: 2.1210, val loss: 2.1516
step 2100: train loss: 2.0189, val loss: 2.0450
step 2400: train loss: 1.9293, val loss: 1.9605
step 2700: train loss: 1.8571, val loss: 1.8878
step 3000: train loss: 1.8022, val loss: 1.8257
step 3300: train loss: 1.7400, val loss: 1.7767
step 3600: train loss: 1.6935, val loss: 1.7252
step 3900: train loss: 1.6570, val loss: 1.6892
step 4200: train loss: 1.6117, val loss: 1.6463
step 4500: train loss: 1.5784, val loss: 1.6164
step 4800: train loss: 1.5514, val loss: 1.5828
step 5100: train loss: 1.5212, val loss: 1.5611
step 5400: train loss: 1.4995, val loss: 1.5381
step 5700: train loss: 1.4714, val loss: 1.5092
step 6000: train loss: 1.4579, val loss: 1.489

KeyboardInterrupt: 

In [593]:
torch.tensor(val_context[0]).unsqueeze(0)

  torch.tensor(val_context[0]).unsqueeze(0)


tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])

In [33]:
def generate(context):
    start = torch.tensor([encode('[CLS]')], dtype=torch.long, device=device)
    response=m.generate(start, context=context.unsqueeze(0).to(device), max_new_tokens=250)[0].tolist()
    indices = [i for i, x in enumerate(response) if x == 2]
    slices = [response[i+1:j] for i, j in zip([0] + indices, indices + [None])]
    print(response)
    for slice in slices:
        print(decode(slice))

In [34]:
generate(val_context[0])

[1, 195, 96, 242, 201, 122, 149, 129, 405, 1129, 362, 122, 111, 155, 23, 246, 170, 113, 128, 111, 384, 1127, 21, 106, 23, 96, 408, 23, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
Whenever ~ deals damage to target of another Cleric equal to the battlefield . If that creature is the next alone , you . ~ shuffle .


In [564]:
torch.tensor(encode('[CLS]'), dtype=torch.long, device=device)

tensor([1], device='cuda:0')

In [598]:
m.generate(torch.tensor([encode('[CLS]')], dtype=torch.long, device=device), context=val_context[0].unsqueeze(0).to(device), max_new_tokens=250)

tensor([[   1,  285,  176,  653,  297, 3165,  113,  238,   17,   85,  170,  113,
         5167,  117,  279,   23,  246,  464,  231,  111,  155,   21,  125,   96,
           26,  106,   18,   21,  145,  185,  149,  113, 1575,  149,  113,   61,
          111,  384,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,  246,   95, 5175,  111,   23,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
            3,    3,    3,  