# Let's Build GPT

Data: https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt

Video: https://www.youtube.com/watch?v=kCc8FmEb1nY&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ

# Imports

In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Pick the compute device once: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Running on:", device)

Running on: cuda


# Dataset

In [2]:
class Tokenizer:
    """Character-level tokenizer backed by two lookup tables.

    ``vocab`` must be a list of single-character strings; the position of a
    character in that list is its token id.
    """

    def __init__(self, vocab):
        assert isinstance(vocab, list)
        assert all(isinstance(symbol, str) for symbol in vocab)
        assert all(len(symbol) == 1 for symbol in vocab)
        self.n_vocab = len(vocab)
        # id -> char first, then invert it to get char -> id.
        self.itos = dict(enumerate(vocab))
        self.stoi = {symbol: index for index, symbol in self.itos.items()}

    def encode(self, text):
        """Map each character of ``text`` to its id (KeyError if unknown)."""
        lookup = self.stoi
        return [lookup[char] for char in text]

    def decode(self, sequence):
        """Turn ids back into text.

        Accepts a list of ids, a 1-D tensor of ids, or a 0-D tensor holding a
        single id. Any other type raises ValueError.
        """
        lookup = self.itos
        if isinstance(sequence, list):
            return ''.join(lookup[index] for index in sequence)
        if not isinstance(sequence, torch.Tensor):
            raise ValueError(f"Type {type(sequence)} not supported")

        assert sequence.ndim in [0, 1]
        if sequence.ndim == 0:
            return lookup[sequence.item()]  # one char
        return ''.join(lookup[index] for index in sequence.tolist())

In [3]:
# Read the whole corpus into memory. The encoding is pinned to UTF-8 so the
# result does not depend on the platform's locale default.
with open('../data/tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print("Num chars:", len(text))
print("Dataset Start:")
print(text[:462])  # preview of the first 462 characters as a sanity check

Num chars: 1115394
Dataset Start:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.


In [4]:
# Get vocabulary: every distinct character of the corpus, in sorted order.
# A str already iterates character by character, so it can feed set()
# directly, and sorted() accepts the set without an intermediate list.
letters = sorted(set(text))
print(''.join(letters))
print('Num:', len(letters))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Num: 65


In [5]:
# Build the tokenizer and round-trip a short string as a smoke test.
tok = Tokenizer(letters)
print(tok.encode("hii there"))
print(tok.decode(tok.encode("hii there")))
# Look the id up outside the f-string: a backslash inside a replacement field
# is a SyntaxError on Python versions before 3.12.
newline_id = tok.encode('\n')[0]
print(f"Newline is: {newline_id}")

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there
Newline is: 0


In [6]:
# Encode the whole corpus once as a 1-D tensor of token ids.
data = torch.tensor(tok.encode(text), dtype=torch.long)
# Compute the split point from the tensor being sliced rather than the raw
# text, so the ratio stays correct even if the tokenizer stops being
# one-token-per-character.
n = int(0.9 * len(data))
train_data, eval_data = data[:n], data[n:]  # 90%/10% split
print(f"Train data len: {len(train_data)}")
print(f"Valid data len: {len(eval_data)}")

Train data len: 1003854
Valid data len: 111540


In [7]:
class DataLoader:
    """Random mini-batch sampler over a 1-D sequence of token ids.

    Each batch holds ``batch_size`` windows of ``sequence_length`` tokens
    drawn at random offsets, plus the same windows shifted one step to the
    right as next-token targets.
    """

    def __init__(self, data, batch_size, sequence_length, device):
        self.data, self.device = data, device
        self.n_batch, self.n_seq = batch_size, sequence_length

    def get_batch(self):
        """Return ``(x, y)``, both of shape (n_batch, n_seq), on ``self.device``."""
        span = self.n_seq
        # Highest start is len(data) - span - 1, so the shifted target window
        # still ends inside the data.
        starts = torch.randint(len(self.data) - span, (self.n_batch,))

        inputs, targets = [], []
        for start in starts:
            inputs.append(self.data[start:start + span])
            targets.append(self.data[start + 1:start + 1 + span])

        x = torch.stack(inputs).to(self.device)
        y = torch.stack(targets).to(self.device)
        return x, y

In [8]:
# Seed the RNG so the small demo batch sampled below is reproducible.
torch.manual_seed(1337)
tr_data_loader = DataLoader(
    data=train_data,
    batch_size=4,
    sequence_length=8,
    device=device,
)

In [9]:
# Draw one random batch: x_batch holds the input windows, y_batch the same
# windows shifted one token to the right (the next-token targets).
x_batch, y_batch = tr_data_loader.get_batch()

In [10]:
# Show the shape of the input batch, then its token ids.
print(x_batch.shape, x_batch, sep="\n")

torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]], device='cuda:0')


