In [165]:
from src.model import Transformer, TransformerConfig
import torch
import torch.nn as nn
import re
from collections import Counter
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import trange
import pandas as pd
import os
import requests
import pickle

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [169]:
# Download data if not already downloaded
path = "./data"
filename = "/shakespeare.txt"
if not os.path.isfile(path+filename):
    if not os.path.exists(path):
        os.makedirs(path)
    r = requests.get("https://www.gutenberg.org/files/100/100-0.txt", stream = True)
    with open(path+filename, "w") as f:
        f.write(r.text)

In [170]:
with open('data/shakespeare.txt') as f:
    data = f.read()

words = re.split(r"\b", data)
print(f"Data size: {len(words)} words, {len(data)} characters")
word_counts = Counter(words)
print(word_counts.most_common(n=10))

Data size: 1987861 words, 5462587 characters
[(' ', 702111), (', ', 67006), ('.\n', 31293), ('the', 25606), ('\n', 23889), ('I', 23693), ('. ', 20640), (',\n', 20451), ('and', 20184), ('to', 17467)]


In [171]:
# Reduce data amount by manually preprocessing some tokens

# Remove space
initial_length = len(words)
to_remove = [" ", "\n"]
words = [word.lower() for word in words if word not in to_remove]
print(f"Removed {initial_length - len(words)} words")

# Simplify punctuation
to_replace = {".\n\n\n": ". ", ".\n\n": ". ", ".\n": ". ", ",\n\n\n": ", ", ",\n\n": ", ", ",\n": ", ", ";\n\n\n": "; ", ";\n\n": "; ", ";\n": "; "}
replaced = 0
for i, word in enumerate(words):
    if word in to_replace.keys():
        words[i] = to_replace[word]
        replaced += 1
print(f"Replaced {replaced} words")
print(f"Data size: {len(words)} words, {len(data)} characters")

word_counts = Counter(words)
print(word_counts.most_common(n=20))


Removed 726000 words
Replaced 73711 words
Data size: 1261861 words, 5462587 characters
[(', ', 87465), ('. ', 69343), ('the', 30320), ('and', 28480), ('i', 24032), ('to', 20971), ('of', 18870), ('’', 16787), ('a', 16341), ('you', 14704), ('; ', 13455), ('my', 13185), ('in', 12464), ('that', 12258), ('is', 9923), ('not', 9085), ('with', 8545), ('s', 8464), ('for', 8302), ('me', 8283)]


In [76]:
# Filter for words with <n occurrences
n = 5
high_occurence_words = []
low_occurence_words = []
for key, value in word_counts.items():
    if value >= n:
        high_occurence_words.append(key)
    else:
        low_occurence_words.append(key)
print(f"{len(high_occurence_words)} words that occur >= {n} times")
print(f"{len(low_occurence_words)} words that occur < {n} times")
print("Example low occurence words:", low_occurence_words[:10])

9623 words that occur >= 5 times
18600 words that occur < 5 times
Example low occurence words: ['\ufeff', 'restrictions', 'included', 'online', 'january', '1994', ' #', '100', ']\n[', 'recently']


In [172]:
# Create embedding
word_to_index = {word:i for i, word in enumerate(high_occurence_words)}
index_to_word = {i:word for i, word in enumerate(high_occurence_words)}
last_index = len(word_to_index)
for word in low_occurence_words:
    word_to_index[word] = last_index
    index_to_word[last_index] = "<unknown>"
print(f"Example common word embedding: {word_to_index['and']}: {index_to_word[word_to_index['and']]}")
print(f"Example uncommon word embedding: {word_to_index['restrictions']}: {index_to_word[word_to_index['restrictions']]}")
print(f"Vocabulary size: {len(index_to_word)}")

Example common word embedding: 21: and
Example uncommon word embedding: 9623: <unknown>
Vocabulary size: 9624


In [173]:
# Tokenize dataset
tokenized_data = [word_to_index[word] for word in words]
labels = tokenized_data[1:] + [len(index_to_word) - 1]

print([index_to_word[x] for x in tokenized_data[-10:]])
print([index_to_word[x] for x in labels[-10:]])

