In [7]:
import pandas as pd
import numpy as np
import re

In [8]:
mtg_df=pd.read_csv('mtg_data.csv', index_col=0)
mtg_df=mtg_df.dropna(subset='text')
mtg_df.head(5)

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,rarity,flavor_text,text
0,Fury Sliver,{5}{R},6.0,Creature — Sliver,All Sliver creatures have double strike.,3.0,3.0,['R'],['R'],[],uncommon,"""A rift opened, and our arrows were abruptly s...",Fury Sliver:\nCreature — Sliver\nAll Sliver cr...
1,Kor Outfitter,{W}{W},2.0,Creature — Kor Soldier,"When Kor Outfitter enters the battlefield, you...",2.0,2.0,['W'],['W'],[],common,"""We take only what we need to survive. Believe...",Kor Outfitter:\nCreature — Kor Soldier\nWhen K...
2,Spirit,,0.0,Token Creature — Spirit,Flying,1.0,1.0,['W'],['W'],[Flying],common,,Spirit:\nToken Creature — Spirit\nFlying\n
3,Siren Lookout,{2}{U},3.0,Creature — Siren Pirate,Flying\nWhen Siren Lookout enters the battlefi...,1.0,2.0,['U'],['U'],"[Flying, Explore]",common,,Siren Lookout:\nCreature — Siren Pirate\nFlyin...
4,Web,{G},1.0,Enchantment — Aura,Enchant creature (Target a creature as you cas...,,,['G'],['G'],[Enchant],rare,,Web:\nEnchantment — Aura\nEnchant creature (Ta...


In [9]:
#pre-processing to get rid of unregonizable characters
rare_char={
    '¡®°²½˝̶π—―’„•…™−∞☐œŠ':'',
    'Äàáâãä':'a',
    'Éèéêë':'e',
    'Ææ':'ae',
    'Óóö':'o',
    'úûü':'u',
    'íī':'i',
    'Ññ':'n'
}
for rarechar, target in rare_char.items():
    for char in [*rarechar]:
        mtg_df['text']=mtg_df['text'].str.replace(char, target)

In [10]:
text_list=list(mtg_df['text'])
text_len=np.array([len(desc) for desc in text_list])
print(
    'total number of characters:\n'+
    str(np.sum(text_len))+'\n'+
    'average numbr of characters:\n'+
    str(np.mean(text_len))+'\n'+
    'total number of cards:\n'+
    str(len(text_len))
)

total number of characters:
18411812
average numbr of characters:
216.5178512629945
total number of cards:
85036


In [11]:
text_total='\n'.join(text_list)
chars=sorted(list(set(text_total)))
vocab_size=len(chars)
print(
    'vocab content:\n'+
    ''.join(chars)+'\n'+
    'vocab size:\n'+
    str(vocab_size)
)

vocab content:

 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_abcdefghijklmnopqrstuvwxyz{|}
vocab size:
94


In [12]:
#create the mapping from characters to integers
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[ch] for ch in s] #encode: take a string, output a list of integers
decode = lambda li: ''.join([itos[i] for i in li]) #decode: take a list of integers, output a string

#test
print(encode('Black Lotus'))
print(decode(encode('Black Lotus')))

[35, 76, 65, 67, 75, 1, 45, 79, 84, 85, 83]
Black Lotus


In [13]:
import torch
from torch.nn.utils.rnn import pad_sequence

#convert data to 2d tensor
encoded_text_list=[torch.Tensor(encode(text)) for text in text_list]
max_len=max([len(item) for item in encoded_text_list])
padded_text_list=[torch.cat((item, torch.ones(max_len-len(item)))) for item in encoded_text_list]

data = pad_sequence(padded_text_list, batch_first=True).long() # N_cards * Char_length
data.shape

torch.Size([85036, 794])

In [14]:
#train test split

n_train=int(0.9*data.shape[0])
train_data=data[:n_train]
val_data=data[n_train:]
print(train_data.shape)
print(val_data.shape)

torch.Size([76532, 794])
torch.Size([8504, 794])


In [15]:
block_size=8
train_data[0][:block_size+1]

tensor([39, 85, 82, 89,  1, 52, 76, 73, 86])

In [16]:
x = train_data[0][:block_size]
y = train_data[0][1:block_size+1]
for t in range(block_size):
    context=x[:t+1]
    target=y[t]
    print(f'when input is {context} the target is {target}')

when input is tensor([39]) the target is 85
when input is tensor([39, 85]) the target is 82
when input is tensor([39, 85, 82]) the target is 89
when input is tensor([39, 85, 82, 89]) the target is 1
when input is tensor([39, 85, 82, 89,  1]) the target is 52
when input is tensor([39, 85, 82, 89,  1, 52]) the target is 76
when input is tensor([39, 85, 82, 89,  1, 52, 76]) the target is 73
when input is tensor([39, 85, 82, 89,  1, 52, 76, 73]) the target is 86


In [17]:
torch.manual_seed(69)
batch_size=4
block_size=8

def get_batch(split):
    #generates a small batch of data input x and target y
    data = train_data if split == 'train' else val_data
    ix = torch.stack([torch.randint(data.shape[0], (batch_size, )), torch.randint(256 - block_size, (batch_size, ))]).T
    x = torch.stack(tuple(data[i[0]][i[1]:i[1] + block_size] for i in ix))
    y = torch.stack(tuple(data[i[0]][i[1] + 1:i[1] + block_size + 1] for i in ix))
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('-----------')

inputs:
torch.Size([4, 8])
tensor([[69, 86, 69, 82,  1, 65,  1, 67],
        [83, 13,  1, 69, 65, 67, 72,  1],
        [76, 76,  1, 84, 72, 73, 83,  1],
        [73, 78, 83,  1, 70, 65, 76, 76]])
targets:
torch.Size([4, 8])
tensor([[86, 69, 82,  1, 65,  1, 67, 82],
        [13,  1, 69, 65, 67, 72,  1, 79],
        [76,  1, 84, 72, 73, 83,  1, 84],
        [78, 83,  1, 70, 65, 76, 76, 73]])
-----------


In [18]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(69)

class BigranLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table=nn.Embedding(vocab_size, vocab_size) #each token directly look up the logit of the next token from a lookup table
    
    def forward(self, idx, targets=None):
        #idx and targets are both (B,T) tensors of integers, where B=batch number, T=position in batch
        logits=self.token_embedding_table(idx) #look up value corresponding to own position in the token embedding table to form C (channel value)

        if targets is None:
            loss=None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            #logits are therefore values associated with each character
            loss=F.cross_entropy(logits, targets) #evaluate loss

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            #get the predictions
            logits, loss = self(idx)
            #use logits only, focus only on last time step
            logits = logits[:, -1, :] #keep only last time step ---> (B, C)
            #apply softmax on logit to get distribution
            probs = F.softmax(logits, dim=-1) #get a (B, C) matrix of probabilities, sum(prob) of each B = 1
            #sample from the distribution
            idx_next=torch.multinomial(probs, num_samples=1) #get a (B, 1) array of predictions
            #append prediction to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) #now a (B, T+1) matrix of returned results
        return idx

m=BigranLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([32, 94])
tensor(4.9992, grad_fn=<NllLossBackward0>)


In [19]:
#testing baseline model
print(decode(m.generate(idx= torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


f)&f
nSa)O(]q3pXWx)lQbVt1M|i{13aqAI@rQlf\QY<%3JJe*FdY\@J9]+.SNn"-vUue_?bxOT+
e?yErcQJ*p$jT!E(wwb-I{Y


In [20]:
#create new optimizer
optimizer=torch.optim.AdamW(m.parameters(), lr=1e-3)

In [21]:
batch_size=32
for steps in range(10000):

    #sample a batch of data
    xb, yb = get_batch('train')
    #evaluate the loss
    logits, loss = m(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    print(loss.item())

5.096214771270752
5.13163423538208
5.243349075317383
5.176997661590576
4.864535808563232
5.183795928955078
4.979980945587158
5.068874835968018
4.967780590057373
4.86656379699707
4.96356725692749
5.119108200073242
5.031829833984375
5.121277809143066
5.016507148742676
5.0143866539001465
5.20923376083374
5.089817047119141
4.963018417358398


5.066132068634033
5.007952690124512
5.022974491119385
4.987492084503174
5.035606384277344
4.990057468414307
5.066300392150879
5.082245826721191
5.061469554901123
5.116850852966309
5.158682823181152
5.1605730056762695
5.02885103225708
4.8886213302612305
5.091151714324951
4.882263660430908
5.0580830574035645
4.991840362548828
5.186854362487793
4.931849479675293
4.977888584136963
5.004281997680664
4.964737892150879
4.906753063201904
4.976316928863525
5.128810882568359
5.054937839508057
5.0019965171813965
5.1322431564331055
4.947111129760742
4.982974052429199
5.093103885650635
5.057512283325195
5.14898157119751
5.001460552215576
4.88487434387207
4.9757304191589355
5.031687259674072
4.957845687866211
5.001106262207031
4.89328145980835
4.889455795288086
4.990607738494873
4.871322154998779
5.076631546020508
4.98946475982666
4.970510005950928
4.980101585388184
5.032027244567871
4.939500331878662
5.050121784210205
4.89335298538208
5.058792591094971
4.94639778137207
4.824566841125488
5.112265586

In [22]:
print(decode(m.generate(idx= torch.zeros((1,1), dtype=torch.long), max_new_tokens=250)[0].tolist()))


"Adiasthesie yat  plyonourerd
Scar   Sea  ol Fiz8/+{1    is:
at.
Otshn   penppiongonteren   tof Ey  Treatay   olis  ba-Bre  ectn hen  Ad
W},          1}:
Cond  actatustrs     ttor  iouror  Eleatonjdeadas  orejifal emangureste Otur mpifo gadoure  ound


In [31]:
torch.manual_seed(69)
B, T, C = 4, 8, 32
x=torch.randn(B, T, C)

head_size=16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias= False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16) --> a array of length 16 for every character at every batch
q = query(x) # (B, T, 16) --> a array of length 16 for every character at every batch
wei = q @ k.transpose(-2, -1) * head_size**0.5 # (T, 16) * (16, T) ---> (T, T) each character gets a weight for all characters in the batch

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v=value(x)
out = wei @ v

print(wei[0])
out.shape

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3989, 0.6011, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2886, 0.1684, 0.5431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5726, 0.2702, 0.0419, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2158, 0.3648, 0.1488, 0.0369, 0.2337, 0.0000, 0.0000, 0.0000],
        [0.1819, 0.0757, 0.3826, 0.1124, 0.0363, 0.2110, 0.0000, 0.0000],
        [0.1323, 0.1567, 0.0327, 0.2290, 0.2709, 0.0831, 0.0953, 0.0000],
        [0.2437, 0.0516, 0.1671, 0.0837, 0.0228, 0.3515, 0.0353, 0.0443]],
       grad_fn=<SelectBackward0>)


torch.Size([4, 8, 16])