In [1]:
from src.model import Transformer, TransformerConfig
from src.load_data import load_data, download_data, create_word_dicts, create_dataset
from src.train import train, eval

import torch
import torch.nn as nn
import shutil
from tqdm.notebook import trange
import os
import json
import wandb

%load_ext autoreload
%autoreload 2

wandb.login()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlovisheindrich[0m ([33mlovis[0m). Use [1m`wandb login --relogin`[0m to force relogin


Device: cuda


In [2]:
download_data()
words = load_data()
word_to_index, index_to_word = create_word_dicts(words, min_occurrences=10)

# Corpus plus both vocabulary mappings, bundled so they can be dumped as one JSON file.
# NOTE(review): if index_to_word is a dict with int keys, JSON will turn them into strings.
data = dict(
    words=words,
    word_to_index=word_to_index,
    index_to_word=index_to_word,
)

# Start from an empty run-data directory: anything left by a previous run
# (including old checkpoints) is deleted first.
run_data_path = "./run_data"
if os.path.exists(run_data_path):
    shutil.rmtree(run_data_path)
os.makedirs(os.path.join(run_data_path, "checkpoints"))

with open(os.path.join(run_data_path, "word_data.json"), "w") as word_file:
    json.dump(data, word_file)

Data size: 1195401 words
6083 words that occur >= 10 times
22188 words that occur < 10 times
Vocabulary size: 6084


In [3]:
# Hyperparameters for this run; passed to W&B so they are recorded with the metrics.
config = dict(
    vocabulary_size=len(index_to_word),
    max_input_length=30,
    batch_size=64,
    embedding_size=192,
    num_blocks=4,
    num_heads=8,
    num_epochs=200,
    val_split=0.2,
    warmup_steps=4000,
    lr_scale=5,
)

wandb.init(
    project="basic-transformer",
    config=config,
    settings=wandb.Settings(start_method="thread"),
)

# Model artifact is filled during/after training; the word dictionaries are uploaded now.
model_artifact = wandb.Artifact(name='models', type='model')
word_artifact = wandb.Artifact(name='word_dicts', type='dataset')
word_artifact.add_file(local_path=run_data_path + "/word_data.json")
wandb.log_artifact(word_artifact)

# For comparison: the original Transformer paper used ~25000 tokens per batch.
tokens_per_batch = config["batch_size"] * config["max_input_length"]
print("Total tokens per batch", tokens_per_batch)

Total tokens per batch 1920


In [30]:
# Persist the hyperparameters next to the word data so the run can be rebuilt offline.
with open(run_data_path + "/config.json", "w") as config_file:
    json.dump(config, config_file)

In [4]:
# Split the corpus into training and validation loaders; sizes come from `config`.
train_dl, val_dl = create_dataset(
    words,
    word_to_index,
    index_to_word,
    batch_size=config["batch_size"],
    val_split=config["val_split"],
    max_input_length=config["max_input_length"],
)

In [5]:
transformer = Transformer(
    TransformerConfig(
        vocab_size=config["vocabulary_size"],
        max_input_length=config["max_input_length"],
        num_heads=config["num_heads"],
        num_blocks=config["num_blocks"],
        embedding_size=config["embedding_size"],
    ),
    # presumably makes the model emit raw logits, which is what nn.CrossEntropyLoss expects
    apply_softmax=False,
)
transformer.to(device)
# Let W&B track gradients/parameters every 1000 logging steps.
wandb.watch(transformer, log_freq=1000)

loss_fn = nn.CrossEntropyLoss()
# Number of training batches per epoch (used to convert epochs <-> optimizer steps).
steps_per_epoch = len(train_dl)

def transformer_lr(step, d_model=config["embedding_size"], warmup_steps=config["warmup_steps"], lr_scale=config["lr_scale"]):
    """Learning rate at optimizer step `step` under the "Attention Is All You Need" schedule.

    lr = lr_scale * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    That is a linear warmup over `warmup_steps` steps, then inverse-square-root
    decay, all multiplied by an extra `lr_scale` factor.

    Args:
        step: optimizer step index. Values below 1 are treated as step 1, because
            step**-0.5 is undefined at 0 and complex for negative steps.
        d_model: model (embedding) width.
        warmup_steps: length of the linear warmup phase, in steps.
        lr_scale: multiplier applied on top of the paper's formula.

    Returns:
        The learning rate as a float.
    """
    # Clamp rather than recurse. The old recursive call for step 0 did not forward
    # `lr_scale`, so a caller-supplied scale was silently replaced by the default.
    step = max(step, 1)
    return lr_scale * (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))

# Adam settings follow the Transformer paper (beta2=0.98, eps=1e-9).
initial_lr = transformer_lr(1)
optim = torch.optim.Adam(
    transformer.parameters(),
    lr=initial_lr,
    betas=(0.9, 0.98),
    eps=1e-09,
)


def lr_per_epoch(epoch):
    """LambdaLR multiplier: the per-step schedule sampled at the first step of `epoch`.

    LambdaLR multiplies the optimizer's base lr (`initial_lr`) by this value, so the
    effective lr is transformer_lr(epoch * steps_per_epoch). The training loop steps
    the scheduler once per epoch, so the lr is constant within an epoch.
    """
    return transformer_lr(epoch * steps_per_epoch) / initial_lr


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_per_epoch)

transformer_params = sum(param.nelement() for param in transformer.parameters())
print("Total parameters:", transformer_params)
print("Total training steps:", steps_per_epoch * config["num_epochs"])

Total parameters: 4720128
Total training steps: 149700


In [6]:
def plot_lr_schedule(num_epochs=300):
    """Plot the per-epoch learning rate that `lr_per_epoch` + LambdaLR will produce.

    A throwaway Adam optimizer and LambdaLR scheduler, configured like the real
    ones, are built on a single dummy parameter. Previewing therefore never touches
    the model or the real optimizer/scheduler state.

    Args:
        num_epochs: number of epochs to simulate (must be >= 1).

    Returns:
        The matplotlib Axes that seaborn drew on.

    Raises:
        ValueError: if num_epochs < 1.
    """
    import seaborn as sns  # plotting-only dependency, imported lazily

    if num_epochs < 1:
        raise ValueError("num_epochs must be >= 1")

    dummy_param = torch.nn.Parameter(torch.zeros(1))
    preview_optim = torch.optim.Adam([dummy_param], lr=initial_lr, betas=(0.9, 0.98), eps=1e-09)
    preview_scheduler = torch.optim.lr_scheduler.LambdaLR(preview_optim, lr_lambda=lr_per_epoch)

    lrs = []
    for _ in range(num_epochs):
        lrs.append(preview_optim.param_groups[0]["lr"])
        # dummy_param has no gradient, so this is a no-op. It only stops PyTorch
        # warning that scheduler.step() was called before optimizer.step().
        preview_optim.step()
        preview_scheduler.step()

    print("initial lr:", initial_lr, "peak lr:", max(lrs))
    ax = sns.lineplot(x=list(range(num_epochs)), y=lrs)
    ax.set(xlabel="epoch", ylabel="learning rate", title="Learning-rate schedule preview")
    return ax

# plot_lr_schedule()  # uncomment to preview the schedule before training

In [11]:
def training_loop(num_epochs, start_epoch=0, model_checkpoint_freq=100):
    """Train `transformer` for `num_epochs` epochs, evaluating and logging after each one.

    Relies on notebook-level state: `transformer`, `optim`, `lr_scheduler`, `loss_fn`,
    `device`, `train_dl`, `val_dl`, `steps_per_epoch`, `model_artifact` and
    `run_data_path`.

    Args:
        num_epochs: number of epochs to run in this call.
        start_epoch: global index of the first epoch (for resuming). It is used for
            printed/logged epoch numbers, checkpoint file names and checkpoint timing.
        model_checkpoint_freq: save a checkpoint whenever the global epoch count is
            a multiple of this value. A falsy value (0/None) disables checkpointing.
    """
    for e in trange(num_epochs):
        epoch = start_epoch + e

        train(transformer, loss_fn=loss_fn, optim=optim, device=device, dl=train_dl)

        # Full evaluation pass over the training set as well as the validation set.
        train_loss, train_acc = eval(transformer=transformer, loss_fn=loss_fn, device=device, dl=train_dl)
        val_loss, val_acc = eval(transformer=transformer, loss_fn=loss_fn, device=device, dl=val_dl)

        # LR used during this epoch: read it before the scheduler advances.
        lr = lr_scheduler.optimizer.param_groups[0]['lr']
        # Global step at the start of this epoch, assuming train() takes one
        # optimizer step per batch.
        step = epoch * steps_per_epoch
        print(f"\nEpoch {epoch}, lr = {lr:.6f}")
        lr_scheduler.step()
        print(f"Training loss {train_loss:.4f}, accuracy {train_acc:.4f}")
        print(f"Eval loss {val_loss:.4f}, accuracy {val_acc:.4f}")

        # Use the global epoch, so resumed runs keep the same cadence and the
        # checkpoint file name matches the epoch at which it was saved.
        if model_checkpoint_freq and (epoch + 1) % model_checkpoint_freq == 0:
            checkpoint_name = f"/checkpoints/checkpoint_{epoch + 1}.pt"
            torch.save(transformer.state_dict(), run_data_path + checkpoint_name)
            model_artifact.add_file(run_data_path + checkpoint_name, name=checkpoint_name)
        wandb.log({"train_acc": train_acc, "train_loss": train_loss, "val_acc": val_acc, "val_loss": val_loss, "learning_rate": lr, "step": step})

In [8]:
# Run the full training schedule configured in `config`, starting from epoch 0.
training_loop(num_epochs=config["num_epochs"])

  0%|          | 0/300 [00:00<?, ?it/s]


Epoch 0, lr = 0.000001
Training loss 32.7505, accuracy 0.0100
Eval loss 32.8000, accuracy 0.0099

Epoch 1, lr = 0.000616
Training loss 7.0642, accuracy 0.0778
Eval loss 7.0852, accuracy 0.0783

Epoch 2, lr = 0.001233
Training loss 5.8961, accuracy 0.0835
Eval loss 5.9437, accuracy 0.0811

Epoch 3, lr = 0.001849
Training loss 5.1550, accuracy 0.1728
Eval loss 5.2229, accuracy 0.1691

Epoch 4, lr = 0.002466
Training loss 4.9541, accuracy 0.1790
Eval loss 5.0377, accuracy 0.1743

Epoch 5, lr = 0.003082
Training loss 4.8494, accuracy 0.1804
Eval loss 4.9352, accuracy 0.1753

Epoch 6, lr = 0.003698
Training loss 4.7860, accuracy 0.1880
Eval loss 4.8760, accuracy 0.1836

Epoch 7, lr = 0.004315
Training loss 4.7628, accuracy 0.1873
Eval loss 4.8662, accuracy 0.1824

Epoch 8, lr = 0.004931
Training loss 4.7258, accuracy 0.1882
Eval loss 4.8356, accuracy 0.1833

Epoch 9, lr = 0.004663
Training loss 4.6630, accuracy 0.1950
Eval loss 4.7848, accuracy 0.1891

Epoch 10, lr = 0.004424
Training loss

In [14]:
# Save the final weights, upload them together with any checkpoints that
# training_loop added to the same artifact, then close the W&B run.
final_model_path = run_data_path + "/model.pt"
torch.save(transformer.state_dict(), final_model_path)
model_artifact.add_file(final_model_path)
wandb.log_artifact(model_artifact)
wandb.finish()

0,1
learning_rate,▄█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_acc,▁▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train_loss,█▆▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▆█████▇▇▇▇▇▇▇▇▆▇▆▆▆▆▆▆▆▅▆▆▆▅▆▅▅▆▆▅▅▆▅▅▅
val_loss,█▂▁▁▂▂▃▄▄▄▅▅▅▅▆▆▆▆▆▇▆▇▇▆▇▇▇▇▇▇▇▇▇█▇█▇███

0,1
learning_rate,0.00064
step,240518.0
train_acc,0.42933
train_loss,2.50929
val_acc,0.19778
val_loss,5.19866
