In [1]:
import os

import pandas as pd
import torch
import yaml
from torch.utils.data import DataLoader

from data_preprocessing import pre_process_data
from mtgpt import mtGPT
from mtg_dataset import MTGDataset
from utility import *

In [25]:
# Notebook switches, gathered in one place.
# load_model_path: checkpoint handed to torch.load further down; a falsy value
# makes the notebook build a fresh mtGPT from config.yaml instead.
load_model_path = './models/2304221751/model.pth'
# run_pre_process_data: when True, pre_process_data() is run on the raw cards CSV
# before the train/validation text files are read.
run_pre_process_data = False
# save_model: when True, the model and config are stored via save_model_and_config().
save_model = False

In [26]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [27]:
# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(seed=42)

<torch._C.Generator at 0x1513a94e170>

# Preprocess and load data

In [28]:
# Optional step: rebuild the pre-processed text from the raw card CSV.
# NOTE(review): presumably pre_process_data also writes the *_train / *_val files
# read below (split by train_data_ratio) - confirm in data_preprocessing.
if run_pre_process_data:
    pre_process_data(
        'data/AllPrintingsCSVFiles/cards.csv',
        'data/preproc/cards_text.txt',
        ['name','manaCost','type','text','power','toughness'],
        True,
        '[½®π∞☐àáâéíöúû−•²]',
        train_data_ratio=0.9,
    )

# Read both splits fully into memory as plain strings.
with open('data/preproc/cards_text_train.txt', 'r') as train_file:
    data_train = train_file.read()

with open('data/preproc/cards_text_val.txt', 'r') as val_file:
    data_val = val_file.read()

In [29]:
# Sanity check: show the first 500 characters of the training text.
print(data_train[0:500])

vine gecko, {1}{g}, creature — elemental lizard, the first kicked spell you cast each turn costs {1} less to cast. whenever you cast a kicked spell, put a +1/+1 counter on vine gecko., power 2, toughness 2
paleoloth, {4}{g}{g}, creature — beast, whenever another creature with power 5 or greater enters the battlefield under your control, you may return target creature card from your graveyard to your hand., power 5, toughness 5
phyrexian swarmlord, {4}{g}{g}, creature — phyrexian insect horror, i


In [30]:
# Sanity check: show the first 500 characters of the validation text.
print(data_val[0:500])

thraben inspector, {w}, creature — human soldier, when thraben inspector enters the battlefield, investigate. (create a colorless clue artifact token with "{2}, sacrifice this artifact: draw a card."), power 1, toughness 2
ancient den, nan, artifact land, {t}: add {w}., power nan, toughness nan
ribbons of night, {4}{b}, sorcery, ribbons of night deals 4 damage to target creature and you gain 4 life. if {u} was spent to cast this spell, draw a card., power nan, toughness nan
cunning wish, {2}{u},


# Create vocabulary and encode text data

In [31]:
# Character-level vocabulary: every distinct character that occurs in either
# split, in sorted order so the ordering is the same on every run.
chars = sorted(set(data_train) | set(data_val))
dim_vocabulary = len(chars)

print(dim_vocabulary)

61


In [32]:
# Lookup tables between characters and integer ids, both built from `chars`
# by helpers from the utility module.
map_char_to_int, map_int_to_char = (
    get_map_char_to_int(chars),
    get_map_int_to_char(chars),
)

In [33]:
# Encode the training text with the char -> int map and store the ids as a
# 64-bit integer tensor.
data_train_encoded = torch.tensor(
    encode(map_char_to_int, data_train),
    dtype=torch.int64,
)

In [34]:
# Encode the validation text with the same char -> int map and store the ids as a
# 64-bit integer tensor.
data_val_encoded = torch.tensor(
    encode(map_char_to_int, data_val),
    dtype=torch.int64,
)

In [35]:
# Report how many encoded tokens each split contains.
print(f'Length train data: {len(data_train_encoded)}\nLength validation data: {len(data_val_encoded)}')

Length train data: 4609801
Length validation data: 503691


# Load training configuration

