In [20]:
import pandas as pd
import numpy as np
import torch
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torch.nn import functional as F
import pickle

In [21]:
#model params
torch.manual_seed(69)
batch_size=512
block_size=36
sampling_size=24
max_iters=5000
eval_interval=300
learning_rate=5e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32
n_heads = 8
n_layers = 10
dropout=0.3

prompt_size = block_size # 30 maximum tokens allowed in the prompt
encoder_num_heads=4
encoder_n_embd=n_embd

# Tokenizer

Run only one of the two cells below

In [22]:
#import tokenizer trainer
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file("mtggenerator.json")
vocab_size=tokenizer.get_vocab_size()

#create the mapping from characters to integers
encode = lambda text: tokenizer.encode(text).ids #encode: take a string, output a list of integers
decode = lambda list: tokenizer.decode(list) #decode: take a list of integers, output a string

In [3]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
vocab_size=tokenizer.vocab_size

#create the mapping from characters to integers
encode = lambda text: tokenizer.encode(text) #encode: take a string, output a list of integers
decode = lambda list: tokenizer.decode(list) #decode: take a list of integers, output a string

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
# For ZY no csv read
#with open('mtgdata.pickle', 'rb') as file:
#    mtg_df=pickle.load(file)

mtg_df=pd.read_csv('mtg_data.csv', index_col=0)
mtg_df=mtg_df.dropna(subset=['text_prompt', 'card_description'])
mtg_df

