
# Autodiff in Practice: Forward-Mode, Reverse-Mode, VJPs

This notebook is a hands-on companion to a lecture on automatic differentiation.  
You will:
- Build intuition for forward- and reverse-mode AD.
- Work through small symbolic chains.
- Implement backpropagation for a tiny linear–ReLU network in NumPy.
- Verify gradients via finite differences.
- Practice deriving Jacobian-vector products (JVPs) and vector-Jacobian products (VJPs).


In [None]:

import numpy as np

np.set_printoptions(suppress=True, linewidth=120)
def seed_all(seed=0):
    np.random.seed(seed)

seed_all(42)



## Part 1 — Forward-mode warm-up

We carry each value together with its directional derivative (its tangent) as we **move forward**.

Chain:
\begin{align}
h_1 &= 2x,\\
h_2 &= h_1^2,\\
y &= \sin(h_2).
\end{align}

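Forward-mode rules, seeded with $\dot{x} = 1$ so that each dotted quantity is a derivative with respect to $x$:
\begin{align}
\dot{h}_1 &= 2\dot{x},\\
\dot{h}_2 &= 2h_1\,\dot{h}_1,\\
\dot{y} &= \cos(h_2)\,\dot{h}_2.
\end{align}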

Run a quick numeric check.


In [None]:

def forward_chain(x, dx=1.0):
    # values
    h1 = 2.0 * x
    h2 = h1**2
    y  = np.sin(h2)
    # forward-mode directional derivatives, seeded with dx = 1 (derivatives w.r.t. x)
    dh1 = 2.0 * dx
    dh2 = 2.0 * h1 * dh1
    dy  = np.cos(h2) * dh2
    return (h1, h2, y), (dh1, dh2, dy)

# test
x = 0.3
(vals, ders) = forward_chain(x)
vals, ders
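The same forward-mode rules can be automated by carrying a (value, tangent) pair through every operation, i.e. dual numbers. A minimal sketch, assuming a hypothetical `Dual` class and `dual_sin` helper that implement only the operations this chain uses:

```python
import math


class Dual:
    """A value paired with its tangent: (val, dot)."""

    def __init__(self, val, dot=0.0):
        self.val = val
        self.dot = dot

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (a*b)' = a'*b + a*b'
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__  # so that 2.0 * Dual(...) also works

    def __pow__(self, n):
        # power rule for a constant exponent n: (a**n)' = n * a**(n-1) * a'
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.dot)


def dual_sin(a):
    # chain rule: sin(a)' = cos(a) * a'
    return Dual(math.sin(a.val), math.cos(a.val) * a.dot)


x = Dual(0.3, 1.0)   # seed xdot = 1, so every .dot is a derivative w.r.t. x
h1 = 2.0 * x
h2 = h1 ** 2
y = dual_sin(h2)
print(y.val, y.dot)  # y.dot should equal cos(4x^2) * 8x at x = 0.3
```

Each pass with seed $\dot{x} = 1$ yields one directional derivative; a function of $n$ inputs needs $n$ such passes to assemble its full gradient.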



## Part 2 — Reverse-mode on the same chain

Reverse-mode starts at the scalar output and **propagates adjoints backward**:
\begin{align}
\bar{y} &= 1,\\
\bar{h}_2 &= \cos(h_2)\,\bar{y},\\
\bar{h}_1 &= 2h_1\,\bar{h}_2,\\
\bar{x}  &= 2\,\bar{h}_1.
\end{align}


In [None]:

def reverse_chain(x):
    # forward pass: cache values
    h1 = 2.0 * x
    h2 = h1**2
    y  = np.sin(h2)
    # backward pass: adjoints, seeded with y_bar = dy/dy = 1
    y_bar  = 1.0
    h2_bar = np.cos(h2) * y_bar   # dy/dh2
    h1_bar = 2.0 * h1 * h2_bar    # dy/dh1
    x_bar  = 2.0 * h1_bar         # dy/dx
    return (h1,h2,y), (x_bar,h1_bar,h2_bar,y_bar)

# test
reverse_chain(0.3)
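Reverse-mode can be automated in a similar spirit: during the forward pass, record one closure per operation that pushes its output's adjoint back to its input, then replay those closures in reverse order (a minimal tape). A sketch assuming a hypothetical scalar `Var` class and a plain Python list as the tape; only the three operations this chain needs are implemented:

```python
import math


class Var:
    """Scalar node: value, adjoint, and the tape it records onto."""

    def __init__(self, val, tape):
        self.val = val
        self.grad = 0.0
        self.tape = tape


def scale(a, c):
    out = Var(c * a.val, a.tape)
    def back():
        a.grad += c * out.grad                  # d(c*a)/da = c
    a.tape.append(back)
    return out


def square(a):
    out = Var(a.val ** 2, a.tape)
    def back():
        a.grad += 2.0 * a.val * out.grad        # d(a^2)/da = 2a
    a.tape.append(back)
    return out


def sin(a):
    out = Var(math.sin(a.val), a.tape)
    def back():
        a.grad += math.cos(a.val) * out.grad    # d sin(a)/da = cos(a)
    a.tape.append(back)
    return out


def backward(y):
    y.grad = 1.0                    # seed ybar = 1
    for back in reversed(y.tape):   # replay the tape in reverse order
        back()


tape = []
x = Var(0.3, tape)
y = sin(square(scale(x, 2.0)))
backward(y)
print(x.grad)  # should equal cos(4x^2) * 8x at x = 0.3
```

Because each closure uses `+=`, a variable consumed by several operations accumulates the sum of its adjoint contributions, which is the multivariate chain rule.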



### Check equivalence

$\dot{y} = \frac{dy}{dx}$ from forward-mode should equal $\bar{x} = \frac{\partial y}{\partial x}$ from reverse-mode.


In [None]:

x = 0.3
(_, (_,_,dy)) = forward_chain(x)
(_, (x_bar,_,_,_)) = reverse_chain(x)
print("Forward-mode dy/dx:", dy)
print("Reverse-mode dy/dx:", x_bar)
assert np.allclose(dy, x_bar, atol=1e-9)



## Part 3 — VJPs and JVPs for a linear layer

For $f(x,W) = Wx$:
- JVP w.r.t. $x$ with direction $u$: $\mathrm{jvp}_x(f,u) = W u$.
- VJP w.r.t. $x$ with vector $v$: $\mathrm{vjp}_x(f,v) = W^\top v$.
- VJP w.r.t. $W$: $\mathrm{vjp}_W(f,v) = v\,x^\top$.

Part 4 uses these rules to hand-code the reverse-mode gradients of a tiny network.
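These rules can be checked numerically. A minimal NumPy sketch; the helper names `jvp_x`, `vjp_x` and `vjp_W` are illustrative, not from any library. It checks the duality identity $\langle v, Ju\rangle = \langle J^\top v, u\rangle$, spot-checks $v\,x^\top$ with a central difference, and shows that assembling the full Jacobian takes $n$ JVPs (one per input) or $m$ VJPs (one per output), which is why reverse-mode wins when the output is a scalar loss.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4
W = rng.standard_normal((m, n))
x = rng.standard_normal(n)

def jvp_x(W, u):
    # directional derivative of f(x) = W x along u: J u = W u
    return W @ u

def vjp_x(W, v):
    # pull a cotangent v back to x: J^T v = W^T v
    return W.T @ v

def vjp_W(x, v):
    # gradient of the scalar v^T W x w.r.t. W: v x^T
    return np.outer(v, x)

u = rng.standard_normal(n)
v = rng.standard_normal(m)

# duality: <v, J u> == <J^T v, u>
lhs = v @ jvp_x(W, u)
rhs = vjp_x(W, v) @ u
print(lhs, rhs)

# spot-check vjp_W on entry (1, 2) with a central difference
eps = 1e-6
E = np.zeros_like(W)
E[1, 2] = eps
fd = (v @ ((W + E) @ x) - v @ ((W - E) @ x)) / (2 * eps)
print(fd, vjp_W(x, v)[1, 2])

# n JVPs with basis vectors give J column by column; m VJPs give it row by row
J_cols = np.stack([jvp_x(W, e) for e in np.eye(n)], axis=1)
J_rows = np.stack([vjp_x(W, e) for e in np.eye(m)], axis=0)
print(np.allclose(J_cols, W), np.allclose(J_rows, W))
```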



## Part 4 — Tiny linear–ReLU network in NumPy

Model:
\begin{align}
h &= \mathrm{ReLU}(W_1 x), \\
y &= W_2 h, \\
L &= \|y - t\|^2.
\end{align}

Backprop:
\begin{align}
\bar{y} &= 2(y - t),\\
\bar{W}_2 &= \bar{y}\,h^\top, \quad \bar{h} = W_2^\top \bar{y},\\
\bar{z} &= \mathbf{1}_{z>0} \odot \bar{h}, \quad z=W_1 x,\\
\bar{W}_1 &= \bar{z}\,x^\top, \quad \bar{x} = W_1^\top \bar{z}.
\end{align}


In [None]:

def relu(z):
    return np.maximum(0.0, z)

def relu_grad(z):
    return (z > 0).astype(z.dtype)

def forward_linear_relu(W1, W2, x, t):
    # forward
    z = W1 @ x
    h = relu(z)
    y = W2 @ h
    L = np.sum((y - t)**2)
    cache = {"x":x, "z":z, "h":h, "y":y}
    return L, cache

def backward_linear_relu(W1, W2, cache, t):
    # unpack
    x, z, h, y = cache["x"], cache["z"], cache["h"], cache["y"]
    # upstream from loss
    y_bar = 2.0*(y - t)
    # grads W2, h
    dW2 = np.outer(y_bar, h)
    h_bar = W2.T @ y_bar
    # through ReLU
    z_bar = relu_grad(z) * h_bar
    dW1 = np.outer(z_bar, x)
    x_bar = W1.T @ z_bar
    return {"dW1": dW1, "dW2": dW2, "dx": x_bar}

# sanity check
seed_all(0)
W1 = np.random.randn(8, 5)*0.1
W2 = np.random.randn(3, 8)*0.1
x  = np.random.randn(5)
t  = np.random.randn(3)

L, cache = forward_linear_relu(W1, W2, x, t)
grads = backward_linear_relu(W1, W2, cache, t)
L, grads["dW1"].shape, grads["dW2"].shape, grads["dx"].shape
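These gradients are all that plain gradient descent needs. A sketch on a single example, with the forward/backward pair compactly re-implemented so the cell runs on its own; the step size `lr = 0.05` and the 200 steps are arbitrary assumptions, not tuned values:

```python
import numpy as np

def forward(W1, W2, x, t):
    z = W1 @ x
    h = np.maximum(0.0, z)
    y = W2 @ h
    return np.sum((y - t) ** 2), (x, z, h, y)

def backward(W1, W2, cache, t):
    x, z, h, y = cache
    y_bar = 2.0 * (y - t)              # dL/dy
    dW2 = np.outer(y_bar, h)           # y_bar h^T
    z_bar = (z > 0) * (W2.T @ y_bar)   # ReLU VJP applied to h_bar
    dW1 = np.outer(z_bar, x)           # z_bar x^T
    return dW1, dW2

rng = np.random.default_rng(0)
W1 = rng.standard_normal((8, 5)) * 0.1
W2 = rng.standard_normal((3, 8)) * 0.1
x = rng.standard_normal(5)
t = rng.standard_normal(3)

lr = 0.05  # assumed step size
losses = []
for step in range(200):
    L, cache = forward(W1, W2, x, t)
    dW1, dW2 = backward(W1, W2, cache, t)
    W1 -= lr * dW1   # in-place gradient step
    W2 -= lr * dW2
    losses.append(L)

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```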



### Finite-difference gradient check

We verify $\nabla_{W_1} L$ and $\nabla_{W_2} L$ numerically at a few randomly chosen entries, using central differences (truncation error $O(\epsilon^2)$):
$$
\frac{\partial L}{\partial \theta_{ij}} \approx \frac{L(\theta_{ij}+\epsilon) - L(\theta_{ij}-\epsilon)}{2\epsilon}
$$


In [None]:

def finite_diff_grad_W(W1, W2, x, t, which="W1", eps=1e-5, num_checks=10, seed=0):
    rng = np.random.default_rng(seed)
    idxs = []
    if which == "W1":
        H, D = W1.shape
        for _ in range(num_checks):
            i = rng.integers(0, H)
            j = rng.integers(0, D)
            idxs.append((i,j))
        analytic = backward_linear_relu(W1, W2, forward_linear_relu(W1, W2, x, t)[1], t)["dW1"]
    else:
        O, H = W2.shape
        for _ in range(num_checks):
            i = rng.integers(0, O)
            j = rng.integers(0, H)
            idxs.append((i,j))
        analytic = backward_linear_relu(W1, W2, forward_linear_relu(W1, W2, x, t)[1], t)["dW2"]
    max_err = 0.0
    for (i,j) in idxs:
        if which == "W1":
            Wp = W1.copy(); Wm = W1.copy()
            Wp[i,j] += eps; Wm[i,j] -= eps
            Lp, _ = forward_linear_relu(Wp, W2, x, t)
            Lm, _ = forward_linear_relu(Wm, W2, x, t)
        else:
            Wp = W2.copy(); Wm = W2.copy()
            Wp[i,j] += eps; Wm[i,j] -= eps
            Lp, _ = forward_linear_relu(W1, Wp, x, t)
            Lm, _ = forward_linear_relu(W1, Wm, x, t)
        num = (Lp - Lm) / (2*eps)
        ana = analytic[i,j]
        err = abs(num - ana)
        max_err = max(max_err, err)
        print(f"Index {(i,j)}: analytic={ana:.6f}, numeric={num:.6f}, abs_err={err:.3e}")
    print("Max abs error:", max_err)

# run checks
seed_all(1)
W1 = np.random.randn(8, 5)*0.1
W2 = np.random.randn(3, 8)*0.1
x  = np.random.randn(5)
t  = np.random.randn(3)
print("Check dW1:")
finite_diff_grad_W(W1, W2, x, t, which="W1", num_checks=8)
print("\nCheck dW2:")
finite_diff_grad_W(W1, W2, x, t, which="W2", num_checks=8)
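Why `eps=1e-5`? Central differences have truncation error $O(\epsilon^2)$ versus $O(\epsilon)$ for one-sided differences, but shrinking $\epsilon$ too far lets floating-point round-off in $L(\theta+\epsilon) - L(\theta-\epsilon)$ dominate. A small sketch on $\sin$ (an arbitrary test function) makes the trade-off visible:

```python
import numpy as np

x0 = 0.7
exact = np.cos(x0)          # true derivative of sin at x0
central_err, forward_err = {}, {}

for eps in [1e-1, 1e-3, 1e-5, 1e-7, 1e-9, 1e-11]:
    central = (np.sin(x0 + eps) - np.sin(x0 - eps)) / (2 * eps)  # truncation O(eps^2)
    forward = (np.sin(x0 + eps) - np.sin(x0)) / eps              # truncation O(eps)
    central_err[eps] = abs(central - exact)
    forward_err[eps] = abs(forward - exact)
    print(f"eps={eps:.0e}  central_err={central_err[eps]:.1e}  forward_err={forward_err[eps]:.1e}")
```

A related caveat for this network: if some pre-activation satisfies $|z| \lesssim \epsilon$, the two perturbed evaluations straddle the ReLU kink and the numeric estimate can disagree with a correct analytic gradient.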



## Part 5 — Mini-batch version

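For a batch $\{(x^{(b)}, t^{(b)})\}_{b=1}^B$ with summed loss $L = \sum_b \|y^{(b)} - t^{(b)}\|^2$, the per-example gradients add, so we sum outer products:
$\nabla_{W_2} L = \sum_b \bar{y}^{(b)} {h^{(b)}}^\top$, $\nabla_{W_1} L = \sum_b \bar{z}^{(b)} {x^{(b)}}^\top$.
Stacking the examples as columns ($X$ is $D \times B$, $T$ is $O \times B$) turns each sum into a single matrix product, e.g. $\nabla_{W_2} L = \bar{Y} H^\top$.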


In [None]:

def forward_batch(W1, W2, X, T):
    # X: (D, B), T: (O, B)
    Z = W1 @ X
    H = np.maximum(0.0, Z)
    Y = W2 @ H
    L = np.sum((Y - T)**2)
    cache = {"X":X, "Z":Z, "H":H, "Y":Y}
    return L, cache

def backward_batch(W1, W2, cache, T):
    X, Z, H, Y = cache["X"], cache["Z"], cache["H"], cache["Y"]
    # upstream
    Y_bar = 2.0*(Y - T)
    dW2 = Y_bar @ H.T            # sum over batch is implicit via matrix product
    H_bar = W2.T @ Y_bar
    Z_bar = (Z > 0).astype(Z.dtype) * H_bar
    dW1 = Z_bar @ X.T
    dX  = W1.T @ Z_bar
    return {"dW1": dW1, "dW2": dW2, "dX": dX}

# test
seed_all(2)
D, H, O, B = 5, 8, 3, 4
W1 = np.random.randn(H, D)*0.1
W2 = np.random.randn(O, H)*0.1
X  = np.random.randn(D, B)
T  = np.random.randn(O, B)
L, cache = forward_batch(W1, W2, X, T)
grads = backward_batch(W1, W2, cache, T)
L, grads["dW1"].shape, grads["dW2"].shape, grads["dX"].shape
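A standalone check that the column-stacked matrix product equals the explicit sum of per-example outer products, using random stand-in matrices (`Y_bar` and `Hmat` are illustrative names, not the cell's cached values):

```python
import numpy as np

rng = np.random.default_rng(0)
O, H, B = 3, 8, 4
Y_bar = rng.standard_normal((O, B))   # stand-in for 2*(Y - T), one column per example
Hmat = rng.standard_normal((H, B))    # stand-in for hidden activations

batched = Y_bar @ Hmat.T                                             # one matrix product
looped = sum(np.outer(Y_bar[:, b], Hmat[:, b]) for b in range(B))   # explicit sum over b
print(np.allclose(batched, looped))
```

With a mean-over-batch loss instead of a sum, every gradient is simply scaled by $1/B$.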



## Practice — Derivations

1. Forward-mode practice: For the chain
   $$ u = 3x+1,\quad v = e^{u},\quad y = \tanh(v), $$
   derive $\dot{y} = \frac{dy}{dx}$ using forward-mode rules. Then check numerically.

2. Reverse-mode practice: For the same chain, derive adjoints
   $\bar{y},\bar{v},\bar{u},\bar{x}$.

3. VJP practice: For $f(x,W,b)=Wx+b$ with bias $b$,
   write $\mathrm{vjp}_x(f,v)$, $\mathrm{vjp}_W(f,v)$, and $\mathrm{vjp}_b(f,v)$.

4. Activation practice: For $y=\phi(z)=\mathrm{ReLU}(z)$ applied elementwise, show that the VJP is
   $\bar{z} = \mathbf{1}_{z>0} \odot \bar{y}$.

Use the empty cells below.


In [None]:
# Your work: forward-mode practice



In [None]:
# Your work: reverse-mode practice



In [None]:
# Your work: VJP practice



In [None]:
# Your work: activation practice




## Part 6 — Memory vs Compute: Gradient Checkpointing (concept)

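Idea: store only selected activations (checkpoints) during the forward pass. During the backward pass, any activation that was not stored is recomputed by rerunning the forward pass from the nearest earlier checkpoint. For a chain of $n$ layers with a checkpoint every $k$ layers, this keeps roughly $n/k$ checkpoints plus at most $k$ recomputed activations alive at once, i.e. $O(\sqrt{n})$ memory for $k \approx \sqrt{n}$, at the cost of about one extra forward pass.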

Demo sketch (not a full autodiff engine): we simulate storing only every k-th activation and "recompute" intermediate ones during backward.


In [None]:

def toy_forward_with_checkpoints(xs, k=2):
    # xs: array of scalars; a running sum of squares simulates a chain of "layers"
    acts = []          # full activation list, kept ONLY to verify recomputation below
    checkpoints = {}
    s = 0.0
    for i, x in enumerate(xs):
        s = s + x*x    # pretend "layer i": s_i = s_{i-1} + x_i^2
        acts.append(s)
        if i % k == 0:
            checkpoints[i] = s  # store every k-th activation
    return acts, checkpoints

def toy_backward_recompute(xs, acts, checkpoints, k=2):
    # gradient of the last activation y = s_{n-1} = sum_i x_i^2 w.r.t. each x_i
    grads = np.zeros_like(xs, dtype=float)
    upstream = 1.0  # dy/ds_i; stays 1 because ds_i/ds_{i-1} = 1 in this toy
    # walk backward; if an activation was not stored, recompute it from the nearest earlier checkpoint
    for i in reversed(range(len(xs))):
        if i in checkpoints:
            s = checkpoints[i]
        else:
            start = (i // k) * k   # nearest earlier checkpoint (always stored)
            s = checkpoints[start]
            for j in range(start + 1, i + 1):
                s = s + xs[j]*xs[j]
        # the recomputed s_i must match the full forward pass
        assert np.isclose(s, acts[i])
        # this toy's local derivative ds_i/dx_i = 2*x_i does not need s_i;
        # a real layer (e.g. tanh) would use the recomputed activation here
        grads[i] = upstream * 2.0*xs[i]
    return grads

xs = np.array([0.5, -1.2, 0.3, 0.7, -0.9])
acts, ckpts = toy_forward_with_checkpoints(xs, k=2)
grads = toy_backward_recompute(xs, acts, ckpts, k=2)
print("checkpoints:", ckpts)
print("toy grads:", grads, "(true should be 2*x):", 2*xs)
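In the toy above, the local derivative $2x_i$ never uses the recomputed activation, so recomputation cannot affect the result. A sketch where it does matter, assuming a hypothetical chain $s_i = \tanh(s_{i-1} + x_i)$ with $s_{-1} = 0$ and output $y = s_{n-1}$: backprop needs every $s_i$ through $\tanh'(a) = 1 - \tanh^2(a)$, so only the state entering each segment of $k$ layers is stored, and each segment is recomputed and backpropagated before moving to the previous one.

```python
import numpy as np

def layer(s_prev, x):
    # hypothetical layer: s_i = tanh(s_{i-1} + x_i)
    return np.tanh(s_prev + x)

def forward_checkpointed(xs, k):
    # keep only the state entering each segment of k layers
    ckpts = {}
    s = 0.0
    for i, x in enumerate(xs):
        if i % k == 0:
            ckpts[i] = s
        s = layer(s, x)
    return s, ckpts

def backward_checkpointed(xs, ckpts, k):
    n = len(xs)
    grads = np.zeros(n)
    s_bar = 1.0                                  # dy/ds_{n-1}, since y = s_{n-1}
    for start in sorted(ckpts, reverse=True):    # last segment first
        end = min(start + k, n)
        # recompute this segment's activations from its checkpoint (at most k live values)
        seg, s = [], ckpts[start]
        for i in range(start, end):
            s = layer(s, xs[i])
            seg.append(s)
        # backprop through the segment in reverse
        for i in reversed(range(start, end)):
            a_bar = s_bar * (1.0 - seg[i - start] ** 2)  # tanh'(a_i) = 1 - s_i^2
            grads[i] = a_bar                             # a_i = s_{i-1} + x_i, so da_i/dx_i = 1
            s_bar = a_bar                                # and da_i/ds_{i-1} = 1
    return grads

def full_forward(xs):
    s = 0.0
    for x in xs:
        s = layer(s, x)
    return s

xs = np.array([0.5, -1.2, 0.3, 0.7, -0.9])
y, ckpts = forward_checkpointed(xs, k=2)
grads = backward_checkpointed(xs, ckpts, k=2)

# central-difference reference, one input at a time
eps = 1e-6
ref = np.array([(full_forward(xs + eps * e) - full_forward(xs - eps * e)) / (2 * eps)
                for e in np.eye(len(xs))])
print("checkpoint indices:", sorted(ckpts))
print("grads:", grads)
print("max |grads - ref|:", np.max(np.abs(grads - ref)))
```

Each layer's forward runs twice in total (once in the forward pass, once during recomputation), which is the compute price paid for storing only `len(ckpts)` states.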



## Part 7 — Finite-difference helper for any scalar function

Utility for validating a custom backward implementation.


In [None]:

def finite_diff_grad(fn, theta, eps=1e-5):
    # theta is 1D array
    g = np.zeros_like(theta)
    for i in range(theta.size):
        th_p = theta.copy(); th_m = theta.copy()
        th_p[i] += eps; th_m[i] -= eps
        g[i] = (fn(th_p) - fn(th_m)) / (2*eps)
    return g



## Practice — Build your own composite and verify

1. Define a scalar function `fn(theta)` of your design (e.g., tiny MLP with ReLU).
2. Write a manual backward for it.
3. Compare analytical gradients to `finite_diff_grad`.


In [None]:

# Example scaffold (replace with your own)
def fn_example(theta):
    # split theta into small pieces
    W = theta[:4].reshape(2,2)
    x = theta[4:6]
    t = theta[6:8]
    h = np.maximum(0.0, W @ x)
    y = h  # identity
    L = np.sum((y - t)**2)
    return L

theta = np.random.randn(8)*0.1
g_num = finite_diff_grad(fn_example, theta)
print("Numeric grad shape:", g_num.shape)
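One way to complete the exercise for the scaffold itself: a hand-written backward for `fn_example`, built from the linear-layer and ReLU VJP rules and compared against central differences. The cell repeats `finite_diff_grad` and `fn_example` so it runs on its own; `grad_example` is a hypothetical name.

```python
import numpy as np

def finite_diff_grad(fn, theta, eps=1e-5):
    # central differences, one coordinate at a time
    g = np.zeros_like(theta)
    for i in range(theta.size):
        th_p = theta.copy(); th_m = theta.copy()
        th_p[i] += eps; th_m[i] -= eps
        g[i] = (fn(th_p) - fn(th_m)) / (2 * eps)
    return g

def fn_example(theta):
    W, x, t = theta[:4].reshape(2, 2), theta[4:6], theta[6:8]
    y = np.maximum(0.0, W @ x)
    return np.sum((y - t) ** 2)

def grad_example(theta):
    # manual reverse-mode pass for fn_example
    W, x, t = theta[:4].reshape(2, 2), theta[4:6], theta[6:8]
    z = W @ x
    y = np.maximum(0.0, z)
    y_bar = 2.0 * (y - t)      # dL/dy
    z_bar = (z > 0) * y_bar    # ReLU VJP
    dW = np.outer(z_bar, x)    # linear-layer VJP w.r.t. W
    dx = W.T @ z_bar           # linear-layer VJP w.r.t. x
    dt = -y_bar                # dL/dt
    # flatten in the same (row-major) order used to slice theta
    return np.concatenate([dW.ravel(), dx, dt])

rng = np.random.default_rng(0)
theta = rng.standard_normal(8)
g_ana = grad_example(theta)
g_num = finite_diff_grad(fn_example, theta)
print("max abs error:", np.max(np.abs(g_ana - g_num)))
```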



## Summary

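- Forward-mode carries a directional derivative (tangent) alongside each value; one pass gives one JVP, so a full gradient over $n$ inputs needs $n$ passes.
- Reverse-mode propagates adjoints backward from a scalar loss, giving the gradient with respect to all parameters in a single backward pass.
- VJPs power reverse-mode; the VJP rules for linear layers and ReLU are simple and compose well.
- Finite differences are a cheap sanity check for gradients of small models.
- Gradient checkpointing trades extra compute (recomputation) for lower activation memory in deep nets.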

You're done. Now extend any section with your own experiments.
