## Make Dataset

In [18]:
from dataclasses import dataclass
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

In [19]:
# Token id -> printable symbol.
# Ids 0-3 are the opening brackets and ids 4-7 the matching closing brackets
# (closer id == opener id + 4); ids 8-10 are the start / end / padding markers.
vocab = dict(
    enumerate(
        ["{", "(", "[", "<", "}", ")", "]", ">", "[SOS]", "[EOS]", "[PAD]"]
    )
)


def gen_seq(len: int) -> tuple[str, list[int]]:
    """Generate one random, perfectly nested bracket sequence.

    ``len // 2`` opening brackets are drawn uniformly at random and followed
    by their matching closers in reverse order, so the result is always
    balanced. The token list is framed by [SOS] (8) and [EOS] (9).

    Args:
        len: Requested number of bracket tokens. An odd value is rounded down
            to the next even number because brackets are emitted in pairs.

    Returns:
        A ``(text, token_ids)`` tuple: the rendered string and the id list.
    """
    pair_count = len // 2
    opens = random.choices([0, 1, 2, 3], k=pair_count)
    # A closing bracket's id is its opening bracket's id shifted by 4.
    closes = [token + 4 for token in reversed(opens)]

    seq = [8, *opens, *closes, 9]
    text = "".join(vocab[token] for token in seq)
    return text, seq


def validate_sequence(seq: list[int]) -> bool:
    """Return True when the bracket tokens in *seq* are correctly nested.

    Token ids follow the file's vocabulary: 0-3 open a bracket and 4-7 close
    it (closer id == opener id + 4).

    Framing tokens are ignored so that sequences can be validated exactly as
    they are produced elsewhere in this file: one leading [SOS] (8), one
    trailing [EOS] (9) and any [PAD] (10) tokens after the end. Previously
    any of these tokens made the function return False, so no framed sequence
    could ever validate.

    Args:
        seq: Token ids to check.

    Returns:
        True if every closer matches the most recent unmatched opener and no
        opener is left unmatched. False otherwise, including when a special
        or unknown token appears between the brackets. An empty bracket body
        counts as balanced.
    """
    sos, eos, pad = 8, 9, 10
    openers = (0, 1, 2, 3)

    tokens = list(seq)

    # Trim the optional framing: [SOS] brackets... [EOS] [PAD]*
    start = 1 if tokens and tokens[0] == sos else 0
    end = len(tokens)
    while end > start and tokens[end - 1] == pad:
        end -= 1
    if end > start and tokens[end - 1] == eos:
        end -= 1

    stack = []
    for x in tokens[start:end]:
        if x in openers:
            stack.append(x)
        elif stack and stack[-1] == x - 4:
            # Only ids 4-7 can satisfy this, since the stack holds ids 0-3.
            stack.pop()
        else:
            # Unmatched closer, wrong closer, or a non-bracket token.
            return False
    return not stack

In [20]:
# Build the corpora: every sample holds 16 bracket tokens framed by
# [SOS]/[EOS] (18 ids per row). Only the id list from gen_seq is kept.
# The training samples are drawn before the test samples.
train_seq = torch.as_tensor([gen_seq(16)[1] for _ in range(64_000)]).long()
test_seq = torch.as_tensor([gen_seq(16)[1] for _ in range(16_000)]).long()

# Each batch served by the loaders is a slice of rows from the id tensors.
train_loader = torch.utils.data.DataLoader(
    train_seq,
    shuffle=True,
    batch_size=128,
)
test_loader = torch.utils.data.DataLoader(
    test_seq,
    shuffle=True,
    batch_size=256,
)

## GPT

In [21]:
# Pick the best available backend instead of hard-coding Apple's "mps":
# the notebook then also runs on CUDA machines and on CPU-only hosts.
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

@dataclass
class GPTConfig:
    """Hyperparameters for the small GPT model defined in this file."""

    # Number of token ids; taken from the module-level ``vocab`` table.
    vocab_size: int = len(vocab)
    # Size of the positional-embedding table and the longest sequence that
    # GPT.forward accepts: 16 bracket tokens plus the [SOS] and [EOS] markers.
    block_size: int = 16 + 2
    # Width of the token / positional embeddings and of every hidden state.
    emb_size: int = 4
    # Used twice: the number of AttentionHead modules per block AND the number
    # of sub-heads each of those modules splits ``emb_size`` into, so
    # ``emb_size`` must be divisible by ``heads``.
    heads: int = 4
    # Number of transformer blocks stacked in GPT.
    num_layers: int = 1
    # Dropout probability on the attention weights and, when > 0, on the
    # attention output inside Block.
    attn_dropout: float = 0
    # Feed-forward hidden width as a multiple of ``emb_size``.
    ff_mult: int = 1
    # Dropout probability on the feed-forward output; applied only when > 0.
    ff_dropout: float = 0

