In [1]:
import os

path = "data/tinystories/"

# Concatenate every .txt book in `path` into one newline-separated corpus.
# Files are visited in sorted order because os.listdir order is arbitrary,
# which would otherwise make the corpus (and every downstream split)
# machine-dependent. Non-.txt entries (e.g. macOS .DS_Store) are skipped.
corpus_parts = []
for filename in sorted(os.listdir(path)):
    if not filename.endswith(".txt"):
        continue
    with open(os.path.join(path, filename), "r", encoding="utf-8") as f:
        file = f.read()
    # len() of a str counts characters, not words.
    print(f"{filename:50} has {len(file):,} characters")
    corpus_parts.append(file + "\n")
# Join once instead of repeated `+=` on a growing string.
corpus = "".join(corpus_parts)


the_odyssey.txt                                    has 698,079 words
moby_dick.txt                                      has 1,238,224 words
alice_wonderland.txt                               has 163,916 words
frankenstein.txt                                   has 438,808 words
dracula.txt                                        has 865,171 words
a_tale_of_two_cities.txt                           has 776,878 words
pride_and_prejudice.txt                            has 748,121 words
the_complete_works_of_william_shakespeare.txt      has 5,378,662 words
a_room_with_a_view.txt                             has 394,369 words
metamorphosis.txt                                  has 138,257 words
the_great_gatspy.txt                               has 290,075 words
adventures_of_huckleberry_finn.txt                 has 590,377 words
the_iliad.txt                                      has 1,116,791 words


In [2]:
import tokenizers
from tokenizers import Tokenizer

# Trained tokenizer for this project. A ByteLevel decoder is attached so that
# tokenizer.decode(...) can turn token ids back into text (the generation cell
# relies on this).
TOKENIZER_PATH = "projects/4-stories/tokenizer.json"
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
tokenizer.decoder = tokenizers.decoders.ByteLevel()

In [3]:
# Train on a single book only: this deliberately replaces the multi-book
# `corpus` built in the first cell with just Alice in Wonderland.
# The encoding is explicit so decoding does not depend on the OS locale.
with open(os.path.join(path, "alice_wonderland.txt"), "r", encoding="utf-8") as f:
    corpus = f.read()

In [4]:
# Tokenize the whole corpus in one pass and keep only the token ids.
encoding = tokenizer.encode(corpus)
encoded = encoding.ids
print(f"Encoded corpus has {len(encoded):,} tokens")

Encoded corpus has 52,316 tokens


In [5]:
# Hold out the last 20% of the token stream for validation. The split is a
# contiguous cut (no shuffling), so validation text comes from the end of the
# corpus.
train_split = 0.8
train_size = int(train_split * len(encoded))
train_data, val_data = encoded[:train_size], encoded[train_size:]

print(f"Training data has {len(train_data):,} tokens")
print(f"Validation data has {len(val_data):,} tokens")

Training data has 41,852 tokens
Validation data has 10,464 tokens


In [None]:
# Pick the best available accelerator instead of hard-coding "mps", which
# fails on any non-Apple machine. torch is imported here because this cell
# runs before the main import cell.
import torch

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import matplotlib.pyplot as plt

In [10]:
# Chunk each token stream into non-overlapping SEQ_LEN-token windows,
# rounding *down* to a whole number of windows.
SEQ_LEN = 128
BATCH_SIZE = 64


def to_windows(tokens, seq_len=SEQ_LEN):
    """Return `tokens` reshaped into a (num_windows, seq_len) LongTensor.

    The trailing ``len(tokens) % seq_len`` tokens are dropped. The cut point is
    computed as an explicit length because the slice ``tokens[:-remainder]``
    becomes ``tokens[:0]`` when the remainder is 0, silently discarding the
    entire stream. dtype is pinned so an empty stream is still integer-typed.
    """
    usable = len(tokens) - len(tokens) % seq_len
    return torch.tensor(tokens[:usable], dtype=torch.long).view(-1, seq_len)


