# Transformer

A minimal, clean, decoder-only Transformer (GPT-style) in PyTorch.

Features
- Token + learned positional embeddings
- Multi-head causal self-attention with a lower-triangular mask
- Pre-norm residual blocks
- GELU feed-forward (MLP)
- Configurable dropout (embeddings, attention weights, residual branches)
- Tied output projection (weights tied with token embedding)
- Optional label smoothing in loss
- An autoregressive `generate` utility

In [7]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple

In [8]:
# -----------------------------
# Config
# -----------------------------

@dataclass
class GPTConfig:
    """Hyperparameters for the decoder-only Transformer.

    Attributes:
        vocab_size: Number of token ids; sizes both the token embedding and
            the (tied) output head.
        n_layers: Number of stacked Transformer blocks.
        n_heads: Attention heads per block; must divide ``d_model``.
        d_model: Width of the residual stream.
        d_ff: Hidden width of each block's feed-forward MLP.
        max_seq_len: Number of learned positions, i.e. the longest sequence
            the model's forward pass accepts.
        attn_dropout: Dropout rate handed to the attention module.
        resid_dropout: Dropout rate for residual-branch outputs (the MLP).
        emb_dropout: Dropout rate on the summed token + position embeddings.
        device: Device string the training code moves the model and batches
            to; the model itself does not read it.

    Raises:
        ValueError: (from ``__post_init__``) if a size is not positive, a
            dropout rate lies outside [0, 1], or ``d_model`` is not divisible
            by ``n_heads``.
    """

    vocab_size: int
    n_layers: int = 4
    n_heads: int = 4
    d_model: int = 256
    d_ff: int = 1024
    max_seq_len: int = 64
    attn_dropout: float = 0.1
    resid_dropout: float = 0.1
    emb_dropout: float = 0.1
    device: str = "cpu"

    def __post_init__(self) -> None:
        # Fail fast with a readable message instead of an AssertionError or an
        # index error deep inside model construction / the forward pass.
        for name in ("vocab_size", "n_layers", "n_heads", "d_model", "d_ff", "max_seq_len"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        for name in ("attn_dropout", "resid_dropout", "emb_dropout"):
            rate = getattr(self, name)
            if not 0.0 <= rate <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {rate}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )

In [9]:
# -----------------------------
# Building blocks
# -----------------------------

class LayerNorm(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` normalising the last dimension.

    The wrapped module is stored as ``self.ln``, so parameters appear in the
    state dict as ``<prefix>.ln.weight`` / ``<prefix>.ln.bias``.

    Args:
        d_model: Size of the normalised (last) dimension.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        norm = nn.LayerNorm(d_model, eps=eps)
        self.ln = norm

    def forward(self, x):
        normalized = self.ln(x)
        return normalized

In [10]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    Args:
        d_model: Input/output feature width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads, each of width ``d_model // n_heads``.
        attn_dropout: Dropout rate applied to the attention weights.
        resid_dropout: Dropout rate applied after the output projection. When
            omitted it defaults to ``attn_dropout`` (the previous behaviour).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int, attn_dropout: float,
                 resid_dropout: Optional[float] = None):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_head = d_model // n_heads
        self.n_heads = n_heads
        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(attn_dropout)
        if resid_dropout is None:
            resid_dropout = attn_dropout
        self.resid_dropout = nn.Dropout(resid_dropout)

    def forward(self, x):
        """Let each position attend to itself and all earlier positions.

        Args:
            x: Tensor of shape ``(B, L, C)`` with ``C == d_model``.

        Returns:
            Tensor of shape ``(B, L, C)``.
        """
        B, L, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        # (B, L, C) -> (B, n_heads, L, d_head)
        q = q.view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        # Lower-triangular bool mask (True where key index <= query index);
        # it broadcasts over the batch and head dimensions.
        causal_mask = torch.ones(L, L, dtype=torch.bool, device=x.device).tril()
        attn_scores = attn_scores.masked_fill(~causal_mask, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        y = attn_weights @ v
        # (B, n_heads, L, d_head) -> (B, L, C): concatenate the heads.
        y = y.transpose(1, 2).contiguous().view(B, L, C)
        return self.resid_dropout(self.proj(y))

In [11]:
class MLP(nn.Module):
    """Position-wise feed-forward network.

    Layout: Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model)
    -> Dropout, held in one ``nn.Sequential`` named ``net``.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width.
        dropout: Rate shared by both dropout layers.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        contract = nn.Linear(d_ff, d_model)
        out_drop = nn.Dropout(dropout)
        # A single Sequential keeps the parameter names net.0.* / net.3.*.
        self.net = nn.Sequential(expand, activation, hidden_drop, contract, out_drop)

    def forward(self, x):
        out = x
        for layer in self.net:
            out = layer(out)
        return out

In [12]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: ``x + Attn(LN1(x))`` then ``x + MLP(LN2(x))``.

    Args:
        cfg: Model configuration; reads ``d_model``, ``n_heads``, ``d_ff``,
            ``attn_dropout`` and ``resid_dropout``.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = LayerNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg.d_model, cfg.n_heads, cfg.attn_dropout)
        # The attention output dropout sits on a residual branch, so drive it
        # from cfg.resid_dropout (as the MLP branch below does) instead of
        # cfg.attn_dropout, which should only govern the attention weights.
        self.attn.resid_dropout = nn.Dropout(cfg.resid_dropout)
        self.ln2 = LayerNorm(cfg.d_model)
        self.mlp = MLP(cfg.d_model, cfg.d_ff, cfg.resid_dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

In [13]:
# -----------------------------
# The model
# -----------------------------

class DecoderOnlyTransformer(nn.Module):
    """GPT-style decoder-only language model.

    Token + learned positional embeddings feed a stack of pre-norm
    ``TransformerBlock``s, a final LayerNorm, and an output head whose weight
    is tied to the token embedding.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.emb_dropout)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.lm_head.weight = self.token_emb.weight
        # nn.Embedding initialises from N(0, 1); because that matrix is also the
        # output head, initial logits would spread on the order of
        # sqrt(d_model), giving a starting loss far above ln(vocab_size).
        # Use the small GPT-2 style initialisation instead.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None, label_smoothing: float = 0.0):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer token ids of shape ``(B, L)``, ``L <= cfg.max_seq_len``.
            targets: Optional ids of the same shape; ``targets[:, t]`` is
                scored against ``logits[:, t]``.
            label_smoothing: Forwarded to ``F.cross_entropy``; 0.0 disables it.

        Returns:
            ``(logits, loss)``: logits of shape ``(B, L, vocab_size)``; ``loss``
            is a scalar tensor, or ``None`` when ``targets`` is None.

        Raises:
            ValueError: if ``L`` exceeds ``cfg.max_seq_len``.
        """
        _, L = idx.shape
        if L > self.cfg.max_seq_len:
            # Otherwise pos_emb is indexed out of range, which surfaces as an
            # opaque IndexError on CPU or a device-side assert on CUDA.
            raise ValueError(
                f"sequence length {L} exceeds max_seq_len {self.cfg.max_seq_len}"
            )
        pos = torch.arange(L, device=idx.device).unsqueeze(0)  # (1, L), broadcasts over B
        x = self.drop(self.token_emb(idx) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if targets is not None:
            # reshape (unlike view) also accepts non-contiguous target tensors.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                label_smoothing=label_smoothing,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature: float = 0.0,
                 top_k: Optional[int] = None):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Runs without autograd tracking. It does not switch train/eval mode;
        call ``model.eval()`` first so dropout is disabled.

        Args:
            idx: Integer token ids of shape ``(B, T)``; only the last
                ``cfg.max_seq_len`` tokens are fed to the model at each step.
            max_new_tokens: Number of tokens to append.
            temperature: ``<= 0`` (the default) picks the arg-max token (greedy
                decoding); a positive value samples from
                ``softmax(logits / temperature)``.
            top_k: When sampling, keep only the ``top_k`` highest-scoring
                tokens. Ignored for greedy decoding.

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(idx_cond)
            next_logits = logits[:, -1, :]
            if temperature <= 0:
                next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
            else:
                next_logits = next_logits / temperature
                if top_k is not None:
                    k = min(top_k, next_logits.size(-1))
                    kth_best = torch.topk(next_logits, k, dim=-1).values[:, -1:]
                    next_logits = next_logits.masked_fill(next_logits < kth_best, float("-inf"))
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

# Training the Model

Token-level training on a small external dataset (WikiText-2).

Requirements:
    pip install datasets tokenizers transformers

This script will:
  - Download WikiText-2 (raw) via Hugging Face Datasets
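  - Load the pretrained Gemma 2 tokenizer (`google/gemma-2-2b-it`) via Hugging Face `transformers` (a gated repository: you may need to accept its license on the Hub and authenticate first)
  - Tokenize each non-empty line, keeping the special tokens the tokenizer adds (such as BOS), and concatenate everything into one token stream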
  - Train the decoder-only Transformer as a causal LM on token IDs
  - Sample a short completion and save a checkpoint + tokenizer

In [14]:
import random
from pathlib import Path
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

In [15]:
# -------------------------
# 1) Load a small external dataset (WikiText-2 raw)
# -------------------------
ds = load_dataset("wikitext", "wikitext-2-raw-v1")


def _non_blank(lines):
    """Return the lines that are neither empty nor whitespace-only.

    Blank lines carry no text; keeping them would mostly add tokenizer
    special tokens to the stream.
    """
    return [line for line in lines if line and not line.isspace()]


train_texts = _non_blank(ds["train"]["text"])
val_texts = _non_blank(ds["validation"]["text"])

README.md: 0.00B [00:00, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [16]:
from transformers import AutoTokenizer

# Reuse a pretrained tokenizer; its vocabulary fixes the model's vocab size.
# NOTE: Gemma repositories on the Hugging Face Hub are gated; accepting the
# license and authenticating may be required before this download succeeds.
BASE_MODEL_NAME = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=BASE_MODEL_NAME,
)

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [17]:
# -------------------------
# 3) Encode dataset into a single stream of IDs (token-level)
# -------------------------
def encode_corpus(texts, tok=None):
    """Tokenize each text and concatenate all ids into one flat stream.

    Args:
        texts: Iterable of strings (e.g. the non-blank corpus lines).
        tok: Object with an ``encode(str) -> list[int]`` method. Defaults to
            the module-level ``tokenizer`` when omitted.

    Returns:
        1-D ``torch.long`` tensor with every text's ids back to back. Any
        special tokens ``encode`` adds (e.g. a BOS token) are kept, so they
        appear once per text.
    """
    if tok is None:
        tok = tokenizer
    ids = []
    for text in texts:
        ids.extend(tok.encode(text))
    return torch.tensor(ids, dtype=torch.long)

In [18]:
# Encode both splits into flat 1-D id streams (train first, then validation).
train_ids, val_ids = map(encode_corpus, (train_texts, val_texts))

In [19]:
# -------------------------
# 4) Batching utility for contiguous language modeling
# -------------------------
# Context window and batch size (a larger run used block_size=256, batch_size=24).
block_size = 128  # context window
batch_size = 12


def make_batch(ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample ``batch_size`` random contiguous windows for next-token training.

    Args:
        ids: 1-D tensor of token ids (the whole corpus as one stream).

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``, where ``y`` is
        ``x`` shifted one position ahead in ``ids`` (``y[b, t]`` follows
        ``x[b, t]``).

    Raises:
        ValueError: if ``ids`` holds fewer than ``block_size + 1`` tokens.
    """
    n = ids.numel()
    if n < block_size + 1:
        raise ValueError(
            f"need at least block_size + 1 = {block_size + 1} tokens, got {n}"
        )
    # Valid starts are 0 .. n - block_size - 1 inclusive (the target window
    # ends at start + block_size + 1 <= n); randint's upper bound is exclusive.
    starts = torch.randint(0, n - block_size, (batch_size,))
    offsets = torch.arange(block_size)
    x_idx = (starts.unsqueeze(1) + offsets).to(ids.device)  # (batch_size, block_size)
    x = ids[x_idx]
    y = ids[x_idx + 1]
    return x, y

In [20]:
# -------------------------
# 5) Configure and build model
# -------------------------
cfg = GPTConfig(
    # len(tokenizer) counts added special tokens as well; tokenizer.vocab_size
    # is only the base vocabulary, so an added token's id could fall outside
    # the embedding table.
    vocab_size=len(tokenizer),
    n_layers=6,
    n_heads=6,
    d_model=384,
    d_ff=1536,
    max_seq_len=block_size,
    attn_dropout=0.1,
    resid_dropout=0.1,
    emb_dropout=0.1,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

model = DecoderOnlyTransformer(cfg)
model.to(cfg.device)

DecoderOnlyTransformer(
  (token_emb): Embedding(256000, 384)
  (pos_emb): Embedding(128, 384)
  (drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (ln1): LayerNorm(
        (ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
      (attn): CausalSelfAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln2): LayerNorm(
        (ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
      (mlp): MLP(
        (net): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=1536, out_features=384, bias=True)
          (4): Dropout(p=0.1, inplace=False)
     

In [24]:
# -------------------------
# 6) Optimizer
# -------------------------
# Decay only matrices (Linear weights, embeddings). Biases and LayerNorm
# gains/shifts are 1-D and are conventionally left undecayed: pulling them
# toward zero does not regularise the model usefully.
# model.parameters() yields the tied token-embedding / lm_head weight once.
decay_params = [p for p in model.parameters() if p.requires_grad and p.dim() >= 2]
no_decay_params = [p for p in model.parameters() if p.requires_grad and p.dim() < 2]
optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.1},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=3e-4,
    betas=(0.9, 0.95),
)

In [None]:
# -------------------------
# 7) Training loop
# -------------------------
steps = 2000  # modest run on a small dataset
eval_every = 200
eval_iters = 10  # validation batches averaged per evaluation
grad_clip = 1.0

model.train()
for step in range(1, steps + 1):
    x, y = make_batch(train_ids)
    x = x.to(model.cfg.device)
    y = y.to(model.cfg.device)

    # Keep only the loss: with a large vocabulary the logits tensor is big,
    # and binding it to a name would keep it alive into the next step.
    loss = model(x, targets=y)[1]
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    if grad_clip is not None:
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

    if step % eval_every == 0 or step == 1:
        # One random validation batch is a very noisy estimate; average several.
        model.eval()
        val_losses = []
        with torch.no_grad():
            for _ in range(eval_iters):
                vx, vy = make_batch(val_ids)
                vx = vx.to(model.cfg.device)
                vy = vy.to(model.cfg.device)
                val_losses.append(model(vx, targets=vy)[1].item())
        model.train()
        val_loss = sum(val_losses) / len(val_losses)
        print(f"step {step:4d} | train loss {loss.item():.3f} | val loss {val_loss:.3f}")

In [None]:
# -------------------------
# 8) Sample a short completion
# -------------------------
model.eval()

prompt_text = "Wikipedia is a free online"
prompt_ids = tokenizer.encode(prompt_text)
prompt = torch.tensor(prompt_ids, dtype=torch.long, device=model.cfg.device).unsqueeze(0)
print(prompt.shape)
# Inference only: no_grad stops autograd from recording all 100 forward passes.
with torch.no_grad():
    out_ids = model.generate(prompt, max_new_tokens=100)

# skip_special_tokens drops special tokens from the text (e.g. a BOS token
# that encode() may have prepended to the prompt).
decoded = tokenizer.decode(out_ids[0].tolist(), skip_special_tokens=True)
print(decoded)