In [None]:
# Citation for gradient calculations in LSTM
# Based on "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
# Cho et al., 2014 - https://arxiv.org/abs/1406.1078

# The paper describes the gradient flow through LSTM gates as follows:
# For each time step t:
# 1. Output gate gradients
# 2. Cell state gradients
# 3. Input gate gradients
# 4. Forget gate gradients
# 5. Hidden state gradients

# These gradients are used to update the parameters W_i, W_f, W_o, W_g, U_i, U_f, U_o, U_g
# and biases b_i, b_f, b_o, b_g through backpropagation through time (BPTT)

import torch
from torch import tanh

# Toy single-layer LSTM used to check a hand-written BPTT against autograd.
input_size = 256
hidden_size = 256
seq_len = 10
batch_size = 1

x = torch.randn(seq_len, batch_size, input_size)  # Input sequence

# Input-to-hidden weights, one (input_size, hidden_size) matrix per gate
W_i = torch.randn(input_size, hidden_size, requires_grad=True)
W_f = torch.randn(input_size, hidden_size, requires_grad=True)
W_o = torch.randn(input_size, hidden_size, requires_grad=True)
W_g = torch.randn(input_size, hidden_size, requires_grad=True)

# Hidden-to-hidden weights, one (hidden_size, hidden_size) matrix per gate
U_i = torch.randn(hidden_size, hidden_size, requires_grad=True)
U_f = torch.randn(hidden_size, hidden_size, requires_grad=True)
U_o = torch.randn(hidden_size, hidden_size, requires_grad=True)
U_g = torch.randn(hidden_size, hidden_size, requires_grad=True)

# Per-gate biases of shape (hidden_size,)
b_i = torch.randn(hidden_size, requires_grad=True)
b_f = torch.randn(hidden_size, requires_grad=True)
b_o = torch.randn(hidden_size, requires_grad=True)
b_g = torch.randn(hidden_size, requires_grad=True)

# Name -> parameter, in the order the comparison table is printed.
_params = {
    'W_i': W_i, 'W_f': W_f, 'W_o': W_o, 'W_g': W_g,
    'U_i': U_i, 'U_f': U_f, 'U_o': U_o, 'U_g': U_g,
    'b_i': b_i, 'b_f': b_f, 'b_o': b_o, 'b_g': b_g,
}


def _lstm_gates(x_t, h_prev):
    """Return the gate activations ``(i, f, g, o)`` for one time step.

    Each pre-activation is ``x_t @ W + h_prev @ U + b``; ``i``, ``f`` and
    ``o`` go through a sigmoid and ``g`` through tanh. Shared by the forward
    pass and the manual backward pass so both use the exact same expressions.
    """
    i = torch.sigmoid(x_t @ W_i + h_prev @ U_i + b_i)
    f = torch.sigmoid(x_t @ W_f + h_prev @ U_f + b_f)
    g = torch.tanh(x_t @ W_g + h_prev @ U_g + b_g)
    o = torch.sigmoid(x_t @ W_o + h_prev @ U_o + b_o)
    return i, f, g, o


# Initialize states - keep every step in a list so nothing is modified
# in place; index t holds the state *before* step t, index t + 1 the state after.
h_states = [torch.zeros(batch_size, hidden_size)]  # h_0
c_states = [torch.zeros(batch_size, hidden_size)]  # c_0

# Forward pass with autograd
for t in range(seq_len):
    h_prev = h_states[t]
    c_prev = c_states[t]

    # Gates
    i, f, g, o = _lstm_gates(x[t], h_prev)

    # Update states (create new tensors, not in-place)
    c_new = f * c_prev + i * g
    h_new = o * torch.tanh(c_new)

    h_states.append(h_new)
    c_states.append(c_new)

# Create a dummy loss from final hidden state
loss = h_states[-1].sum()
loss.backward()

# Snapshot the autograd gradients (clone so they are independent of .grad)
autograd_grads = {name: p.grad.clone() for name, p in _params.items()}

# Accumulators for the hand-derived gradients, same shapes as the parameters
manual_grads = {name: torch.zeros_like(p) for name, p in _params.items()}