['to', 'our', '<unknown>', '<unknown>', 'to', 'hear', 'about', 'new', 'ebooks', '. ']
['our', '<unknown>', '<unknown>', 'to', 'hear', 'about', 'new', 'ebooks', '. ', '<unknown>']


In [184]:
vocabulary_size = len(index_to_word)
total_steps = 100000
max_input_length = 30
batch_size = 128
embedding_size = 128
val_split = 0.2
num_blocks = 4
num_heads = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device", device)
# Paper used 25000 tokens per batch
print("Total tokens per batch", batch_size*max_input_length)

Device cpu
Total tokens per batch 3840


In [175]:
# Create pytorch dataset
cutoff = len(tokenized_data) % max_input_length
data_x = tokenized_data[:-cutoff]
data_y = labels[:-cutoff]
data_x = torch.LongTensor(data_x).view(-1, max_input_length)
data_y = torch.LongTensor(data_y).view(-1, max_input_length)
ds = TensorDataset(data_x, data_y)
val_size = int(len(ds)*val_split)
train_ds, val_ds = torch.utils.data.random_split(ds, [len(ds)-val_size, val_size])
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

# Check if order is intact
for _, data in enumerate(train_dl):
    x, y = data[0], data[1]
    print(x.shape, y.shape)
    print([index_to_word[word] for word in x[5, :10].tolist()])
    print([index_to_word[word] for word in y[5, :10].tolist()])
    break

torch.Size([128, 30]) torch.Size([128, 30])
['were', 'post', '—\n', 'the', 'man', 'i', '’ ', 'th', '’ ', 'moon']
['post', '—\n', 'the', 'man', 'i', '’ ', 'th', '’ ', 'moon', '’']


In [176]:
transformer = Transformer(TransformerConfig(vocab_size=vocabulary_size, max_input_length=max_input_length, num_heads=num_heads, num_blocks=num_blocks, embedding_size=embedding_size), apply_softmax=False)
transformer.to(device)
loss_fn = nn.CrossEntropyLoss()

steps_per_epoch = len(train_dl)
num_epochs = total_steps // steps_per_epoch

def transformer_lr(step, d_model=embedding_size, warmup_steps=4000):
    if step==0:
        return transformer_lr(1, d_model, warmup_steps)
    return (d_model ** -0.5)*min(step**-0.5, step*(warmup_steps**-0.5))

initial_lr = transformer_lr(1)
optim = torch.optim.Adam(transformer.parameters(), lr=initial_lr, betas=(0.9, 0.98), eps=1e-09)
lr_per_epoch = lambda epoch: transformer_lr(epoch*steps_per_epoch) / initial_lr
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lr_per_epoch])

for _, data in enumerate(train_dl):
    x, y = data[0].to(device), data[1].to(device)
    output = transformer(x)
    print(x.shape, y.shape, output.shape)

    # Flatten batch and sequence dimension
    loss = loss_fn(output.view(-1, output.shape[-1]), y.view(-1))
    print(loss)
    
    pred = nn.functional.softmax(output, dim=-1)
    print(pred.shape)
    print([index_to_word[word] for word in x[5, :10].tolist()])
    print([index_to_word[word] for word in y[5, :10].tolist()])
    break

torch.Size([128, 30]) torch.Size([128, 30]) torch.Size([128, 30, 9624])
tensor(43.9588, grad_fn=<NllLossBackward>)
torch.Size([128, 30, 9624])
['giddy', 'in', 'spirit', ', ', 'still', 'gazing', 'in', 'a', 'doubt', 'whether']
['in', 'spirit', ', ', 'still', 'gazing', 'in', 'a', 'doubt', 'whether', 'those']


In [181]:
params = 0
for param in transformer.parameters():
    params += param.nelement()
print("Total parameters:", params)

Total parameters: 2026752