Unnamed: 0,name,mana_cost,cmc,type_line,oracle_text,power,toughness,colors,color_identity,keywords,rarity,flavor_text,text,text_prompt,card_description
0,Fury Sliver,{5}{R},6.0,Creature — Sliver,All Sliver creatures have double strike.,3,3,['R'],['R'],[],uncommon,"""A rift opened, and our arrows were abruptly s...",Fury Sliver: [SEP] {5}{R} [SEP] Creature — Sli...,Fury Sliver: [SEP] {5}{R},Creature — Sliver [SEP] All Sliver creatures h...
1,Kor Outfitter,{W}{W},2.0,Creature — Kor Soldier,"When ~ enters the battlefield, you may attach ...",2,2,['W'],['W'],[],common,"""We take only what we need to survive. Believe...",Kor Outfitter: [SEP] {W}{W} [SEP] Creature — K...,Kor Outfitter: [SEP] {W}{W},Creature — Kor Soldier [SEP] When ~ enters the...
2,Spirit,,0.0,Token Creature — Spirit,Flying,1,1,['W'],['W'],[Flying],common,,Spirit: [SEP] [SEP] Token Creature — Spirit [...,Spirit: [SEP],Token Creature — Spirit [SEP] Flying
3,Siren Lookout,{2}{U},3.0,Creature — Siren Pirate,"Flying\nWhen ~ enters the battlefield, it expl...",1,2,['U'],['U'],"[Flying, Explore]",common,,Siren Lookout: [SEP] {2}{U} [SEP] Creature — S...,Siren Lookout: [SEP] {2}{U},Creature — Siren Pirate [SEP] Flying\nWhen ~ e...
4,Web,{G},1.0,Enchantment — Aura,Enchant creature (Target a creature as you cas...,,,['G'],['G'],[Enchant],rare,,Web: [SEP] {G} [SEP] Enchantment — Aura [SEP] ...,Web: [SEP] {G},Enchantment — Aura [SEP] Enchant creature (Tar...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
85059,Celestine Reef,,0.0,Plane — Luvion,Creatures without flying or islandwalk can't a...,,,[],[],[],rare,,Celestine Reef: [SEP] [SEP] Plane — Luvion [S...,Celestine Reef: [SEP],Plane — Luvion [SEP] Creatures without flying ...
85060,Horned Troll,{2}{G},3.0,Creature — Troll,{G}: Regenerate ~.,2,2,['G'],['G'],[],common,Sword hilts jut from some trolls' bodies where...,Horned Troll: [SEP] {2}{G} [SEP] Creature — Tr...,Horned Troll: [SEP] {2}{G},Creature — Troll [SEP] {G}: Regenerate ~.
85061,Faerie Bladecrafter,{2}{B},3.0,Creature — Faerie Rogue,Flying\nWhenever one or more Faeries you contr...,2,2,['B'],['B'],[Flying],rare,,Faerie Bladecrafter: [SEP] {2}{B} [SEP] Creatu...,Faerie Bladecrafter: [SEP] {2}{B},Creature — Faerie Rogue [SEP] Flying\nWhenever...
85062,Exultant Skymarcher,{1}{W}{W},3.0,Creature — Vampire Soldier,Flying,2,3,['W'],['W'],[Flying],common,"""We have come at last to this holiest of holy ...",Exultant Skymarcher: [SEP] {1}{W}{W} [SEP] Cre...,Exultant Skymarcher: [SEP] {1}{W}{W},Creature — Vampire Soldier [SEP] Flying


In [24]:
#pre-processing to get rid of unregonizable characters
rare_char={
    '¡®°²½˝̶π’„•…™−∞☐œŠ':'',
    'Äàáâãä':'a',
    'Éèéêë':'e',
    'Ææ':'ae',
    'Óóö':'o',
    'úûü':'u',
    'íī':'i',
    'Ññ':'n'
}
for rarechar, target in rare_char.items():
    for char in [*rarechar]:
        mtg_df['text_prompt']=mtg_df['text_prompt'].str.replace(char, target)
        mtg_df['card_description']=mtg_df['card_description'].str.replace(char, target)

prompt_list=list(mtg_df['text_prompt'])
text_list=list(mtg_df['card_description'])
print(f'length of prompts is {len(prompt_list)}\nlength of descriptions is {len(text_list)}')

length of prompts is 82351
length of descriptions is 82351


In [25]:
encoded_text_list=[torch.Tensor(encode(text)) for text in text_list]
max_len=max([len(item) for item in encoded_text_list])
padded_text_list=[torch.cat((item, torch.full((max_len - len(item),), 3))) for item in encoded_text_list] # the [PAD] token has id=3
padded_text_list_with_CLS = [torch.cat((torch.tensor([1]), item)) for item in padded_text_list]

encoded_prompt_list=[torch.Tensor(encode(text)) for text in prompt_list]
max_prompt_len=max([len(item) for item in encoded_prompt_list])
padded_prompt_list=[torch.cat((item, torch.full((max_prompt_len - len(item),), 3)))[:prompt_size-1] for item in encoded_prompt_list] # the [PAD] token has id=3
padded_prompt_list_with_CLS = [torch.cat((torch.tensor([1]), item)) for item in padded_prompt_list]

In [26]:
data = pad_sequence(padded_text_list, batch_first=True).long()
prompts = pad_sequence(padded_prompt_list_with_CLS, batch_first=True).long()
n_train = int(0.9*data.shape[0])
train_data = data[:n_train]
val_data = data[n_train:]
train_prompts = prompts[:n_train]
val_prompts = prompts[n_train:]

In [27]:
train_prompts[100]

tensor([   1, 5951,  506,   29,    2,   87,   33,   89,    3,    3,    3,    3,
           3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,
           3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3,    3])

In [28]:
def get_batch(split):
    #generates a small batch of data input x and target y
    data = train_data if split == 'train' else val_data
    ix = torch.stack([torch.randint(data.shape[0], (batch_size, )), torch.randint(sampling_size, (batch_size, ))]).T
    x = torch.stack(tuple(data[i[0]][i[1]:i[1] + block_size] for i in ix))
    y = torch.stack(tuple(data[i[0]][i[1] + 1:i[1] + block_size + 1] for i in ix))

    prompt_data = train_prompts if split == 'train' else val_prompts
    x_prompt = torch.stack(tuple(prompt_data[i[0]] for i in ix))
    x_prompt = torch.cat((x_prompt, x), dim=-1)
    x_prompt = x_prompt.to(device)
    y = y.to(device)

    return x_prompt, y

In [29]:
@torch.no_grad()
def estimate_loss():
    out={}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split]=losses.mean()
    model.train()
    return out


class MaskedHead(nn.Module):
    #one self attention head

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias= False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**0.5
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v=self.value(x)
        out=wei @ v

        return out


class MaskedMultiHeadAttention(nn.Module):
    """multi head attention"""
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads=nn.ModuleList([MaskedHead(head_size) for _ in range(num_heads)])
        self.proj=nn.Linear(head_size*num_heads, n_embd)
        self.dropout=nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)
    
class Head(nn.Module):
    #one self attention head (unmasked)

    def __init__(self, encoder_head_size):
        super().__init__()
        self.key = nn.Linear(encoder_n_embd, encoder_head_size, bias=False)
        self.query = nn.Linear(encoder_n_embd, encoder_head_size, bias= False)
        self.value = nn.Linear(encoder_n_embd, encoder_head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**0.5
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v=self.value(x)
        out=wei @ v
        return out

class MultiHeadAttention(nn.Module):
    """multi head attention (unmasked)"""
    def __init__(self, encoder_num_heads, encoder_head_size):
        super().__init__()
        self.heads=nn.ModuleList([Head(encoder_head_size) for _ in range(encoder_num_heads)])
        self.proj=nn.Linear(encoder_head_size*encoder_num_heads, encoder_n_embd)
        self.dropout=nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)

class FeedForward(nn.Module):
    """simple feedforward perceptron layer"""
    def __init__(self, n_embd):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd),
            nn.Dropout(dropout),
        )
    
    def forward(self, x):
        return self.net(x)


class EncoderBlock(nn.Module):
    """Transformer encoder block: multihead self attention, followed by feedforward in to k and v"""
    def __init__(self, encoder_n_embd, encoder_n_head):
        super().__init__()
        encoder_head_size=encoder_n_embd//encoder_n_head
        self.selfattention=MultiHeadAttention(encoder_n_head, encoder_head_size)
        self.ffwd=FeedForward(encoder_n_embd)
        self.ln1=nn.LayerNorm(encoder_n_embd)
        self.ln2=nn.LayerNorm(encoder_n_embd)
    
    def forward(self, x):
        x = x+self.sa(self.ln1(x))
        x = x+self.ffwd(self.ln2(x))

class CrossAttentionHead(nn.Module):
    """Cross attention block: takes encoder embeddings and decoder embeddings to generate cross attention"""
    def __init__(self, ca_head_size):
        super().__init__()
        self.key = nn.Linear(encoder_n_embd, ca_head_size, bias=False)
        self.query = nn.Linear(n_embd, ca_head_size, bias= False)
        self.value = nn.Linear(encoder_n_embd, ca_head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, embedded_x_prompt):
        x=embedded_x_prompt[:, block_size:, :]
        prompt=embedded_x_prompt[:, :block_size, :]
        B, T, C = x.shape
        B_, T_, C_ = prompt.shape
        k = self.key(prompt)
        q = self.query(x)
        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**0.5
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v=self.value(prompt)
        out=wei @ v

        return out

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads=nn.ModuleList([CrossAttentionHead(head_size) for _ in range(num_heads)])
        self.proj=nn.Linear(head_size*num_heads, n_embd)
        self.dropout=nn.Dropout(dropout)

    def forward(self, embedded_x_prompt):
        out = torch.cat([head(embedded_x_prompt) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return torch.cat((embedded_x_prompt[:, :block_size, :], out), dim=-2)


class DecoderBlock(nn.Module):
    """Transformer decoder block: multihead self attention followed by one Feedforward layer, followed by cross-attention, followed by ffwd"""
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd//n_head
        self.sa=MaskedMultiHeadAttention(n_head, head_size)
        self.ffwd=FeedForward(n_embd)
        self.ln1=nn.LayerNorm(n_embd)
        self.ln2=nn.LayerNorm(n_embd)
        self.ln3=nn.LayerNorm(n_embd)

        self.crossattention=MultiHeadCrossAttention(n_head, head_size) # cross attention module
    
    def forward(self, embedded_x_prompt):

        x=embedded_x_prompt[:, block_size:, :]
        x_sa = x+self.sa(self.ln1(x))
        x_ca = x_sa+self.crossattention(self.ln2(torch.cat((embedded_x_prompt[:, :block_size, :], x_sa), dim=-2)))[:, block_size:, :] # do cross attention with output of self attention
        out_x = x_ca+self.ffwd(self.ln3(x_ca))
        out_x_prompt=embedded_x_prompt[:, :block_size, :]

        """Cross Attention + Feed forward"""

        return torch.cat((out_x_prompt, out_x), dim=-2)


class MTGCardGenerator(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table=nn.Embedding(vocab_size, n_embd) #each token directly look up the logit of the next token from a lookup table
        self.lmhead=nn.Linear(n_embd, vocab_size)
        self.position_embedding_table=nn.Embedding(block_size, n_embd) #each token gets a position embeding of block_size, stores the relative position of token in the block

        self.encoder_token_embedding_table=nn.Embedding(vocab_size, encoder_n_embd)
        self.encoder_postion_embedding_table=nn.Embedding(vocab_size, encoder_n_embd)

        self.block=nn.Sequential(*[DecoderBlock(n_embd, n_head=n_heads) for _ in range(n_layers)])
    
    def forward(self, x_prompt, targets=None, mode="train"):

        prompt=x_prompt[:,:block_size]
        idx=x_prompt[:,block_size:]
        
        B, T = idx.shape
        B_, T_ = prompt.shape

        #idx and targets are both (B,T) tensors of integers, where B=batch number, T=position in batch
        token_embeddings=self.token_embedding_table(idx) #look up value corresponding to own position in the token embedding table to form C (channel value)
        position_embeddings=self.position_embedding_table(torch.arange(T, device=device)) #add position embeddings to token embedding
        x= token_embeddings + position_embeddings
        encoder_token_embeddings=self.encoder_postion_embedding_table(prompt)
        encoder_position_embeddings=self.encoder_postion_embedding_table(torch.arange(T_, device=device))
        prompt_x = encoder_token_embeddings+encoder_position_embeddings

        embedded_x_prompt=torch.cat((prompt_x, x), dim=-2)
        #returned_x = self.block(embedded_x_prompt)[1]
        logits=self.lmhead(self.block(embedded_x_prompt)[:,block_size:,:])

        if targets is None:
            loss=None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            #logits are therefore values associated with each character
            loss=F.cross_entropy(logits, targets) #evaluate loss

        return logits, loss
    
    def generate(self, context, max_new_tokens):
        prompt=context[:,:block_size]
        idx=context[:,block_size:]
        for i in range(max_new_tokens):
            if idx.shape[-1]>block_size:
            #crop idx to max block size
                idx_cond=idx[:, -block_size:]
            else:
                idx_cond=idx
            #get the predictions
            logits, loss = self(torch.cat((prompt, idx_cond), dim=-1))
            #use logits only, focus only on last time step
            logits = logits[:, -1, :] #keep only last time step ---> (B, C)
            #apply softmax on logit to get distribution
            probs = F.softmax(logits, dim=-1) #get a (B, C) matrix of probabilities, sum(prob) of each B = 1
            #sample from the distribution
            idx_next=torch.multinomial(probs, num_samples=1) #get a (B, 1) array of predictions
            #append prediction to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) #now a (B, T+1) matrix of returned results
        return idx

In [11]:
model=MTGCardGenerator()
m=model.to(device)

In [33]:
model_path='mtggenerator_v5_check.pt'
model=MTGCardGenerator()
model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda')))
model.eval()
m=model.to(device)

In [12]:
torch.autograd.set_detect_anomaly(True)

optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3)

for iter in range(max_iters):
    # every once in a while evaluate the loss of train and val
    if iter % eval_interval == 0:
        losses=estimate_loss()
        print(f"step {iter}: train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")
    
    #sample a batch of data
    xb_prompt, yb= get_batch('train')

    #evaluate the loss
    logits, loss = model(xb_prompt, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss: 10.7375, val loss: 10.7299
step 300: train loss: 2.9598, val loss: 2.9356


KeyboardInterrupt: 

In [None]:
torch.save(m.state_dict(), 'mtggenerator_v5.pt')

In [31]:
def generate(cardname, mana):
    prompt= torch.tensor([encode(f'[CLS] {cardname}: [SEP] {mana}')], dtype=torch.long, device=device)
    padding_values = torch.full((1, prompt_size-prompt.shape[-1]), 3, dtype=torch.long, device=device)
    padded_prompt = torch.cat((prompt, padding_values), dim=-1)
    start = torch.tensor([encode('[CLS]')], dtype=torch.long, device=device)
    context=torch.cat((padded_prompt, start), dim=-1)
    response=m.generate(context, max_new_tokens=20)[0].tolist()
    return decode(response)
generate('The Big Bang', '{2}{R}')

'party { T }: Add { B } ( You may cast this card for its of any name'

In [34]:
generate('High Tide', '{2}{U}')

'spell or dealt damage to each attacking .'

In [119]:
import requests
requests.get('http://107.22.21.89/', params={'prompt':'hello'}).json()

{'cost': 'hello Van Darkholme Cost',
 'description': 'hello My name is Van',
 'name': 'hello Van Darkholme Name'}