In [None]:
import torch
from torch import tanh

# Create a toy LSTM cell with parameters
input_size = 256
hidden_size = 256
seq_len = 10
batch_size = 1

# Seed the global RNG so the inputs/parameters (and therefore the printed
# gradient comparison) are reproducible from run to run.
torch.manual_seed(0)

x = torch.randn(seq_len, batch_size, input_size)  # Input sequence: (seq_len, batch, input_size)

# Initialize parameters: for each gate (input i, forget f, output o, candidate g)
# an input-to-hidden matrix W_*, a hidden-to-hidden matrix U_*, and a bias b_*.
# All are standard-normal leaf tensors with requires_grad=True. There is no
# fan-in scaling, so x[t] @ W_* alone has std of roughly sqrt(input_size).
W_i = torch.randn(input_size, hidden_size, requires_grad=True)
W_f = torch.randn(input_size, hidden_size, requires_grad=True)
W_o = torch.randn(input_size, hidden_size, requires_grad=True)
W_g = torch.randn(input_size, hidden_size, requires_grad=True)

U_i = torch.randn(hidden_size, hidden_size, requires_grad=True)
U_f = torch.randn(hidden_size, hidden_size, requires_grad=True)
U_o = torch.randn(hidden_size, hidden_size, requires_grad=True)
U_g = torch.randn(hidden_size, hidden_size, requires_grad=True)

b_i = torch.randn(hidden_size, requires_grad=True)
b_f = torch.randn(hidden_size, requires_grad=True)
b_o = torch.randn(hidden_size, requires_grad=True)
b_g = torch.randn(hidden_size, requires_grad=True)

# Initialize states - use lists to avoid in-place modifications.
# The leading dimension is batch_size so that every stored state (including
# h_0) has shape (batch_size, hidden_size); the manual backward's
# h_prev.T @ δz products only line up when all timesteps share that shape.
h_states = [torch.zeros(batch_size, hidden_size)]  # h_0
c_states = [torch.zeros(batch_size, hidden_size)]  # c_0

# Forward pass with autograd: unroll the LSTM over the sequence while autograd
# records the computation graph.
def _affine(x_t, h, W, U, b):
    """Return the gate pre-activation ``x_t @ W + h @ U + b``."""
    return x_t @ W + h @ U + b


for t in range(seq_len):
    # States entering step t were stored at index t (index 0 holds h_0 / c_0).
    h_prev, c_prev = h_states[t], c_states[t]

    # Gate activations: input (i), forget (f), candidate (g), output (o).
    i = torch.sigmoid(_affine(x[t], h_prev, W_i, U_i, b_i))
    f = torch.sigmoid(_affine(x[t], h_prev, W_f, U_f, b_f))
    g = torch.tanh(_affine(x[t], h_prev, W_g, U_g, b_g))
    o = torch.sigmoid(_affine(x[t], h_prev, W_o, U_o, b_o))

    # Build fresh state tensors instead of overwriting earlier ones, so every
    # stored state remains intact for backward().
    c_new = f * c_prev + i * g
    h_new = o * torch.tanh(c_new)

    h_states.append(h_new)
    c_states.append(c_new)

# Dummy scalar objective: the sum of the final hidden state.
loss = h_states[-1].sum()
loss.backward()

# Snapshot each parameter's autograd gradient (cloned, so the stored copies are
# independent of the .grad tensors themselves).
autograd_grads = {
    name: param.grad.clone()
    for name, param in (
        ("W_i", W_i), ("W_f", W_f), ("W_o", W_o), ("W_g", W_g),
        ("U_i", U_i), ("U_f", U_f), ("U_o", U_o), ("U_g", U_g),
        ("b_i", b_i), ("b_f", b_f), ("b_o", b_o), ("b_g", b_g),
    )
}

