In [1]:
import pandas as pd
import torch
import yaml

from torch.utils.data import DataLoader

from data_preprocessing import pre_process_data
from mtgpt import mtGPT
from mtg_dataset import MTGDataset
from utility import *

In [4]:
run_pre_process_data = False
load_model_path = './models/2304161750/model.pth'

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [6]:
torch.manual_seed(42)

<torch._C.Generator at 0x17c20728150>

# Preprocess and load data

In [7]:
if run_pre_process_data:
    pre_process_data(
            'data/AllPrintingsCSVFiles/cards.csv', 
            'data/preproc/cards_text.txt', 
            ['name','manaCost','type','text','power','toughness'], 
            True, 
            '[½®π∞☐àáâéíöúû−•²]', 
            train_data_ratio=0.9
            )

with open('data/preproc/cards_text_train.txt', 'r') as fh:
    data_train = fh.read()
with open('data/preproc/cards_text_val.txt', 'r') as fh:
    data_val = fh.read()

In [8]:
print(data_train[:500])

vine gecko, {1}{g}, creature — elemental lizard, the first kicked spell you cast each turn costs {1} less to cast. whenever you cast a kicked spell, put a +1/+1 counter on vine gecko., power 2, toughness 2
paleoloth, {4}{g}{g}, creature — beast, whenever another creature with power 5 or greater enters the battlefield under your control, you may return target creature card from your graveyard to your hand., power 5, toughness 5
phyrexian swarmlord, {4}{g}{g}, creature — phyrexian insect horror, i


In [9]:
print(data_val[:500])

thraben inspector, {w}, creature — human soldier, when thraben inspector enters the battlefield, investigate. (create a colorless clue artifact token with "{2}, sacrifice this artifact: draw a card."), power 1, toughness 2
ancient den, nan, artifact land, {t}: add {w}., power nan, toughness nan
ribbons of night, {4}{b}, sorcery, ribbons of night deals 4 damage to target creature and you gain 4 life. if {u} was spent to cast this spell, draw a card., power nan, toughness nan
cunning wish, {2}{u},


# Create vocubulary and encode text data

In [10]:
chars = sorted(list(set(data_train + data_val)))
dim_vocabulary = len(chars)

print(dim_vocabulary)

61


In [11]:
map_char_to_int = get_map_char_to_int(chars)
map_int_to_char = get_map_int_to_char(chars)

In [12]:
data_train_encoded = torch.tensor(encode(map_char_to_int, data_train), dtype=torch.long)

In [13]:
data_val_encoded = torch.tensor(encode(map_char_to_int, data_val), dtype=torch.long)

In [14]:
print('Length train data: {}\nLength validation data: {}'.format(len(data_train_encoded), len(data_val_encoded)))

Length train data: 4609801
Length validation data: 503691


# Load training configuration

In [15]:
with open('./config.yaml', 'r') as fh:
    config = yaml.safe_load(fh)

# Create data loader

In [16]:
training_dataset = MTGDataset(
    data_train_encoded, 
    config['model']['dim_context']
    )
training_dataloader = DataLoader(
    training_dataset, 
    batch_size=config['train']['batch_size'], 
    shuffle=True
    )

In [17]:
validation_dataset = MTGDataset(
    data_val_encoded, 
    config['model']['dim_context']
    )
validation_dataloader = DataLoader(
    validation_dataset, 
    batch_size=config['train']['batch_size'], 
    shuffle=True
    )

# Create MTGPT Model

In [18]:
if load_model_path:
    model = torch.load(load_model_path)
else:
    model = mtGPT(
        dim_vocabulary,
        config['model']['dim_embedding'],
        config['model']['dim_context'],
        config['model']['dim_feedforward'],
        config['model']['num_heads'],
        config['model']['num_layers'],
        config['model']['prob_dropout'],
        device
    )

In [19]:
m = model.to(device)

In [20]:
num_model_params = sum(p.numel() for p in model.parameters())
print('Number of model parameters: {}'.format(num_model_params))

Number of model parameters: 10785853


In [21]:
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=config['train']['learning_rate']
)