In [11]:
# Show the shape of the target batch, then its token ids.
print(y_batch.shape, y_batch, sep="\n")

torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]], device='cuda:0')


In [38]:
class Head(nn.Module):
    """Single causal self-attention head.

    Projects the input into key, query and value spaces of width ``n_head``,
    scores every query against every key, hides future positions with a
    lower-triangular mask, and returns the attention-weighted values.
    """

    def __init__(self, n_seq, n_in, n_head):
        super().__init__()
        self.n_head = n_head

        # Bias-free projections, created in key / query / value order.
        self.key = nn.Linear(n_in, n_head, bias=False)
        self.query = nn.Linear(n_in, n_head, bias=False)
        self.value = nn.Linear(n_in, n_head, bias=False)

        # Lower-triangular ones matrix, registered as a buffer rather than a
        # parameter; zeros mark the (query, key) pairs that must be masked.
        causal = torch.ones((n_seq, n_seq)).tril()
        self.register_buffer('tril', causal)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return a (B, T, H) tensor."""
        _, T, _ = x.shape
        k = self.key(x)      # B,T,H
        q = self.query(x)    # B,T,H
        v = self.value(x)    # B,T,H

        # Scaled dot-product scores: (B,T,H) @ (B,H,T) -> (B,T,T).
        head_dim = k.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / head_dim ** 0.5

        # Positions later than the query get -inf, i.e. zero weight after softmax.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = torch.softmax(scores, dim=-1)

        # The latest result is also kept on the module as ``out``.
        self.out = torch.matmul(weights, v)
        return self.out

In [41]:
class MultiHead(nn.Module):
    """Several independent self-attention heads run side by side.

    Each head sees the same input; the per-head outputs are joined along the
    last dimension, so the result is ``head_size * num_heads`` wide.
    """

    def __init__(self, n_seq, n_in, head_size, num_heads):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(Head(n_seq, n_in, head_size))
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the results feature-wise."""
        per_head = [head(x) for head in self.heads]
        return torch.cat(per_head, dim=-1)