In [133]:
def eval(dl):
    transformer.eval()
    with torch.no_grad():
        total_loss = 0
        total_acc = 0
        for _, data in enumerate(dl):
            x, y = data[0].to(device), data[1].to(device)
            output = transformer(x)
            
            # Flatten batch and sequence dimension
            loss = loss_fn(output.view(-1, output.shape[-1]), y.view(-1)).item()
            total_loss += loss
            pred = nn.functional.softmax(output, dim=-1)
            pred = pred.argmax(dim=-1)
            accuracy = (pred == y).float().mean().item()
            total_acc += accuracy
    return total_loss / len(dl), total_acc / len(dl)
        

In [None]:
def train(epochs=5):
    stat_columns = ["Epoch", "Step", "Lr", "TrainLoss", "TrainAcc", "ValLoss", "ValAcc"]
    stats = []
    # Basic training loop
    for e in trange(epochs):
        transformer.train()
        for i, data in enumerate(train_dl):
            x, y = data[0].to(device), data[1].to(device)
            output = transformer(x)
            
            # Flatten batch and sequence dimension
            loss = loss_fn(output.view(-1, output.shape[-1]), y.view(-1))
            optim.zero_grad()
            loss.backward()
            optim.step()

        train_loss, train_acc = eval(train_dl)
        val_loss, val_acc = eval(val_dl)
        lr = lr_scheduler.optimizer.param_groups[0]['lr']
        step = e*steps_per_epoch
        print(f"\nEpoch {e}, lr = {lr:.6f}")
        lr_scheduler.step()
        print(f"Training loss {train_loss:.4f}, accuracy {train_acc:.4f}")
        print(f"Eval loss {val_loss:.4f}, accuracy {val_acc:.4f}")
        stats.append([e, step, lr, train_loss, train_acc, val_loss, val_acc])
    stat_df = pd.DataFrame(stats, columns=stat_columns)
    return stat_df


In [None]:
train_stats = train(num_epochs)

In [None]:
params = {
    "vocabulary_size": vocabulary_size,
    "max_input_length": max_input_length,
    "batch_size": batch_size,
    "embedding_size": embedding_size,
    "word_to_index": word_to_index,
    "index_to_word": index_to_word,
    "num_blocks": num_blocks,
    "num_heads": num_heads
}

def save_model():
    path = "./models"
    model_filename = "/transformer_model.pt"
    settings_filename = "/transformer_settings.pkl"
    stats_filename = "/transformer_stats.csv"
    if not os.path.isfile(path+model_filename):
        print("Saving model")
        if not os.path.exists(path):
            os.makedirs(path)
        torch.save(transformer.state_dict(), path+model_filename)
        with open(path + settings_filename, "wb") as f:
            pickle.dump(params, f)
        train_stats.to_csv()
    else:
        print("Model already exists, manually delete it to save current model instead.")

save_model()

In [155]:
transformer.eval()

with torch.no_grad():
  for _, data in enumerate(train_dl):
    x, y = data[0].to(device), data[1].to(device)
    output = transformer(x)

    # Flatten batch and sequence dimension
    loss = loss_fn(output.view(-1, output.shape[-1]), y.view(-1))
    print(loss)
    
    pred = nn.functional.softmax(output, dim=-1)
    print(pred[0, 0, :])
    pred = pred.argmax(dim=-1)
    print("Example")
    print("True:", " ".join([index_to_word[word] for word in y[1, :20].tolist()]))
    print("Prediction:", " ".join([index_to_word[word] for word in pred[1, :20].tolist()]))
    print("")
    print("True:", " ".join([index_to_word[word] for word in y[10, :20].tolist()]))
    print("Prediction:", " ".join([index_to_word[word] for word in pred[10, :20].tolist()]))
    break

tensor(42.2960)
tensor([5.5613e-20, 2.1305e-18, 6.7830e-15,  ..., 1.3166e-10, 4.7985e-13,
        7.7562e-19])
Example
True: give him the ring ,  and bring him if thou canst unto antonio ’ s house .  away ,  make
Prediction: beheaded directly cool pestilent pestilent pestilent directly pestilent pestilent pestilent pestilent pestilent pestilent pestilent pestilent pestilent pestilent pestilent pestilent stoop

True: glowing :
 whereas reproof ,  obedient and in order ,  fits kings ,  as they are men ,  for they
Prediction: times times times times times times times getting getting vexed times getting times times times vexed doors vexed bold doors