class AttentionHead(nn.Module):
    """Scaled dot-product self-attention over ``config.heads`` sub-heads.

    Despite the name, each instance projects the input to queries, keys and
    values of width ``emb_size``, splits them into ``config.heads`` slices of
    width ``emb_size // heads``, attends within each slice and merges the
    slices again through the ``out`` projection.

    Args:
        config: Model hyperparameters; reads ``emb_size``, ``heads`` and
            ``attn_dropout``.
        layer_idx: Index of the enclosing block (stored, not used here).
        head_idx: Index of this module within its layer (stored, not used
            here).
        cache_enabled: When True, key/value projections are memoised during
            no-grad inference (see ``forward``).

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``heads``; the
            per-head reshape in ``forward`` could never succeed.
    """

    def __init__(self, config: GPTConfig, layer_idx, head_idx, cache_enabled=False):
        super().__init__()
        if config.emb_size % config.heads != 0:
            raise ValueError(
                f"emb_size ({config.emb_size}) must be divisible by "
                f"heads ({config.heads})"
            )

        self.config = config
        self.layer_idx = layer_idx
        self.head_idx = head_idx
        self.cache_enabled = cache_enabled
        # Maps an input shape (B, T, C) to {"x": input, "k": keys, "v": values}.
        self.cache = {}

        self.q = nn.Linear(config.emb_size, config.emb_size)
        self.k = nn.Linear(config.emb_size, config.emb_size)
        self.v = nn.Linear(config.emb_size, config.emb_size)

        self.out = nn.Linear(config.emb_size, config.emb_size)

        self.attn_dropout = nn.Dropout(config.attn_dropout)

    def _split_heads(self, t):
        """Reshape (B, T, C) into (B, heads, T, C // heads)."""
        B, T, C = t.size()
        heads = self.config.heads
        return t.view(B, T, heads, C // heads).transpose(1, 2)

    @staticmethod
    def _same_input(cached_x, x):
        """Return True when a cached input is identical to ``x``."""
        return (
            cached_x.device == x.device
            and cached_x.dtype == x.dtype
            and torch.equal(cached_x, x)
        )

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape (B, T, C) with C == ``config.emb_size``.
            mask: Optional tensor broadcastable to (B, heads, T, T); positions
                where it equals 0 are excluded from attention.

        Returns:
            A tuple ``(out, attn)``: the projected output of shape (B, T, C)
            and the attention weights of shape (B, heads, T, T).

        Caching: keys and values are reused only when caching is enabled, the
        module is in eval mode, autograd is off, and the cached entry was
        computed from an identical input. The cache used to be keyed by shape
        alone, which returned stale keys/values for any new input of the same
        shape. Any forward pass in training mode empties the cache because
        the projection weights are about to change. Weights replaced while
        the module stays in eval mode are not detected.
        """
        B, T, C = x.size()
        cache_key = (B, T, C)

        if self.training and self.cache:
            self.cache.clear()

        use_cache = (
            self.cache_enabled
            and not self.training
            and not torch.is_grad_enabled()
        )
        entry = self.cache.get(cache_key) if use_cache else None

        if entry is not None and self._same_input(entry["x"], x):
            k, v = entry["k"], entry["v"]
        else:
            k = self._split_heads(self.k(x))
            v = self._split_heads(self.v(x))
            if use_cache:
                # Keep a private copy of the input so later in-place edits of
                # ``x`` cannot make a stale entry look valid.
                self.cache[cache_key] = {"x": x.detach().clone(), "k": k, "v": v}

        q = self._split_heads(self.q(x))

        # Scale by sqrt(head_dim) to keep the logits in a stable range.
        attn = (q @ k.transpose(-2, -1)) / ((C // self.config.heads) ** 0.5)

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        # Merge the sub-heads back into a single (B, T, C) tensor.
        x = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)

        return self.out(x), attn

class MaskedMultiHeadAttention(nn.Module):
    """Run ``config.heads`` AttentionHead modules and average their outputs.

    Every head receives the same input and mask. The heads' outputs are
    averaged element-wise (not concatenated), so the result has the same
    shape as the input.
    """

    def __init__(self, config: GPTConfig, layer_idx):
        super().__init__()
        self.config = config

        head_modules = []
        for head_idx in range(config.heads):
            head_modules.append(
                AttentionHead(config, layer_idx=layer_idx, head_idx=head_idx)
            )
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x, mask=None):
        """Return the mean of all heads' outputs for ``x``.

        Each head returns ``(output, attention_weights)``; only the output is
        used here.
        """
        head_outputs = [head(x, mask=mask)[0] for head in self.heads]
        return torch.stack(head_outputs).mean(dim=0)

class Block(nn.Module):
    """Pre-norm transformer block: attention, then feed-forward, each residual.

    Both sub-layers normalise their input with LayerNorm first and add the
    un-normalised input back afterwards. The dropout modules are created only
    when the corresponding probability in ``config`` is greater than zero.
    """

    def __init__(self, config: GPTConfig, layer_idx):
        super().__init__()
        self.config = config

        width = config.emb_size
        hidden = config.ff_mult * width

        self.ln1 = nn.LayerNorm(width)
        self.attn = MaskedMultiHeadAttention(config, layer_idx)

        self.ln2 = nn.LayerNorm(width)
        self.ff = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

        # Optional sub-modules: absent attributes mean "no dropout".
        if config.ff_dropout > 0:
            self.ff_dropout = nn.Dropout(config.ff_dropout)

        if config.attn_dropout > 0:
            self.attn_dropout = nn.Dropout(config.attn_dropout)

    def forward(self, x, mask=None):
        """Apply the block to ``x`` of shape (B, T, C); the shape is preserved."""
        # Unpacking also rejects inputs that are not three-dimensional.
        _batch, _steps, _width = x.size()

        # Attention sub-layer with residual connection.
        attended = self.attn(self.ln1(x), mask=mask)
        attn_drop = getattr(self, "attn_dropout", None)
        if attn_drop is not None:
            attended = attn_drop(attended)
        x = attended + x

        # Feed-forward sub-layer with residual connection.
        transformed = self.ff(self.ln2(x))
        ff_drop = getattr(self, "ff_dropout", None)
        if ff_drop is not None:
            transformed = ff_drop(transformed)

        return transformed + x

class GPT(nn.Module):
    """Decoder-only transformer over the bracket vocabulary.

    Token and positional embeddings are summed, passed through
    ``config.num_layers`` causal blocks and a final LayerNorm, then projected
    to vocabulary logits by a head whose weight is tied to the token
    embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.emb_size)
        self.pos_emb = nn.Embedding(config.block_size, config.emb_size)

        self.blocks = nn.ModuleList([Block(config, layer_idx=i) for i in range(config.num_layers)])

        self.ln = nn.LayerNorm(config.emb_size)
        self.head = nn.Linear(config.emb_size, config.vocab_size, bias=False)

        # tie weights: the output projection shares the token-embedding matrix
        self.head.weight = self.token_emb.weight

        # initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; called for every module via ``apply``."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape (B, T) with T <= ``config.block_size``.

        Returns:
            Logits of shape (B, T, vocab_size).
        """
        B, T = x.size()
        assert not T > self.config.block_size, "Sequence length is longer than block size"

        # Build helper tensors on the input's own device rather than on the
        # module-level default, so the model works wherever it is moved.
        dev = x.device

        emb = self.token_emb(x)
        # NOTE(review): positions count down (T-1 ... 0), so a token's
        # positional index depends on the total length T and differs between
        # fixed-length training batches and step-by-step generation. Confirm
        # this is intended.
        pe = self.pos_emb(torch.arange(T - 1, -1, step=-1, device=dev))

        x = emb + pe

        # The causal mask is identical for every block, so build it once.
        mask = torch.tril(torch.ones(T, T, device=dev)).view(1, T, T)
        for block in self.blocks:
            x = block(x, mask=mask)

        x = self.ln(x)
        return self.head(x)

    def loss(self, y, y_pred):
        """Return the mean cross-entropy between targets ``y`` and logits ``y_pred``.

        ``y`` is flattened to one dimension and ``y_pred`` to
        (N, vocab_size); ``y_pred`` must be contiguous for the ``view``.
        """
        # Input is a contiguous tensor
        y = y.flatten()
        y_pred = y_pred.view(-1, y_pred.size(-1))

        return F.cross_entropy(y_pred, y)

    def get_param_count(self):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def generate(self, max_len: int = 128, temperature: float = 1.0):
        """Sample a sequence token by token, starting from [SOS].

        Args:
            max_len: Maximum number of tokens to sample after [SOS].
            temperature: Positive divisor applied to the logits before the
                softmax; lower values make sampling greedier.

        Returns:
            ``(text, token_ids)``. Both start with [SOS]; sampling stops at
            the first [EOS], which is not included.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Sample in eval mode, but hand the model back in the mode we found it
        # so a call in the middle of training does not silently disable dropout.
        was_training = self.training
        self.eval()

        # Follow the model's own device instead of the module-level default.
        model_device = next(self.parameters()).device

        generated = [8]
        primer_t = torch.as_tensor(generated, device=model_device).unsqueeze(0)

        try:
            for _ in range(max_len):
                # Keep only the most recent block_size tokens as context.
                if primer_t.size(1) > self.config.block_size:
                    primer_t = primer_t[:, -self.config.block_size :]
                out = self(primer_t)
                out = out[:, -1, :] / temperature
                out = torch.multinomial(F.softmax(out, dim=-1), num_samples=1)

                gen = out.item()
                if gen == 9:
                    break
                generated.append(gen)

                primer_t = torch.cat((primer_t, out), dim=1)
        finally:
            self.train(was_training)

        return "".join([vocab[x] for x in generated]), generated


