In [None]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.notebook as tqdm

from tokenizers import Tokenizer, models, decoders, trainers, tools, pre_tokenizers

In [None]:
# Byte-level BPE tokenizer; the matching byte-level decoder is attached so that
# token ids can be turned back into text later on.
tokenizer = Tokenizer(models.BPE())
tokenizer.decoder = decoders.ByteLevel()
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()

# Trainer settings: a 4096-entry vocabulary that includes five special tokens.
trainer = trainers.BpeTrainer(
    vocab_size=4096,
    special_tokens=["[PAD]", "[SOS]", "[EOS]", "[MASK]", "[UNK]"],
)

In [None]:
# Learn the BPE vocabulary from the training split only, then report its size.
tokenizer.train(["blog/5-shakespeare/data/train.txt"], trainer=trainer)
print(f"Vocab size: {tokenizer.get_vocab_size()}")

# The trainer is not referenced again after this cell; drop the name.
del trainer

In [None]:
# Read both splits with an explicit UTF-8 encoding. Without `encoding=`,
# `open()` decodes with the platform's locale encoding, so the same file could
# produce different text (and therefore different token ids) on another machine.
with open("blog/5-shakespeare/data/train.txt", "r", encoding="utf-8") as f:
    train_corpus = f.read()

with open("blog/5-shakespeare/data/test.txt", "r", encoding="utf-8") as f:
    test_corpus = f.read()

# Encode each split into one flat list of token ids.
# Note: the notebook uses the *test* file as its validation set.
train_encoded_corpus = tokenizer.encode(train_corpus).ids
val_encoded_corpus = tokenizer.encode(test_corpus).ids

# The raw strings are not needed once they have been encoded.
del train_corpus, test_corpus

