In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch import Tensor
from torch.utils.data import DataLoader

from jaxtyping import Float, Integer, Bool
from tinyshakespeare import Vocab, TinyShakespeareDataset, make_train_val_dataloader

In [3]:
# Even though GPT is a "decoder-only" transformer, in PyTorch terminology
# a decoder layer also contains a cross-attention block (which GPT lacks).
# We therefore build GPT from PyTorch's TransformerEncoderLayer and pass a
# causal (upper-triangular) attention mask so that each position can attend
# only to itself and earlier positions.

# TODO: since vocab is part of the model (unfortunately?), put it in the GPT class

class GPT(nn.Module):
    """GPT-style language model built from PyTorch's encoder stack.

    Token ids are embedded, passed through ``num_layers`` self-attention
    layers under a causal mask, and projected back to per-token logits.

    Args:
        vocab_size: number of distinct token ids (rows of the embedding table
            and width of the output logits).
        d_model: embedding / hidden width.
        num_heads: attention heads per layer; must divide ``d_model``.
        d_ff: hidden width of each layer's feed-forward sub-block.
        dropout: dropout probability used inside each encoder layer.
        num_layers: number of stacked encoder layers (6 by default).
        max_seq_len: if given, add a learned positional embedding supporting
            sequences of up to this many tokens (longer inputs then fail in
            the embedding lookup). ``None`` (default) keeps the model free of
            explicit position information.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        num_heads: int = 8,
        d_ff: int = 256,
        dropout: float = 0.,
        num_layers: int = 6,
        max_seq_len: int | None = None,
    ):
        super().__init__()

        # Kept as a local on purpose: nn.TransformerEncoder deep-copies this
        # template num_layers times, so storing it on `self` would register an
        # extra, never-used set of parameters in parameters()/state_dict().
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = (
            nn.Embedding(max_seq_len, d_model) if max_seq_len is not None else None
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # swappable output "head" at the end of the encoders
        self.output = nn.Sequential(
            nn.Linear(d_model, vocab_size),
        )

    def forward(
        self, x: Integer[Tensor, "batch tokens"]
    ) -> Float[Tensor, "batch tokens vocab"]:
        """Map token ids of shape (batch, tokens) to logits (batch, tokens, vocab_size)."""
        _, t = x.shape
        # Boolean mask: True marks a key position the query may NOT attend to,
        # i.e. every position strictly after the query. A float mask is *added*
        # to the attention scores, so a 0/1 float mask would not block future
        # tokens on its own; the mask must agree with the is_causal hint.
        causal_mask = torch.triu(
            torch.ones(t, t, dtype=torch.bool, device=x.device), diagonal=1
        )
        h: Float[Tensor, "batch tokens d_model"] = self.embed(x)
        if self.pos_embed is not None:
            h = h + self.pos_embed(torch.arange(t, device=x.device))

        enc_o = self.encoder(h, mask=causal_mask, is_causal=True)
        return self.output(enc_o)


In [4]:
@torch.no_grad()
def estimate_loss(
    model: nn.Module,
    dl: dict[str, DataLoader],
    device: torch.device,
    num_iters: int = 10,
) -> dict[str, float]:
    """Estimate the mean cross-entropy loss on the "train" and "val" splits.

    Runs ``model`` without gradients on up to ``num_iters`` batches per split
    and averages the per-batch losses.

    Args:
        model: module expected to map token ids (batch, tokens) to logits
            (batch, tokens, vocab).
        dl: mapping with "train" and "val" keys; each value is iterated to
            yield ``(x, y)`` batches.
        device: device that the model and batches are moved to.
        num_iters: maximum number of batches drawn per split. If a split yields
            fewer, the average is taken over the batches it did yield.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}``.

    Raises:
        ValueError: if ``num_iters`` < 1 or a split yields no batches.

    Side effects: the model is moved to ``device``. Its previous train/eval
    mode is restored before returning.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    model.to(device)
    was_training = model.training
    model.eval()
    try:
        losses = {}
        for split in ("train", "val"):
            total, count = 0.0, 0
            # islice stops early instead of raising StopIteration when the
            # loader has fewer than num_iters batches
            for x, y in itertools.islice(dl[split], num_iters):
                x = x.to(device)
                y = y.to(device)

                logits = model(x)
                # cross_entropy wants the class dim second: (batch, vocab, tokens)
                total += F.cross_entropy(logits.transpose(-2, -1), y).item()
                count += 1

            if count == 0:
                raise ValueError(f"dataloader for split {split!r} yielded no batches")
            losses[split] = total / count
    finally:
        model.train(was_training)

    return losses
        

