# RNNs, Attention, and Transformers

In [25]:
%pip install torch

Note: you may need to restart the kernel to use updated packages.


In [26]:
import math, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

torch.manual_seed(0)
torch.__version__

'2.8.0'

## Sequence processing in a nutshell

Many tasks are **sequence-to-sequence**: speech $\to$ text, translation, captioning. We need models that consume inputs over time and optionally emit outputs over time.


## A simple RNN cell

A vanilla RNN updates a hidden state $h_t$ using the previous state and current input:
$$
h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h),\quad
y_t = W_{hy} h_t + b_y.
$$
Below is a tiny character-level **toy** RNN that predicts the next character in a short string to illustrate shape flow and training.


In [27]:
# Tiny toy dataset: predict next char in a small alphabet
alphabet = list("abcd ")
stoi = {c:i for i,c in enumerate(alphabet)}
itos = {i:c for c,i in stoi.items()}

def encode(s): return torch.tensor([stoi[c] for c in s], dtype=torch.long)
def onehot(i, n): 
    x = torch.zeros(n); x[i] = 1.0; return x

seq = "abca abca abca "
data = encode(seq)

nin = len(alphabet); nh = 16; nout = len(alphabet)
Wxh = nn.Linear(nin, nh)
Whh = nn.Linear(nh, nh)
Why = nn.Linear(nh, nout)
params = list(Wxh.parameters()) + list(Whh.parameters()) + list(Why.parameters())
opt = optim.Adam(params, lr=3e-2)

def rnn_step(h, x_idx):
    x = onehot(x_idx, nin)
    h = torch.tanh(Whh(h) + Wxh(x))
    y = Why(h)
    return h, y

for step in range(400):
    h = torch.zeros(nh)
    loss = 0.0
    for t in range(len(data)-1):
        h, y = rnn_step(h, data[t].item())
        loss = loss + F.cross_entropy(y.unsqueeze(0), data[t+1].unsqueeze(0))
    opt.zero_grad()
    loss.backward()
    opt.step()

# Sample a few steps
h = torch.zeros(nh)
x = stoi["a"]
out = []
for t in range(20):
    h, y = rnn_step(h, x)
    p = F.softmax(y, -1)
    x = int(torch.multinomial(p, 1))
    out.append(itos[x])
"".join(out)

'bca abca abca abca a'

## Vanishing and exploding gradients (intuition)

Backpropagation through time multiplies by the same weights many times, so gradients can shrink or grow exponentially. For long‑range dependencies, **plain RNNs** struggle.


## LSTM and GRU

Gated units regulate information flow with **gates** to help preserve long‑term dependencies.

- **LSTM** maintains a cell $c_t$ with forget/input/output gates.
- **GRU** merges some gates and omits a separate cell, often working well with fewer parameters.

Below we compare a vanilla `nn.RNN` vs `nn.LSTM` on a synthetic long dependency: the label is the **first** token of the sequence.


In [28]:
# Synthetic dataset: sequences of length T, label = first token (0/1)
def make_dataset(N=256, T=50):
    X = torch.randint(0, 2, (N, T)).long()
    y = X[:, 0]
    return X, y

def run_model(cell_type="rnn", T=50, epochs=20):
    X, y = make_dataset(256, T)
    emb = nn.Embedding(2, 8)
    if cell_type == "rnn":
        rnn = nn.RNN(8, 16, batch_first=True, nonlinearity="tanh")
    else:
        rnn = nn.LSTM(8, 16, batch_first=True)
    clf = nn.Linear(16, 2)
    opt = optim.Adam(list(emb.parameters()) + list(rnn.parameters()) + list(clf.parameters()), lr=1e-2)
    for ep in range(epochs):
        x = emb(X)
        out, _ = rnn(x)
        logits = clf(out[:, -1])  # last step
        loss = F.cross_entropy(logits, y)
        opt.zero_grad();
        loss.backward();
        opt.step()
    with torch.no_grad():
        x = emb(X);
        out, _ = rnn(x)
        logits = clf(out[:, -1])
        acc = logits.argmax(-1).eq(y).float().mean().item()
    return loss.item(), acc

loss_rnn, acc_rnn = run_model("rnn", T=50, epochs=30)
loss_lstm, acc_lstm = run_model("lstm", T=50, epochs=30)
{"RNN_acc": round(acc_rnn, 3), "LSTM_acc": round(acc_lstm, 3)}

{'RNN_acc': 0.68, 'LSTM_acc': 0.57}

## Attention as content‑based lookup

Given **queries** $Q$, **keys** $K$, and **values** $V$, attention produces a weighted sum of values where weights reflect $QK^\top$ similarity.

