## **PAIR ASSIGNMENT**

Please work in pairs for this exercise. Pair up with the person next to you. If you find that there isn't anyone sitting next to you or if you're unable to form a pair, please raise your hand, and I will assist in pairing you with someone.


#### **TODO: Please assign your full name and your partner's full name to the variables below, respectively.**


Example:

```
your_name = 'Fatma Tarlaci'
your_partner_name = 'Sagnik Majumder'
```

In [None]:
# Assign your and your partner's names here.
# If you find that there isn't anyone sitting next to you
# or if you're unable to form a pair, please raise your hand.
# If, in the end, we are unable to find a partner for you,
# please assign the word "self" to the `your_partner_name` variable.
your_name = ''
your_partner_name = ''


# **Vanishing Gradients, RNN vs LSTM (Language Modeling)**

**Goals**
- Observe **vanishing gradients** in a vanilla RNN language model.
- Implement and compare a stronger model (**LSTM**) on the same data.
- Measure **perplexity** (PP) and **gradient norms** to see the difference.

> **Time box:** ~45–60 minutes. We work on a **tiny** subset of WikiText-2 for speed.



## 0) Setup

Run the following to install dependencies (if needed) and import libraries.


In [None]:
# If running in Colab or a fresh environment, uncomment:
# !pip -q install torch datasets tqdm numpy matplotlib

import math
import random
from typing import Tuple, List, Dict, Iterable, Optional
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from tqdm.auto import tqdm
from dataclasses import dataclass


In [None]:
# Reproducibility and device selection
torch.manual_seed(42)
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



## **TASK 1: Observe the Data: Tiny WikiText-2 slice and the Vocabulary implementation**

We will use the **Hugging Face** `wikitext` dataset (`wikitext-2-raw-v1` configuration) for word-level language modeling. You can see its details here: https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-2-raw-v1

**Review and discuss the pipeline implemented below. The pipeline:**
1. Loads the training and validation splits.
2. Builds a word vocabulary from a **small slice** of the training split.
3. Converts text to IDs and creates fixed-length sequences for language modeling.
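
To make step 3 concrete, here is a minimal sketch on a toy sentence. The sentence, the mini-vocabulary, and `seq_len = 3` are illustrative assumptions, not the notebook's data; the slicing mirrors what `LMDataset.__getitem__` does, with each target window being the input window shifted by one token.

```python
# Toy corpus (assumption); the notebook uses WikiText-2 text instead.
text = "the cat sat on the mat"
tokens = text.lower().split()

# Hypothetical mini-vocabulary: IDs 0 and 1 are reserved for <pad> and <unk>.
stoi = {"<pad>": 0, "<unk>": 1}
for tok in tokens:
    stoi.setdefault(tok, len(stoi))

ids = [stoi[tok] for tok in tokens]

seq_len = 3
# Pair every window of length seq_len with the same window shifted by one token.
pairs = [
    (ids[i : i + seq_len], ids[i + 1 : i + seq_len + 1])
    for i in range(len(ids) - seq_len)
]

for x, y in pairs:
    print("input:", x, "-> target:", y)
```

Because consecutive windows overlap almost entirely, even a short token stream yields many training examples.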


In [None]:
def build_vocab(texts: Iterable[str], max_tokens: int = 20000) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Build a simple word-index vocabulary.

    Args:
        texts: Iterable of documents (strings) used to build the vocabulary.
        max_tokens: Maximum vocab size including special tokens.

    Returns:
        A pair (stoi, itos) where:
            stoi: dict mapping string tokens to integer IDs.
            itos: dict mapping IDs back to string tokens.
    """
    # Special tokens
    specials = ["<pad>", "<unk>"]
    freq: Dict[str, int] = {}
    for doc in texts:
        # Lowercase and split by whitespace for simplicity
        for w in doc.strip().lower().split():
            freq[w] = freq.get(w, 0) + 1
    # Sort by frequency (desc), then alphabetically to stabilize ties
    items = sorted(freq.items(), key=lambda x: (-x[1], x[0]))
    # Truncate to max_tokens - len(specials)
    trimmed = items[: max(0, max_tokens - len(specials))]
    # Build vocab dicts
    itos = {0: "<pad>", 1: "<unk>"}
    for i, (w, _) in enumerate(trimmed, start=len(specials)):
        itos[i] = w
    stoi = {w: i for i, w in itos.items()}
    return stoi, itos


def encode(text: str, stoi: Dict[str, int]) -> List[int]:
    """Encode a string into a list of token IDs using the provided vocab.

    Args:
        text: Input string (document or sentence).
        stoi: Vocabulary mapping from token string to index.

    Returns:
        List of integer token IDs. Unknown tokens map to <unk>.
    """
    unk = stoi.get("<unk>", 1)
    return [stoi.get(w, unk) for w in text.strip().lower().split()]


class LMDataset(Dataset):
    """A tiny language modeling dataset built from text token IDs.

    Each example is a pair (x, y) where:
      - x is a sequence of token IDs of length `seq_len`
      - y is the next-token targets of length `seq_len` (i.e., x shifted by one)
    """

    def __init__(self, ids: List[int], seq_len: int = 30):
        """Initialize the dataset.

        Args:
            ids: A long list of token IDs concatenated from documents.
            seq_len: Unroll length for truncated BPTT.
        """
        self.ids = ids
        self.seq_len = seq_len

    def __len__(self) -> int:
        # Number of full sequences we can make (minus 1 for the next-token target)
        return max(0, len(self.ids) - self.seq_len - 1)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Compute the starting index of the subsequence
        start = idx

        # Input sequence: grab tokens from [start : start + seq_len]
        x = self.ids[start : start + self.seq_len]

        # Target sequence: same as input but shifted one step to the right
        y = self.ids[start + 1 : start + self.seq_len + 1]

        # Return both sequences as PyTorch tensors of type long (integer indices)
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


# 1) Load small wikitext-2-raw-v1 dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1")

# 2) Take a tiny slice for speed (the first 2,000 training entries and the first 500 validation entries)
train_texts = ds["train"]["text"][:2000]   # small slice of documents
valid_texts = ds["validation"]["text"][:500]

# 3) Build vocab from the training
# build_vocab(...) takes the training texts and constructs a mapping from string tokens → integer IDs.
# stoi: string-to-index dictionary (e.g., "the" → 5).
# itos: index-to-string dictionary (e.g., 5 → "the").
# max_tokens=15000 means the vocabulary will be capped at 15,000 unique tokens.
# The most frequent tokens are kept, and rarer ones usually get mapped to an <unk> token.
stoi, itos = build_vocab(train_texts, max_tokens=15000)

# You’ll use this vocab_size to define the embedding layer and the output layer of your language model.
vocab_size = len(stoi)
print(f'The number of unique tokens in your vocabulary (≤ 15,000) is: {vocab_size}')


In [None]:
# 4) Encode text into token IDs and concatenate
train_ids: List[int] = []
for t in train_texts:
    train_ids.extend(encode(t, stoi))

valid_ids: List[int] = []
for t in valid_texts:
    valid_ids.extend(encode(t, stoi))

len(train_ids), len(valid_ids), vocab_size


In [None]:
# 5) Build datasets and loaders
SEQ_LEN = 30
BATCH_SIZE = 64

train_ds = LMDataset(train_ids, seq_len=SEQ_LEN)
valid_ds = LMDataset(valid_ids, seq_len=SEQ_LEN)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

len(train_ds), len(valid_ds)



## **TASK 2: Implement Models**

You will now implement and compare a **Vanilla RNN LM** and an **LSTM LM**. Both are small and trained briefly. Complete the models below.
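
Before implementing, it can help to estimate how much larger the LSTM's recurrent layer is. The sketch below counts recurrent-layer parameters by hand, assuming PyTorch's layout for `nn.RNNCell` and a single-layer `nn.LSTM` (an input weight matrix, a hidden weight matrix, and two bias vectors, with the LSTM stacking four gate blocks). It also assumes an embedding size of 128 and a hidden size of 256. The helper names are hypothetical.

```python
def rnn_cell_params(input_dim: int, hidden_dim: int) -> int:
    """Parameters in one tanh RNN cell: W_ih, W_hh, b_ih, b_hh."""
    return hidden_dim * input_dim + hidden_dim * hidden_dim + 2 * hidden_dim


def lstm_layer_params(input_dim: int, hidden_dim: int) -> int:
    """A single LSTM layer stacks four such blocks (input, forget, cell, output)."""
    return 4 * rnn_cell_params(input_dim, hidden_dim)


emb_dim, hidden_dim = 128, 256  # assumed sizes for this estimate
rnn_n = rnn_cell_params(emb_dim, hidden_dim)
lstm_n = lstm_layer_params(emb_dim, hidden_dim)
print(f"RNNCell: {rnn_n:,} parameters")
print(f"LSTM:    {lstm_n:,} parameters ({lstm_n // rnn_n}x)")
```

The embedding and output layers have the same shape in both models, so the factor of four applies only to the recurrent layer.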


In [None]:
class VanillaRNNLM(nn.Module):
    """A tiny Vanilla RNN language model.

    Architecture:
      - Embedding
      - Manual loop with `nn.RNNCell` (single layer)
      - Linear decoder to vocab logits

    Notes:
      We keep it tiny and explicit to make gradient tracking easy.
    """

    def __init__(self, vocab_size: int, emb_dim: int = 128, hidden_dim: int = 256, pad_idx: int = 0):
        # Initialize parent nn.Module (needed for all PyTorch models)
        super().__init__()

        # Embedding layer: maps token IDs to dense vectors
        # vocab_size = number of unique tokens in the vocabulary
        # emb_dim = dimensionality of each embedding vector
        # padding_idx ensures the pad token always maps to a zero vector
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)

        # RNNCell: processes one time step at a time
        # Takes embedding input of size emb_dim, outputs hidden state of size hidden_dim
        # Uses tanh nonlinearity by default
        self.rnn_cell = nn.RNNCell(emb_dim, hidden_dim, nonlinearity="tanh")

        # Fully connected (linear) layer: projects hidden state → vocabulary logits
        # Needed for predicting next-token probabilities
        self.fc = nn.Linear(hidden_dim, vocab_size)

        # Store hidden dimension size for later use (e.g., when initializing hidden state)
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor, h0: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through a sequence.

        Args:
            x: LongTensor of shape (batch, seq_len) with token IDs.
            h0: Optional initial hidden state of shape (batch, hidden_dim).

        Returns:
            logits: FloatTensor of shape (batch, seq_len, vocab_size)
            h: Final hidden state (batch, hidden_dim)
        """
        batch, seqlen = x.size()
        if h0 is None:
            # We use x.new_zeros(...) instead of torch.zeros(..., device=x.device)
            # because new_zeros creates the tensor on the same device (CPU/GPU) as `x`.
            # By default it would also copy x's dtype, but `x` holds integer token IDs,
            # so we override the dtype to float32 for the hidden state.
            # This avoids a common beginner pitfall, especially in Colab when switching
            # between CPU and GPU: mismatched devices
            # (RuntimeError: Expected all tensors to be on the same device).
            h = x.new_zeros((batch, self.hidden_dim), dtype=torch.float32)
        else:
            h = h0

        # Prepare a list to collect per-step logits
        logits_steps = []
        for t in range(seqlen):
            # Embed current time-step tokens
            emb_t = self.emb(x[:, t])
            # One RNNCell step
            h = self.rnn_cell(emb_t, h)
            # Decode to vocabulary logits
            logits_t = self.fc(h)
            logits_steps.append(logits_t.unsqueeze(1))  # keep time dimension

        # Concatenate over time → (batch, seq_len, vocab_size)
        logits = torch.cat(logits_steps, dim=1)
        return logits, h


class LSTMLM(nn.Module):
    """A tiny LSTM language model using `nn.LSTM`.

    Architecture:
      - Embedding
      - Single-layer LSTM
      - Linear decoder to vocab logits
    """

    def __init__(self, vocab_size: int, emb_dim: int = 128, hidden_dim: int = 256, pad_idx: int = 0):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x: torch.Tensor, state=None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass through a sequence.

        Args:
            x: LongTensor of shape (batch, seq_len) with token IDs.
            state: Optional tuple (h0, c0) initial states of shape (1, batch, hidden_dim).

        Returns:
            logits: FloatTensor of shape (batch, seq_len, vocab_size)
            new_state: Tuple (h, c) final states.
        """
        emb = self.emb(x)                       # (batch, seq_len, emb_dim)
        out, new_state = self.lstm(emb, state)  # (batch, seq_len, hidden_dim)
        logits = self.fc(out)                   # (batch, seq_len, vocab_size)
        return logits, new_state



## **TASK 3: Loss, Perplexity, and Gradient Tracking**

**Loss Function (Cross-Entropy):**
For each batch, compute the cross-entropy loss between the predicted token probabilities and the true next token. This measures how well the model is doing at next-token prediction.

**Perplexity (PP):**
After computing the loss, calculate perplexity as:
**PP=exp(loss)**

* Perplexity gives an interpretable measure of how many choices the model is “confused” between on average. A lower perplexity indicates better predictive performance.
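
As a small worked example of the formula, the sketch below uses made-up probabilities (an assumption for illustration, not model output). If the model assigns probability 0.25 to every correct token, its perplexity is 4, as if it were choosing uniformly among four words. A uniform guess over a vocabulary of size V gives a perplexity of V.

```python
import math

# Hypothetical probabilities the model assigned to each correct next token.
probs = [0.25, 0.25, 0.25, 0.25]

# Cross-entropy is the mean negative log-probability of the correct tokens.
mean_loss = -sum(math.log(p) for p in probs) / len(probs)
perplexity = math.exp(mean_loss)

# Baseline: a uniform guess over V tokens has perplexity V.
vocab_size = 15000
uniform_pp = math.exp(-math.log(1.0 / vocab_size))

print(f"loss={mean_loss:.4f}  PP={perplexity:.2f}  uniform PP={uniform_pp:.0f}")
```

The uniform baseline is a useful sanity check: a freshly initialized model should start near it, and training should push perplexity well below it.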

**Gradient Tracking for Vanishing Gradients:**
To observe the vanishing gradient problem, track how gradient magnitudes change during training:

* After each backward pass, read the gradient of the recurrent weight matrix (in this notebook, `rnn_cell.weight_hh` for the `nn.RNNCell` model and `lstm.weight_hh_l0` for the `nn.LSTM` model).

* Compute the L2 norm of that gradient.

* Log or plot this value across training iterations.

* Norms that shrink toward zero point to vanishing gradients; very large values point to exploding gradients. Keep in mind that this norm sums contributions from all time steps, so it is an indirect signal: it can also shrink as training converges.

This task is designed to help you connect theory (cross-entropy, perplexity, vanishing gradients) to practice by giving you a direct way to measure model learning and stability.
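
The recurrent-weight norm described above aggregates all time steps into one number. A more direct probe, sketched here as an optional aside, keeps the gradient at every hidden state and compares early steps with late ones. The sizes, the random inputs, and the loss (which depends only on the last hidden state) are illustrative assumptions, not the notebook's training setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only; these are not the notebook's settings.
batch, seq_len, emb_dim, hidden_dim = 4, 30, 8, 16
cell = nn.RNNCell(emb_dim, hidden_dim, nonlinearity="tanh")
inputs = torch.randn(batch, seq_len, emb_dim)

h = torch.zeros(batch, hidden_dim)
hidden_states = []
for t in range(seq_len):
    h = cell(inputs[:, t], h)
    h.retain_grad()  # keep .grad on this non-leaf tensor after backward()
    hidden_states.append(h)

# The loss depends only on the final hidden state, so any gradient that
# reaches h_t has to travel backward through every later step.
loss = hidden_states[-1].pow(2).sum()
loss.backward()

step_norms = [hs.grad.norm(p=2).item() for hs in hidden_states]
for t in (0, seq_len // 2, seq_len - 1):
    print(f"t={t:2d}  ||dL/dh_t||_2 = {step_norms[t]:.3e}")

# For comparison: the single aggregated number that the notebook logs.
print(f"||dL/dW_hh||_2 = {cell.weight_hh.grad.norm(p=2).item():.3e}")
```

If gradients vanish through time, the norm at `t=0` should be orders of magnitude smaller than at the last step, even when the aggregated `weight_hh` norm looks unremarkable.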


In [None]:
def lm_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    ignore_index: int = 0,
    reduction: str = "mean"
) -> torch.Tensor:
    """Compute cross-entropy loss for next-token language modeling.

    Args:
        logits: FloatTensor of shape (batch, seq_len, vocab_size).
            Raw model outputs (unnormalized log-probabilities).
        targets: LongTensor of shape (batch, seq_len).
            Next-token IDs aligned with logits.
        ignore_index: Token ID to ignore in the loss (e.g., padding).
        reduction: 'mean' (default) or 'sum'.
            - 'mean' → average loss across tokens
            - 'sum' → total loss (useful for token-weighted eval)

    Returns:
        torch.Tensor: Scalar tensor with the loss value.
    """
    batch, seqlen, vocab = logits.shape
    return F.cross_entropy(
        logits.reshape(batch * seqlen, vocab),   # Flatten to (N, V)
        targets.reshape(batch * seqlen),         # Flatten to (N,)
        ignore_index=ignore_index,
        reduction=reduction
    )

# @torch.no_grad() is a PyTorch decorator that disables gradient tracking inside the function it decorates.
# Operations inside evaluate() (forward passes, loss computation, etc.) do not build a computation graph.
# This saves memory and makes evaluation faster, because gradients aren't needed for evaluation.
# Note that it does not switch layers such as dropout to evaluation behavior; model.eval() below does that.
# In short: @torch.no_grad() = "run this function without recording gradients, so no backprop is possible."
@torch.no_grad()
def evaluate(
    model: nn.Module,
    data_loader: DataLoader,
    pad_id: int = 0
) -> Tuple[float, float]:
    """Evaluate model on a dataset and compute perplexity.

    Uses token-weighted averaging to fairly handle batches of varying length.

    Args:
        model: Trained language model (nn.Module).
        data_loader: DataLoader yielding (x, y) batches.
        pad_id: Padding token ID to ignore.

    Returns:
        (mean_loss, perplexity)
        mean_loss: Average cross-entropy per token.
        perplexity: exp(mean_loss); interpretable measure of uncertainty.
    """
    model.eval()

    total_loss_sum = 0.0
    total_tokens = 0

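    # Look up the model's device once instead of once per batch
    model_device = next(model.parameters()).device

    for x, y in data_loader:
        x = x.to(model_device)
        y = y.to(model_device)

        logits, _ = model(x)

        # Sum of per-token losses
        loss_sum = lm_loss(logits, y, ignore_index=pad_id, reduction="sum")

        # Count valid (non-pad) tokens
        valid = (y != pad_id).sum().item()

        total_loss_sum += loss_sum.item()
        # Add the true count: a padding-only batch contributes zero loss and zero tokens
        total_tokens += valid

    # Guard against division by zero when no valid tokens were seen
    mean_loss = total_loss_sum / max(1, total_tokens)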
    perplexity = math.exp(mean_loss)
    return mean_loss, perplexity


def l2_grad_norm_recurrent(model: nn.Module) -> float:
    """Track vanishing/exploding gradients by measuring the gradient norm of recurrent weights.

    In RNN/LSTM/GRU modules, recurrent weights typically include 'weight_hh'
    in their parameter names (e.g., 'rnn.weight_hh_l0').

    Args:
        model: Language model (nn.Module) after backpropagation.

    Returns:
        float: L2 norm of recurrent gradients.
            - Near 0 → vanishing gradients.
            - Very large → exploding gradients.
    """
    total_sq = 0.0
    count = 0
    for name, p in model.named_parameters():
        if ("weight_hh" in name) and (p.grad is not None):
            # Accumulate the squared L2 norm of this parameter's gradient
            total_sq += (p.grad.detach().float().norm(p=2).item()) ** 2
            count += 1
    return (total_sq ** 0.5) if count > 0 else 0.0


## **TASK 4: Observe Vanishing Gradients: Training Vanilla RNN**

We train briefly and record:
- Training loss and PP
- **Gradient L2 norm** of the recurrent weight (`rnn_cell.weight_hh`) after each step


In [None]:
def train_rnn(model: VanillaRNNLM, train_loader: DataLoader, valid_loader: DataLoader, epochs: int = 2, lr: float = 1e-3):
    """Train a VanillaRNNLM and log gradient norms for the recurrent matrix.

    Args:
        model: VanillaRNNLM instance.
        train_loader: Training data loader.
        valid_loader: Validation data loader.
        epochs: Number of training epochs.
        lr: Learning rate.

    Returns:
        (logs, model): a dict of training logs (train_loss, train_pp, grad_norm_rhh, val)
        and the trained model.
    """
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    logs = {
        "train_loss": [],
        "train_pp": [],
        "grad_norm_rhh": [],
        "val": [],
    }

    for ep in range(1, epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"RNN Epoch {ep}")
        for x, y in pbar:
            x = x.to(device)
            y = y.to(device)

            # Forward pass
            logits, _ = model(x)

            # Compute loss
            loss = lm_loss(logits, y)

            # Backward
            opt.zero_grad(set_to_none=True)
            loss.backward()

            # Record gradient norm of recurrent weight (a common place to see vanishing)
            rhh = model.rnn_cell.weight_hh
            grad_norm = rhh.grad.detach().norm(2).item() if rhh.grad is not None else 0.0

            # Clip the global gradient norm to guard against exploding gradients
            # (same max_norm as the LSTM trainer, for a fair comparison)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update
            opt.step()

            logs["train_loss"].append(loss.item())
            logs["train_pp"].append(math.exp(loss.item()))
            logs["grad_norm_rhh"].append(grad_norm)
            pbar.set_postfix(loss=f"{loss.item():.3f}", pp=f"{math.exp(loss.item()):.1f}", grad=f"{grad_norm:.4e}")

        # Validation
        val_loss, val_pp = evaluate(model, valid_loader)
        logs["val"].append((val_loss, val_pp))
        print(f"[RNN] Epoch {ep}: val_loss={val_loss:.3f}, val_pp={val_pp:.1f}")

    return logs, model



## **TASK 5: Train LSTM And Compare the Results to Vanilla RNN below**

We repeat with an **LSTM** and track the gradient norm of its recurrent weight (`weight_hh_l0`).


In [None]:
def _find_lstm_recurrent_weight(model: nn.Module) -> Optional[torch.nn.Parameter]:
    """Best-effort lookup for the first layer's recurrent weight matrix in an LSTM.

    LSTM parameter names in PyTorch typically look like:
        - weight_ih_l0  (input-to-hidden, layer 0)
        - weight_hh_l0  (hidden-to-hidden, layer 0)  ← recurrent matrix we want
      If the model is wrapped (e.g., inside a module attr like `model.lstm`
      or DataParallel), we search by substring.

    Args:
        model: LSTM language model (must contain an nn.LSTM module).

    Returns:
        torch.nn.Parameter or None if not found.
    """
    # Common direct attribute path: model.lstm.weight_hh_l0
    if hasattr(model, "lstm") and hasattr(model.lstm, "weight_hh_l0"):
        return model.lstm.weight_hh_l0

    # Fallback: scan parameters by name for the first layer recurrent weight
    for name, p in model.named_parameters():
        # Match both layer 0 and a generic 'weight_hh' just in case
        if ("weight_hh_l0" in name) or (name.endswith("weight_hh") and p.ndimension() == 2):
            return p

    return None


def train_lstm(
    model: "LSTMLM",
    train_loader: DataLoader,
    valid_loader: DataLoader,
    epochs: int = 2,
    lr: float = 1e-3
) -> Tuple[Dict[str, List[float]], "LSTMLM"]:
    """Train an LSTM language model; log perplexity & recurrent gradient norms.

    This mirrors the vanilla RNN trainer so you can compare *vanishing/exploding*
    behavior across architectures under similar training conditions.

    Training loop (per batch):
      1) Forward pass → logits over vocab for each time step
      2) Cross-entropy loss against next-token targets
      3) Backpropagate gradients
      4) Record L2 norm of *recurrent* matrix grads (weight_hh_l0)
      5) (Gentle) gradient clipping to tame explosions
      6) Optimizer step
      7) Log loss, perplexity, grad norm

    Notes:
      • We log the L2 norm of `weight_hh_l0` gradients. RNNs tend to show more
        vanishing; LSTMs often mitigate it via gating, so you should see larger,
        healthier gradient norms on average—but still decaying with long sequences.
      • If your pad id ≠ 0, set it in your `lm_loss` / `evaluate` helpers.
      • Ensure your `model(x)` returns (logits, state) where logits=(B, T, V).

    Args:
        model: LSTMLM instance (embedding → LSTM → linear-to-vocab).
        train_loader: Dataloader yielding (x, y) training batches.
        valid_loader: Dataloader yielding (x, y) validation batches.
        epochs: Number of full passes over the training set.
        lr: Learning rate for Adam.

    Returns:
        (logs, model)
        logs:
          - "train_loss":    per-step cross-entropy (float)
          - "train_pp":      per-step perplexity = exp(loss)
          - "grad_norm_hh":  per-step L2 norm of recurrent weight grads
          - "val":           per-epoch (val_loss, val_pp)
        model: The trained model (same object, updated in-place).
    """
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    logs: Dict[str, List] = {
        "train_loss": [],
        "train_pp": [],
        "grad_norm_hh": [],
        "val": [],
    }

    # Resolve the recurrent matrix once (will still check grad presence each step)
    recurrent_w = _find_lstm_recurrent_weight(model)
    if recurrent_w is None:
        print("[warn] Could not find LSTM recurrent weight; grad tracking will be 0.0. "
              "Adapt `_find_lstm_recurrent_weight` for your model structure.")

    for ep in range(1, epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"LSTM Epoch {ep}")

        for x, y in pbar:
            # ---- 1) Move batch to device
            x = x.to(device)
            y = y.to(device)

            # ---- 2) Forward pass: logits shape (B, T, V); state unused here
            logits, _ = model(x)

            # ---- 3) Cross-entropy over next-token targets (ignore padding id=0 by default)
            loss = lm_loss(logits, y, ignore_index=0, reduction="mean")

            # ---- 4) Backpropagation
            opt.zero_grad(set_to_none=True)
            loss.backward()

            # ---- 5) Recurrent gradient L2 norm (vanishing/exploding signal)
            # Tracking this norm shows whether gradients are vanishing (very close to zero)
            # or exploding (very large). It’s a diagnostic tool for training stability in RNNs and LSTMs.

            # The check below ensures that the recurrent (hidden-to-hidden) weight matrix was found
            # and that a gradient has been computed for it (i.e., after loss.backward()).
            if recurrent_w is not None and recurrent_w.grad is not None:
                # detach() returns the gradient without autograd history, float() casts it to
                # float32 to avoid precision issues, and norm(p=2) computes the L2 (Euclidean) norm.
                # The result is the overall "size" of the gradient for this weight matrix.
                grad_norm = recurrent_w.grad.detach().float().norm(p=2).item()
            else:
                # If the weight or its gradient doesn't exist, record the gradient norm as zero.
                grad_norm = 0.0

            # ---- 6) Gentle global clipping helps with explosions without hiding vanishing
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # ---- 7) Parameter update
            opt.step()

            # ---- 8) Metrics & live logging
            loss_val = loss.item()
            pp_val = math.exp(loss_val)
            logs["train_loss"].append(loss_val)
            logs["train_pp"].append(pp_val)
            logs["grad_norm_hh"].append(grad_norm)

            pbar.set_postfix(
                loss=f"{loss_val:.3f}",
                pp=f"{pp_val:.1f}",
                grad=f"{grad_norm:.4e}"
            )

        # ---- 9) End-of-epoch validation (token-weighted mean loss + perplexity)
        val_loss, val_pp = evaluate(model, valid_loader)  # assumes ignore_index inside
        logs["val"].append((val_loss, val_pp))
        print(f"[LSTM] Epoch {ep}: val_loss={val_loss:.3f}, val_pp={val_pp:.1f}")

    return logs, model



## **TASK 6: Run Experiments and Discuss the Results**

We keep the setup small so it trains quickly, but you should still observe the key differences between architectures:

**Vanilla RNN**

Recurrent gradient norms often shrink toward zero, illustrating the vanishing gradient problem.

Struggles to capture longer-range dependencies, leading to higher perplexity (worse language modeling performance).

**LSTM**

Thanks to gating mechanisms (input/forget/output), gradient flow is healthier, and norms stay larger and more stable.

Handles longer context better, resulting in lower perplexity (better predictive performance).

As you compare plots of perplexity and gradient norms, think about:

* How the theory of vanishing gradients is visible in practice.

* Why LSTMs, despite being more complex, were the default for many sequence tasks before Transformers.

* What tradeoffs you might expect if you scaled up sequence length, hidden size, or training time.

In [None]:
def set_seed(seed: int = 1337) -> None:
    """Fix random seeds for fair comparisons."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(1337)


# ----------------------------
# Config (shared across models)
# ----------------------------
@dataclass
class ExpConfig:
    emb: int = 128
    hid: int = 256
    epochs: int = 2
    lr: float = 1e-3
    pad_idx: int = 0

CFG = ExpConfig()

In [None]:
# Vanilla RNN
rnn_model = VanillaRNNLM(
    vocab_size=vocab_size,
    emb_dim=CFG.emb,
    hidden_dim=CFG.hid,
    pad_idx=CFG.pad_idx,
).to(device)
rnn_logs, rnn_model = train_rnn(
    rnn_model, train_loader, valid_loader, epochs=CFG.epochs, lr=CFG.lr
)

# LSTM
lstm_model = LSTMLM(
    vocab_size=vocab_size,
    emb_dim=CFG.emb,
    hidden_dim=CFG.hid,
    pad_idx=CFG.pad_idx,
).to(device)
lstm_logs, lstm_model = train_lstm(
    lstm_model, train_loader, valid_loader, epochs=CFG.epochs, lr=CFG.lr  # same LR as the RNN for a fair comparison
)


# ----------------------------
# Summarize results
# ----------------------------
def summarize_run(name: str, logs: Dict) -> Dict[str, float]:
    """Create a compact summary for printing/plotting."""
    last_train_loss = logs["train_loss"][-1] if logs["train_loss"] else float("nan")
    last_train_pp   = logs["train_pp"][-1]   if logs["train_pp"]   else float("nan")
    last_grad_norm  = logs["grad_norm_rhh"][-1] if "grad_norm_rhh" in logs and logs["grad_norm_rhh"] \
                      else (logs["grad_norm_hh"][-1] if "grad_norm_hh" in logs and logs["grad_norm_hh"] else float("nan"))
    last_val_loss, last_val_pp = logs["val"][-1] if logs["val"] else (float("nan"), float("nan"))
    return {
        "model": name,
        "train_loss": last_train_loss,
        "train_pp": last_train_pp,
        "val_loss": last_val_loss,
        "val_pp": last_val_pp,
        "last_grad_L2_recurrent": last_grad_norm,
    }

rnn_summary  = summarize_run("VanillaRNN", rnn_logs)
lstm_summary = summarize_run("LSTM",        lstm_logs)

# Pretty print a quick table
def _fmt(x: float) -> str:
    return f"{x:.3f}" if isinstance(x, float) and x == x else "n/a"

print("\n=== Summary (same config, different cells) ===")
print(f"{'Model':12}  {'TrainLoss':>10}  {'TrainPP':>10}  {'ValLoss':>10}  {'ValPP':>10}  {'GradL2(hh)':>12}")
for row in (rnn_summary, lstm_summary):
    print(f"{row['model']:12}  "
          f"{_fmt(row['train_loss']):>10}  {_fmt(row['train_pp']):>10}  "
          f"{_fmt(row['val_loss']):>10}  {_fmt(row['val_pp']):>10}  {_fmt(row['last_grad_L2_recurrent']):>12}")



## **Observation: Visualize Perplexity and Gradient Norms**


In [None]:
import matplotlib.pyplot as plt  # numpy (np) and torch are already imported in the setup cell

# 1) Training Loss (all steps) – complements our Train PP slice
plt.figure()
plt.plot(rnn_logs["train_loss"], label="RNN loss")
plt.plot(lstm_logs["train_loss"], label="LSTM loss")
plt.xlabel("Update step")
plt.ylabel("Cross-Entropy (train)")
plt.legend()
plt.title("Training Loss (all steps)")
plt.tight_layout()
plt.show()

# 2) (Optional) Smoothed Train PP with a rolling mean
def moving_avg(xs, k=20):
    """Return the k-step rolling mean of xs (xs itself if it is shorter than k)."""
    if len(xs) < k:
        return list(xs)
    c = np.cumsum([0.0] + list(xs))
    return [(c[i + k] - c[i]) / k for i in range(len(xs) - k + 1)]

plt.figure()
plt.plot(moving_avg(rnn_logs["train_pp"][:300]), label="RNN PP (smoothed)")
plt.plot(moving_avg(lstm_logs["train_pp"][:300]), label="LSTM PP (smoothed)")
plt.xlabel("Update step (smoothed)")
plt.ylabel("Perplexity (train)")
plt.yscale("log")  # helps when values span orders of magnitude
plt.legend(); plt.title("Train Perplexity (first 300 steps, smoothed)")
plt.tight_layout(); plt.show()

# 3) Gradient norms, plus log-scale for visibility if needed
plt.figure()
plt.plot(rnn_logs["grad_norm_rhh"][:300], label="RNN ‖∇(hh)‖₂")
plt.plot(lstm_logs["grad_norm_hh"][:300], label="LSTM ‖∇(hh_l0)‖₂")
plt.xlabel("Update step")
plt.ylabel("L2 grad norm")
plt.yscale("log")  # helps when values span orders of magnitude
plt.legend(); plt.title("Recurrent Gradient Norms (first 300 steps, log scale)")
plt.tight_layout(); plt.show()

# 4) Validation PP per epoch
rnn_val_pp  = [pp for _, pp in rnn_logs["val"]]
lstm_val_pp = [pp for _, pp in lstm_logs["val"]]
plt.figure()
plt.plot(rnn_val_pp, marker="o", label="RNN val PP")
plt.plot(lstm_val_pp, marker="o", label="LSTM val PP")
plt.xlabel("Epoch"); plt.ylabel("Validation PP")
plt.legend(); plt.title("Validation Perplexity by Epoch")
plt.tight_layout(); plt.show()




## **Discussion: Explain in 2–3 sentences**

- What evidence in the plots suggests **vanishing gradients** for the vanilla RNN?
  - The RNN's gradient norm curve drops sharply over training steps, approaching zero on a log scale. This decay suggests the gradient signal is dying out, so earlier time steps contribute little to weight updates: classic vanishing-gradient behavior.
- How do **perplexity** trends differ between RNN and LSTM?
  - For vanilla RNNs, perplexity decreases at first but then plateaus at a higher value, showing the model cannot capture long dependencies. LSTMs continue reducing perplexity more steadily and reach lower values, because the cell state and gates preserve long-term information.
- If you increased `SEQ_LEN` (unroll length), how would you expect the RNN's gradient norms to change? Why?
  - The gradient norms for the vanilla RNN would shrink even faster. A longer unroll means more Jacobian multiplications in BPTT, so if each Jacobian has norm below 1, the product decays exponentially with the number of steps. As a result, gradients from distant time steps vanish more severely, and the model struggles to learn long-range dependencies.
- If you extend SEQ_LEN from 30 to 90 with the same LR, what happens to RNN grads and why do LSTM gates help?
  - For RNNs: Gradients vanish almost completely, since the signal must travel through 3× as many steps; weight updates for long-range dependencies approach zero.
  - For LSTMs: The gating mechanism allows information to flow along the cell state with minimal decay. Forget/input/output gates decide what to keep, update, or reveal, so the effective gradient doesn’t vanish as badly. This enables LSTMs to remain trainable even with longer unroll lengths.
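
The Jacobian-product argument in these answers can be checked with a minimal scalar sketch. It assumes a one-dimensional tanh RNN with a fixed recurrent weight of 0.9 and a constant input, and for the LSTM it looks only at the direct cell-state path with a constant forget gate of 0.97. All of these values and the helper names are hypothetical and chosen for illustration.

```python
import math


def rnn_backprop_factor(w: float, steps: int, x: float = 0.5) -> float:
    """|d h_T / d h_0| for the scalar RNN h_t = tanh(w * h_{t-1} + x)."""
    h = 0.0
    factor = 1.0
    for _ in range(steps):
        h = math.tanh(w * h + x)
        # Chain rule: d h_t / d h_{t-1} = w * (1 - h_t^2)
        factor *= w * (1.0 - h * h)
    return abs(factor)


def cell_state_factor(forget_gate: float, steps: int) -> float:
    """d c_T / d c_0 along the direct cell-state path with a constant forget gate."""
    return forget_gate ** steps


for steps in (30, 90):
    rnn_f = rnn_backprop_factor(0.9, steps)
    lstm_f = cell_state_factor(0.97, steps)
    print(f"steps={steps:2d}  RNN factor={rnn_f:.3e}  LSTM cell-path factor={lstm_f:.3e}")
```

Every per-step RNN factor here is below 0.9 in magnitude, so tripling the unroll length multiplies in 60 more shrinking terms. The cell-state path decays only as fast as the forget gate allows, which is why a gate near 1 preserves the signal.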