In [36]:
# Parse the settings file; later cells read its 'model' and 'train' sections.
# safe_load only builds plain Python containers and scalars.
with open('./config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Create data loader

In [37]:
# Wrap the encoded training ids in the project dataset (given the model's
# context length) and serve them in shuffled mini-batches.
training_dataset = MTGDataset(data_train_encoded, config['model']['dim_context'])
training_dataloader = DataLoader(
    dataset=training_dataset,
    batch_size=config['train']['batch_size'],
    shuffle=True,
)

In [38]:
# Same wrapping for the validation ids. Shuffling is kept on, so the batches
# drawn from this loader differ between passes.
validation_dataset = MTGDataset(data_val_encoded, config['model']['dim_context'])
validation_dataloader = DataLoader(
    dataset=validation_dataset,
    batch_size=config['train']['batch_size'],
    shuffle=True,
)

# Create MTGPT Model

In [39]:
# Either restore a previously saved model or build a fresh one from config.yaml.
if load_model_path:
    # map_location remaps the stored tensors onto the device chosen above, so a
    # checkpoint written on a GPU machine still loads on a CPU-only one (without
    # it torch.load tries to restore CUDA tensors and fails when no GPU exists).
    #
    # SECURITY: this restores a whole pickled model object, which can execute
    # arbitrary code while unpickling - only load checkpoints you created
    # yourself. Recent PyTorch releases default torch.load to weights_only=True,
    # which rejects full-model pickles; saving/loading a state_dict is the
    # long-term fix.
    #
    # NOTE(review): config.yaml is read independently of the checkpoint; confirm
    # that its 'model' section (e.g. dim_context) matches the loaded model.
    model = torch.load(load_model_path, map_location=device)
else:
    model_cfg = config['model']
    model = mtGPT(
        dim_vocabulary,
        model_cfg['dim_embedding'],
        model_cfg['dim_context'],
        model_cfg['dim_feedforward'],
        model_cfg['num_heads'],
        model_cfg['num_layers'],
        model_cfg['prob_dropout'],
        device
    )

In [40]:
# Move the model's parameters to the selected device. The generation cells
# further down call the result through the short name `m`.
# NOTE(review): for a torch.nn.Module, .to() returns the module itself, so `m`
# and `model` should be the same object - confirm mtGPT does not override it.
m = model.to(device=device)

In [41]:
# Total number of scalar values across all parameter tensors of the model.
num_model_params = sum(param.numel() for param in model.parameters())
print(f'Number of model parameters: {num_model_params}')

Number of model parameters: 10785853


In [42]:
# AdamW over every model parameter; the learning rate comes from config.yaml and
# all other hyper-parameters stay at PyTorch's defaults.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config['train']['learning_rate'],
)

In [24]:
# Training loop.
#
# Every `eval_interval` epochs (and on the last epoch) the train/validation loss
# is estimated over `eval_iters` batches and a short text sample is printed;
# then up to `max_batches_per_epoch` optimisation steps are taken.
train_cfg = config['train']
num_epochs = train_cfg['epochs']
max_batches = train_cfg['max_batches_per_epoch']

for epoch in range(num_epochs):

    if epoch % train_cfg['eval_interval'] == 0 or epoch == num_epochs - 1:
        train_loss = estimate_loss(
            model,
            device,
            training_dataloader,
            train_cfg['eval_iters']
        )

        val_loss = estimate_loss(
            model,
            device,
            validation_dataloader,
            train_cfg['eval_iters']
        )

        print(
            "[{}/{}]: Training loss {:.4f}, Validation loss {:.4f}".format(
                epoch + 1,
                num_epochs,
                train_loss,
                val_loss
                )
            )

        # Sample in eval mode so layers such as dropout are inactive while
        # generating; training mode is restored just before the batch loop.
        model.eval()
        with torch.no_grad():
            # A single start token with id 0, created directly on the target device.
            # NOTE(review): id 0 is presumably '\n' (first entry of the sorted
            # vocabulary) - confirm against get_map_char_to_int.
            newline_context = torch.zeros((1, 1), dtype=torch.long, device=device)
            generated_tokens = decode(
                        map_int_to_char,
                        model.generate(newline_context, max_new_tokens=100)[0].tolist()
                    )
            print("*** Example text: ")
            print('\t{}'.format(generated_tokens))
            print('***')

    # Re-assert training mode every epoch: the evaluation/sampling above (or
    # estimate_loss itself) may have left the model in eval mode.
    model.train()
    for batch_idx, (x, y) in enumerate(training_dataloader):
        # Check the cap *before* the step so exactly `max_batches` batches are
        # trained per epoch (checking afterwards trained one batch too many).
        # A cap of None keeps the old "no limit" behaviour.
        if max_batches is not None and batch_idx >= max_batches:
            break

        x, y = x.to(device), y.to(device)

        # The model returns (logits, loss); only the loss is needed here.
        _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    

[1/30]: Training loss 0.4481, Validation loss 0.4994
*** Example text: 
	
gravenous deals 2 damage to any target., power nan, toughness nan
cluted fellims, {5}{g}{g}, creatur
***
[2/30]: Training loss 0.4490, Validation loss 0.4980
*** Example text: 
	
