# Automatic Differentiation with Nabla

This notebook explores the core operations of automatic differentiation (AD). The central idea of AD is to compute the action of the Jacobian matrix on a vector efficiently, rather than materializing the (potentially very large) matrix itself.

### The Jacobian: The Full Derivative

For any differentiable function $f: \mathbb{R}^n \to \mathbb{R}^m$, its derivative at a point $x$ is the unique linear map $df_x: \mathbb{R}^n \to \mathbb{R}^m$ that provides the best linear approximation of the function's change. The **Jacobian matrix**, $J_f(x)$, is the matrix representation of this linear map with respect to the standard bases. It is an $m \times n$ matrix of all first-order partial derivatives:

$$
(J_f(x))_{ij} = \frac{\partial f_i}{\partial x_j}
$$

For complex models where $n$ and $m$ are large, constructing this matrix is often infeasible. AD provides an efficient way to compute its products with vectors.

### 1. Jacobian-Vector Product (JVP): The "Pushforward"

The JVP is the core operation of **forward-mode AD**. It computes the **directional derivative** of the function $f$ at a point $x$ along a tangent vector $v \in \mathbb{R}^n$. This is achieved by computing the product:

$$
\text{JVP}_f(x)(v) = J_f(x) v
$$

The result is a vector in the output space $\mathbb{R}^m$ representing the **rate of change** of the output in the direction of $v$. For small $v$, it gives the first-order approximation of the function's change: $f(x+v) \approx f(x) + J_f(x)v$. This operation is computed *without ever explicitly forming* the full Jacobian matrix.
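As a rough sanity check of the directional-derivative reading, the sketch below compares a hand-written $J_f(x)v$ with a central finite difference. The toy function `toy_f`, its hand-derived JVP `toy_jvp`, and the helper `finite_difference_jvp` are illustrative assumptions written in plain Python; none of them is part of `nabla`.

```python
import math

def toy_f(x):
    # Toy map R^2 -> R^2, chosen only for illustration.
    return [x[0] * x[1], math.sin(x[0])]

def toy_jvp(x, v):
    # Hand-derived J(x) @ v, with J(x) = [[x1, x0], [cos(x0), 0]].
    return [x[1] * v[0] + x[0] * v[1], math.cos(x[0]) * v[0]]

def finite_difference_jvp(f, x, v, eps=1e-6):
    # Central difference: (f(x + eps*v) - f(x - eps*v)) / (2*eps).
    plus = f([xi + eps * vi for xi, vi in zip(x, v)])
    minus = f([xi - eps * vi for xi, vi in zip(x, v)])
    return [(p - m) / (2 * eps) for p, m in zip(plus, minus)]

x = [0.5, -1.5]
v = [2.0, 1.0]
exact = toy_jvp(x, v)
approx = finite_difference_jvp(toy_f, x, v)
print("exact: ", exact)
print("approx:", approx)
```

The two results should agree to several digits. The finite difference needs two function evaluations per direction and carries truncation error, whereas forward-mode AD yields the exact product in a single augmented pass.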

### 2. Vector-Jacobian Product (VJP): The "Pullback"

The VJP is the core operation of **reverse-mode AD**, the engine behind **backpropagation**. It computes the product of the transposed Jacobian $J_f(x)^T$ with a "cotangent" vector $u \in \mathbb{R}^m$.

$$
\text{VJP}_f(x)(u) = J_f(x)^T u
$$

This operation "pulls back" the cotangent vector $u$ from the output space to the input space. Its critical application is in the chain rule: if $f$ is part of a larger composition $g(x) = L(f(x))$, where $L: \mathbb{R}^m \to \mathbb{R}$ is a scalar function, the VJP yields the gradient of $g$ when given the gradient of $L$ evaluated at $f(x)$. Specifically, if $u = \nabla L(f(x))$, then $\nabla g(x) = \text{VJP}_f(x)(u) = J_f(x)^T u$.
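To make the chain-rule statement concrete, here is a minimal plain-Python sketch that pulls back the gradient of a scalar loss through a toy function and compares the result with the gradient derived by hand. The names `toy_f`, `toy_vjp`, `loss`, and `loss_grad` are hypothetical helpers for this illustration and do not come from `nabla`.

```python
import math

def toy_f(x):
    # Toy map R^2 -> R^2, chosen only for illustration.
    return [x[0] * x[1], math.sin(x[0])]

def toy_vjp(x, u):
    # Hand-derived J(x)^T @ u, with J(x) = [[x1, x0], [cos(x0), 0]].
    return [x[1] * u[0] + math.cos(x[0]) * u[1], x[0] * u[0]]

def loss(y):
    # Scalar function L: R^2 -> R.
    return y[0] ** 2 + y[1] ** 2

def loss_grad(y):
    # Gradient of L with respect to its input y.
    return [2.0 * y[0], 2.0 * y[1]]

x = [0.5, -1.5]
u = loss_grad(toy_f(x))          # cotangent: gradient of L at f(x)
grad_via_vjp = toy_vjp(x, u)     # pulled back to the input space

# Gradient of g(x) = (x0*x1)^2 + sin(x0)^2, differentiated by hand.
grad_direct = [
    2.0 * x[0] * x[1] ** 2 + 2.0 * math.sin(x[0]) * math.cos(x[0]),
    2.0 * x[0] ** 2 * x[1],
]
print("via VJP:", grad_via_vjp)
print("direct: ", grad_direct)
```

Both lists should match up to floating-point rounding, which is the identity $\nabla g(x) = J_f(x)^T \nabla L(f(x))$ evaluated at one point.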

---

![image.png](../_static/pushfwd_pullbck.png)

---

In this notebook, we will use `nabla` to compute the JVP and VJP for a function $f$ that takes an input from $\mathbb{R}^{2k}$ and maps it to $\mathbb{R}^k$, defined for inputs $(x_1, x_2)$ with $x_1, x_2 \in \mathbb{R}^k$ as:

$$ f(x_1, x_2) = \cos(x_1 \odot x_2) \odot \log(x_2) $$

We will then see how these fundamental operations can be composed to construct the full Jacobian and second-order derivatives (the Hessian).

#### Setup

In [3]:
import sys
import subprocess

try:
    import nabla as nb
except ImportError:
    print("Installing nabla...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "nabla-ml"])
    import nabla as nb

print(
    f"🎉 Nabla is ready! Running on Python {sys.version_info.major}.{sys.version_info.minor}"
)

🎉 Nabla is ready! Running on Python 3.12


#### Function Definition

In [4]:
def my_program(x1: nb.Array, x2: nb.Array) -> nb.Array:
    a = x1 * x2
    b = nb.cos(a)
    c = nb.log(x2)
    y = b * c
    return y

#### Compute the regular forward pass

In [5]:
# init input arrays
x1 = nb.array([1.0, 2.0, 3.0])
x2 = nb.array([2.0, 3.0, 4.0])
print("x1:", x1)
print("x2:", x2)

# compute the value of the program
value = my_program(x1, x2)
print("fwd_output:", value)

x1: [1. 2. 3.]:f32[3]
x2: [2. 3. 4.]:f32[3]
fwd_output: [-0.288451   1.0548549  1.16983  ]:f32[3]


#### Compute the JVP (Jacobian-Vector Product)

In [6]:
# init input tangents
x1_tangent = nb.randn_like(x1)
x2_tangent = nb.randn_like(x2)
print("x1_tangent:", x1_tangent)
print("x2_tangent:", x2_tangent)

# compute the actual jvp
value, value_tangent = nb.jvp(my_program, (x1, x2), (x1_tangent, x2_tangent))
print("value:", value)
print("value_tangent:", value_tangent)

x1_tangent: [1.7640524 0.4001572 0.978738 ]:f32[3]
x2_tangent: [1.7640524 0.4001572 0.978738 ]:f32[3]
value: [-0.288451   1.0548549  1.16983  ]:f32[3]
value_tangent: [-3.7025769   0.74225295  5.3027043 ]:f32[3]
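The printed `value_tangent` can be reproduced by hand. Because the function acts elementwise, each output component depends only on the matching components of `x1` and `x2`, so the JVP reduces to $\frac{\partial f}{\partial x_1} \dot{x}_1 + \frac{\partial f}{\partial x_2} \dot{x}_2$ per element. The sketch below uses only Python's `math` module; the helper `partials` is a hypothetical name, and the tangent values are copied from the output above.

```python
import math

def partials(x1, x2):
    # Elementwise partial derivatives of cos(x1 * x2) * log(x2).
    a = x1 * x2
    d_x1 = -math.sin(a) * x2 * math.log(x2)
    d_x2 = -math.sin(a) * x1 * math.log(x2) + math.cos(a) / x2
    return d_x1, d_x2

x1 = [1.0, 2.0, 3.0]
x2 = [2.0, 3.0, 4.0]
# Tangents copied from the printed output of the cell above.
x1_tangent = [1.7640524, 0.4001572, 0.978738]
x2_tangent = [1.7640524, 0.4001572, 0.978738]

value_tangent = []
for a, b, ta, tb in zip(x1, x2, x1_tangent, x2_tangent):
    d1, d2 = partials(a, b)
    value_tangent.append(d1 * ta + d2 * tb)
print("value_tangent:", value_tangent)
```

The result should agree with the `nb.jvp` output to roughly single precision, since the notebook arrays are `f32` while plain Python floats are double precision.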


#### Compute the VJP (Vector-Jacobian Product)

In [7]:
# compute value and pullback function
value, pullback = nb.vjp(my_program, x1, x2)
print("value:", value)

# init output cotangent
value_cotangent = nb.randn_like(value)
print("value_cotangent:", value_cotangent)

# compute the actual vjp
x1_cotangent, x2_cotangent = pullback(value_cotangent)
print("x1_cotangent:", x1_cotangent)
print("x2_cotangent:", x2_cotangent)

value: [-0.288451   1.0548549  1.16983  ]:f32[3]
value_cotangent: [1.7640524 0.4001572 0.978738 ]:f32[3]
x1_cotangent: [-2.223683    0.36850792  2.9121294 ]:f32[3]
x2_cotangent: [-1.478894    0.37374496  2.390575  ]:f32[3]
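JVP and VJP are transposes of each other, so they satisfy $\langle u, J_f(x) v \rangle = \langle J_f(x)^T u, v \rangle$ for any tangent $v$ and cotangent $u$. The sketch below checks this identity using only the numbers printed by the JVP and VJP cells; the helper `dot` is a hypothetical name introduced for this check.

```python
def dot(a, b):
    # Plain inner product of two equal-length lists.
    return sum(ai * bi for ai, bi in zip(a, b))

# Values copied from the printed outputs of the JVP and VJP cells.
x1_tangent = [1.7640524, 0.4001572, 0.978738]
x2_tangent = [1.7640524, 0.4001572, 0.978738]
value_tangent = [-3.7025769, 0.74225295, 5.3027043]

value_cotangent = [1.7640524, 0.4001572, 0.978738]
x1_cotangent = [-2.223683, 0.36850792, 2.9121294]
x2_cotangent = [-1.478894, 0.37374496, 2.390575]

# <u, J v>: pair the output cotangent with the pushed-forward tangent.
lhs = dot(value_cotangent, value_tangent)
# <J^T u, v>: pair the pulled-back cotangents with the input tangents.
rhs = dot(x1_cotangent, x1_tangent) + dot(x2_cotangent, x2_tangent)
print("lhs:", lhs)
print("rhs:", rhs)
```

The two inner products should agree up to the rounding of the printed `f32` values. A check of this kind is a common way to test a forward rule against its reverse rule.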


#### Compute the full Jacobian automatically (vmap + vjp, a.k.a. batched reverse-mode AD)

In [8]:
jac_fn = nb.jacrev(my_program)
jacobian = jac_fn(x1, x2)
print("jacobian:", jacobian)

jacobian: [[[-1.2605538   0.          0.        ]
 [-0.          0.92090786  0.        ]
 [-0.          0.          2.975392  ]]:f32[3,3], [[-0.8383503   0.          0.        ]
 [-0.          0.93399537  0.        ]
 [-0.          0.          2.4425075 ]]:f32[3,3]]
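Conceptually, a reverse-mode Jacobian is assembled one row at a time: pulling back the $i$-th standard basis vector $e_i$ yields row $i$, because $J^T e_i$ is the $i$-th row of $J$. The plain-Python sketch below mimics this with a hand-written pullback for the notebook's function and an explicit loop standing in for the batching. The function `pullback` is a hypothetical helper, and this is an illustration of the idea rather than a description of how `nb.jacrev` is implemented.

```python
import math

def pullback(x1, x2, u):
    # Hand-written VJP of cos(x1 * x2) * log(x2), applied elementwise.
    g1, g2 = [], []
    for a, b, ui in zip(x1, x2, u):
        s = math.sin(a * b)
        g1.append(-s * b * math.log(b) * ui)
        g2.append((-s * a * math.log(b) + math.cos(a * b) / b) * ui)
    return g1, g2

x1 = [1.0, 2.0, 3.0]
x2 = [2.0, 3.0, 4.0]
m = len(x1)  # number of outputs

jac_x1, jac_x2 = [], []
for i in range(m):
    # One-hot cotangent e_i selects row i of each Jacobian block.
    one_hot = [1.0 if j == i else 0.0 for j in range(m)]
    row_x1, row_x2 = pullback(x1, x2, one_hot)
    jac_x1.append(row_x1)
    jac_x2.append(row_x2)

print("d f / d x1:", jac_x1)
print("d f / d x2:", jac_x2)
```

Both blocks come out diagonal because the function is elementwise, matching the structure of the printed `jacobian`. Reverse mode needs one pullback per output, so it is the cheaper choice when there are fewer outputs than inputs.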


#### Compute the full Jacobian automatically (vmap + jvp, a.k.a. batched forward-mode AD)

In [9]:
jac_fn = nb.jacfwd(my_program)
jacobian = jac_fn(x1, x2)
print("jacobian:", jacobian)

jacobian: [[[-1.2605538  -0.         -0.        ]
 [ 0.          0.92090786  0.        ]
 [ 0.          0.          2.975392  ]]:f32[3,3], [[-0.8383503  -0.         -0.        ]
 [ 0.          0.93399537  0.        ]
 [ 0.          0.          2.4425075 ]]:f32[3,3]]


#### Compute the full Hessian automatically (Forward-over-Reverse)

In [10]:
jac_fn = nb.jacrev(my_program)
hessian_fn = nb.jacfwd(jac_fn)
hessian = hessian_fn(x1, x2)
print("hessian:", hessian)

hessian: [[[[[  1.153804  -0.        -0.      ]
  [  0.         0.         0.      ]
  [  0.         0.         0.      ]]

 [[ -0.        -0.        -0.      ]
  [  0.        -9.493694   0.      ]
  [  0.         0.         0.      ]]

 [[ -0.        -0.        -0.      ]
  [  0.         0.         0.      ]
  [  0.         0.       -18.71728 ]]]:f32[3,3,3], [[[ -0.96267235   0.           0.        ]
  [  0.           0.           0.        ]
  [  0.           0.           0.        ]]

 [[  0.           0.           0.        ]
  [  0.          -5.7427444    0.        ]
  [  0.           0.           0.        ]]

 [[  0.           0.           0.        ]
  [  0.           0.           0.        ]
  [  0.           0.         -12.75754   ]]]:f32[3,3,3]], [[[[ -0.96267235  -0.          -0.        ]
  [  0.           0.           0.        ]
  [  0.           0.           0.        ]]

 [[ -0.          -0.          -0.        ]
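
The nonzero entries visible in this output can be cross-checked against second derivatives worked out by hand. The sketch below evaluates $\partial^2 f / \partial x_1^2$ and the mixed derivative $\partial^2 f / \partial x_1 \partial x_2$ elementwise with Python's `math` module; the function names `d2_dx1dx1` and `d2_dx1dx2` are hypothetical helpers for this check.

```python
import math

def d2_dx1dx1(x1, x2):
    # Second derivative of cos(x1 * x2) * log(x2) with respect to x1.
    return -math.cos(x1 * x2) * x2 ** 2 * math.log(x2)

def d2_dx1dx2(x1, x2):
    # Mixed derivative: differentiate -sin(x1*x2) * x2 * log(x2) in x2.
    a = x1 * x2
    return -math.cos(a) * a * math.log(x2) - math.sin(a) * (math.log(x2) + 1.0)

x1 = [1.0, 2.0, 3.0]
x2 = [2.0, 3.0, 4.0]

diag_x1x1 = [d2_dx1dx1(a, b) for a, b in zip(x1, x2)]
diag_x1x2 = [d2_dx1dx2(a, b) for a, b in zip(x1, x2)]
print("d2f/dx1dx1:", diag_x1x1)
print("d2f/dx1dx2:", diag_x1x2)
```

These values should line up with the diagonal entries of the first two printed blocks, up to `f32` rounding. All other entries are zero because output $i$ depends only on the $i$-th components of the inputs.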

#### Compute the full Hessian automatically (Reverse-over-Forward)

In [11]:
jac_fn = nb.jacfwd(my_program)
hessian_fn = nb.jacrev(jac_fn)
hessian = hessian_fn(x1, x2)
print("hessian:", hessian)

hessian: [[[[[  1.153804   0.         0.      ]
  [  0.         0.         0.      ]
  [  0.         0.         0.      ]]

 [[  0.         0.         0.      ]
  [  0.        -9.493694   0.      ]
  [  0.         0.         0.      ]]

 [[  0.         0.         0.      ]
  [  0.         0.         0.      ]
  [  0.         0.       -18.71728 ]]]:f32[3,3,3], [[[ -0.9626724   0.          0.       ]
  [  0.          0.          0.       ]
  [  0.          0.          0.       ]]

 [[  0.          0.          0.       ]
  [  0.         -5.7427444   0.       ]
  [  0.          0.          0.       ]]

 [[  0.          0.          0.       ]
  [  0.          0.          0.       ]
  [  0.          0.        -12.757539 ]]]:f32[3,3,3]], [[[[ -0.96267235   0.           0.        ]
  [  0.           0.           0.        ]
  [  0.           0.           0.        ]]

 [[  0.           0.           0.        ]
  [  0.          -5.742745