train_dataset = to_windows(train_data)
val_dataset = to_windows(val_data)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Small GPT-2 style decoder for a single-book corpus.
config = transformers.GPT2Config(
    vocab_size=4096,
    n_positions=512,  # max context length; training windows are 128 tokens
    n_embd=128,
    n_layer=8,
    n_head=4,
    n_inner=None, # default 4 * n_embd
)
# Every id the tokenizer can emit must index a row of the embedding table;
# otherwise the mismatch only surfaces as an indexing failure mid-training.
assert tokenizer.get_vocab_size() <= config.vocab_size, (
    f"tokenizer has {tokenizer.get_vocab_size():,} tokens but the model "
    f"vocab_size is only {config.vocab_size:,}"
)
model = transformers.GPT2LMHeadModel(config).to(device)
num_train_steps = 0  # advanced once per optimizer step by the training loop
print(f"Model using {model.num_parameters():,} parameters.")

In [None]:
# AdamW over all model parameters: learning rate 2e-4, decoupled weight
# decay 0.01.
optim = torch.optim.AdamW(
    model.parameters(),
    weight_decay=0.01,
    lr=2e-4,
)

In [None]:
import tqdm.notebook as tqdm

In [None]:
@torch.no_grad()
def evaluate(model, loader):
    """Return the mean next-token cross-entropy of `model` over `loader`.

    Assumes each batch is a (batch, seq_len) LongTensor of token ids, as built
    by the dataset cell. GPT2LMHeadModel shifts `labels` internally, so the
    full window is passed as both input and labels; pre-shifting the targets
    would make the model score the token two positions ahead.

    The loss is averaged per predicted token rather than per batch, so a
    smaller final batch does not skew the result. The model's train/eval mode
    is restored afterwards.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        for batch in tqdm.tqdm(loader, desc="Evaluation"):
            batch = batch.to(device)
            output = model(batch, labels=batch)
            # output.loss is the mean over the batch's (seq_len - 1) shifted
            # predictions per row; re-weight by that count.
            n_predicted = batch[:, 1:].numel()
            total_loss += output.loss.item() * n_predicted
            total_tokens += n_predicted
    finally:
        model.train(was_training)
    if total_tokens == 0:
        raise ValueError("evaluate() received an empty loader; nothing to score")
    return total_loss / total_tokens


evaluate(model, val_loader)

In [None]:
from ema_pytorch import EMA
# Exponential moving average of the model weights, refreshed by ema.update()
# after every optimizer step in the training loop.
EMA_DECAY = 0.99
ema = EMA(model, beta=EMA_DECAY)

In [None]:
# Training loop: NUM_EPOCHS passes over train_loader. Validation loss is
# measured once before training (baseline) and after every epoch, and shown
# next to the running training loss.
NUM_EPOCHS = 20
val_loss = evaluate(model, val_loader)
epoch_bar = tqdm.tqdm(range(NUM_EPOCHS), desc="Training")
for epoch in epoch_bar:
    # evaluate() may leave the model in eval mode; re-enable dropout etc.
    model.train()
    batch_bar = tqdm.tqdm(train_loader, leave=True, desc=f"Epoch {epoch}")
    for seq in batch_bar:
        seq = seq.to(device)

        optim.zero_grad()
        # GPT2LMHeadModel shifts labels internally: passing the full window as
        # both input and labels trains position t to predict token t+1.
        output = model(seq, labels=seq)
        loss = output.loss
        loss.backward()
        optim.step()

        ema.update()
        num_train_steps += 1

        batch_bar.set_postfix_str(f"Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")

    val_loss = evaluate(model, val_loader)
    epoch_bar.set_postfix_str(f"Val Loss: {val_loss:.4f}")

In [None]:
# Continue a story prompt with the trained model.
# The training loop leaves the model in train mode and generate() does not
# change it, so switch to eval to disable dropout while sampling.
model.eval()
prompt = "Once upon a time, there was a young prince named"
# dtype is explicit: torch.tensor([]) would otherwise default to float.
context = torch.tensor(tokenizer.encode(prompt).ids, dtype=torch.long, device=device).unsqueeze(0)
# max_length counts the prompt tokens as well as the newly generated ones.
generated = model.generate(context, attention_mask=torch.ones_like(context), max_length=100)
print(tokenizer.decode(generated[0].tolist()))


In [None]:
# Save the trained weights into the same project directory the tokenizer is
# loaded from (projects/4-stories), creating it if needed. Writing into
# another project's directory would overwrite that project's checkpoint.
MODEL_DIR = "projects/4-stories"
os.makedirs(MODEL_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(MODEL_DIR, "model.pt"))