In [20]:
model.train()
for epoch in range(config['train']['epochs']):
    
    if epoch % config['train']['eval_interval'] == 0 or epoch == config['train']['epochs'] - 1:
        train_loss = estimate_loss(
            model, 
            device, 
            training_dataloader,
            config['train']['eval_iters']
        )

        val_loss = estimate_loss(
            model, 
            device, 
            validation_dataloader,
            config['train']['eval_iters']
        )

        print(
            "[{}/{}]: Training loss {:.4f}, Validation loss {:.4f}".format(
                epoch + 1, 
                config['train']['epochs'], 
                train_loss,
                val_loss
                )
            )
        with torch.no_grad():
            newline_context = torch.zeros((1, 1), dtype=torch.long, device=device)
            generated_tokens = decode(
                        map_int_to_char, 
                        model.generate(newline_context.to(device), max_new_tokens=100)[0].tolist()
                    ) 
            print("*** Example text: ")
            print('\t{}'.format(generated_tokens))
            print('***')
    
    for batch_idx, (x, y) in enumerate(training_dataloader):
        x, y = x.to(device), y.to(device)

        logits, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if batch_idx == config['train']['max_batches_per_epoch']:
            break

    

[1/20]: Training loss 4.0858, Validation loss 4.0854
*** Example text: 
	
—f23vm|u/m2’d3/8+&dz,l54rh_'9e’;i*z]rd?_[d3u"e}i(,"}?7}[ s,j4}[cd45qi]fnh5j’*"l_}_(t*jws&n[-r}}l8)7p
***
[2/20]: Training loss 0.5152, Validation loss 0.5423
*** Example text: 
	
wato ond of the vens, {3}{u}, instant, choose one —  destroy all ongre creatures the choice., power 
***
[3/20]: Training loss 0.4201, Validation loss 0.4569
*** Example text: 
	
out's horder, {2}{g}, creature — funturfl, when owt egarreder enters the battlefield, exwarder gets 
***
[4/20]: Training loss 0.3831, Validation loss 0.4355
*** Example text: 
	
array falcon, {2}{g}{g}, creature — elemental, {t}, sacrifice array falcon: search your library for 
***
[5/20]: Training loss 0.3495, Validation loss 0.4139
*** Example text: 
	
rowhenite hydra, {3}{r}{r}, creature — dragon, flying whenever an opponent casts a red spell, its co
***
[6/20]: Training loss 0.3360, Validation loss 0.4077
*** Example text: 
	
purdian look, {3}{b}, sorcery, 

In [21]:
save_model_and_config(model, config)

In [22]:
new_tokens = 2000
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    generated_tokens = m.generate(
        context, 
        max_new_tokens=new_tokens 
        )[0].tolist()
    print(decode(
                map_int_to_char, 
                generated_tokens
            ) 
        )


—fishmatic strider, {4}{r}{r}, creature — elemental warrior, trample as hunting strider enters the battlefield, you may search your library for an instant or sorcery card, reveal it, put it into your hand, then shuffle. {x}, {t}: target player mills x cards, where x is the sacrificed creature's power., power 3, toughness 3
rortipling wind, {2}{g}, enchantment, constellation — whenever an enchantment enters the battlefield under your control, create a 1/1 colorless pirate creature token with no ign combat. (it can't be the target of spells or abilities your opponents control.), power nan, toughness nan
yami, rewinder avenger, {4}{u}{u}, legendary creature — human wizard, whenever you cast an instant or sorcery spell, choose one —  return target enchantment you control to its owner's hand. then that player shuffles their library., power 7, toughness 6
gorex shrieker, {5}, artifact creature — phyrexian construct, infect (this creature deals damage to creatures in the form of -1/-1 counte

In [23]:
new_tokens = 2000
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_cards = ''
for _ in range(10): 
    with torch.no_grad():
        generated_tokens = m.generate(
            context, 
            max_new_tokens=new_tokens 
            )[0].tolist()
        new_cards = decode(map_int_to_char, generated_tokens) 
        generated_cards += '\n'.join(new_cards.split('\n')[:-1])
                            

In [24]:
with open('results/generated_cards.txt', 'w') as fh:
    fh.write(generated_cards)