In [24]:
import pathlib
import torch
from torch.utils.data import DataLoader

from gpt_builder.tokenizer import Tokenizer
from gpt_builder.dataset import BigramDataset
from gpt_builder.model import BigramLLM
from gpt_builder.utils import bigram_crossentropy_loss, train_step

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Read in raw text

In [2]:
data_dir = pathlib.Path("data")
with open(data_dir / "wizard_of_oz.txt", "r", encoding="utf-8") as f:
    text = f.read()
print("Text length: ", len(text))
print(text[:200])

Text length:  232284
DOROTHY AND THE WIZARD IN OZ

BY

L. FRANK BAUM

AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ, OZMA OF OZ, ETC.

ILLUSTRATED BY JOHN R. NEILL

BOOKS OF WONDER WILLIAM MORROW & CO., INC. NEW YORK


[Illu


# Create tokenizer

In [3]:
# Get unique characters
chars = sorted(set(text))
print("Number of unique characters: ", len(chars))
print(chars)

Number of unique characters:  80
['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [4]:
tokenizer = Tokenizer(chars)
hello_tokens = tokenizer.encode("Hello")
print("Encoded hello: ", hello_tokens)
hello_decoded = "".join(tokenizer.decode(hello_tokens))
print("Decoded hello: ", hello_decoded)

Encoded hello:  [32, 58, 65, 65, 68]
Decoded hello:  Hello


In [5]:
# Tokenize Wizard of Oz
data = tokenizer.encode(text, return_tensors=True)
print(data[:200])

tensor([28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1, 47, 33,
        50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0, 26, 49,  0,  0, 36, 11,
         1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,  0,  0, 25, 45, 44, 32, 39,
        42,  1, 39, 30,  1, 44, 32, 29,  1, 47, 33, 50, 25, 42, 28,  1, 39, 30,
         1, 39, 50,  9,  1, 44, 32, 29,  1, 36, 25, 38, 28,  1, 39, 30,  1, 39,
        50,  9,  1, 39, 50, 37, 25,  1, 39, 30,  1, 39, 50,  9,  1, 29, 44, 27,
        11,  0,  0, 33, 36, 36, 45, 43, 44, 42, 25, 44, 29, 28,  1, 26, 49,  1,
        34, 39, 32, 38,  1, 42, 11,  1, 38, 29, 33, 36, 36,  0,  0, 26, 39, 39,
        35, 43,  1, 39, 30,  1, 47, 39, 38, 28, 29, 42,  1, 47, 33, 36, 36, 33,
        25, 37,  1, 37, 39, 42, 42, 39, 47,  1,  4,  1, 27, 39, 11,  9,  1, 33,
        38, 27, 11,  1, 38, 29, 47,  1, 49, 39, 42, 35,  0,  0,  0, 51, 33, 65,
        65, 74])


# Create Bigram dataset

In [6]:
dataset = BigramDataset(data)
in_bigram, out_bigram = dataset[0]
print("In bigram: ", in_bigram)
print("Out bigram: ", out_bigram)

In bigram:  tensor([28, 39, 42, 39, 44, 32, 49,  1])
Out bigram:  tensor([39, 42, 39, 44, 32, 49,  1, 25])


# Examine model

In [28]:
vocab_size = len(chars)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
llm = BigramLLM(vocab_size).to(device)
x_out = llm(in_bigram.to(device))
x_out.shape

torch.Size([8, 80])

In [8]:
# Compute cros entropy lloss
bigram_crossentropy_loss(x_out, out_bigram)

tensor(5.1492, grad_fn=<NllLossBackward0>)

In [29]:
# Generate new tokens
llm.eval()
x_new = llm.generate(in_bigram.to(device), 10)

In [30]:
# Decode new sequence
print("Context: ", tokenizer.decode(in_bigram.tolist()))
print("New sequence: ", tokenizer.decode(x_new[0].cpu().tolist()))

Context:  ['D', 'O', 'R', 'O', 'T', 'H', 'Y', ' ']
New sequence:  ['D', 'O', 'R', 'O', 'T', 'H', 'Y', ' ', 'c', ':', 'k', '4', 'p', '!', 's', 'y', 'a', '3']


## Training loop

In [32]:
N_ITERS = 5
LEARNING_RATE = 3e-4

train_dl = DataLoader(dataset, batch_size=64)
optim = torch.optim.AdamW(llm.parameters(), lr=LEARNING_RATE)

for _ in range(N_ITERS):
    inputs, targets = next(iter(train_dl))
    inputs, targets = inputs.to(device), targets.to(device)
    train_step(llm, inputs, targets, optim)

Loss:  4.894145488739014
Loss:  4.893566131591797
Loss:  4.892986297607422
Loss:  4.892405986785889
Loss:  4.89182710647583
