# Let's Build GPT

Data: https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt

Video: https://www.youtube.com/watch?v=kCc8FmEb1nY&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ

# Imports

In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Running on:", device)

Running on: cuda


# Dataset

In [2]:
class Tokenizer:
    """Character-level tokenizer: maps single characters to integer ids and back.

    The id of a character is its position in the ``vocab`` list given to the
    constructor.
    """

    def __init__(self, vocab):
        """Build both lookup tables from ``vocab``, a list of one-character strings."""
        assert isinstance(vocab, list)
        for entry in vocab:
            assert isinstance(entry, str)
            assert len(entry) == 1
        self.n_vocab = len(vocab)
        # id -> char comes straight from the list order; char -> id is its inverse.
        self.itos = dict(enumerate(vocab))
        self.stoi = {ch: i for i, ch in self.itos.items()}

    def encode(self, text):
        """Return the list of ids for ``text``; an unknown character raises KeyError."""
        ids = []
        for ch in text:
            ids.append(self.stoi[ch])
        return ids

    def decode(self, sequence):
        """Turn ids back into text.

        ``sequence`` may be a list of ids, a 1-D tensor of ids, or a 0-D tensor
        holding a single id (which yields a single character). Any other type
        raises ValueError.
        """
        if isinstance(sequence, list):
            return ''.join(self.itos[i] for i in sequence)

        if not isinstance(sequence, torch.Tensor):
            raise ValueError(f"Type {type(sequence)} not supported")

        assert sequence.ndim in [0, 1]
        if sequence.ndim == 0:
            return self.itos[sequence.item()]  # one char
        return ''.join(self.itos[i.item()] for i in sequence)

In [3]:
# Read the corpus as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open('../data/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print("Num chars:", len(text))
print("Dataset Start:")
print(text[:462])  # preview of the first 462 characters

Num chars: 1115394
Dataset Start:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.


In [4]:
# Get vocabulary: the sorted list of distinct characters in the corpus.
# `text` is already a string and `sorted` accepts any iterable, so no
# intermediate `''.join(...)` / `list(...)` copies are needed.
letters = sorted(set(text))
print(''.join(letters))
print('Num:', len(letters))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Num: 65


In [5]:
tok = Tokenizer(letters)
print(tok.encode("hii there"))
print(tok.decode(tok.encode("hii there")))
# Look the id up outside the f-string: a backslash inside an f-string
# replacement field is a SyntaxError on Python versions before 3.12.
newline_id = tok.encode('\n')[0]
print(f"Newline is: {newline_id}")

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there
Newline is: 0


In [6]:
# Encode the whole corpus once as a 1-D tensor of token ids.
data = torch.tensor(tok.encode(text), dtype=torch.long)

# 90%/10% split: the first 90% is used for training, the tail for evaluation.
n = int(0.9*len(text))
train_data = data[:n]
eval_data = data[n:]

print(f"Train data len: {len(train_data)}")
print(f"Valid data len: {len(eval_data)}")

Train data len: 1003854
Valid data len: 111540


In [7]:
class DataLoader:
    """Random mini-batch sampler over a 1-D token sequence for next-token prediction.

    Every batch holds ``batch_size`` windows of ``sequence_length`` consecutive
    tokens (the inputs) and the same windows shifted one position to the right
    (the targets).
    """

    def __init__(self, data, batch_size, sequence_length, device):
        """Store the sampling configuration.

        Args:
            data: 1-D tensor of token ids to sample windows from.
            batch_size: number of windows per batch.
            sequence_length: number of tokens per window.
            device: device the sampled batches are moved to.

        Raises:
            ValueError: if ``data`` is not longer than ``sequence_length``; one
                extra token is needed for the shifted target window.
        """
        if len(data) <= sequence_length:
            raise ValueError(
                f"data has {len(data)} tokens but at least "
                f"{sequence_length + 1} are needed for sequence_length="
                f"{sequence_length}"
            )
        self.data = data
        self.n_batch = batch_size
        self.n_seq = sequence_length
        self.device = device

    def get_batch(self):
        """Sample one batch and return ``(x, y)``, both of shape (n_batch, n_seq).

        ``y[b, t]`` is the token that follows ``x[b, t]`` in ``data``.
        """
        # randint's upper bound is exclusive, so the largest start index is
        # len(data) - n_seq - 1 and the shifted target window stays in range.
        starts = torch.randint(len(self.data) - self.n_seq, (self.n_batch,)).tolist()
        x = torch.stack([self.data[s:s + self.n_seq] for s in starts])
        y = torch.stack([self.data[s + 1:s + 1 + self.n_seq] for s in starts])
        return x.to(self.device), y.to(self.device)

