# ZeptoGPT Colab Notebook

One of the smallest GPTs in the universe. 

In [None]:
# --- Mini GPT: Learns to alphabetize simple 3-letter sequences ---

import torch
import torch.nn as nn
import torch.nn.functional as F
import random

# ----- Reproducibility -----
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this script draws from.

    Training batches are shuffled with the standard-library ``random``
    module (see ``make_batch``), while weight initialisation uses PyTorch's
    global generator. Seeding only one of them leaves the run
    non-repeatable, so both are seeded here.

    Args:
        seed: Integer seed applied to both generators.
    """
    random.seed(seed)
    torch.manual_seed(seed)


set_seed(42)

# ----- Vocabulary -----
# Id 0 is the padding token; ids 1..3 are the letters being sorted.
vocab = ['<pad>', 'a', 'b', 'c']
stoi = dict(zip(vocab, range(len(vocab))))   # token -> integer id
itos = dict(enumerate(vocab))                # integer id -> token
vocab_size = len(vocab)

for label, value in (
    ("Vocab:", vocab),
    ("stoi:", stoi),
    ("itos:", itos),
    ("vocab_size:", vocab_size),
):
    print(label, value)

# ----- Hyperparameters -----
# n_embd:     width of every token / position embedding vector
# block_size: number of positions per sequence (letters followed by padding)
# n_heads:    attention heads; must divide n_embd evenly
n_embd, block_size, n_heads = 16, 11, 2
lr = 0.001         # learning rate handed to the optimizer

# ----- Embeddings -----
# One learned vector per vocabulary entry and one per sequence position.
token_embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embd)
pos_embed = nn.Embedding(num_embeddings=block_size, embedding_dim=n_embd)

# ----- Multi-Head Self-Attention -----
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values, split into
    ``n_heads`` heads of width ``n_embd // n_heads``, attended with a
    lower-triangular mask (a position only sees itself and earlier
    positions), then merged and passed through an output projection.

    Args:
        n_embd: Feature width of the input and of the output.
        n_heads: Number of attention heads; must divide ``n_embd``.
        block_size: Side length of the pre-built causal mask.
    """

    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads
        assert n_embd % n_heads == 0
        # Creation order (query, key, value, proj) fixes the order in which
        # the layers draw their random initial weights.
        self.query = nn.Linear(n_embd, n_embd)
        self.key = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        # Ones on and below the diagonal mark the positions a query may see.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal)

    def _split_heads(self, t):
        """Reshape ``(B, T, C)`` into ``(B, n_heads, T, head_dim)``."""
        batch, steps, _ = t.size()
        return t.view(batch, steps, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape ``(B, T, C)``."""
        B, T, C = x.size()
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))

        # Scaled dot-product scores, one (T, T) matrix per head.
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Future positions get -inf so softmax assigns them zero weight.
        blocked = self.mask[:T, :T] == 0
        weights = F.softmax(scores.masked_fill(blocked, float('-inf')), dim=-1)

        # Weighted sum of values, then fold the heads back into C features.
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(merged)

# ----- Feed-forward (MLP) -----
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    Expands ``n_embd`` features to ``4 * n_embd``, applies ReLU, and
    projects back down to ``n_embd``.

    Args:
        n_embd: Feature width of the input and of the output.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        expand = nn.Linear(n_embd, hidden)
        activation = nn.ReLU()
        contract = nn.Linear(hidden, n_embd)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Return the MLP output for ``x``; the shape is unchanged."""
        return self.net(x)

# ----- Transformer Block -----
class TransformerBlock(nn.Module):
    """One transformer layer: attention and MLP, each behind a LayerNorm.

    Both sub-layers normalise their input first and add their output back
    onto the incoming tensor (residual connection).

    Args:
        n_embd: Feature width of the input and of the output.
        n_heads: Number of attention heads.
        block_size: Maximum sequence length supported by the attention mask.
    """

    def __init__(self, n_embd, n_heads, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # Attention is built before the MLP so random initialisation is
        # drawn in that order.
        self.attn = MultiHeadSelfAttention(n_embd, n_heads, block_size)
        self.ff = FeedForward(n_embd)

    def forward(self, x):
        """Run attention then the MLP, each with a residual connection."""
        attended = self.attn(self.ln1(x))
        after_attn = x + attended
        transformed = self.ff(self.ln2(after_attn))
        return after_attn + transformed

# ----- Instantiate Transformer -----
# The model is one transformer block, a final LayerNorm, and a linear
# projection from embedding space to vocabulary logits.
block = TransformerBlock(n_embd=n_embd, n_heads=n_heads, block_size=block_size)
final_norm = nn.LayerNorm(normalized_shape=n_embd)
lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size)
# Weight tying: the output projection reuses the token-embedding matrix, so
# lm_head.weight and token_embed.weight are the very same Parameter object.
lm_head.weight = token_embed.weight

# ----- Dataset: alphabetizing task -----
def make_batch(batch_size=32):
    """Build one random batch for the alphabetizing task.

    Each row of ``x`` holds a random ordering of 'a', 'b', 'c' in its first
    three positions. The matching row of ``y`` holds the same letters in
    sorted order. All remaining positions in both tensors are 0, the
    ``<pad>`` id.

    Note: every sample is a permutation of the same three letters, so the
    target row is always a, b, c followed by padding.

    Args:
        batch_size: Number of rows to generate.

    Returns:
        Tuple ``(x, y)`` of long tensors, each ``(batch_size, block_size)``.
    """
    letters = ['a', 'b', 'c']
    n_letters = len(letters)
    # Start from all-padding tensors and fill in only the letter slots.
    x = torch.zeros((batch_size, block_size), dtype=torch.long)
    y = torch.zeros((batch_size, block_size), dtype=torch.long)
    for row in range(batch_size):
        shuffled = random.sample(letters, n_letters)   # scrambled input
        ordered = sorted(shuffled)                     # expected output
        x[row, :n_letters] = torch.tensor([stoi[ch] for ch in shuffled])
        y[row, :n_letters] = torch.tensor([stoi[ch] for ch in ordered])
    return x, y

# ----- Optimizer -----
# lm_head.weight is tied to token_embed.weight (one shared Parameter), so
# simply concatenating every module's parameter list would hand that tensor
# to AdamW twice; PyTorch warns about duplicate parameters in a group.
# Collect each Parameter once, keyed by identity, keeping first-seen order.
unique_params = list({
    id(param): param
    for module in (token_embed, pos_embed, block, final_norm, lm_head)
    for param in module.parameters()
}.values())
optimizer = torch.optim.AdamW(unique_params, lr=lr)

# ----- Training loop -----
# Every iteration draws a fresh random batch, runs the full forward pass
# (embeddings -> block -> final norm -> lm_head) and takes one optimizer step.
# The loss is averaged over all positions, padding included.
num_steps = 500
log_every = 50
for step in range(num_steps):
    optimizer.zero_grad()
    x, y = make_batch()

    # Position ids 0..block_size-1, repeated for each row of the batch.
    positions = torch.arange(block_size).unsqueeze(0).expand(x.size(0), -1)
    hidden = token_embed(x) + pos_embed(positions)
    logits = lm_head(final_norm(block(hidden)))

    # Flatten (batch, position) so cross-entropy sees one row per token.
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    loss.backward()
    optimizer.step()

    if step % log_every == 0:
        print(f"Step {step:03d} | Loss: {loss.item():.4f}")

print("✅ Training complete!")

# ----- Generation function -----
@torch.no_grad()
def predict(x):
    """Return the most likely token id at every position of ``x``.

    Position ids are derived from the input's own length and device, so
    sequences shorter than ``block_size`` are accepted.

    Args:
        x: Long tensor of token ids shaped ``(batch, T)`` with
            ``T <= block_size``.

    Returns:
        Long tensor of shape ``(batch, T)`` holding the arg-max token id
        for each position.

    Raises:
        ValueError: If ``T`` exceeds ``block_size``, the number of
            positions the positional embedding was built for.
    """
    seq_len = x.size(1)
    if seq_len > block_size:
        raise ValueError(
            f"sequence length {seq_len} exceeds block_size {block_size}"
        )
    # Shape (1, T); broadcasts across the batch when added to tok_emb.
    positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
    tok_emb = token_embed(x)
    pos_emb = pos_embed(positions)
    out = tok_emb + pos_emb
    out = block(out)
    normed_out = final_norm(out)
    logits = lm_head(normed_out)
    pred = torch.argmax(logits, dim=-1)
    return pred

# ----- Test -----
# Feed one scrambled sequence through the trained model and decode the first
# positions of its prediction. The printed label is built from the actual
# input so it cannot drift out of sync with the tensor being tested.
test_letters = ['c', 'b', 'a']
test = torch.tensor(
    [[stoi[ch] for ch in test_letters] + [0]*(block_size - len(test_letters))]
)
pred = predict(test)[0][:len(test_letters)].tolist()
decoded = [itos[i] for i in pred]
print(f"Input: {test_letters} → Output:", decoded)
