In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch.utils.data import DataLoader

import torch.optim as optim

from jaxtyping import Float, Integer, Bool
from tinyshakespeare import Vocab, TinyShakespeareDataset, make_train_val_dataloader

In [3]:
# Although GPT is a "decoder-only" transformer, PyTorch's nn.TransformerDecoderLayer
# also contains a cross-attention sub-layer (which GPT does not have). We therefore
# build the "decoder-only" transformer from PyTorch's nn.TransformerEncoderLayer
# and pass a causal (upper-triangular) attention mask so that each position can
# only attend to itself and earlier positions.

# TODO: since vocab is part of the model (unfortunately?), put it in the GPT class

class GPT(nn.Module):
    """Token embedding -> causally-masked Transformer encoder stack -> vocab logits.

    Args:
        vocab_size: number of distinct tokens; sets the number of embedding rows
            and the width of the output logits.
        d_model: embedding / hidden width of every layer.
        num_heads: attention heads per layer (PyTorch requires it to divide
            ``d_model``).
        d_ff: hidden width of each layer's feed-forward sub-block.
        dropout: dropout probability used inside every encoder layer.
        num_layers: number of stacked encoder layers (default 6).
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        num_heads: int = 8,
        d_ff: int = 256,
        dropout: float = 0.,
        num_layers: int = 6,
    ):
        super().__init__()

        # nn.TransformerEncoder deep-copies this template `num_layers` times.
        # Keep it local: storing it on `self` would register an extra, never-used
        # layer whose parameters still appear in parameters()/state_dict().
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.embed = nn.Embedding(vocab_size, d_model)
        # NOTE(review): there is no positional embedding, so token order is only
        # available implicitly through the causal mask. A learned positional
        # embedding (as in GPT) would likely help; left as a TODO.
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # swappable output "head" at the end of the encoders
        self.output = nn.Sequential(
            nn.Linear(d_model, vocab_size),
        )

    def forward(
        self, x: 'Integer[Tensor, "batch tokens"]'
    ) -> 'Float[Tensor, "batch tokens vocab"]':
        """Return next-token logits for every position of ``x``.

        Position ``i`` only attends to positions ``<= i``.
        """
        _, t = x.shape
        # Canonical causal mask: 0 on/below the diagonal, -inf above it. It must
        # be exactly this mask because `is_causal=True` tells PyTorch the mask is
        # causal. A float 0/1 mask would be *added* to the attention scores, only
        # biasing future positions instead of hiding them.
        causal_mask = torch.triu(
            torch.full((t, t), float("-inf"), device=x.device), diagonal=1
        )
        embed_o: Float[Tensor, "batch tokens {d_model}"] = self.embed(x)

        enc_o = self.encoder(embed_o, mask=causal_mask, is_causal=True)
        return self.output(enc_o)


In [4]:
@torch.no_grad()
def estimate_loss(
    model,
    dl,
    device: torch.device,
    num_iters: int = 10
):
    """Estimate the mean cross-entropy loss on the "train" and "val" splits.

    Args:
        model: module mapping token ids ``x`` to logits whose last two dims are
            (tokens, vocab) — assumes the GPT-style output shape; verify for
            other models.
        dl: mapping with "train" and "val" entries, each an iterable of
            ``(x, y)`` batches.
        device: device the model and batches are moved to.
        num_iters: maximum number of batches averaged per split. If a split's
            loader is shorter, all of its batches are used.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}`` (per-batch average).

    Raises:
        ValueError: if a split yields no batches.

    The model is evaluated in eval mode, and its previous train/eval mode is
    restored before returning (even on error).
    """
    losses = {"train": 0.0, "val": 0.0}

    was_training = model.training
    model.to(device)
    model.eval()

    try:
        for split in ("train", "val"):
            n_batches = 0
            # range goes first in zip so no extra batch is pulled once
            # num_iters is reached; a shorter loader simply ends the loop
            # instead of raising StopIteration.
            for _, (x, y) in zip(range(num_iters), dl[split]):
                x = x.to(device)
                y = y.to(device)

                logits = model(x)
                # cross_entropy expects the class dim second: (batch, vocab, tokens)
                loss = F.cross_entropy(logits.transpose(-2, -1), y)
                losses[split] += loss.item()
                n_batches += 1

            if n_batches == 0:
                raise ValueError(f"no batches available for split {split!r}")
            losses[split] /= n_batches
    finally:
        model.train(was_training)

    return losses
        

In [5]:
def train(
    model: nn.Module,
    dl: dict[str, DataLoader],
    optimizer: optim.Optimizer,
    steps: int = 100,
    eval_interval: int = 10,
    device: str | None = None,
):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    
    model.to(device)
    
    train_iter = iter(dl["train"])
    
    for steps in range(1, steps + 1):
        model.train()
        
        x, y = next(train_iter)
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        loss = F.cross_entropy(logits.transpose(-2, -1), y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if steps % eval_interval == 0:
            losses = estimate_loss(model, dl, device)
            print(f"train loss: {losses['train']:.4f} ; val loss: {losses['val']:.4f}")
    

In [6]:
# hyperparameters

seq_len = 128
batch_size = 32
train_steps = 10000
learning_rate = 3e-4
seed = 38650

# Seed before building the dataset/dataloaders, so any randomness they consume
# at construction time (e.g. a random train/val split, if
# make_train_val_dataloader does one — not visible here) is reproducible too.
torch.manual_seed(seed)

ds = TinyShakespeareDataset("./data/tinyshakespeare.txt", seq_len=seq_len)
dl = make_train_val_dataloader(ds, batch_size=batch_size, shuffle=True, num_workers=1)

model = GPT(
    ds.vocab.size
)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

train(
    model,
    dl,
    optimizer,
    steps=train_steps,
    # ~100 evaluations per run; clamp to 1 so train_steps < 100 does not yield 0
    eval_interval=max(1, train_steps // 100),
)

train loss: 2.5678 ; val loss: 2.5585
train loss: 2.4736 ; val loss: 2.4889
train loss: 2.4155 ; val loss: 2.4364
train loss: 2.3475 ; val loss: 2.3924
train loss: 2.3344 ; val loss: 2.3633
train loss: 2.2978 ; val loss: 2.3191
train loss: 2.2469 ; val loss: 2.2896
train loss: 2.2135 ; val loss: 2.2354
train loss: 2.1607 ; val loss: 2.1885
train loss: 2.0921 ; val loss: 2.1706
train loss: 2.0530 ; val loss: 2.1354
train loss: 2.0336 ; val loss: 2.0939
train loss: 1.9752 ; val loss: 2.0581
train loss: 1.9553 ; val loss: 2.0205
train loss: 1.9262 ; val loss: 2.0126
train loss: 1.8869 ; val loss: 1.9971
train loss: 1.8315 ; val loss: 1.9534
train loss: 1.8292 ; val loss: 1.9429
train loss: 1.8230 ; val loss: 1.9433
train loss: 1.7753 ; val loss: 1.8985
train loss: 1.7669 ; val loss: 1.8943
train loss: 1.7491 ; val loss: 1.8890
train loss: 1.7310 ; val loss: 1.8727
train loss: 1.7328 ; val loss: 1.8659
train loss: 1.6958 ; val loss: 1.8530
train loss: 1.6882 ; val loss: 1.8279
train loss: 

In [7]:
def generate(
    model: GPT,
    vocab: Vocab,
    prompt: str,
    context_len: int = 8,
    num_toks: int = 32,
    device: str | None = None
):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model.to(device)
        
    # batch size 1
    ctx = torch.tensor([vocab.encode(list(prompt))], dtype=torch.long).to(device)

    generated = []
    
    for i in range(num_toks):
        # crop context to the context length
        logits = model(ctx[:, -context_len:])
        # we only care about the last token
        logits = logits[:, -1, :]
        # get probabilities
        probs = F.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)

        ctx = torch.cat((ctx, next_tok), dim=1)

    return "".join(vocab.decode(ctx.tolist()[0]))

In [8]:
# Sample 512 characters from the trained model, seeded with a single space and
# using the full training context window.
gen = generate(
    model=model,
    vocab=ds.vocab,
    prompt=" ",
    context_len=seq_len,
    num_toks=512,
)
print(gen)

 not ground stap and made in me.
No, there 'tis you near friar your death.

KING RICHARD II:
Ay, the doth is cheeks sit?

QUEEN ELIZABETH:
And here it be has relet you,
There thy treason same where is me:
Nor see it fengar them her is life,
And marrings dead. 'But and ratisful, Thurse speed
To sent thyself, and you life, be your very,
Some welcome then.

Both:
'Tis open bruing for you, surple, and me me,
The morning of your forfecialintess!

DUKE VINCENTIO:
Fealines, come, so,
May doubting so him nox'd way: 