In [8]:
# Fix the RNG seed so the sampled batch is reproducible, then build a tiny
# loader (4 sequences of 8 tokens) that is easy to inspect by eye.
torch.manual_seed(1337)
tr_data_loader = DataLoader(
    data=train_data,
    batch_size=4,
    sequence_length=8,
    device=device,
)

In [9]:
# Draw one random mini-batch: x_batch holds the input windows and y_batch the
# same windows shifted one token to the right (the next-token targets).
x_batch, y_batch = tr_data_loader.get_batch()

In [10]:
# Inputs: shape first, then the raw token ids.
print(x_batch.shape, x_batch, sep="\n")

torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]], device='cuda:0')


In [11]:
# Targets: shape first, then the raw token ids.
print(y_batch.shape, y_batch, sep="\n")

torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]], device='cuda:0')


In [12]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    A single ``n_vocab x n_vocab`` embedding table acts as a lookup: row ``i``
    holds the unnormalised scores (logits) for the token following token ``i``.
    """

    def __init__(self, n_vocab):
        super().__init__()
        self.emb_table = nn.Embedding(n_vocab, n_vocab)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: ``torch.long`` tensor of token ids, shape (B, T).
            targets: optional ``torch.long`` tensor of next-token ids with the
                same shape as ``idx``.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, C) with
            C = n_vocab, and ``loss`` is the mean cross-entropy over all B*T
            positions, or ``None`` when ``targets`` is not given.
        """
        assert idx.dtype == torch.long
        assert targets is None or targets.dtype == torch.long

        logits = self.emb_table(idx)    # B,T,C <- B,T

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (rather than view) so that non-contiguous targets, e.g. a
        # transposed tensor, are accepted as well.
        logits_ = logits.reshape(B * T, C)  # B*T, C
        targets_ = targets.reshape(B * T)   # B*T
        loss = F.cross_entropy(logits_, targets_)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients, so never build a graph here
    def generate(self, idx, max_tokens):
        """Sample ``max_tokens`` new tokens, appending them to ``idx``.

        Args:
            idx: ``torch.long`` tensor of context token ids, shape (B, T).
            max_tokens: number of tokens to sample (int).

        Returns:
            ``torch.long`` tensor of shape (B, T + max_tokens): the original
            context followed by the sampled tokens.
        """
        assert idx.dtype == torch.long
        assert isinstance(max_tokens, int)

        for _ in range(max_tokens):
            # Model output for the whole context so far
            logits, _ = self(idx)      # B,T,C <- B,T

            # Only the last time step predicts the next token
            logits = logits[:, -1, :]  # B,C <- B,T,C

            probs = F.softmax(logits, dim=-1)  # B,C

            # Sample one token per batch row from the predicted distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # B,1

            idx = torch.cat((idx, idx_next), dim=1)  # B,T+1

        return idx

In [None]:
# Reproducibility: seed before the model is built so its initial weights are repeatable
torch.manual_seed(1337)

# Hyperparameters
n_vocab = tok.n_vocab   # token dictionary size (number of distinct characters)
n_batch = 32            # mini-batch size: sequences processed in parallel
n_seq = 8               # maximum context length fed into the model
n_emb = n_vocab         # size of embeddings, i.e. 'first layer'
lr = 1e-3               # learning rate passed to AdamW

# Data Loaders: one sampler per split, sharing the same batch geometry
tr_data_loader = DataLoader(train_data, n_batch, n_seq, device)
ev_data_loader = DataLoader(eval_data, n_batch, n_seq, device)

# Model, created and moved to the selected device in one step
model = BigramLanguageModel(n_vocab).to(device)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Initial Loss: a uniform guess over n_vocab tokens scores -ln(1/n_vocab),
# which is the reference value for the untrained model's loss.
uniform_loss = -torch.tensor(1/n_vocab).log()
print(f"Expected initial loss: {uniform_loss}")
logits, loss = model(x_batch, y_batch)
print(f"Initial model loss: {loss}")

Expected initial loss: 4.174387454986572
Initial model loss: 4.878634929656982


In [None]:
# Example Generation
model.eval()
with torch.no_grad():
    # Start from a single token with id 0 ('\n'): shape B=1, T=1
    idx = torch.zeros((1, 1), dtype=torch.long, device=device)
    res = model.generate(idx, max_tokens=100)
# Decode the only sequence in the batch back into text
print(tok.decode(res[0].tolist()))


