In [1]:
from src.model import Transformer, TransformerConfig
from src.load_data import load_data, download_data, create_word_dicts, create_dataset
from src.train import train, eval

import torch
import torch.nn as nn
import torch.nn.functional as F
import shutil
from tqdm.notebook import trange
import pandas as pd
import os
import json

import pickle
from matplotlib import pyplot as plt
import wandb

%load_ext autoreload
%autoreload 2

wandb.login()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlovisheindrich[0m ([33mlovis[0m). Use [1m`wandb login --relogin`[0m to force relogin


Device: cuda


In [2]:
# Fetch the corpus and build the vocabulary (words seen fewer than 10 times are not given
# their own index).
download_data()
words = load_data()
word_to_index, index_to_word = create_word_dicts(words, min_occurrences=10)

# Start from an empty run-data directory. This deletes the previous run's checkpoints
# and word data.
run_data_path = "./run_data"
if os.path.exists(run_data_path):
    shutil.rmtree(run_data_path)
os.makedirs(os.path.join(run_data_path, "checkpoints"))

# Persist the corpus and both lookup tables so the run can be reproduced / decoded later.
word_data = dict(words=words, word_to_index=word_to_index, index_to_word=index_to_word)
with open(os.path.join(run_data_path, "word_data.json"), 'w') as outfile:
    json.dump(word_data, outfile)

Data size: 1195401 words
6083 words that occur >= 10 times
22188 words that occur < 10 times
Vocabulary size: 6084


In [3]:
# Hyperparameters for this run; also handed to wandb so the run records its own config.
config = dict(
    vocabulary_size=len(index_to_word),
    max_input_length=30,
    batch_size=64,
    embedding_size=256,
    num_blocks=4,
    num_heads=8,
    num_epochs=300,
    val_split=0.2,
    warmup_steps=4000,
    lr_scale=5,
)

wandb.init(
    project="basic-transformer",
    config=config,
    settings=wandb.Settings(start_method="thread"),
)

# Upload the vocabulary written in the previous cell as a dataset artifact.
word_artifact = wandb.Artifact('word_dicts', 'dataset')
word_artifact.add_file(local_path=run_data_path + "/word_data.json")
wandb.log_artifact(word_artifact)

# Checkpoints are added to this artifact during training; it is logged at the end.
model_artifact = wandb.Artifact('models', 'model')

# The original Transformer paper used ~25000 tokens per batch.
tokens_per_batch = config["batch_size"] * config["max_input_length"]
print("Total tokens per batch", tokens_per_batch)

Total tokens per batch 1920


In [4]:
# Split the corpus into training/validation DataLoaders of fixed-length sequences.
train_dl, val_dl = create_dataset(
    words,
    word_to_index,
    index_to_word,
    batch_size=config["batch_size"],
    val_split=config["val_split"],
    max_input_length=config["max_input_length"],
)

In [5]:
model_config = TransformerConfig(
    vocab_size=config["vocabulary_size"],
    max_input_length=config["max_input_length"],
    num_heads=config["num_heads"],
    num_blocks=config["num_blocks"],
    embedding_size=config["embedding_size"],
)
# apply_softmax=False: the model output is presumably raw logits (later cells apply softmax
# themselves), which is what nn.CrossEntropyLoss expects.
# nn.Module.to() moves the module in place and returns it, so this binds the same object.
transformer = Transformer(model_config, apply_softmax=False).to(device)
wandb.watch(transformer, log_freq=1000)

loss_fn = nn.CrossEntropyLoss()
steps_per_epoch = len(train_dl)

def transformer_lr(step, d_model=config["embedding_size"], warmup_steps=config["warmup_steps"], lr_scale=config["lr_scale"]):
    """Scaled learning-rate schedule from "Attention Is All You Need".

    lr = lr_scale * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    The rate grows linearly for the first ``warmup_steps`` steps and then decays as
    step**-0.5.

    Args:
        step: optimizer step count. Values below 1 are treated as step 1, because
            0**-0.5 is undefined.
        d_model: model embedding size.
        warmup_steps: length of the linear warm-up phase, in steps.
        lr_scale: overall multiplier on the schedule.

    Returns:
        The learning rate (float) for ``step``.
    """
    # Clamp instead of recursing. The old recursive call dropped ``lr_scale``, so a
    # caller-supplied scale was silently replaced by the config default at step 0.
    step = max(step, 1)
    return lr_scale * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# LambdaLR multiplies the optimizer's base lr by lr_lambda(epoch). The lambda below
# divides by that base lr, so the effective lr is exactly transformer_lr(...).
initial_lr = transformer_lr(1)
optim = torch.optim.Adam(transformer.parameters(), lr=initial_lr, betas=(0.9, 0.98), eps=1e-09)


def lr_per_epoch(epoch):
    """Multiplier for LambdaLR: the schedule evaluated at the first step of ``epoch``.

    NOTE(review): training_loop calls lr_scheduler.step() once per epoch, so the lr only
    changes at epoch boundaries and the warm-up is a staircase rather than per-step.
    """
    return transformer_lr(epoch * steps_per_epoch) / initial_lr


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_per_epoch)

transformer_params = sum(param.nelement() for param in transformer.parameters())
print("Total parameters:", transformer_params)
print("Total training steps:", steps_per_epoch * config["num_epochs"])

Total parameters: 4720128
Total training steps: 149700


In [6]:
def plot_lr_schedule(num_epochs=None):
    """Plot the per-epoch learning rate produced by ``lr_per_epoch`` under LambdaLR.

    A throwaway optimizer/scheduler is built on a dummy parameter, so the real ``optim``,
    ``lr_scheduler`` and model are left untouched.

    Args:
        num_epochs: number of epochs to simulate. Defaults to ``config["num_epochs"]``,
            so the plot follows the configured run length.

    Returns:
        The matplotlib Axes that holds the plot.
    """
    if num_epochs is None:
        num_epochs = config["num_epochs"]

    dummy_param = torch.zeros(1, requires_grad=True)
    dummy_optim = torch.optim.Adam([dummy_param], lr=initial_lr, betas=(0.9, 0.98), eps=1e-09)
    dummy_scheduler = torch.optim.lr_scheduler.LambdaLR(dummy_optim, lr_lambda=lr_per_epoch)

    lrs = []
    for _ in range(num_epochs):
        lrs.append(dummy_optim.param_groups[0]['lr'])
        # No-op step (the dummy param has no grad). It keeps the usual
        # optimizer.step() -> scheduler.step() order and avoids PyTorch's ordering warning.
        dummy_optim.step()
        dummy_scheduler.step()
    print(initial_lr, max(lrs))

    fig, ax = plt.subplots()
    ax.plot(lrs)
    ax.set(title="Learning-rate schedule", xlabel="epoch", ylabel="learning rate")
    return ax

# Uncomment to visualise the schedule:
# plot_lr_schedule();

In [7]:
def training_loop(num_epochs, start_epoch=0, model_checkpoint_freq=100):
    """Train ``transformer`` for ``num_epochs`` epochs, evaluating and logging after each.

    Uses notebook globals: transformer, loss_fn, optim, lr_scheduler, device, train_dl,
    val_dl, steps_per_epoch, run_data_path and model_artifact. ``train``/``eval`` are the
    helpers imported from src.train (``eval`` shadows the Python builtin).

    Args:
        num_epochs: number of epochs to run in this call.
        start_epoch: global index of the first epoch, for resuming a run. Printed epoch
            numbers, logged steps and checkpoint names continue from here.
        model_checkpoint_freq: save a checkpoint every this many global epochs.
    """
    for e in trange(num_epochs):
        epoch = start_epoch + e

        train(transformer, loss_fn=loss_fn, optim=optim, device=device, dl=train_dl)

        train_loss, train_acc = eval(transformer=transformer, loss_fn=loss_fn, device=device, dl=train_dl)
        val_loss, val_acc = eval(transformer=transformer, loss_fn=loss_fn, device=device, dl=val_dl)

        # The lr that was in effect during this epoch; read it before the scheduler advances.
        lr = lr_scheduler.optimizer.param_groups[0]['lr']
        step = epoch * steps_per_epoch
        lr_scheduler.step()

        print(f"\nEpoch {epoch}, lr = {lr:.6f}")
        print(f"Training loss {train_loss:.4f}, accuracy {train_acc:.4f}")
        print(f"Eval loss {val_loss:.4f}, accuracy {val_acc:.4f}")

        # Use the global epoch. With the local index, a resumed run (start_epoch > 0) would
        # overwrite earlier checkpoints with mislabelled ones.
        if (epoch + 1) % model_checkpoint_freq == 0:
            checkpoint_name = f"/checkpoints/checkpoint_{epoch + 1}.pt"
            torch.save(transformer.state_dict(), run_data_path + checkpoint_name)
            model_artifact.add_file(run_data_path + checkpoint_name, name=checkpoint_name)
        wandb.log({"train_acc": train_acc, "train_loss": train_loss, "val_acc": val_acc, "val_loss": val_loss, "learning_rate": lr, "step": step})

In [8]:
# Full training run from epoch 0 using the configured epoch count.
training_loop(num_epochs=config["num_epochs"])

  0%|          | 0/300 [00:00<?, ?it/s]


Epoch 0, lr = 0.000001
Training loss 32.7505, accuracy 0.0100
Eval loss 32.8000, accuracy 0.0099

Epoch 1, lr = 0.000616
Training loss 7.0642, accuracy 0.0778
Eval loss 7.0852, accuracy 0.0783

Epoch 2, lr = 0.001233
Training loss 5.8961, accuracy 0.0835
Eval loss 5.9437, accuracy 0.0811

Epoch 3, lr = 0.001849
Training loss 5.1550, accuracy 0.1728
Eval loss 5.2229, accuracy 0.1691

Epoch 4, lr = 0.002466
Training loss 4.9541, accuracy 0.1790
Eval loss 5.0377, accuracy 0.1743

Epoch 5, lr = 0.003082
Training loss 4.8494, accuracy 0.1804
Eval loss 4.9352, accuracy 0.1753

Epoch 6, lr = 0.003698
Training loss 4.7860, accuracy 0.1880
Eval loss 4.8760, accuracy 0.1836

Epoch 7, lr = 0.004315
Training loss 4.7628, accuracy 0.1873
Eval loss 4.8662, accuracy 0.1824

Epoch 8, lr = 0.004931
Training loss 4.7258, accuracy 0.1882
Eval loss 4.8356, accuracy 0.1833

Epoch 9, lr = 0.004663
Training loss 4.6630, accuracy 0.1950
Eval loss 4.7848, accuracy 0.1891

Epoch 10, lr = 0.004424
Training loss

In [None]:
# Save the final weights and add them to the model artifact (which already holds any
# checkpoints added during training). Then upload the artifact and close the wandb run.
final_model_path = run_data_path + "/model.pt"
torch.save(transformer.state_dict(), final_model_path)
model_artifact.add_file(final_model_path)
wandb.log_artifact(model_artifact)
wandb.finish()

0,1
learning_rate,▁▃▆█
step,▁▃▆█
train_acc,▁▂▆█
train_loss,█▁▁▁
val_acc,▁▃▆█
val_loss,█▁▁▁

0,1
learning_rate,0.00185
step,1497.0
train_acc,0.172
train_loss,5.13684
val_acc,0.16939
val_loss,5.18249


In [105]:
# Sanity check on one training batch: print the batch loss, then show targets next to the
# model's greedy predictions for a few sequences.
# NOTE(review): this leaves the model in eval mode. Make sure train() switches it back
# (or call transformer.train()) before resuming training.
NUM_EXAMPLES_TO_SHOW = 5

transformer.eval()

with torch.no_grad():
    batch = next(iter(train_dl))
    x, y = batch[0].to(device), batch[1].to(device)
    output = transformer(x)

    # Flatten batch and sequence dimensions for CrossEntropyLoss
    loss = loss_fn(output.view(-1, output.shape[-1]), y.view(-1))
    # argmax over logits equals argmax over softmax(logits), so no normalisation is needed.
    pred = output.argmax(dim=-1)

print(f"Batch loss {loss.item():.4f}", pred.shape, y.shape)
for target_row, pred_row in zip(y[:NUM_EXAMPLES_TO_SHOW].tolist(), pred[:NUM_EXAMPLES_TO_SHOW].tolist()):
    print("True:", " ".join(index_to_word[word] for word in target_row))
    print("Prediction:", " ".join(index_to_word[word] for word in pred_row))

torch.Size([64, 30]) torch.Size([64, 30])
True: doth not wish you joy ! gonzalo . be it so . amen ! re - enter ariel with the master and boatswain <unknown> following . o look , sir
Prediction: i appear know his joy . but . i it so . amen . amen - enter ariel with the master and boatswain . . . i , , sir
True: yet take this again - and yet i thank you - meaning henceforth to trouble you no more . speed . aside and yet you will ; and yet another
Prediction: <unknown> i the <unknown> , <unknown> yet i will you - meaning henceforth to trouble you no more . exeunt . sir ay yet i will not for yet another
True: <unknown> : that <unknown> must rise <unknown> that <unknown> him . you know the <unknown> <unknown> the duke has ? iailor . very well . daughter . she is <unknown>
Prediction: <unknown> , <unknown> i , be , , <unknown> the from enter have , gentleman of of <unknown> has the iailor . yes well ; daughter . i is <unknown>
True: look to taste the due meet for rebellion and such acts as y

In [106]:
def generate_next_token(tokens=None):
    """Greedily predict one word index to follow ``tokens``.

    The last vocabulary index is appended as a placeholder, and the argmax prediction at
    the placeholder's position is returned. The last vocabulary index (<unknown>) is
    never chosen.

    Relies on the globals ``transformer``, ``index_to_word`` and ``device``. The model is
    expected to be in eval mode (set by an earlier cell).

    Args:
        tokens: list of int vocabulary indices, or ``None`` to start from an empty context.

    Returns:
        int index of the predicted word.
    """
    if tokens is None:
        tokens = []
    unknown_token = len(index_to_word) - 1
    # NOTE(review): this reads the prediction at the appended placeholder's position,
    # while the sampling cell further down reads it at the last real token. If output i
    # predicts token i+1, this convention skips a word — TODO confirm against create_dataset.
    x = torch.tensor([tokens + [unknown_token]], dtype=torch.long, device=device)
    with torch.no_grad():
        logits = transformer(x)
    # Don't allow the model to generate <unknown> tokens
    logits = logits[:, :, :-1]
    pred = logits.argmax(dim=-1).view(-1)
    return pred[len(tokens)].item()

def print_sentence(words):
    """Print a list of vocabulary indices as a single space-separated line of words."""
    decoded = (index_to_word[idx] for idx in words)
    print(" ".join(decoded))

generate_next_token()

9

In [116]:
def generate_sentence(start=None, max_length=None):
    """Greedily extend a prompt one word at a time and print the result.

    Args:
        start: optional whitespace-separated prompt. Words missing from the vocabulary
            map to the last index (<unknown>) instead of raising KeyError.
        max_length: total sentence length, in tokens, to generate up to. Defaults to
            ``config["max_input_length"]``, the longest input the model was configured for.
    """
    if max_length is None:
        max_length = config["max_input_length"]
    unknown_token = len(index_to_word) - 1

    if start is None:
        sentence = []
    else:
        sentence = [word_to_index.get(word, unknown_token) for word in start.split()]

    while len(sentence) < max_length:
        sentence.append(generate_next_token(sentence))

    print_sentence(sentence)

generate_sentence("then being asked")

then being asked , , , , knowing , knowing , knowing , knowing , in , , , , , , , , , , , , , ,


In [109]:
# Draw 10 samples from the model's predicted distribution for the word after a prompt.
# NOTE(review): this reads the prediction at the last real token, unlike
# generate_next_token, which appends a placeholder first — TODO settle on one convention.
transformer.eval()
prompt = "and bring him if . if thou issueless shalt hap"
prompt_ids = [word_to_index[word] for word in prompt.split(" ")]

x = torch.LongTensor([prompt_ids]).to(device)
with torch.no_grad():
    output = transformer(x)
# Probabilities over the whole vocabulary at the last position.
next_word_probs = F.softmax(output[:, -1, :].view(-1).cpu(), dim=0)
dist = torch.distributions.categorical.Categorical(probs=next_word_probs)

print(next_word_probs.shape)
samples = dist.sample([10]).tolist()
print(samples)
print([index_to_word[word] for word in samples])

torch.Size([6084])
[9, 19, 3899, 0, 27, 9, 19, 9, 9, 9]
[',', 'and', 'betimes', 'the', 'with', ',', 'and', ',', ',', ',']