# Manual backprop through time (BPTT).
# The whole pass runs under torch.no_grad(). The stored h/c states and the
# parameters have requires_grad=True, so outside no_grad every δ tensor (and
# each in-place accumulated manual gradient) would record an autograd graph
# that is never used.
with torch.no_grad():
    dL_dh_T = torch.ones_like(h_states[-1])  # dL/dh_T for loss = h_T.sum()
    δh_next = dL_dh_T
    # The loss reads only h_T, so no gradient reaches c_T directly.
    δc_next = torch.zeros_like(c_states[-1])

    # Initialize manual gradient accumulators (zeros_like does not copy requires_grad).
    manual_grads = {
        'W_i': torch.zeros_like(W_i), 'W_f': torch.zeros_like(W_f),
        'W_o': torch.zeros_like(W_o), 'W_g': torch.zeros_like(W_g),
        'U_i': torch.zeros_like(U_i), 'U_f': torch.zeros_like(U_f),
        'U_o': torch.zeros_like(U_o), 'U_g': torch.zeros_like(U_g),
        'b_i': torch.zeros_like(b_i), 'b_f': torch.zeros_like(b_f),
        'b_o': torch.zeros_like(b_o), 'b_g': torch.zeros_like(b_g)
    }

    # Backward pass through time, last step first.
    for t in reversed(range(seq_len)):
        h_prev = h_states[t]
        c_prev = c_states[t]
        c_curr = c_states[t + 1]

        # Recompute the gate activations for step t with the forward-pass formulas.
        i = torch.sigmoid(x[t] @ W_i + h_prev @ U_i + b_i)
        f = torch.sigmoid(x[t] @ W_f + h_prev @ U_f + b_f)
        g = torch.tanh(x[t] @ W_g + h_prev @ U_g + b_g)
        o = torch.sigmoid(x[t] @ W_o + h_prev @ U_o + b_o)
        tanh_c = torch.tanh(c_curr)  # shared by δc_t and δo below

        # h_t = o * tanh(c_t): gradient reaching c_t through h_t, plus the
        # gradient carried back from step t + 1 (δc_next = δc_{t+1} * f_{t+1}).
        δh_t = δh_next
        δc_t = δc_next + δh_t * o * (1 - tanh_c**2)

        # Gradients w.r.t. gates, from c_t = f * c_prev + i * g and h_t = o * tanh(c_t).
        δi = δc_t * g
        δf = δc_t * c_prev
        δg = δc_t * i
        δo = δh_t * tanh_c

        # Gradients w.r.t. gate pre-activations (applying activation derivatives)
        δz_i = δi * i * (1 - i)  # sigmoid derivative
        δz_f = δf * f * (1 - f)  # sigmoid derivative
        δz_o = δo * o * (1 - o)  # sigmoid derivative
        δz_g = δg * (1 - g**2)   # tanh derivative

        # Accumulate parameter gradients. Weights are shared across timesteps, so
        # contributions are summed; `.T @` and `.sum(dim=0)` also sum over batch rows.
        manual_grads['W_i'] += x[t].T @ δz_i
        manual_grads['W_f'] += x[t].T @ δz_f
        manual_grads['W_o'] += x[t].T @ δz_o
        manual_grads['W_g'] += x[t].T @ δz_g

        manual_grads['U_i'] += h_prev.T @ δz_i
        manual_grads['U_f'] += h_prev.T @ δz_f
        manual_grads['U_o'] += h_prev.T @ δz_o
        manual_grads['U_g'] += h_prev.T @ δz_g

        manual_grads['b_i'] += δz_i.sum(dim=0)
        manual_grads['b_f'] += δz_f.sum(dim=0)
        manual_grads['b_o'] += δz_o.sum(dim=0)
        manual_grads['b_g'] += δz_g.sum(dim=0)

        # Propagate to step t - 1: h_prev feeds every gate through U_*, and
        # c_prev feeds c_t scaled by this step's forget gate.
        δh_next = δz_i @ U_i.T + δz_f @ U_f.T + δz_o @ U_o.T + δz_g @ U_g.T
        δc_next = δc_t * f

# Compare gradients
print("Gradient Comparison (Autograd vs Manual):")
print("=" * 50)

# Per-parameter elementwise tolerances: |manual - autograd| <= atol + rtol * |autograd|.
# A relative check scales with gradient magnitude, unlike a single absolute
# threshold on the summed mean differences. float32 rounding accumulated over
# seq_len steps produces small but nonzero differences.
rtol, atol = 1e-4, 1e-5

total_diff = 0
all_match = True
for param_name, auto_grad in autograd_grads.items():
    manual_grad = manual_grads[param_name]
    abs_diff = torch.abs(auto_grad - manual_grad)  # computed once, reused below
    diff = abs_diff.mean()
    max_diff = abs_diff.max()
    rel_error = diff / (torch.abs(auto_grad).mean() + 1e-8)

    print(f"{param_name:4s} | Mean diff: {diff.item():8.2e} | Max diff: {max_diff.item():8.2e} | Rel error: {rel_error.item():8.2e}")
    total_diff += diff.item()
    # Autograd is the reference (second argument), so rtol scales with its magnitude.
    if not torch.allclose(manual_grad, auto_grad, rtol=rtol, atol=atol):
        all_match = False

print("=" * 50)
print(f"Total mean difference: {total_diff:.2e}")

if all_match:
    print("✅ Gradients match! Manual backprop implementation is correct.")
else:
    print("❌ Gradients don't match. Check manual backprop implementation.")

Gradient Comparison (Autograd vs Manual):
W_i  | Mean diff: 6.91e-08 | Max diff: 5.72e-06 | Rel error: 2.80e-07
W_f  | Mean diff: 1.54e-08 | Max diff: 1.19e-06 | Rel error: 1.56e-07
W_o  | Mean diff: 6.15e-08 | Max diff: 4.77e-06 | Rel error: 2.12e-07
W_g  | Mean diff: 5.48e-08 | Max diff: 7.63e-06 | Rel error: 1.82e-07
U_i  | Mean diff: 8.94e-09 | Max diff: 8.34e-07 | Rel error: 2.27e-07
U_f  | Mean diff: 4.51e-09 | Max diff: 3.58e-07 | Rel error: 1.59e-07
U_o  | Mean diff: 1.00e-08 | Max diff: 8.34e-07 | Rel error: 1.84e-07
U_g  | Mean diff: 1.22e-08 | Max diff: 1.91e-06 | Rel error: 1.96e-07
b_i  | Mean diff: 8.42e-08 | Max diff: 1.79e-06 | Rel error: 2.71e-07
b_f  | Mean diff: 1.90e-08 | Max diff: 3.58e-07 | Rel error: 1.55e-07
b_o  | Mean diff: 7.41e-08 | Max diff: 1.43e-06 | Rel error: 2.01e-07
b_g  | Mean diff: 6.78e-08 | Max diff: 1.91e-06 | Rel error: 1.74e-07
Total mean difference: 4.81e-07
✅ Gradients match! Manual backprop implementation is correct.
