# Build Your Own Transformer for Text Generation

This notebook shows how to build a **tiny GPT‑style transformer** for text generation, step by step.
It is written for someone who is **comfortable with code and math**, but **new to deep learning / transformers**.

We will think of a transformer as a **smart autocomplete**:
- You give it some text (a *prompt*),
- It repeatedly guesses the **next character**,
- Those guesses form new text.
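
In code, this "smart autocomplete" loop takes only a few lines. Below is a minimal plain-Python sketch. `autocomplete` and the toy `toy_predict_next` rule are hypothetical stand-ins. Later in the notebook, the trained transformer plays the role of the predictor.

```python
from typing import Callable

def autocomplete(prompt: str, predict_next: Callable[[str], str], max_new_chars: int) -> str:
    """Grow `prompt` one character at a time using a next-character predictor."""
    text = prompt
    for _ in range(max_new_chars):
        text += predict_next(text)  # each guess becomes context for the next guess
    return text

# Hypothetical toy predictor: alternate "b" and "a" based on the current length.
def toy_predict_next(text: str) -> str:
    return "ab"[len(text) % 2]

print(autocomplete("x", toy_predict_next, 4))  # xbaba
```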

To do that, we will build the following pieces ourselves:

1. **Token & position representation** – turn characters into vectors and tell the model *where* each token is
2. **Causal self‑attention** – each position decides *which previous positions to pay attention to*
3. **Multi‑head attention** – several attention "views" in parallel, then combined
4. **Feed‑forward block** – a small MLP applied to each position independently, adding per‑token processing after attention has mixed information across positions
5. **Decoder blocks** – attention + feed‑forward + residual connections + layer norm
6. **Training loop** – make the model good at “next‑token prediction” on a small text corpus
7. **Generation loop** – use the model to sample new text one token at a time

You can read this notebook top‑to‑bottom like a tutorial, or run it cell by cell as a lab.

## 1. Imports and device

We use **PyTorch** for tensors and neural network layers.

The only slightly fancy thing we do here is pick a **device**:
- If you have a GPU (CUDA or Apple Silicon MPS), we use it.
- Otherwise we quietly fall back to CPU.

Everything else in the notebook just uses this `device` variable.

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Pick the best available compute device, in order of preference:
# - an NVIDIA GPU (CUDA),
# - an Apple Silicon GPU (MPS),
# - otherwise the CPU.

def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")  # Apple Silicon GPU
    return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

## 2. Positional encoding

A transformer looks at a **set of vectors**; by default it does not know which token came first, second, etc.

To fix that, we *add* a vector that describes the **position** of each token:
- Token at position 0 gets one vector,
- Token at position 1 gets a different vector,
- and so on.

Here we use the classic **sinusoidal positional encoding** from the original transformer paper.
You can think of it as giving each position a unique "barcode" of sines and cosines that the model can learn to interpret.
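
Written out, the encoding computed in the cell below is

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right).$$

To check the formula by hand, here is a plain-Python sketch (no PyTorch) that computes the vector for a single position. The name `sinusoidal_pe` is ours, for illustration only. It uses the same even-sine / odd-cosine layout as `PositionalEncoding`. The first sine/cosine pair rotates by one radian per position, and later pairs rotate more and more slowly. The combination of fast and slow "clock hands" is what makes each position's barcode distinct.

```python
import math

def sinusoidal_pe(pos: int, d_model: int) -> list:
    """Positional-encoding vector for one position, computed with plain floats."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):  # i = 0, 2, 4, ... plays the role of "2i" in the formula
        freq = math.exp(-math.log(10000.0) * i / d_model)  # = 1 / 10000 ** (i / d_model)
        pe[i] = math.sin(pos * freq)
        if i + 1 < d_model:  # with an odd d_model, the last slot has no cosine partner
            pe[i + 1] = math.cos(pos * freq)
    return pe

print(sinusoidal_pe(0, 4))                         # [0.0, 1.0, 0.0, 1.0]
print([round(v, 3) for v in sinusoidal_pe(1, 4)])  # [0.841, 0.54, 0.01, 1.0]
```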

In [None]:
class PositionalEncoding(nn.Module):
    """Add a fixed "position barcode" to each token embedding.

    Analogy: if each word is a *card* with some information,
    positional encoding writes the *position number* on the card,
    so the model can tell "this was the 3rd word".
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # pe has shape (max_len, d_model).
        # Row i contains the positional encoding for position i.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)

        # Different frequencies along the feature dimension.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Even feature indices get sine, odd indices get cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so this also works when d_model is odd.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Add batch dimension: (1, max_len, d_model).
        pe = pe.unsqueeze(0)

        # register_buffer = tensor is part of the module, but not a learnable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
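        """Add positional encodings to a batch of embeddings.

        Args:
            x: (batch, seq_len, d_model) token embeddings, with seq_len <= max_len.

        Returns:
            (batch, seq_len, d_model) embeddings with position information added.
        """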
        # Slice `pe` to the sequence length we actually have, then add.
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

## 3. Causal self-attention

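For **text generation**, the model must not "see" future tokens: the output at position $i$ may depend only on positions $\leq i$. We enforce this with a **causal mask**. We compute scaled dot-product scores $QK^\top / \sqrt{d_k}$, then overwrite every entry above the diagonal (key position $j >$ query position $i$) with a large negative number, so those positions get essentially zero weight after the softmax. The allowed entries therefore form a lower-triangular pattern. The function below returns only these masked scores. The softmax and the weighted sum over the values happen inside the multi-head attention module.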
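
To see what the mask does to individual rows of scores, here is a framework-free sketch in plain Python. `causal_mask` builds the same boolean pattern as `torch.triu(..., diagonal=1).bool()` (True means masked). `masked_softmax` fills masked entries with the same `-1e9` before normalizing. Both names are ours, for illustration.

```python
import math

def causal_mask(seq_len: int) -> list:
    # True where query position i must NOT see key position j (j > i),
    # the same pattern as torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores: list, mask_row: list, mask_value: float = -1e9) -> list:
    filled = [mask_value if masked else s for s, masked in zip(scores, mask_row)]
    m = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(3)
for row in mask:
    print(["x" if masked else "." for masked in row])  # x = masked (a future position)
# ['.', 'x', 'x']
# ['.', '.', 'x']
# ['.', '.', '.']
print(masked_softmax([1.0, 2.0, 3.0], mask[0]))  # [1.0, 0.0, 0.0]
print(masked_softmax([1.0, 1.0, 5.0], mask[1]))  # [0.5, 0.5, 0.0]
```

However large the score on a future position, it receives zero weight. The first position can only attend to itself.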

In [None]:
def causal_attention_scores(q, k, mask_value=-1e9):
    """
    Scaled dot-product attention with causal mask.

    Args:
        q (torch.Tensor): Query tensor of shape (batch, heads, seq_len, d_k).
        k (torch.Tensor): Key tensor of shape (batch, heads, seq_len, d_k).
        mask_value (float, optional): Value to use for masked (future) positions above the diagonal. Default: -1e9.

    Returns:
        torch.Tensor: Attention scores of shape (batch, heads, seq_len, seq_len), with masked (upper-triangular, i.e. future) positions set to mask_value.
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    seq_len = scores.size(-1)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=scores.device), diagonal=1).bool()
    scores = scores.masked_fill(causal_mask, mask_value)
    return scores

## 4. Multi-head attention

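We run several **attention heads** in parallel, then concatenate their outputs and apply a final linear projection (`w_o`). Each head has its own Q, K, V projections. In practice, we compute all heads' queries (and likewise keys and values) with one `nn.Linear(d_model, d_model)`. We then split the result into `num_heads` chunks of size `d_k = d_model // num_heads`.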
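
For a single position, the `view(..., num_heads, d_k)` reshape just cuts the `d_model` features into consecutive chunks, one per head. The final `transpose(...).view(...)` glues them back together. Here is a plain-Python sketch of that bookkeeping. The helper names are hypothetical.

```python
def split_heads(vec: list, num_heads: int) -> list:
    """Cut one position's d_model features into num_heads consecutive chunks of size d_k."""
    d_model = len(vec)
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    return [vec[h * d_k:(h + 1) * d_k] for h in range(num_heads)]

def merge_heads(heads: list) -> list:
    """Concatenate per-head outputs back into one d_model-sized vector."""
    return [x for head in heads for x in head]

features = list(range(8))            # d_model = 8
heads = split_heads(features, 2)     # num_heads = 2, so d_k = 4
print(heads)                         # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(merge_heads(heads) == features)  # True
```

Each head computes attention using only its own `d_k`-sized slice of Q, K and V. This lets different heads learn different attention patterns at roughly the same total cost as one full-width head.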

In [None]:
class CausalMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
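        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # feature size per head
        self.w_q = nn.Linear(d_model, d_model)  # queries for all heads at once
        self.w_k = nn.Linear(d_model, d_model)  # keys for all heads at once
        self.w_v = nn.Linear(d_model, d_model)  # values for all heads at once
        self.w_o = nn.Linear(d_model, d_model)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch, seq_len, _ = x.shape
        # Project, split d_model into (num_heads, d_k), and move the head axis forward:
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        q = self.w_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # (batch, num_heads, seq_len, seq_len), with future positions masked out
        scores = causal_attention_scores(q, k)
        attn = F.softmax(scores, dim=-1)  # each row is a distribution over past positions
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, d_k)
        # Merge heads back: (batch, seq_len, num_heads * d_k) == (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        return self.w_o(out)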

## 5. Feed-forward block

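Each transformer block has a two-layer MLP after attention: linear → GELU → linear, applied to every position independently. The inner dimension `d_ff` defaults to 4× the model size. (The original transformer used ReLU here; GPT-style models typically use GELU, as we do.)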
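
Two quick calculations for this block, in plain Python. PyTorch's default `F.gelu` is the exact form $\mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{1}{2}x\,(1 + \mathrm{erf}(x/\sqrt{2}))$, where $\Phi$ is the standard normal CDF. Unlike ReLU, it lets small negative values through, damped. The parameter count shows that with the default `d_ff = 4 * d_model`, the feed-forward layers hold about twice as many weights as the four attention projections. The helper names are illustrative.

```python
import math
from typing import Optional

def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def relu(x: float) -> float:
    return max(0.0, x)

def ffn_param_count(d_model: int, d_ff: Optional[int] = None) -> int:
    """Weights + biases of Linear(d_model, d_ff) followed by Linear(d_ff, d_model)."""
    d_ff = d_ff or 4 * d_model
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

print(round(gelu(1.0), 4), round(gelu(-1.0), 4), relu(-1.0))  # 0.8413 -0.1587 0.0
print(ffn_param_count(128))  # 131712
```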

In [None]:
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))

## 6. Decoder block

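One block = **causal multi-head attention** followed by a **feed-forward** MLP, each wrapped in a residual connection with layer norm. We use the **pre-norm** arrangement from GPT-2: layer norm is applied to the *input* of each sub-layer, `x = x + sublayer(norm(x))`. The original transformer paper instead applies it after the residual addition.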
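
The difference between the two arrangements is easiest to see on a toy example. The plain-Python sketch below treats one position's vector as a list and the sub-layer as an arbitrary function. Its `layer_norm` omits the learnable scale and shift that `nn.LayerNorm` adds, and all names are hypothetical. With a sub-layer that outputs zeros, the pre-norm version passes its input through unchanged, so the residual path is a clean identity. The post-norm version still renormalizes the input. That clean identity path is commonly cited as one reason pre-norm stacks are easier to train.

```python
import math

def layer_norm(x: list, eps: float = 1e-5) -> list:
    """Normalize to zero mean and unit variance (no learnable gain or bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance, as nn.LayerNorm uses
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add(a: list, b: list) -> list:
    return [u + v for u, v in zip(a, b)]

def pre_norm_residual(x: list, sublayer) -> list:
    # This notebook's DecoderBlock: x + sublayer(norm(x))
    return add(x, sublayer(layer_norm(x)))

def post_norm_residual(x: list, sublayer) -> list:
    # Original transformer paper: norm(x + sublayer(x))
    return layer_norm(add(x, sublayer(x)))

def zero_sublayer(v: list) -> list:
    return [0.0] * len(v)

vec = [1.0, 2.0, 3.0]
print(pre_norm_residual(vec, zero_sublayer))                          # [1.0, 2.0, 3.0]
print([round(v, 3) for v in post_norm_residual(vec, zero_sublayer)])  # [-1.225, 0.0, 1.225]
```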

In [None]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int = None, dropout: float = 0.1):
        super().__init__()
        self.self_attn = CausalMultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

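    def forward(self, x):
        # Pre-norm residual sub-layers: x + sublayer(norm(x))
        x = x + self.dropout(self.self_attn(self.norm1(x)))  # positions exchange information
        x = x + self.dropout(self.ff(self.norm2(x)))  # per-position processing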
        return x

## 7. Full decoder-only transformer

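We stack decoder blocks on top of **token embeddings + positional encoding**, apply a final layer norm (as is usual with pre-norm blocks), and project the hidden state at *every* position to vocabulary-sized logits. Training uses the logits at all positions at once; generation only needs the last one.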
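
As a sanity check on model size, you can count the parameters by hand. The sketch below (plain Python; `count_parameters` is a hypothetical helper) adds up five groups of weights:

- the embedding,
- each block's four attention projections, feed-forward layers and two layer norms,
- the final layer norm,
- the output head.

The sinusoidal table is a buffer rather than a parameter, and `num_heads` does not change the count. For a model built with these settings, the result should agree with `sum(p.numel() for p in model.parameters())`.

```python
from typing import Optional

def count_parameters(vocab_size: int, d_model: int, num_layers: int, d_ff: Optional[int] = None) -> int:
    d_ff = d_ff or 4 * d_model
    embed = vocab_size * d_model                                # nn.Embedding
    attn = 4 * (d_model * d_model + d_model)                    # w_q, w_k, w_v, w_o (weights + biases)
    ff = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)   # linear1 + linear2
    norms = 2 * (2 * d_model)                                   # norm1, norm2: gain + bias each
    blocks = num_layers * (attn + ff + norms)
    final_norm = 2 * d_model
    head = d_model * vocab_size + vocab_size                    # nn.Linear(d_model, vocab_size)
    return embed + blocks + final_norm + head

print(count_parameters(vocab_size=10, d_model=128, num_layers=3))  # 597642
```

With `d_model=128`, each extra vocabulary entry costs 257 parameters: 128 in the embedding and 129 in the output head.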

In [None]:
class TransformerLM(nn.Module):
    """Decoder-only language model for next-token prediction."""
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 3,
        d_ff: int = None,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len, dropout)
        self.blocks = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        x = self.embed(x)
        x = self.pos(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        logits = self.head(x)
        return logits

## 8. Character-level tokenizer

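To keep the tutorial self-contained, we use a **character-level** vocabulary: each distinct character in the training text is one token. `encode` only knows characters that appear in that text; any other character raises a `KeyError`. You can replace this with a subword tokenizer (e.g. BPE) later.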
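
Here is what the mapping looks like on a tiny string. It is written out in plain Python with the same `sorted(set(...))` construction as `CharTokenizer`. The variable names are illustrative. The example also shows that a character missing from the training text has no id.

```python
sample_text = "hello"
chars = sorted(set(sample_text))               # ['e', 'h', 'l', 'o']
stoi = {c: i for i, c in enumerate(chars)}     # char -> id
itos = {i: c for c, i in stoi.items()}         # id -> char

sample_ids = [stoi[c] for c in "hello"]
print(sample_ids)                              # [1, 0, 2, 2, 3]
print("".join(itos[i] for i in sample_ids))    # hello

try:
    [stoi[c] for c in "help"]                  # 'p' never appeared in sample_text
except KeyError as err:
    print("unknown character:", err)           # unknown character: 'p'
```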

In [None]:
class CharTokenizer:
    def __init__(self, text: str):
        self.chars = sorted(set(text))
        self.stoi = {c: i for i, c in enumerate(self.chars)}
        self.itos = {i: c for c, i in self.stoi.items()}
        self.vocab_size = len(self.chars)

    def encode(self, s: str):
        return [self.stoi[c] for c in s]

    def decode(self, ids):
        return "".join(self.itos[i] for i in ids)

## 9. Training loop

We train on **next-token prediction**: given a sequence, predict the next token at each position. Loss is cross-entropy over the logits (shifted so targets are one position ahead).
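
On a single sequence, the shift looks like this. Every prefix of the sequence is a training example whose target is the token that follows it, so one sequence of length $T$ yields $T-1$ predictions. The cross-entropy for one position is $-\log \mathrm{softmax}(\text{logits})[\text{target}]$. With uniform logits over $V$ tokens, it equals $\ln V$, which is roughly where an untrained model's loss should start. Here is a plain-Python sketch with hypothetical helper names:

```python
import math

def next_token_examples(seq: list) -> list:
    """(context, target) pairs: position t sees seq[:t+1] and must predict seq[t+1]."""
    return [(seq[: t + 1], seq[t + 1]) for t in range(len(seq) - 1)]

def cross_entropy(logits: list, target: int) -> float:
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

print(next_token_examples([5, 1, 4, 2]))
# [([5], 1), ([5, 1], 4), ([5, 1, 4], 2)]
print(round(cross_entropy([0.0, 0.0, 0.0, 0.0], target=2), 4))  # 1.3863, i.e. ln 4
```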

In [None]:
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch)
        # Next-token prediction: predict batch[:, 1:] from logits[:, :-1]
        logits_flat = logits[:, :-1].reshape(-1, logits.size(-1))
        targets_flat = batch[:, 1:].reshape(-1)
        loss = F.cross_entropy(logits_flat, targets_flat)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

## 10. Data and training run

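We create a small toy corpus (two sentences repeated 20 times) and train for 50 epochs. Here an "epoch" is just `len(data) // (batch_size * block_size)` batches of randomly chosen chunks, not a strict pass over the text. Each batch has shape `(batch_size, block_size)`. With such a small, repetitive corpus, expect the model to largely memorize it.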
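
The sampling in `get_batches` can be mirrored in plain Python to check the arithmetic. Each call yields `len(data) // (batch_size * block_size)` batches, and every row is a run of `block_size` consecutive tokens starting at a random offset. This sketch uses Python's `random` module in place of `torch.randint`, and the names are hypothetical.

```python
import random

def random_chunks(token_ids: list, batch_size: int, block_size: int, rng: random.Random) -> list:
    n = len(token_ids)
    batches = []
    for _ in range(n // (batch_size * block_size)):
        # randrange's upper bound is exclusive, like torch.randint's `high`
        starts = [rng.randrange(0, n - block_size) for _ in range(batch_size)]
        batches.append([token_ids[s : s + block_size] for s in starts])
    return batches

toy_ids = list(range(100))
toy_batches = random_chunks(toy_ids, batch_size=4, block_size=5, rng=random.Random(0))
print(len(toy_batches), len(toy_batches[0]), len(toy_batches[0][0]))  # 5 4 5
```

Rows in the same batch can overlap, and some text may go unsampled in a given epoch. For a toy setup like this, that is fine.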

In [None]:
def get_batches(data, batch_size, block_size):
    """Yield random (batch_size, block_size) chunks from data (1D tensor of token ids)."""
    n = len(data)
    for _ in range(n // (batch_size * block_size)):
        starts = torch.randint(0, n - block_size, (batch_size,))
        batch = torch.stack([data[s : s + block_size] for s in starts])
        yield batch

block_size = 32
batch_size = 32
epochs = 50
lr = 3e-4

text = """
The transformer is a deep learning architecture introduced in the paper Attention is All You Need.
It relies entirely on self-attention for sequence modeling. We build a small one here for fun.
""" * 20

tokenizer = CharTokenizer(text)
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

model = TransformerLM(
    vocab_size=tokenizer.vocab_size,
    d_model=128,
    num_heads=4,
    num_layers=3,
    max_len=block_size,  # generation never feeds more than block_size tokens
    dropout=0.1,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for epoch in range(epochs):
    # Draw a fresh set of random chunks each epoch instead of reusing one fixed set.
    loader = list(get_batches(data, batch_size, block_size))
    loss = train_epoch(model, loader, optimizer, device)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}  loss = {loss:.4f}")

## 11. Text generation

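We **sample** one token at a time. Each step runs as follows:

1. Feed in the current context, cropped to the last `block_size` tokens (the longest sequence the model was built for).
2. Take the logits at the last position and divide them by the temperature.
3. Turn them into probabilities with softmax and sample a token.
4. Append the token to the context and repeat.

Temperatures below 1 make the output more conservative; temperatures above 1 make it more random.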
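
To see what temperature does, here is the same divide-by-temperature-then-softmax step in plain Python. A sampler built on `random.choices` stands in for `torch.multinomial`. The helper names are illustrative. Lower temperatures sharpen the distribution toward the most likely token; higher ones flatten it.

```python
import math
import random

def softmax_with_temperature(logits: list, temperature: float) -> list:
    scaled = [v / temperature for v in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(logits: list, temperature: float, rng: random.Random) -> int:
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

example_logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(example_logits, t)])
# 0.5 [0.867, 0.117, 0.016]
# 1.0 [0.665, 0.245, 0.09]
# 2.0 [0.506, 0.307, 0.186]
print(sample_token(example_logits, 0.7, random.Random(0)))
```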

In [None]:
@torch.no_grad()
def generate(model, tokenizer, prompt: str, max_new_tokens: int = 80, temperature: float = 0.8):
    model.eval()
    ids = tokenizer.encode(prompt)
    context = torch.tensor([ids], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        if context.size(1) > block_size:
            context = context[:, -block_size:]
        logits = model(context)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        context = torch.cat([context, next_id], dim=1)
    return tokenizer.decode(context[0].tolist())

print(generate(model, tokenizer, "The ", max_new_tokens=120, temperature=0.7))

## Next steps

- **Larger model**: increase `d_model`, `num_heads`, `num_layers`, and train longer.
- **Subword tokenizer**: use a BPE tokenizer (e.g. `tiktoken` or Hugging Face `tokenizers`). Each token then covers several characters, so the same `block_size` spans more text.
- **Larger data**: train on a real corpus (e.g. Wikipedia, books) with a proper train/validation split and batching.
- **Learning rate schedule**: use linear warmup followed by cosine decay, which often trains more stably than a constant learning rate.
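
For the learning-rate schedule, one common recipe is linear warmup followed by cosine decay. Here is a plain-Python sketch of the schedule as a function of the step number. `lr_at` and its parameters (`warmup_steps`, `total_steps`, `min_lr`) are our own illustrative choices, not something this notebook defines.

```python
import math

def lr_at(step: int, max_lr: float, warmup_steps: int, total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 9, 10, 55, 100):
    print(step, lr_at(step, max_lr=3e-4, warmup_steps=10, total_steps=100))
```

To use it with the optimizer above, one option is `torch.optim.lr_scheduler.LambdaLR`. It multiplies the optimizer's base learning rate by the factor its `lr_lambda` returns. You would therefore pass `lambda step: lr_at(step, ...) / max_lr` and call `scheduler.step()` after each `optimizer.step()`.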