pYCXxfRkRZd
wc'wfNfT;OLlTEeC K
jxqPToTb?bXAUG:C-SGJO-33SM:C?YI3a
hs:LVXJFhXeNuwqhObxZ.tSVrddXlaSZaNe


In [15]:
@torch.no_grad()
def evaluate(data_loader, num_evals, device=device):
    """Estimate the mean loss of the global ``model`` over random batches.

    Draws ``num_evals`` batches from ``data_loader``, evaluates each without
    gradient tracking, and returns the average loss as a Python float. The
    model is switched to eval mode and left in that mode.
    """
    model.eval()
    losses = torch.zeros(num_evals, device=device)
    for step in range(num_evals):
        inputs, targets = data_loader.get_batch()
        # model(...) returns (logits, loss); only the loss is needed here
        losses[step] = model(inputs, targets)[1]
    mean_loss = losses.mean()
    return mean_loss.item()

In [17]:
# Training schedule
num_epochs = 30_000   # total optimisation steps (one mini-batch per step)
eval_every = 1_000    # estimate and print the losses every this many steps
eval_iters = 200      # number of mini-batches averaged per loss estimate

In [19]:
model.train()
t_start = time.time()
for i in range(num_epochs):
    # One optimisation step on a fresh random mini-batch
    xb, yb = tr_data_loader.get_batch()
    optimizer.zero_grad(set_to_none=True)
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

    # Only report on every eval_every-th step
    if i % eval_every != 0:
        continue

    train_loss = evaluate(tr_data_loader, eval_iters, device)
    eval_loss = evaluate(ev_data_loader, eval_iters, device)
    model.train()  # evaluate() leaves the model in eval mode

    # Time since the previous report (covers the training steps and the evaluation)
    t_diff = time.time() - t_start
    t_start = time.time()
    print(f"t={t_diff:.2f}s i={i}, l_running={loss.item():.4f}, "
          f"tr={train_loss:.4f} ev={eval_loss:.4f}")

t=0.34s i=0, l_running=2.4709, tr=2.4535 ev=2.4892
t=1.82s i=1000, l_running=2.4689, tr=2.4502 ev=2.4780
t=1.85s i=2000, l_running=2.4717, tr=2.4506 ev=2.4722
t=1.83s i=3000, l_running=2.6102, tr=2.4492 ev=2.4858
t=1.83s i=4000, l_running=2.4576, tr=2.4573 ev=2.4785
t=1.73s i=5000, l_running=2.4079, tr=2.4547 ev=2.4695
t=1.75s i=6000, l_running=2.5362, tr=2.4524 ev=2.4749
t=1.81s i=7000, l_running=2.4280, tr=2.4466 ev=2.4855
t=1.84s i=8000, l_running=2.4976, tr=2.4577 ev=2.4838
t=1.85s i=9000, l_running=2.3893, tr=2.4520 ev=2.4956
t=1.84s i=10000, l_running=2.4272, tr=2.4535 ev=2.4853
t=1.83s i=11000, l_running=2.3820, tr=2.4471 ev=2.4758
t=1.84s i=12000, l_running=2.3700, tr=2.4543 ev=2.4974
t=1.77s i=13000, l_running=2.4850, tr=2.4565 ev=2.4826
t=1.72s i=14000, l_running=2.4776, tr=2.4558 ev=2.4748
t=1.83s i=15000, l_running=2.4000, tr=2.4523 ev=2.4991
t=1.69s i=16000, l_running=2.3444, tr=2.4575 ev=2.4945
t=1.68s i=17000, l_running=2.5712, tr=2.4543 ev=2.4849
t=1.69s i=18000, l_runn

In [None]:
# Generate
model.eval()
with torch.no_grad():
    # Start from a single token with id 0 ('\n'): shape B=1, T=1
    idx = torch.zeros((1, 1), dtype=torch.long, device=device)
    res = model.generate(idx, max_tokens=500)
# Decode the only sequence in the batch back into text
print(tok.decode(res[0].tolist()))


Wawice my.

Hastarom oroup
Yowhthetof isth ble mil ndill, ath iree sengmin lat Heriliovets, and Win nghir.
Thanousel lind me l.
HAshe ce hiry:
Supr aisspllw y.
Herindu n Boopetelaves
MP:

Pl, d mothakleo Windo whth eisbyo the m dourive we higend t so mower; te

AN ad nterupt f s ar igr t m:

Thiny aleronth,
Mad
RD:

WISo myr f-bube!
KENoby ak
Sadsal thes ghesthidin cour ay aney Iry ts I fr y ce.
Jonghe nd, bemary.
Yof 'sour menm sora anghy t--pond betwe ten.
Wand thot sulin s th llety ome.
I muc
