# 10 - RNN: Backpropagation Through Time (BPTT)

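Backpropagation Through Time (BPTT) is the algorithm used to train RNNs on sequences. It unrolls the RNN over time into a feed-forward computation graph, with one copy of the recurrent cell per time step and all copies sharing the same weights. It then applies the chain rule to that graph, so gradients flow back through every time step.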

While transformers use attention instead of recurrence, understanding BPTT is key to grasping how sequence models learn dependencies across time steps.

## 🔁 What is BPTT?

BPTT unrolls the RNN across all time steps, computes the loss at each step (or only at the last step, for tasks with a single output), and then backpropagates gradients through the entire sequence.
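For concreteness, here is one common way to write the unrolled computation. It assumes a tanh hidden activation and a softmax output layer (the `W_hy`, `b_y` parameters listed in the training loop below). Other activations or output heads change the details but not the structure:

$$
\begin{aligned}
a_t &= W_{xh} x_t + W_{hh} h_{t-1} + b_h, \qquad h_t = \tanh(a_t), \qquad t = 1, \dots, T,\\
\hat{y}_t &= \operatorname{softmax}(W_{hy} h_t + b_y),\\
\mathcal{L} &= \sum_{t=1}^{T} \mathcal{L}_t(\hat{y}_t, y_t).
\end{aligned}
$$

Unrolling means treating $h_0 \to h_1 \to \dots \to h_T$ as a $T$-layer network whose layers all share $W_{xh}$, $W_{hh}$ and $b_h$.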

**LLM/Transformer Context:**
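- BPTT is a precursor of how transformers are trained. Transformers also use ordinary backpropagation. However, attention connects every position directly, so gradients between distant tokens do not pass through a long chain of recurrent steps, and all positions are processed in parallel.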

### Task:
- Scaffold a function to unroll an RNN and store all intermediate activations needed for backpropagation.
- Add a docstring explaining its role.

```python
import numpy as np

def rnn_forward_cache(X_seq, h0, W_xh, W_hh, b_h):
    """
    Run an RNN over a sequence and cache all activations for BPTT.
    Args:
        X_seq (np.ndarray): Input sequence (seq_len x input_dim)
        h0 (np.ndarray): Initial hidden state (hidden_dim,)
        W_xh, W_hh, b_h: RNN parameters
    Returns:
        dict: Cached activations (inputs, hidden states, pre-activations)
    """
    # TODO: Implement forward pass with caching for BPTT
    pass
```
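One possible way to fill in this stub is sketched below. It is a sketch under stated assumptions, not the notebook's reference solution. It assumes a tanh recurrence $h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$ and a hypothetical shape convention: `W_xh` is `(hidden_dim, input_dim)`, `W_hh` is `(hidden_dim, hidden_dim)`, and `b_h` is `(hidden_dim,)`. The cache keys `"X"`, `"a"` and `"h"` are also illustrative choices.

```python
import numpy as np


def rnn_forward_cache(X_seq, h0, W_xh, W_hh, b_h):
    """Unroll a tanh RNN over X_seq and cache everything BPTT needs.

    Assumed (hypothetical) shapes: X_seq (T, input_dim), h0 (hidden_dim,),
    W_xh (hidden_dim, input_dim), W_hh (hidden_dim, hidden_dim), b_h (hidden_dim,).
    """
    T, H = X_seq.shape[0], h0.shape[0]
    a = np.zeros((T, H))      # pre-activations a_t, needed for tanh'(a_t)
    h = np.zeros((T + 1, H))  # h[0] = h0, h[t + 1] is the state after step t
    h[0] = h0
    for t in range(T):
        a[t] = W_xh @ X_seq[t] + W_hh @ h[t] + b_h
        h[t + 1] = np.tanh(a[t])
    return {"X": X_seq, "a": a, "h": h}
```

Storing `h0` as row 0 of `h` means the backward pass can read $h_{t-1}$ as `h[t]` for every step, including the first, without special-casing it.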

## 🔗 Backward Pass: Computing Gradients Through Time

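The backward pass computes gradients for all parameters by propagating errors backward through each time step. Because the same weights are reused at every step, each parameter's gradient is the sum of its contributions from all time steps.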
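Assume the tanh recurrence $a_t = W_{xh} x_t + W_{hh} h_{t-1} + b_h$, $h_t = \tanh(a_t)$, with loss $\mathcal{L} = \sum_t \mathcal{L}_t$. Under that assumption, the backward pass is a recursion that runs from $t = T$ down to $t = 1$. Here $\delta_t$ is the total gradient of the loss with respect to $h_t$, and $g_t$ is the gradient with respect to the pre-activation $a_t$:

$$
\begin{aligned}
\delta_t &= \frac{\partial \mathcal{L}_t}{\partial h_t} + W_{hh}^{\top} g_{t+1}, \qquad g_{T+1} = 0, \qquad
g_t = (1 - h_t \odot h_t) \odot \delta_t,\\
\frac{\partial \mathcal{L}}{\partial W_{xh}} &= \sum_{t=1}^{T} g_t\, x_t^{\top}, \qquad
\frac{\partial \mathcal{L}}{\partial W_{hh}} = \sum_{t=1}^{T} g_t\, h_{t-1}^{\top}, \qquad
\frac{\partial \mathcal{L}}{\partial b_h} = \sum_{t=1}^{T} g_t, \qquad
\frac{\partial \mathcal{L}}{\partial x_t} = W_{xh}^{\top} g_t.
\end{aligned}
$$

The $W_{hh}^{\top} g_{t+1}$ term carries error from later steps back to earlier ones. The sums over $t$ reflect the weight sharing. If only the final step has a loss, $\partial \mathcal{L}_t / \partial h_t = 0$ for every $t < T$.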

**LLM/Transformer Context:**
- This is analogous to how gradients flow through the layers and positions in a transformer.

### Task:
- Scaffold a function to compute gradients for all RNN parameters using cached activations and sequence loss gradients.
- Add a docstring explaining the process.

```python
def rnn_bptt(cache, dL_dh, W_xh, W_hh):
    """
    Perform backpropagation through time (BPTT) for an RNN.
    Args:
        cache (dict): Cached activations from the forward pass.
        dL_dh (np.ndarray): Gradient of the loss w.r.t. each hidden state
            (seq_len x hidden_dim). If the loss depends only on the last
            hidden state, all rows except the last are zero.
        W_xh, W_hh: RNN weights (W_hh carries gradients back one time step;
            W_xh is needed for the input gradients).
    Returns:
        dict: Gradients for W_xh, W_hh, b_h, and the input sequence.
    """
    # TODO: Implement BPTT to compute gradients for all parameters
    pass
```
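A possible BPTT sketch is shown below. It assumes a tanh RNN and the shape convention `W_xh: (hidden_dim, input_dim)`, `W_hh: (hidden_dim, hidden_dim)`. It also assumes a hypothetical cache layout: `"X"` holds the inputs, and `"h"` holds `seq_len + 1` hidden states with `h0` in row 0. The sketch takes per-step gradients `dL_dh` of shape `(seq_len, hidden_dim)`, so a loss at every step is handled. A loss on the final state only is the special case where every row but the last is zero. A compact forward pass is repeated so the cell runs on its own.

```python
import numpy as np


def rnn_forward_cache(X_seq, h0, W_xh, W_hh, b_h):
    """Compact tanh-RNN forward pass, repeated here so this cell runs on its own."""
    T, H = X_seq.shape[0], h0.shape[0]
    a = np.zeros((T, H))      # pre-activations a_t
    h = np.zeros((T + 1, H))  # h[0] = h0, h[t + 1] is the state after step t
    h[0] = h0
    for t in range(T):
        a[t] = W_xh @ X_seq[t] + W_hh @ h[t] + b_h
        h[t + 1] = np.tanh(a[t])
    return {"X": X_seq, "a": a, "h": h}


def rnn_bptt(cache, dL_dh, W_xh, W_hh):
    """BPTT for h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h).

    dL_dh has shape (seq_len, hidden_dim): the direct gradient of the loss
    w.r.t. each hidden state.
    """
    X, h = cache["X"], cache["h"]
    dW_xh = np.zeros_like(W_xh)
    dW_hh = np.zeros_like(W_hh)
    db_h = np.zeros(W_hh.shape[0])
    dX = np.zeros_like(X)
    dh_next = np.zeros(W_hh.shape[0])  # gradient arriving at h_t from step t + 1

    for t in reversed(range(X.shape[0])):
        dh = dL_dh[t] + dh_next          # total dL/dh_t
        da = (1.0 - h[t + 1] ** 2) * dh  # through tanh: d tanh(a)/da = 1 - tanh(a)^2
        dW_xh += np.outer(da, X[t])      # weights are shared, so sum over steps
        dW_hh += np.outer(da, h[t])      # h[t] holds h_{t-1}
        db_h += da
        dX[t] = W_xh.T @ da
        dh_next = W_hh.T @ da            # send gradient back to h_{t-1}

    return {"W_xh": dW_xh, "W_hh": dW_hh, "b_h": db_h, "X": dX, "h0": dh_next}
```

A quick way to trust a hand-written backward pass is a finite-difference check. Perturb one weight by a small epsilon, rerun the forward pass, and compare the change in loss with the corresponding entry of the returned gradient.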

## 🧮 Gradient Clipping

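RNNs can suffer from exploding gradients during BPTT, because the backward pass multiplies by the recurrent weight matrix once per time step. Gradient clipping prevents this by rescaling the gradients whenever their norm exceeds a threshold (clipping each element to a fixed range is another variant). Clipping does not help with vanishing gradients.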
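To see why, assume the tanh recurrence $h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$. The Jacobian of one step is $\operatorname{diag}(1 - h_t^2)\, W_{hh}$. The gradient travelling from step $t$ back to step $k$ passes through $t - k$ such factors, ordered with $j = t$ on the left:

$$
\frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^{t} \operatorname{diag}\!\left(1 - h_j^{2}\right) W_{hh},
\qquad
\left\lVert \frac{\partial h_t}{\partial h_k} \right\rVert_2 \le \prod_{j=k+1}^{t} \left\lVert \operatorname{diag}\!\left(1 - h_j^{2}\right) \right\rVert_2 \left\lVert W_{hh} \right\rVert_2 \le \left\lVert W_{hh} \right\rVert_2^{\,t-k}.
$$

Since $0 < 1 - h_j^2 \le 1$, the bound shrinks geometrically when $\lVert W_{hh} \rVert_2 < 1$ (vanishing gradients). When $W_{hh}$ has singular values above 1, the product can instead grow geometrically with the gap $t - k$. That growth is the explosion clipping guards against.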

**LLM/Transformer Context:**
- Gradient clipping is also used in transformer training to stabilize optimization.

### Task:
- Scaffold a function to clip gradients to a maximum norm.
- Add a docstring explaining why this is important.

```python
def clip_gradients(grads, max_norm):
    """
    Rescale gradients so that their combined (global) L2 norm does not
    exceed max_norm, preventing exploding gradients.
    Used in both RNN and transformer training for stability.
    Args:
        grads (dict): Dictionary of gradients (arrays).
        max_norm (float): Maximum allowed norm.
    Returns:
        dict: Clipped gradients.
    """
    # TODO: Implement gradient clipping
    pass
```
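A minimal sketch of clipping by global norm follows. It treats all gradient arrays as one long vector, computes its L2 norm, and scales every array by the same factor if that norm exceeds `max_norm`. Using one shared factor preserves the direction of the overall update. The `1e-6` epsilon is an illustrative choice.

```python
import numpy as np


def clip_gradients(grads, max_norm):
    """Rescale all gradients together so their global L2 norm is at most max_norm.

    Scaling every array by the same factor keeps the overall update direction.
    """
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads.values()))
    # The small epsilon avoids division by zero when all gradients are zero.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return {name: g * scale for name, g in grads.items()}
```

PyTorch's `torch.nn.utils.clip_grad_norm_` applies the same global-norm rule to a model's parameters, which is the usual choice in transformer training.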

## 🔁 Training Loop with BPTT

Combine forward and backward passes to train the RNN on sequence data using BPTT.

**LLM/Transformer Context:**
- This is the sequence-level training loop, analogous to how LLMs are trained on long text sequences.

### Task:
- Scaffold a function for the RNN training loop using BPTT and gradient clipping.
- Add a docstring explaining each step.

```python
def train_rnn_bptt(X_seq, targets_seq, params, loss_fn, lr, epochs, max_norm):
    """
    Train an RNN on sequence data using BPTT and gradient clipping.
    Args:
        X_seq (np.ndarray): Input sequence (seq_len x input_dim)
        targets_seq (np.ndarray): Target sequence (seq_len,)
        params (dict): RNN parameters (W_xh, W_hh, b_h, W_hy, b_y)
        loss_fn (callable): Loss function (e.g., cross-entropy)
        lr (float): Learning rate.
        epochs (int): Number of training epochs.
        max_norm (float): Max norm for gradient clipping.
    Returns:
        dict: Trained parameters.
    """
    # TODO: Implement the RNN training loop with BPTT and gradient clipping
    pass
```
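The sketch below shows one way to combine the pieces. It makes several assumptions that are not fixed by the stub:
- a tanh RNN with `h0 = 0`;
- a linear output layer `logits_t = W_hy h_t + b_y`;
- full-sequence (untruncated) BPTT and plain SGD;
- a hypothetical contract where `loss_fn(logits, targets)` returns `(loss, dL/dlogits)`.

`softmax_cross_entropy` is one such loss (mean cross-entropy over time steps). Unlike the stub, the function also returns the per-epoch loss history.

```python
import numpy as np


def softmax_cross_entropy(logits, targets):
    """Mean cross-entropy over time steps; returns (loss, dL/dlogits)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    T = logits.shape[0]
    loss = -np.mean(np.log(probs[np.arange(T), targets]))
    dlogits = probs.copy()
    dlogits[np.arange(T), targets] -= 1.0
    return loss, dlogits / T


def clip_gradients(grads, max_norm):
    """Scale all gradients by one factor so their global L2 norm is <= max_norm."""
    total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads.values()))
    scale = min(1.0, max_norm / (total + 1e-6))
    return {k: g * scale for k, g in grads.items()}


def train_rnn_bptt(X_seq, targets_seq, params, loss_fn, lr, epochs, max_norm):
    p = {k: v.copy() for k, v in params.items()}  # do not mutate the caller's dict
    T, H = X_seq.shape[0], p["W_hh"].shape[0]
    history = []
    for _ in range(epochs):
        # 1) Forward: unroll the RNN from h0 = 0 and keep every hidden state.
        h = np.zeros((T + 1, H))
        for t in range(T):
            h[t + 1] = np.tanh(p["W_xh"] @ X_seq[t] + p["W_hh"] @ h[t] + p["b_h"])
        logits = h[1:] @ p["W_hy"].T + p["b_y"]  # (T, output_dim)

        # 2) Loss over the whole sequence and its gradient w.r.t. the logits.
        loss, dlogits = loss_fn(logits, targets_seq)
        history.append(loss)

        # 3) Backward: output layer first, then BPTT through the recurrence.
        g = {k: np.zeros_like(v) for k, v in p.items()}
        g["W_hy"] = dlogits.T @ h[1:]
        g["b_y"] = dlogits.sum(axis=0)
        dL_dh = dlogits @ p["W_hy"]  # direct gradient into each h_t, (T, H)
        dh_next = np.zeros(H)
        for t in reversed(range(T)):
            da = (1.0 - h[t + 1] ** 2) * (dL_dh[t] + dh_next)
            g["W_xh"] += np.outer(da, X_seq[t])
            g["W_hh"] += np.outer(da, h[t])
            g["b_h"] += da
            dh_next = p["W_hh"].T @ da

        # 4) Clip the global gradient norm, then 5) take an SGD step.
        g = clip_gradients(g, max_norm)
        for k in p:
            p[k] -= lr * g[k]
    return p, history
```

Clipping happens after all gradients for the sequence have been accumulated and before the update. That placement lets one shared scale factor bound the size of the whole parameter step.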

## 🧠 Final Summary: BPTT and Sequence Gradients in LLMs

- BPTT enables RNNs to learn dependencies across time steps, a key challenge in sequence modeling.
- Transformers use attention to address some of the limitations of BPTT, but the core idea of propagating gradients through sequences remains.
- Gradient clipping is a standard safeguard for stable training in both RNNs and transformers.

In the next notebook, you'll use these ideas to build a character-level language model!