**Scaled dot‑product attention:**
$$
\mathrm{Attn}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V.
$$


In [29]:
# Minimal scaled dot-product attention with masking
def scaled_dot_attn(Q, K, V, mask=None):
    # Q: [B, H, Tq, Dh], K,V: [B, H, Tk, Dh]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask==0, float('-inf'))
    A = torch.softmax(scores, dim=-1)
    return A @ V, A

# Quick demo
B,H,Tq,Tk,Dh = 2, 1, 3, 4, 8
Q = torch.randn(B,H,Tq,Dh)
K = torch.randn(B,H,Tk,Dh)
V = torch.randn(B,H,Tk,Dh)
out, attn = scaled_dot_attn(Q,K,V)
out.shape, attn.shape

(torch.Size([2, 1, 3, 8]), torch.Size([2, 1, 3, 4]))

## Positional encoding

Self‑attention is permutation‑equivariant, so we inject positions. A common choice is **sinusoidal** encoding:
$$
\mathrm{PE}(pos,2i) = \sin\!\Big(\frac{pos}{10000^{2i/d}}\Big),\quad
\mathrm{PE}(pos,2i+1) = \cos\!\Big(\frac{pos}{10000^{2i/d}}\Big).
$$


In [30]:
def sinusoidal_pe(T, d):
    pos = torch.arange(T).unsqueeze(1)
    i = torch.arange(d).unsqueeze(0)
    angles = pos / (10000 ** (2*(i//2)/d))
    pe = torch.zeros(T,d)
    pe[:,0::2] = torch.sin(angles[:,0::2])
    pe[:,1::2] = torch.cos(angles[:,1::2])
    return pe

pe = sinusoidal_pe(10, 16)
pe[:3, :8]

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.3110,  0.9504,  0.0998,  0.9950,  0.0316,  0.9995],
        [ 0.9093, -0.4161,  0.5911,  0.8066,  0.1987,  0.9801,  0.0632,  0.9980]])

## Multi‑head attention

Use $h$ heads with separate projections, concatenate head outputs, and project again:
$$
\mathrm{MHA}(X) = \big[ \mathrm{Attn}(XW_Q^{(1)}, XW_K^{(1)}, XW_V^{(1)});\ldots \big] W_O.
$$


In [31]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        assert d_model % n_heads == 0
        self.nh = n_heads
        self.dh = d_model // n_heads
        self.Q = nn.Linear(d_model, d_model, bias=False)
        self.K = nn.Linear(d_model, d_model, bias=False)
        self.V = nn.Linear(d_model, d_model, bias=False)
        self.O = nn.Linear(d_model, d_model, bias=False)
    def forward(self, x, mask=None):
        B,T,D = x.shape
        Q = self.Q(x).view(B,T,self.nh,self.dh).transpose(1,2) # [B,H,T,Dh]
        K = self.K(x).view(B,T,self.nh,self.dh).transpose(1,2)
        V = self.V(x).view(B,T,self.nh,self.dh).transpose(1,2)
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over heads
        y,_ = scaled_dot_attn(Q,K,V,mask=mask)
        y = y.transpose(1,2).contiguous().view(B,T,D)
        return self.O(y)

x = torch.randn(2, 5, 64)
mha = MultiHeadAttention(64, 4)
mha(x).shape

torch.Size([2, 5, 64])

## Transformer encoder block

Stack **Multi‑Head Attention** and a **positionwise MLP**, with **residual connections** and **layer normalization**:
$$
x \leftarrow x + \mathrm{MHA}(\mathrm{LN}(x)),\quad
x \leftarrow x + \mathrm{MLP}(\mathrm{LN}(x)).
$$


In [32]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model=64, n_heads=4, d_ff=128, pdrop=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff  = nn.Sequential(nn.Linear(d_model, d_ff),
                                 nn.ReLU(),
                                 nn.Linear(d_ff, d_model))
        self.drop = nn.Dropout(pdrop)
    def forward(self, x, mask=None):
        x = x + self.drop(self.mha(self.ln1(x), mask=mask))
        x = x + self.drop(self.ff(self.ln2(x)))
        return x

tok = torch.randn(2, 6, 64)
blk = TransformerEncoderBlock()
blk(tok).shape

torch.Size([2, 6, 64])

## Decoder self‑attention with causal masking

In the decoder we **mask future positions** so the model cannot peek ahead during training. The mask is a lower‑triangular matrix $M$:
$$
M_{ij} = 1 \ \text{if}\ j \le i,\ \ 0\ \text{otherwise}.
$$