In [None]:
# Create dataset
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-token dataset over a flat sequence of token ids.

    Item ``i`` is the pair
    ``(corpus[i : i + seq_len], corpus[i + 1 : i + seq_len + 1])``: an input
    window and the same window shifted one token to the right.

    Args:
        corpus: Sequence of integer token ids (a list or a 1-D tensor).
        seq_len: Number of tokens in each input/target window; must be positive.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """

    def __init__(self, corpus, seq_len):
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")

        # Convert once up front so __getitem__ is a cheap tensor slice instead
        # of a Python list slice plus a list-to-tensor conversion per sample.
        self.corpus = torch.as_tensor(corpus, dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        # A corpus with no room for a full window and its shifted target has
        # zero samples; never report a negative length (len() would raise).
        return max(0, len(self.corpus) - self.seq_len)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` windows starting at ``idx``.

        The returned tensors are views into the dataset's token tensor; clone
        them before modifying in place.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        # Raising IndexError also lets plain iteration over the dataset stop.
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for dataset of length {size}")

        end = idx + self.seq_len
        return self.corpus[idx:end], self.corpus[idx + 1 : end + 1]
    

# Context length: number of tokens in every input/target window.
seq_len = 64
train_dataset = Dataset(train_encoded_corpus, seq_len)
val_dataset = Dataset(val_encoded_corpus, seq_len)

# Only the training windows are shuffled. The validation loader keeps a fixed
# order so that repeated evaluations see identical batches: the evaluation
# routine averages per-batch losses, so reshuffling would change which windows
# fall into the final (possibly shorter) batch and add noise to the metric
# that drives the learning-rate scheduler.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)

# The datasets hold the token ids themselves; drop the notebook-level names.
del train_encoded_corpus, val_encoded_corpus

In [None]:
# Pick the compute device by preference: CUDA, then Apple MPS (only when this
# torch build exposes the backend), otherwise the CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    else "cpu"
)

In [None]:
# Re-running this cell builds a fresh model, so first remove any model,
# optimizer or scheduler left in the namespace by a previous run.
for _leftover in ("model", "optimizer", "scheduler"):
    locals().pop(_leftover, None)
del _leftover

@dataclass
class GPTConfig:
    """Hyperparameters read by the attention, block and GPT modules below.

    The ``vocab_size`` and ``block_size`` defaults are taken from the notebook
    globals ``tokenizer`` and ``seq_len`` once, when this class statement is
    executed; re-run this cell after changing either of them.
    """

    # Rows of the token-embedding table and width of the output logits.
    vocab_size: int = tokenizer.get_vocab_size()
    # Longest sequence the model accepts; also the rows of the position-embedding table.
    block_size: int = seq_len
    # Width of the embeddings, the attention projections and the LayerNorms.
    emb_size: int = 128
    # Used twice: how many groups each AttentionHead splits its channels into,
    # and how many AttentionHead modules MaskedMultiHeadAttention averages.
    heads: int = 8
    # Number of Block modules stacked inside GPT.
    num_layers: int = 8
    # Dropout rate on attention weights; Block also applies it to the attention output when > 0.
    attn_dropout: float = 0
    # Hidden width of the feed-forward network as a multiple of emb_size.
    ff_mult: int = 2
    # Dropout rate applied to the feed-forward output when > 0.
    ff_dropout: float = 0


class AttentionHead(nn.Module):
    """Scaled dot-product self-attention with learned q/k/v/output projections.

    Despite the name, the projected channels are split into ``config.heads``
    groups that attend independently and are then concatenated again before
    the output projection.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        width = config.emb_size

        # Projections are created in the order q, k, v, out.
        self.q = nn.Linear(width, width)
        self.k = nn.Linear(width, width)
        self.v = nn.Linear(width, width)

        self.out = nn.Linear(width, width)

        self.attn_dropout = nn.Dropout(config.attn_dropout)

    def _split(self, t):
        """Reshape ``(B, T, C)`` into ``(B, heads, T, C // heads)``."""
        batch, steps, channels = t.size()
        groups = self.config.heads
        return t.view(batch, steps, groups, channels // groups).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, T, C)``.

        Args:
            x: Input activations, ``(B, T, C)``.
            mask: Optional tensor broadcastable to ``(B, heads, T, T)``;
                positions where it equals 0 are excluded from attention.

        Returns:
            Tuple of the projected output ``(B, T, C)`` and the attention
            weights ``(B, heads, T, T)`` (after dropout).
        """
        batch, steps, channels = x.size()
        head_dim = channels // self.config.heads

        q, k, v = (self._split(proj(x)) for proj in (self.q, self.k, self.v))

        scores = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Move the group axis back next to the channels and merge the two.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, steps, channels)

        return self.out(merged), weights


class MaskedMultiHeadAttention(nn.Module):
    """Averages the outputs of ``config.heads`` independent ``AttentionHead`` modules.

    Every member receives the same input and mask; only the projected outputs
    are combined, the members' attention weights are discarded.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        members = [AttentionHead(config) for _ in range(config.heads)]
        self.heads = nn.ModuleList(members)

    def forward(self, x, mask=None):
        """Return the element-wise mean of all members' outputs (same shape as ``x``)."""
        # Index 0 of each member's result is its projected output.
        outputs = [member(x, mask=mask)[0] for member in self.heads]
        return torch.stack(outputs).mean(dim=0)


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention, then a feed-forward network.

    Each sub-layer normalises its input, transforms it, optionally applies
    dropout, and adds the result back onto the un-normalised input.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        width = config.emb_size
        hidden = config.ff_mult * width

        self.ln1 = nn.LayerNorm(width)
        self.attn = MaskedMultiHeadAttention(config)

        self.ln2 = nn.LayerNorm(width)
        self.ff = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

        # Dropout layers exist only when their rate is positive; forward()
        # looks the attributes up before using them.
        if config.ff_dropout > 0:
            self.ff_dropout = nn.Dropout(config.ff_dropout)

        if config.attn_dropout > 0:
            self.attn_dropout = nn.Dropout(config.attn_dropout)

    def forward(self, x, mask=None):
        """Apply the block to ``x`` of shape ``(B, T, C)``; ``mask`` is passed to attention."""
        # Unpacking fails for anything that is not 3-D.
        batch, steps, channels = x.size()

        attn_drop = getattr(self, "attn_dropout", None)
        ff_drop = getattr(self, "ff_dropout", None)

        # Attention sub-layer with residual connection.
        branch = self.attn(self.ln1(x), mask=mask)
        if attn_drop is not None:
            branch = attn_drop(branch)
        x = branch + x

        # Feed-forward sub-layer with residual connection.
        branch = self.ff(self.ln2(x))
        if ff_drop is not None:
            branch = ff_drop(branch)

        return branch + x


class GPT(nn.Module):
    """Decoder-only transformer language model over token ids.

    Token and position embeddings are summed, passed through
    ``config.num_layers`` ``Block`` modules under a causal mask, normalised,
    and projected to vocabulary logits by a bias-free head whose weight is
    shared with the token embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.emb_size)
        self.pos_emb = nn.Embedding(config.block_size, config.emb_size)

        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])

        self.ln = nn.LayerNorm(config.emb_size)
        self.head = nn.Linear(config.emb_size, config.vocab_size, bias=False)

        # tie weights
        self.head.weight = self.token_emb.weight

        # initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule; called for every submodule via ``self.apply``.

        Linear and embedding weights are drawn from N(0, 0.02^2), linear biases
        are zeroed, and LayerNorm is reset to scale 1 / shift 0.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Token ids of shape ``(B, T)`` with ``T <= config.block_size``.

        Returns:
            Logits of shape ``(B, T, config.vocab_size)``.

        Raises:
            AssertionError: If ``T`` exceeds ``config.block_size``.
        """
        B, T = x.size()
        assert T <= self.config.block_size, "Sequence length is longer than block size"

        # Helper tensors are created on the input's device so the module does
        # not depend on a notebook-global `device` variable.
        dev = x.device

        # Positions count backwards from the end of the sequence: the last
        # token always uses row 0 of `pos_emb`, the one before it row 1, etc.
        positions = torch.arange(T - 1, -1, step=-1, device=dev)
        x = self.token_emb(x) + self.pos_emb(positions)

        # The causal mask is identical for every block, so build it once.
        mask = torch.tril(torch.ones(T, T, device=dev)).view(1, T, T)
        for block in self.blocks:
            x = block(x, mask=mask)

        x = self.ln(x)
        return self.head(x)

    def loss(self, y, y_pred):
        """Mean cross-entropy between target ids ``y`` and logits ``y_pred``.

        Args:
            y: Target token ids, any shape with ``N`` elements in total.
            y_pred: Logits whose last dimension is the vocabulary and whose
                leading dimensions also total ``N``.
        """
        y = y.flatten()
        # reshape (unlike view) also accepts non-contiguous logits.
        y_pred = y_pred.reshape(-1, y_pred.size(-1))

        return F.cross_entropy(y_pred, y)

    def get_param_count(self):
        """Return the total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def generate(self, primer: str, max_len: int = 128, temperature: float = 1.0):
        """Sample a continuation of ``primer`` and return the decoded text.

        Uses the notebook-global ``tokenizer`` for encoding and decoding, and
        leaves the model in eval mode.

        Args:
            primer: Prompt text; must encode to at least one token.
            max_len: Number of new tokens to sample.
            temperature: Positive divisor applied to the logits before softmax.

        Returns:
            The decoded primer followed by the sampled tokens.

        Raises:
            ValueError: If ``temperature`` is not positive or ``primer``
                encodes to no tokens.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        self.eval()

        generated = list(tokenizer.encode(primer).ids)
        if not generated:
            raise ValueError("primer must encode to at least one token")

        # Sample on whatever device the model's weights live on.
        dev = next(self.parameters()).device
        primer_t = torch.as_tensor(generated, dtype=torch.long, device=dev).unsqueeze(0)

        for _ in range(max_len):
            # Only the most recent `block_size` tokens fit in the model's context.
            if primer_t.size(1) > self.config.block_size:
                primer_t = primer_t[:, -self.config.block_size :]
            out = self(primer_t)
            # Keep the logits of the last position only and sample one token.
            out = out[:, -1, :] / temperature
            out = torch.multinomial(F.softmax(out, dim=-1), num_samples=1)

            generated.append(out.item())

            primer_t = torch.cat((primer_t, out), dim=1)

        return tokenizer.decode(generated)


# Optimizer steps taken so far; the training loop increments this counter.
num_train_steps = 0

# Build the model from the default hyperparameters and move it to `device`.
config = GPTConfig()
model = GPT(config).to(device)

print(f"Model has {model.get_param_count():,} parameters")
# print(model.generate("First Citizen:", max_len=64))

# `model.config` still holds the configuration; only the notebook-level name goes.
del config

In [None]:
@torch.no_grad()
def evaluate(model, dataloader):
    """Return the mean loss of ``model`` over all targets in ``dataloader``.

    Batches are moved to the notebook-global ``device``. The model is switched
    to eval mode and left there; callers that go on training must call
    ``model.train()`` afterwards.

    Args:
        model: Module that is callable on an input batch and provides
            ``loss(y, y_pred)`` returning a scalar tensor.
        dataloader: Iterable of ``(x, y)`` tensor batches.

    Returns:
        The loss averaged over every target element, as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_targets = 0
    pbar = tqdm.tqdm(dataloader, desc="Evaluation")
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = model.loss(y, y_pred).item()
        # Weight each batch by its number of targets so a shorter final batch
        # does not count as much as a full one (assumes `model.loss` is a mean
        # over the batch's targets, as GPT.loss is).
        targets = y.numel()
        total_loss += loss * targets
        total_targets += targets
        pbar.set_postfix({"loss": loss})

    if total_targets == 0:
        raise ValueError("dataloader produced no batches to evaluate")
    return total_loss / total_targets

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
# Halve the learning rate once the validation loss has not improved for more
# than `patience` consecutive epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

# Exponential moving average of the training loss (shown in the progress bar);
# 4 is only a starting value for the average.
roll_loss = 4

# Validation loss of the current weights, displayed during the first epoch.
val_loss = evaluate(model, val_loader)

for epoch in range(100):
    # evaluate() leaves the model in eval mode, so switch back before training.
    model.train()

    # Created after evaluation so the bar's timer covers training only.
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = model.loss(y, y_pred)
        loss.backward()
        optimizer.step()

        num_train_steps += 1
        roll_loss = 0.9 * roll_loss + 0.1 * loss.item()

        pbar.set_postfix_str(f"loss: {roll_loss:.4f}, val_loss: {val_loss:.4f}, steps: {num_train_steps:,}")

    # Validate the weights this epoch produced and let the scheduler react to
    # that value, so the learning-rate decision does not lag one epoch behind
    # and the final weights are validated as well.
    val_loss = evaluate(model, val_loader)
    scheduler.step(val_loss)

In [None]:
# Print the mean and standard deviation of every parameter tensor with four
# decimals. The spec must be `.4f`: a bare `4f` only sets a minimum field
# width of 4 and leaves the precision at the default of six digits.
for name, param in model.named_parameters():
    print(f"{name}: {param.mean().item():.4f}, {param.std().item():.4f}")

In [None]:
# Sample 64 new tokens that continue the prompt and print the decoded text.
model.eval()
print(model.generate("The Project Gutenberg eBook", max_len=64))

In [None]:
# Plot the position embeddings as a heat map
# (rows: positions, columns: embedding dimensions).
import matplotlib.pyplot as plt

pos_emb = model.pos_emb.weight.detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(20, 5))
image = ax.imshow(pos_emb, aspect="auto")
fig.colorbar(image, ax=ax)
ax.set_title("Position Embeddings")
ax.set_xlabel("Embedding Dimension")
ax.set_ylabel("Position")
plt.show()


In [None]:
# Trace a single embedding channel (column 0) across all positions.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(model.pos_emb.weight.detach().cpu().numpy()[:, 0])
ax.set_title("Position Embedding 0")
ax.set_xlabel("Position")
ax.set_ylabel("Value")
plt.show()