# GRU in Python (Colab)

This notebook teaches GRU (Gated Recurrent Unit) **step-by-step**:

1. Intuition + equations
2. Implement a **GRU cell from scratch** in PyTorch
3. Train a small GRU model on a toy sequence task
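4. Compare with PyTorch's built-in `nn.GRU`, plus practical tips: gradient clipping (used in training) and padding/packing (listed under next steps)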


## 1) Setup
Colab usually includes PyTorch. If it does not, uncomment the `pip install` line at the top of the cell below.

In [None]:
# If needed:
# !pip -q install torch

import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed both random number generators this notebook draws from (Python's
# `random` for the toy dataset, PyTorch for weights and shuffles) so that
# reruns are reproducible.
random.seed(0)
torch.manual_seed(0)

# Prefer the GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device  # last expression in the cell, so the notebook displays it

## 2) GRU equations (reference)

Given input $x_t$ and previous hidden state $h_{t-1}$:

$$z_t=\sigma(W_z x_t + U_z h_{t-1} + b_z)$$
$$r_t=\sigma(W_r x_t + U_r h_{t-1} + b_r)$$
$$\tilde{h}_t=\tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h)$$
$$h_t=(1-z_t)\odot h_{t-1} + z_t\odot \tilde{h}_t$$

Here $\sigma$ is the logistic sigmoid (applied elementwise) and $\odot$ denotes elementwise multiplication.

## 3) Implement a GRU cell from scratch (PyTorch)

We’ll build a module that consumes a **single time step**: `(x_t, h_prev) -> h_t`.


In [None]:
class GRUCellScratch(nn.Module):
    """A single-step GRU cell assembled from plain ``nn.Linear`` layers.

    Each of the three components (update gate ``z``, reset gate ``r`` and
    candidate state ``h~``) gets its own pair of projections: ``x2*`` acts on
    the input and carries the bias, ``h2*`` acts on the hidden state and has
    no bias.

    Args:
        input_size: Number of features in each input vector ``x_t``.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Separate layers (rather than one fused matrix) keep the code close
        # to the equations.

        # Update gate z_t.
        self.x2z = nn.Linear(input_size, hidden_size)
        self.h2z = nn.Linear(hidden_size, hidden_size, bias=False)

        # Reset gate r_t.
        self.x2r = nn.Linear(input_size, hidden_size)
        self.h2r = nn.Linear(hidden_size, hidden_size, bias=False)

        # Candidate state h~_t.
        self.x2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:
        """Advance the hidden state by one time step.

        Args:
            x_t: Input at the current step, shape ``(batch, input_size)``.
            h_prev: Previous hidden state, shape ``(batch, hidden_size)``.

        Returns:
            The new hidden state, shape ``(batch, hidden_size)``.
        """
        update_gate = torch.sigmoid(self.x2z(x_t) + self.h2z(h_prev))
        reset_gate = torch.sigmoid(self.x2r(x_t) + self.h2r(h_prev))

        # The reset gate scales the previous state *before* the recurrent
        # projection of the candidate.
        gated_prev = reset_gate * h_prev
        candidate = torch.tanh(self.x2h(x_t) + self.h2h(gated_prev))

        # Blend old state and candidate: z close to 0 keeps the old state,
        # z close to 1 replaces it with the candidate.
        return (1 - update_gate) * h_prev + update_gate * candidate

# Quick shape test: run the cell for one step on random inputs and display
# the output shape, which should be (batch, hidden_size).
batch, input_size, hidden_size = 4, 6, 10
cell = GRUCellScratch(input_size, hidden_size).to(device)
x_t = torch.randn(batch, input_size, device=device)  # one time step of inputs
h_prev = torch.randn(batch, hidden_size, device=device)  # stand-in previous state
h_t = cell(x_t, h_prev)
h_t.shape

## 4) Build a simple GRU sequence model (from scratch cell)

We’ll unroll the cell across time.


In [None]:
class GRUSequenceModelScratch(nn.Module):
    """Unroll a ``GRUCellScratch`` over a sequence and read out the last step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the GRU hidden state.
        output_size: Size of the prediction produced from the final state.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.cell = GRUCellScratch(input_size, hidden_size)
        self.readout = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, h0: torch.Tensor | None = None):
        """Run the cell over every time step of ``x``.

        Args:
            x: Input batch of shape ``(batch, seq_len, input_size)``;
                ``seq_len`` must be at least 1.
            h0: Optional initial hidden state of shape
                ``(batch, hidden_size)``. Defaults to zeros.

        Returns:
            A tuple ``(y, H)`` where ``y`` has shape ``(batch, output_size)``
            and is computed from the final hidden state, and ``H`` has shape
            ``(batch, seq_len, hidden_size)`` and holds the state after every
            step.
        """
        b, t, _ = x.shape
        if h0 is None:
            # new_zeros inherits both device AND dtype from x, so the default
            # state also works when the model runs in float64 or float16
            # (torch.zeros would always give float32 and break the matmul).
            h = x.new_zeros(b, self.cell.hidden_size)
        else:
            h = h0
        hs = []
        for i in range(t):
            h = self.cell(x[:, i, :], h)
            hs.append(h)
        H = torch.stack(hs, dim=1)  # (batch, seq_len, hidden)
        # After the loop, h is the last entry of hs, i.e. H[:, -1, :].
        y = self.readout(h)  # final-step prediction
        return y, H

# Smoke test: a batch of 8 random sequences, 5 steps each, 6 features per step.
model = GRUSequenceModelScratch(input_size=6, hidden_size=24, output_size=1).to(device)
x = torch.randn(8, 5, 6, device=device)
y, H = model(x)
# y is the final-step readout; H stacks the hidden state at every step.
y.shape, H.shape

## 5) Toy task: "Did we see a special symbol?" (sequence classification)

We create sequences of length `T` with one-hot vectors of size `V`.
Label is **1** if token `special_id` appears anywhere, else **0**.


In [None]:
def make_dataset(n=2000, T=20, V=12, special_id=3, p_special=0.3):
    """Build a toy "does the special token appear?" classification dataset.

    Each sample is a length-``T`` sequence of one-hot vectors over a
    vocabulary of ``V`` tokens. With probability ``p_special`` a sample is
    positive: ``special_id`` is planted at one random position and the other
    positions are drawn uniformly from the whole vocabulary. Otherwise the
    sample is negative and every position is drawn uniformly from the tokens
    other than ``special_id``, so the label always agrees with the content.

    Randomness comes from Python's ``random`` module; seed it for
    reproducible data.

    Args:
        n: Number of sequences.
        T: Sequence length.
        V: Vocabulary size (width of the one-hot vectors).
        special_id: Token whose presence defines the positive class; must
            satisfy ``0 <= special_id < V``.
        p_special: Probability that a sequence is positive.

    Returns:
        ``(X, y)`` where ``X`` is a float tensor of shape ``(n, T, V)`` and
        ``y`` is a float tensor of shape ``(n, 1)`` holding 0.0 or 1.0.

    Raises:
        ValueError: If ``special_id`` is outside ``[0, V)``, or if negative
            samples may be needed but ``V < 2`` leaves no non-special token.
    """
    if not 0 <= special_id < V:
        raise ValueError(f"special_id must be in [0, {V}), got {special_id}")
    if V < 2 and p_special < 1.0:
        raise ValueError("V must be at least 2 to generate negative samples")

    X = torch.zeros(n, T, V)
    y = torch.zeros(n, 1)
    for i in range(n):
        has_special = random.random() < p_special
        if has_special:
            pos = random.randrange(T)
            tokens = [random.randrange(V) for _ in range(T)]
            tokens[pos] = special_id
            y[i, 0] = 1.0
        else:
            # Draw from the V - 1 non-special tokens: sample an index in
            # [0, V - 1) and shift everything at or above special_id up by
            # one so that special_id itself is skipped.
            tokens = []
            for _ in range(T):
                token = random.randrange(V - 1)
                if token >= special_id:
                    token += 1
                tokens.append(token)
        for t, token in enumerate(tokens):
            X[i, t, token] = 1.0
    return X, y

# Sequence length and vocabulary size for the toy task.
T, V = 25, 16
X, y = make_dataset(n=4000, T=T, V=V, special_id=7, p_special=0.5)

# Train/val split: shuffle the sample indices once, then use the first 80%
# for training and the rest for validation. The cut point is derived from
# the dataset size (3200 for the 4000 samples above), so changing `n` keeps
# the same ratio instead of silently shrinking or emptying the validation set.
idx = torch.randperm(X.size(0))
n_train = (X.size(0) * 4) // 5
train_idx, val_idx = idx[:n_train], idx[n_train:]
X_train, y_train = X[train_idx].to(device), y[train_idx].to(device)
X_val, y_val = X[val_idx].to(device), y[val_idx].to(device)
X_train.shape, y_train.shape

## 6) Train the scratch GRU model

We’ll use BCEWithLogitsLoss for binary classification.


In [None]:
# Fresh scratch-GRU classifier: one-hot tokens of size V in, one logit out.
model = GRUSequenceModelScratch(input_size=V, hidden_size=64, output_size=1).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# BCEWithLogitsLoss takes raw logits, so the model applies no sigmoid itself.
loss_fn = nn.BCEWithLogitsLoss()

def accuracy(logits, y_true):
    """Return the fraction of binary predictions that match ``y_true``.

    A logit counts as class 1 when its sigmoid is strictly greater than 0.5;
    everything else (including a logit of exactly 0) counts as class 0.

    Args:
        logits: Raw model outputs (pre-sigmoid).
        y_true: Targets of 0.0 / 1.0, broadcast-compatible with ``logits``.

    Returns:
        Mean accuracy as a Python float.
    """
    probabilities = torch.sigmoid(logits)
    predicted = (probabilities > 0.5).float()
    hits = predicted.eq(y_true)
    return hits.float().mean().item()

batch_size = 128
num_train = X_train.size(0)
for epoch in range(1, 11):
    model.train()
    # Fresh shuffle every epoch so mini-batches differ between epochs.
    perm = torch.randperm(num_train, device=device)
    # Running sums are weighted by batch size, so the epoch averages stay
    # exact even when the last mini-batch is smaller than batch_size.
    total_loss = 0.0
    total_acc = 0.0
    for i in range(0, num_train, batch_size):
        idxb = perm[i:i+batch_size]
        xb, yb = X_train[idxb], y_train[idxb]
        opt.zero_grad(set_to_none=True)
        logits, _ = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        # Gradient clipping is common for RNNs
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        n_in_batch = xb.size(0)
        total_loss += loss.item() * n_in_batch
        total_acc += accuracy(logits.detach(), yb) * n_in_batch

    # Evaluate on the whole validation set in a single forward pass, with
    # gradient tracking disabled.
    model.eval()
    with torch.no_grad():
        val_logits, _ = model(X_val)
        val_loss = loss_fn(val_logits, y_val).item()
        val_acc = accuracy(val_logits, y_val)

    print(f"Epoch {epoch:02d} | train loss {total_loss/num_train:.4f} acc {total_acc/num_train:.3f} | val loss {val_loss:.4f} acc {val_acc:.3f}")

## 7) The PyTorch built-in GRU

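`nn.GRU` handles batching efficiently and supports multi-layer and bidirectional GRUs. Note that PyTorch uses a slightly different (but equally expressive) convention from Section 2: the update gate is mirrored, $h_t=(1-z_t)\odot n_t + z_t\odot h_{t-1}$, and the reset gate is applied after the recurrent matrix multiply, so weights cannot be copied directly between `nn.GRU` and the scratch cell. The model below reads the output at the last time step, which is the right summary for a unidirectional GRU.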


In [None]:
class GRUSequenceModelTorch(nn.Module):
    """Binary sequence classifier built on PyTorch's ``nn.GRU``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the GRU hidden state.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return one logit per sequence, shape ``(batch, 1)``.

        ``x`` is expected as ``(batch, seq_len, input_size)`` because the GRU
        is created with ``batch_first=True``.
        """
        # per_step: (batch, seq_len, hidden). The final hidden state that
        # nn.GRU also returns is not needed here.
        per_step, _ = self.gru(x)
        last_step = per_step[:, -1, :]
        return self.head(last_step)

# Same task and data as the scratch model, now with the built-in nn.GRU.
model2 = GRUSequenceModelTorch(input_size=V, hidden_size=64).to(device)
opt2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
# Re-created here so this cell does not depend on the loss from the previous one.
loss_fn = nn.BCEWithLogitsLoss()

for epoch in range(1, 6):
    model2.train()
    # Fresh shuffle of the training indices every epoch.
    perm = torch.randperm(X_train.size(0), device=device)
    for i in range(0, X_train.size(0), batch_size):
        idxb = perm[i:i+batch_size]
        xb, yb = X_train[idxb], y_train[idxb]
        opt2.zero_grad(set_to_none=True)
        logits = model2(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model2.parameters(), 1.0)
        opt2.step()

    # Validation: one forward pass over the whole validation set, no gradients.
    model2.eval()
    with torch.no_grad():
        val_logits = model2(X_val)
        val_loss = loss_fn(val_logits, y_val).item()
        val_acc = accuracy(val_logits, y_val)
    print(f"Epoch {epoch:02d} | val loss {val_loss:.4f} acc {val_acc:.3f}")

## 8) Next steps

- Variable-length sequences: `pack_padded_sequence` / `pad_packed_sequence`
- Time series forecasting: predict next value(s) instead of classification
- Add dropout / layer norm / better token embeddings