In [33]:
def causal_mask(T):
    return torch.tril(torch.ones(T, T, dtype=torch.bool))

T = 5
mask = causal_mask(T)  # [T,T]
# Example: apply in attention (broadcast to [1,1,T,T]) so MHA can broadcast across batch and heads
B, H, D_model = 2, 4, 16
x = torch.randn(B, T, D_model)
mha = MultiHeadAttention(d_model=D_model, n_heads=H)
# Broadcast-safe mask: [1,1,T,T] will expand to [B,H,T,T] inside attention
attn_mask = mask.unsqueeze(0).unsqueeze(0)
y = mha(x, mask=attn_mask)
y.shape, mask

(torch.Size([2, 5, 16]),
 tensor([[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]))

## Encoder–decoder cross‑attention

Keys/values come from the **encoder**, queries from the **decoder**. This lets the decoder look back at the encoded source sequence:
$$
\mathrm{Attn}(Q_\text{dec}, K_\text{enc}, V_\text{enc}).
$$
Below is a minimal cross‑attention module (API‑compatible with the self‑attention above).


In [34]:
class CrossAttention(nn.Module):
    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        assert d_model % n_heads == 0
        self.nh = n_heads; self.dh = d_model // n_heads
        self.Q = nn.Linear(d_model, d_model, bias=False)
        self.K = nn.Linear(d_model, d_model, bias=False)
        self.V = nn.Linear(d_model, d_model, bias=False)
        self.O = nn.Linear(d_model, d_model, bias=False)
    def forward(self, qx, kx):
        # qx: [B,Tq,D] (decoder states); kx: [B,Tk,D] (encoder states)
        B,Tq,D = qx.shape; Tk = kx.size(1)
        Q = self.Q(qx).view(B,Tq,self.nh,self.dh).transpose(1,2)
        K = self.K(kx).view(B,Tk,self.nh,self.dh).transpose(1,2)
        V = self.V(kx).view(B,Tk,self.nh,self.dh).transpose(1,2)
        y,_ = scaled_dot_attn(Q,K,V)
        y = y.transpose(1,2).contiguous().view(B,Tq,D)
        return self.O(y)

enc = torch.randn(2,7,64)
dec = torch.randn(2,5,64)
cross = CrossAttention(64,4)
cross(dec, enc).shape

torch.Size([2, 5, 64])

## A tiny end‑to‑end toy: copy task

To keep runtime small, we solve a simple **copy** task: given a short token sequence, predict the same sequence (teacher forcing). This demonstrates shapes, masking, and how attention learns alignments.


In [35]:
class TinyTransformer(nn.Module):
    def __init__(self, vocab=16, d_model=64, n_heads=4, n_layers=2, d_ff=128, T=12):
        super().__init__()
        self.T = T; self.d = d_model; self.vocab = vocab
        self.emb = nn.Embedding(vocab, d_model)
        self.pos = sinusoidal_pe(T, d_model)
        self.blocks = nn.ModuleList([TransformerEncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.lm_head = nn.Linear(d_model, vocab)
    def forward(self, x):
        # x: [B,T]
        B,T = x.shape
        h = self.emb(x) + self.pos[:T].to(x.device)
        mask = causal_mask(T).to(x.device).unsqueeze(0).unsqueeze(0)  # [1,1,T,T]
        for blk in self.blocks:
            h = blk(h, mask=mask)
        return self.lm_head(h)

def make_copy_data(N=256, T=12, vocab=16):
    X = torch.randint(1, vocab, (N,T))  # avoid pad=0 for simplicity
    Y = X.clone()
    return X, Y

model = TinyTransformer(vocab=20, T=12)
opt = optim.Adam(model.parameters(), lr=3e-3)
X, Y = make_copy_data(N=256, T=12, vocab=20)

for step in range(300):
    logits = model(X)                # [B,T,V]
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1))
    opt.zero_grad(); loss.backward(); opt.step()
    
with torch.no_grad():
    pred = model(X[:4]).argmax(-1)
loss.item(), X[:1], pred[:1]

(2.8386692065396346e-05,
 tensor([[ 6,  3,  6, 15,  5,  1,  9,  8, 17,  9,  1,  6]]),
 tensor([[ 6,  3,  6, 15,  5,  1,  9,  8, 17,  9,  1,  6]]))

## Where to go next

- Replace the toy tasks with a real dataset (e.g., small translation or language modeling corpus).
- Explore **masking strategies** (padding vs. causal) and how they affect attention.
- Compare learned embeddings vs. fixed sinusoidal encodings.
- Extend to a full **encoder–decoder** by adding cross‑attention and a separate decoder stack.