In [5]:
def train(
    model: nn.Module,
    dl: dict[str, DataLoader],
    optimizer: optim.Optimizer,
    steps: int = 100,
    eval_interval: int = 10,
    device: str | None = None,
):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    
    model.to(device)
    
    train_iter = iter(dl["train"])
    
    for steps in range(1, steps + 1):
        model.train()
        
        x, y = next(train_iter)
        x = x.to(device)
        y = y.to(device)
        
        logits = model(x)
        loss = F.cross_entropy(logits.transpose(-2, -1), y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if steps % eval_interval == 0:
            losses = estimate_loss(model, dl, device)
            print(f"train loss: {losses['train']:.4f} ; val loss: {losses['val']:.4f}")
    

In [6]:
# Seed torch's global RNG so weight initialisation (and, presumably, the
# DataLoader's shuffling and later sampling) is reproducible across
# Restart & Run All.
torch.manual_seed(1337)

# training context length; generate() defaults to context_len=8 to match
seq_len = 8
ds = TinyShakespeareDataset("./tinyshakespeare.txt", seq_len=seq_len)
dl = make_train_val_dataloader(ds, batch_size=32, shuffle=True, num_workers=1)

train_steps = 1000
learning_rate = 3e-4

model = GPT(ds.vocab.size)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

train(
    model,
    dl,
    optimizer,
    steps=train_steps,
)

train loss: 3.4424 ; val loss: 3.4715
train loss: 3.3115 ; val loss: 3.3194
train loss: 3.1665 ; val loss: 3.1910
train loss: 3.0207 ; val loss: 3.0618
train loss: 2.8966 ; val loss: 2.9075
train loss: 2.8249 ; val loss: 2.8290
train loss: 2.7897 ; val loss: 2.7467
train loss: 2.6947 ; val loss: 2.6769
train loss: 2.6245 ; val loss: 2.6123
train loss: 2.5750 ; val loss: 2.6162
train loss: 2.5730 ; val loss: 2.5613
train loss: 2.5613 ; val loss: 2.5512
train loss: 2.5431 ; val loss: 2.5477
train loss: 2.5078 ; val loss: 2.5032
train loss: 2.4721 ; val loss: 2.5028
train loss: 2.4487 ; val loss: 2.5714
train loss: 2.5039 ; val loss: 2.4845
train loss: 2.4952 ; val loss: 2.4885
train loss: 2.4755 ; val loss: 2.4428
train loss: 2.4878 ; val loss: 2.4358
train loss: 2.4566 ; val loss: 2.4913
train loss: 2.4740 ; val loss: 2.3791
train loss: 2.4309 ; val loss: 2.4159
train loss: 2.4171 ; val loss: 2.3822
train loss: 2.3556 ; val loss: 2.3672
train loss: 2.3970 ; val loss: 2.4314
train loss: 

In [9]:
def generate(
    model: GPT,
    vocab: Vocab,
    prompt: str,
    context_len: int = 8,
    num_toks: int = 32,
    device: str | None = None
) -> str:
    """Autoregressively sample ``num_toks`` tokens continuing ``prompt``.

    At each step, the model sees only the last ``context_len`` tokens. The
    next token is drawn from the softmax over the final position's logits.
    Sampling runs without gradient tracking and in eval mode. The model's
    previous train/eval mode is restored afterwards.

    Args:
        model: the language model; moved to ``device`` as a side effect.
        vocab: object providing ``encode`` (list of characters -> ids) and
            ``decode`` (ids -> sequence joined into a string).
        prompt: non-empty seed text.
        context_len: number of most recent tokens fed to the model per step;
            presumably should match the training ``seq_len`` (8 here).
        num_toks: number of tokens to sample.
        device: target device; defaults to "cuda" when available, else "cpu".

    Returns:
        The decoded prompt followed by the sampled tokens.

    Raises:
        ValueError: if ``prompt`` is empty.
    """
    if not prompt:
        raise ValueError("prompt must contain at least one character")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model.to(device)
    was_training = model.training
    model.eval()
    try:
        # no_grad: sampling never backpropagates, so don't build a graph per token
        with torch.no_grad():
            # batch size 1
            ctx = torch.tensor([vocab.encode(list(prompt))], device=device)
            for _ in range(num_toks):
                # crop context to the context length
                logits = model(ctx[:, -context_len:])
                # only the last position predicts the next token
                probs = F.softmax(logits[:, -1, :], dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                ctx = torch.cat((ctx, next_tok), dim=1)
    finally:
        model.train(was_training)

    return "".join(vocab.decode(ctx.tolist()[0]))

In [11]:
# sample a continuation from the trained model, seeded with a single space
generate(
    model=model,
    vocab=ds.vocab,
    prompt=" ",
)

' may kind end him. fore noth Well'