In [42]:
class TransformerModel(nn.Module):
    """Small character-level language model.

    Pipeline: token embedding + position embedding -> multi-head causal
    self-attention -> linear projection to vocabulary logits.

    Args:
        n_seq: Maximum context length; sizes the position-embedding table
            and the attention mask.
        n_vocab: Number of distinct tokens.
        n_embd: Embedding width. Must be divisible by ``num_heads`` so the
            concatenated head outputs are ``n_embd`` wide again.
        num_heads: Number of attention heads (defaults to 4).

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``num_heads``.
    """

    def __init__(self, n_seq, n_vocab, n_embd, num_heads=4):
        super().__init__()
        if n_embd % num_heads != 0:
            # Otherwise the concatenated heads would be narrower than the
            # input expected by lm_head and forward() would fail later.
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by num_heads ({num_heads})"
            )
        self.tok_emb_table = nn.Embedding(n_vocab, n_embd)
        self.pos_emb_table = nn.Embedding(n_seq, n_embd)
        self.sa_heads = MultiHead(
            n_seq, n_in=n_embd, head_size=n_embd // num_heads, num_heads=num_heads
        )
        self.lm_head = nn.Linear(n_embd, n_vocab)

    def forward(self, idx, targets=None):
        """Compute logits, and the cross-entropy loss when targets are given.

        Args:
            idx: Long tensor of token ids, shape (B, T).
            targets: Optional long tensor of next-token ids, shape (B, T).

        Returns:
            ``(logits, loss)``: logits has shape (B, T, n_vocab); loss is a
            scalar tensor, or None when ``targets`` is None.
        """
        assert idx.dtype == torch.long
        assert targets is None or targets.dtype == torch.long
        B, T = idx.shape

        tok_emb = self.tok_emb_table(idx)                  # B,T,E <- B,T
        # Build the positions on the input's own device instead of relying on
        # a module-level global, so the model works wherever it is moved.
        positions = torch.arange(T, device=idx.device)     # T
        pos_emb = self.pos_emb_table(positions)            # T,E (broadcast over B)
        x = tok_emb + pos_emb                              # B,T,E
        x = self.sa_heads(x)                               # B,T,E
        logits = self.lm_head(x)                           # B,T,V <- B,T,E

        if targets is None:
            return logits, None

        # cross_entropy wants (N, C) logits and (N,) targets; reshape (rather
        # than view) also accepts non-contiguous inputs.
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.reshape(B * T, C), targets.reshape(B * T))
        return logits, loss

    def generate(self, idx, n_seq, max_tokens):
        """Sample ``max_tokens`` new tokens, starting from ``idx``.

        Args:
            idx: Long tensor of prompt token ids, shape (B, T).
            n_seq: Context window; only the last ``n_seq`` tokens are fed to
                the model at each step.
            max_tokens: Number of tokens to append (int).

        Returns:
            Long tensor of shape (B, T + max_tokens): the prompt followed by
            the sampled tokens.
        """
        assert idx.dtype == torch.long
        assert isinstance(max_tokens, int)

        for _ in range(max_tokens):
            # Sliding window over idx
            idx_tail = idx[:, -n_seq:]

            # Model output for the window
            logits, _ = self(idx_tail)          # B,T,C <- B,T

            # Only the last time step predicts the next token
            logits = logits[:, -1, :]           # B,C <- B,T,C
            probs = F.softmax(logits, dim=-1)   # B,C

            # Sample one token per batch row and append it
            idx_next = torch.multinomial(probs, num_samples=1)  # B,1
            idx = torch.cat((idx, idx_next), dim=1)             # B,T+1

        return idx

In [43]:
# Reproducibility: seed before the model is built so its initial weights
# are the same on every run.
torch.manual_seed(1337)

# Hyperparameters
n_vocab = tok.n_vocab   # token dictionary size (number of distinct letters)
n_batch = 32            # mini-batch size: how many sequences run in parallel
n_seq = 8               # max context length fed into the model
n_embd = 32             # embedding width, i.e. the 'first layer'
lr = 1e-3               # optimizer learning rate

# Data loaders: one over the training split, one over the held-out split
tr_data_loader = DataLoader(
    data=train_data, batch_size=n_batch, sequence_length=n_seq, device=device
)
ev_data_loader = DataLoader(
    data=eval_data, batch_size=n_batch, sequence_length=n_seq, device=device
)

# Model, moved to the selected device right after construction
model = TransformerModel(n_seq, n_vocab, n_embd).to(device)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Initial loss: a uniform guess over the vocabulary scores -log(1/n_vocab),
# which is the reference value for a freshly initialized model.
print(f"Expected initial loss: {-torch.tensor(1/n_vocab).log()}")
logits, loss = model(x_batch, y_batch)
print(f"Initial model loss: {loss}")

Expected initial loss: 4.174387454986572
Initial model loss: 4.353755950927734


In [44]:
# Example generation: sample 100 tokens from the model as it stands,
# starting from token id 0 ('\n').
model.eval()
with torch.no_grad():
    idx = torch.zeros((1, 1), dtype=torch.long, device=device)  # B=1, T=1
    res = model.generate(idx, n_seq, max_tokens=100)
    sampled_ids = res[0].tolist()
    print(tok.decode(sampled_ids))