# Counter for optimisation steps; starts at zero.
num_train_steps = 0

# Build the model from the default hyperparameters and move it to the
# selected device (Module.to moves the parameters in place).
config = GPTConfig()
model = GPT(config)
model.to(device)

print(f"Model has {model.get_param_count():,} parameters")
# print(model.generate(max_len=64)[0])

# The model keeps its own reference as ``model.config``; drop the global name.
del config

Model has 500 parameters


In [24]:
# Adam over every model parameter with a fixed learning rate of 5e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)

In [25]:
# Next-token training: the model sees every token but the last and must
# predict every token but the first. ``val_loss`` shown in the progress bar is
# the value from the previous epoch (0 as a placeholder during the first one).
val_loss = 0
for epoch in range(20):
    # Validation below switches to eval mode, so re-enable training behaviour
    # (e.g. dropout) at the start of every epoch.
    model.train()

    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
    for seq in pbar:
        seq = seq.to(device)
        inputs = seq[:, :-1]
        targets = seq[:, 1:]

        optimizer.zero_grad()
        outputs = model(inputs)
        # cross_entropy expects the class dimension second: (B, vocab, T).
        loss = F.cross_entropy(outputs.transpose(1, 2), targets)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

    # Validation: eval mode and no autograd graph, so dropout is disabled and
    # no memory is spent on gradients that are never used.
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for seq in test_loader:
            seq = seq.to(device)
            inputs = seq[:, :-1]
            targets = seq[:, 1:]

            outputs = model(inputs)
            loss = F.cross_entropy(outputs.transpose(1, 2), targets)
            val_loss += loss.item()
    val_loss /= len(test_loader)

Epoch 0:  97%|█████████▋| 487/500 [00:09<00:00, 51.80it/s, loss=1.24, val_loss=0]

In [None]:
import matplotlib.pyplot as plt

# Plot the model's distribution over the vocabulary for the token that follows
# the first sequence of the most recent ``inputs`` batch.
sample = inputs[0].unsqueeze(0)
next_token_probs = F.softmax(model.forward(sample)[0, -1], dim=-1)

# NOTE(review): the division by 0.1 happens after the softmax, so the curve is
# the probabilities scaled by 10 rather than a temperature-sharpened
# distribution. Confirm that this is the intended plot.
plt.plot(next_token_probs.cpu().detach().numpy() / 0.1)
plt.xticks(range(len(vocab)), vocab.values())
plt.show()

In [None]:
# Sample ten sequences from the model and print their text form
# (generate returns both the text and the token ids; only the text is shown).
for _ in range(10):
    text, _token_ids = model.generate(max_len=64)
    print(text)