# Manual backprop.
# The whole pass runs under no_grad: the stored states and the parameters all
# require grad, so without it every δ tensor and every in-place accumulation
# below would build a fresh autograd graph and leave the entries of
# manual_grads with requires_grad=True.
with torch.no_grad():
    # loss = h_T.sum()  =>  ∂L/∂h_T is all ones
    δL_δh_t = torch.ones_like(h_states[-1])
    δh_t_future = δL_δh_t
    δc_t_future = torch.zeros_like(c_states[-1])

    # This loop goes backwards through time
    for t in reversed(range(seq_len)):
        h_prev = h_states[t]
        c_prev = c_states[t]
        c_curr = c_states[t + 1]

        # Recompute gates for current timestep
        i, f, g, o = _lstm_gates(x[t], h_prev)
        tanh_c = torch.tanh(c_curr)

        # Gradient arriving at h_t from the step after it (or from the loss
        # on the first iteration).
        δh_t = δh_t_future

        # ∂L/∂c_t = ∂L/∂h_t * ∂h_t/∂c_t
        #         = δh_t * o_t * (1 - tanh²(c_t))
        # plus the gradient coming from the next timestep (δc_t_future)
        δc_t = δc_t_future + δh_t * o * (1 - tanh_c**2)

        # Gradients w.r.t. gate activations
        # (c_t = f * c_{t-1} + i * g and h_t = o * tanh(c_t))
        δi = δc_t * g
        δf = δc_t * c_prev
        δg = δc_t * i
        δo = δh_t * tanh_c

        # Gradients w.r.t. gate pre-activations (applying activation derivatives)
        δz_i = δi * i * (1 - i)  # sigmoid derivative
        δz_f = δf * f * (1 - f)  # sigmoid derivative
        δz_o = δo * o * (1 - o)  # sigmoid derivative
        δz_g = δg * (1 - g**2)   # tanh derivative

        # Accumulate parameter gradients (summed over all timesteps).
        # x[t] is (batch_size, input_size) and δz_* is (batch_size, hidden_size),
        # so x[t].T @ δz_* matches the (input_size, hidden_size) shape of W_*.
        manual_grads['W_i'] += x[t].T @ δz_i
        manual_grads['W_f'] += x[t].T @ δz_f
        manual_grads['W_o'] += x[t].T @ δz_o
        manual_grads['W_g'] += x[t].T @ δz_g

        # h_prev is (batch_size, hidden_size), so h_prev.T @ δz_* matches the
        # (hidden_size, hidden_size) shape of U_*.
        manual_grads['U_i'] += h_prev.T @ δz_i
        manual_grads['U_f'] += h_prev.T @ δz_f
        manual_grads['U_o'] += h_prev.T @ δz_o
        manual_grads['U_g'] += h_prev.T @ δz_g

        # Summing δz_* over the batch dimension gives the (hidden_size,)
        # shape of b_*.
        manual_grads['b_i'] += δz_i.sum(dim=0)
        manual_grads['b_f'] += δz_f.sum(dim=0)
        manual_grads['b_o'] += δz_o.sum(dim=0)
        manual_grads['b_g'] += δz_g.sum(dim=0)

        # Gradients handed to the previous timestep (flowing backwards).
        # h_{t-1} feeds all four gates, so its gradient is the sum of the four
        # contributions. Each pre-activation is z = h_{t-1} @ U with U laid out
        # as (in, out) = (H, H), so the chain rule gives ∂L/∂h_{t-1} = δz @ U.T:
        # (B, H) @ (H, H) -> (B, H). The transpose comes from differentiating
        # the matmul, not from any storage convention.
        δh_t_future = δz_i @ U_i.T + δz_f @ U_f.T + δz_o @ U_o.T + δz_g @ U_g.T
        # c_{t-1} reaches c_t only through the forget gate: ∂c_t/∂c_{t-1} = f
        δc_t_future = δc_t * f

# Compare gradients
print("Gradient Comparison (Autograd vs Manual):")
print("=" * 50)

total_diff = 0
for param_name in autograd_grads:
    abs_err = torch.abs(autograd_grads[param_name] - manual_grads[param_name])
    diff = abs_err.mean()
    max_diff = abs_err.max()
    # Mean error relative to the mean gradient magnitude; the epsilon guards
    # against division by zero for an all-zero gradient.
    rel_error = diff / (torch.abs(autograd_grads[param_name]).mean() + 1e-8)

    print(f"{param_name:4s} | Mean diff: {diff.item():8.2e} | Max diff: {max_diff.item():8.2e} | Rel error: {rel_error.item():8.2e}")
    total_diff += diff.item()

print("=" * 50)
print(f"Total mean difference: {total_diff:.2e}")

if total_diff < 1e-5:
    print("✅ Gradients match! Manual backprop implementation is correct.")
else:
    print("❌ Gradients don't match. Check manual backprop implementation.")

Gradient Comparison (Autograd vs Manual):
W_i  | Mean diff: 8.92e-08 | Max diff: 5.72e-06 | Rel error: 2.27e-07
W_f  | Mean diff: 5.60e-08 | Max diff: 7.63e-06 | Rel error: 2.39e-07
W_o  | Mean diff: 1.42e-07 | Max diff: 1.14e-05 | Rel error: 2.75e-07
W_g  | Mean diff: 9.00e-08 | Max diff: 1.91e-05 | Rel error: 2.09e-07
U_i  | Mean diff: 1.53e-08 | Max diff: 1.19e-06 | Rel error: 2.01e-07
U_f  | Mean diff: 1.49e-08 | Max diff: 2.86e-06 | Rel error: 2.27e-07
U_o  | Mean diff: 2.27e-08 | Max diff: 1.43e-06 | Rel error: 2.34e-07
U_g  | Mean diff: 1.23e-08 | Max diff: 4.77e-06 | Rel error: 1.99e-07
b_i  | Mean diff: 1.08e-07 | Max diff: 1.43e-06 | Rel error: 2.21e-07
b_f  | Mean diff: 7.15e-08 | Max diff: 2.86e-06 | Rel error: 2.37e-07
b_o  | Mean diff: 1.86e-07 | Max diff: 2.86e-06 | Rel error: 2.84e-07
b_g  | Mean diff: 1.18e-07 | Max diff: 5.72e-06 | Rel error: 2.09e-07
Total mean difference: 9.26e-07
✅ Gradients match! Manual backprop implementation is correct.