?qf;xbDkRZkNwc'wf,ZT,OLFT,ebtK
b:iPjCkMBbzA$3:XaSvgO-33jM:F?gLTauhX:YVXJthXfNuwqcPMxv.tbVr dXl!DZaAe


In [45]:
@torch.no_grad()
def evaluate(data_loader, num_evals, device=device):
    """Return the mean loss of the global ``model`` over random batches.

    Draws ``num_evals`` batches from ``data_loader`` and averages their
    losses. The model is switched to eval mode and is NOT switched back,
    so callers that keep training must call ``model.train()`` afterwards.

    Args:
        data_loader: Object whose ``get_batch()`` returns ``(inputs, targets)``.
        num_evals: Number of batches to average over.
        device: Device on which the per-batch losses are accumulated.

    Returns:
        The average loss as a Python float.
    """
    model.eval()
    batch_losses = torch.zeros(num_evals, device=device)
    for step in range(num_evals):
        inputs, targets = data_loader.get_batch()
        batch_losses[step] = model(inputs, targets)[1]
    return batch_losses.mean().item()

In [46]:
# Training schedule:
#   num_epochs - number of optimizer steps (one mini-batch per step)
#   eval_every - report losses every this many steps
#   eval_iters - number of batches averaged per reported loss
num_epochs, eval_every, eval_iters = 5000, 1000, 200

In [47]:
# Training loop: one random mini-batch per step, with a periodic report of
# the averaged train and eval losses.
model.train()
# perf_counter is monotonic, so the measured intervals cannot be skewed by
# system clock adjustments.
t_start = time.perf_counter()
for i in range(num_epochs):

    xb, yb = tr_data_loader.get_batch()

    # Forward pass, then clear stale gradients, backpropagate and update.
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Report on the first step, every eval_every steps, and on the last step.
    if i % eval_every == 0 or i == num_epochs - 1:
        train_loss = evaluate(tr_data_loader, eval_iters, device)
        eval_loss = evaluate(ev_data_loader, eval_iters, device)
        model.train()  # evaluate() leaves the model in eval mode

        # Read the clock once so no time is lost between the measurement
        # and the restart of the timer.
        now = time.perf_counter()
        t_diff, t_start = now - t_start, now
        print(f"t={t_diff:.2f}s i={i}, l_running={loss.item():.4f}, "
              f"tr={train_loss:.4f} ev={eval_loss:.4f}")

t=0.83s i=0, l_running=4.2242, tr=4.2120 ev=4.2124
t=11.05s i=1000, l_running=2.5644, tr=2.5095 ev=2.5217
t=10.93s i=2000, l_running=2.3421, tr=2.3871 ev=2.3863
t=10.93s i=3000, l_running=2.3494, tr=2.3211 ev=2.3315
t=11.02s i=4000, l_running=2.2687, tr=2.2750 ev=2.2967
t=10.94s i=4999, l_running=2.2519, tr=2.2574 ev=2.2714


In [48]:
# Generate: sample 500 tokens from the model, starting from token id 0 ('\n').
model.eval()
with torch.no_grad():
    idx = torch.zeros((1, 1), dtype=torch.long, device=device)  # B=1, T=1
    res = model.generate(idx, n_seq, max_tokens=500)
    sampled_ids = res[0].tolist()
    print(tok.decode(sampled_ids))


Wawice my.

HDER:
Atzo mup
Yownt
Moof is he cove whedill, aes isee--ve cin lat Herid ove the, in me nownow that opel lind te lit-hus, cochiry ptupr aiss hiwhy.

Srings kne
To thig I whom.

I the to ake onWinsot her piibys worti dourive wee, ime st so mower; thune kind thrupt foron; igre! me monge inled, the af Pried my om.

HKINGLER:
Hind is:
Shosal ther ghe thinicenst asar tey Ire to chan thove youne ton, bemary.

He 'sth wherm sonot; myse.

An;
Therwerten.

CJRY; pen;
I he what wery ome.

Thuc