zurgand poasm, {1}{b}, creature — bird, flying {1}, sacrifice fleets fleet: it deals dama equal to i
***
[3/30]: Training loss 0.4510, Validation loss 0.5005
*** Example text: 
	
noot of the pe weating, {1}{u}, instant, if it was sprocemon weave, you may pay {u}{u}. if you do, y
***
[4/30]: Training loss 0.4461, Validation loss 0.4975
*** Example text: 
	
druid, {1}{w}, creature — human advisor, {w}{w}: return target spirit card from a graveyard to its o
***
[5/30]: Training loss 0.4468, Validation loss 0.5021
*** Example text: 
	
por-lop vocken failin, {4}{r}{r}, enchantment, put a -2/-2 counter on one opponent or mills three ca
***
[6/30]: Training loss 0.4487, Validation loss 0.4956
*** Example text: 
	
red plust, {3}{u}, instant, cou

# Save model

In [43]:
# Decide where the artefacts written by the following cells (loss.yaml,
# generated_cards.txt) go.
if save_model:
    # save_model_and_config stores the model and config and returns the
    # directory it used.
    save_path = save_model_and_config(model, config)
elif load_model_path:
    # Reuse the directory of the loaded checkpoint. os.path.dirname understands
    # the platform's path separators, unlike splitting on '/' by hand.
    save_path = os.path.dirname(load_model_path)
else:
    # Nothing was saved or loaded: fall back to a path relative to the current
    # working directory instead of failing on a missing checkpoint path.
    save_path = ''

# Estimate and save final loss values to yaml file

In [44]:
# Estimate the loss on both splits with the current weights and store the two
# numbers next to the model as loss.yaml.
train_loss = estimate_loss(
    model,
    device,
    training_dataloader,
    config['train']['eval_iters']
)

val_loss = estimate_loss(
    model,
    device,
    validation_dataloader,
    config['train']['eval_iters']
)

# float() converts any single-element tensor - on CPU or GPU, with or without
# grad - as well as a plain number. Tensor.numpy() raises for CUDA tensors and
# for tensors that require grad, so it is avoided here.
dict_loss = {
    'train_loss': float(train_loss),
    'val_loss': float(val_loss)
}

with open(os.path.join(save_path, 'loss.yaml'), 'w') as fh:
    yaml.dump(dict_loss, fh, default_flow_style=False)

# Generate some new cards

In [21]:
# Sample 2000 new tokens from the model and print the decoded text.
new_tokens = 2000
# Start from a single token with id 0, shaped (1, 1).
context = torch.zeros((1, 1), dtype=torch.int64, device=device)
with torch.no_grad():  # inference only, no gradients needed
    generated_tokens = m.generate(context, max_new_tokens=new_tokens)[0].tolist()
    print(decode(map_int_to_char, generated_tokens))


—angelic minotaur, {2}{r}{r}, creature — angel, flying, haste if a source you control would deal damage to an opponent, it deals that much damage to each player instead., power 5, toughness 5
longtusk scavenger, {2}{b}, artifact creature — bird, flying when longtusk scavenger enters the battlefield, sacrifice a creature other than dom indestructible. when longtusk scavenger enters the battlefield, return target creature an opponent controls to its owner's hand., power 1, toughness 1
guardian of ixiny, {5}, artifact, you may cast guardian of ixiny's finery affector from your graveyard rather than pay this spell's mana cost. when guardian of ixiny dies, each opponent loses two 3 life. exile creatures they may cast that card from their graveyard., power nan, toughness nan
knight of the heard dead, {b}, creature — human knight, protection from whenever a player casts a historic spell, you may pay {1}. if you do, target player draw a card. (artifacts, legendaries, and sagas are historic.),

## Generate and save to file

In [23]:
# Draw ten independent samples of 2000 tokens each and collect them in one string.
new_tokens = 2000
context = torch.zeros((1, 1), dtype=torch.int64, device=device)
generated_cards = ''
with torch.no_grad():  # inference only, no gradients needed
    for _ in range(10):
        generated_tokens = m.generate(context, max_new_tokens=new_tokens)[0].tolist()
        new_cards = decode(map_int_to_char, generated_tokens)
        # Drop everything after the last newline: the sample is cut off at the
        # token limit, so its final line is most likely an unfinished card.
        complete_lines = new_cards.split('\n')[:-1]
        generated_cards = generated_cards + '\n'.join(complete_lines)
                            

In [24]:
# Write the collected samples next to the model.
# The text contains non-ASCII characters (e.g. the em dash in type lines), so
# the encoding is pinned to UTF-8 rather than left to the platform default,
# which differs between operating systems.
with open(os.path.join(save_path, 'generated_cards.txt'), 'w', encoding='utf-8') as fh:
    fh.write(generated_cards)