# Automatic Differentiation with Nabla

In this notebook, we'll explore the core automatic differentiation (AD) capabilities of `nabla`. We will start with a simple multivariate function and then demonstrate how to compute its derivatives in various forms.

### Our Example Program

Let's consider the following function $f: \mathbb{R}^n \times \mathbb{R}^n \to \mathbb{R}^n$ as our running example:

$$ f(x_1, x_2) = \cos(x_1 \odot x_2) \odot \log(x_2) $$

where $\odot$ denotes the element-wise product.

Using `nabla`, we will compute the following quantities for this function:
1.  **Jacobian-Vector Product (JVP)**: The product of the Jacobian of $f$ with a given tangent vector $v$. This is also known as a directional derivative.
    $$ \text{JVP}_f(x)(v) = J_f(x) v $$

2.  **Vector-Jacobian Product (VJP)**: The product of the transposed Jacobian of $f$ with a cotangent vector $v$, which lives in the output space of $f$. This is the fundamental operation in reverse-mode AD and is what gradient computations are built on. Up to a transpose, it is the same as left-multiplying the Jacobian by the row vector $v^T$, i.e., $J_f(x)^T v = (v^T J_f(x))^T$.
    $$ \text{VJP}_f(x)(v) = J_f(x)^T v $$

3.  **Full Jacobian**: The matrix of all first-order partial derivatives. For a function $f: \mathbb{R}^m \to \mathbb{R}^k$, the Jacobian is a $k \times m$ matrix. In our case, `nabla` computes the Jacobian with respect to each input argument, resulting in a tuple of Jacobian matrices. These can be high-dimensional tensors if the inputs or outputs are themselves multi-dimensional arrays.
    $$ (J_f)_{ij} = \frac{\partial f_i}{\partial x_j} $$

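4.  **Full Hessian**: The tensor of all second-order partial derivatives. For a scalar-valued function, the Hessian is the Jacobian of the gradient; for a vector-valued function such as $f$, it is the Jacobian of the Jacobian, so it carries one output index and two input indices. For a function with multi-dimensional inputs and outputs, this can be a high-rank tensor.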
    $$ (H_f)_{ijk} = \frac{\partial^2 f_i}{\partial x_j \partial x_k} $$
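
Because $f$ acts element-wise, output $i$ depends only on $x_{1,i}$ and $x_{2,i}$, so each Jacobian block is diagonal. As a reference for checking the numerical results, the nonzero first-order partials can be derived by hand with the product and chain rules (the component notation $x_{1,i}$, $x_{2,i}$ is introduced here for illustration only):

$$ \frac{\partial f_i}{\partial x_{1,i}} = -x_{2,i} \sin(x_{1,i} x_{2,i}) \log(x_{2,i}), \qquad \frac{\partial f_i}{\partial x_{2,i}} = -x_{1,i} \sin(x_{1,i} x_{2,i}) \log(x_{2,i}) + \frac{\cos(x_{1,i} x_{2,i})}{x_{2,i}} $$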


In [23]:
import nabla as nb


def my_program(x1: nb.Array, x2: nb.Array) -> nb.Array:
    """Compute cos(x1 * x2) * log(x2) element-wise."""
    a = x1 * x2      # element-wise product of the two inputs
    b = nb.cos(a)
    c = nb.log(x2)
    y = b * c
    return y

#### Compute the regular forward pass

In [31]:
# init input arrays
x1 = nb.array([1.0, 2.0, 3.0])
x2 = nb.array([2.0, 3.0, 4.0])
print("x1:", x1)
print("x2:", x2)

# compute the value of the program
value = my_program(x1, x2)
print("fwd_output:", value)

x1: [1. 2. 3.]:f32[3]
x2: [2. 3. 4.]:f32[3]
fwd_output: [-0.288451   1.0548549  1.16983  ]:f32[3]
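
As a sanity check on the forward pass, the same values can be reproduced without `nabla`. The sketch below uses only the standard `math` module; `my_program_ref` is a hypothetical helper name introduced for this illustration, not part of the library.

```python
import math


def my_program_ref(x1, x2):
    """Plain-Python reference: cos(x1 * x2) * log(x2), element by element."""
    return [math.cos(a * b) * math.log(b) for a, b in zip(x1, x2)]


x1 = [1.0, 2.0, 3.0]
x2 = [2.0, 3.0, 4.0]
reference = my_program_ref(x1, x2)
print("reference:", reference)
```

The reference runs in float64, so it should agree with the `f32` output above to roughly six significant digits.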


#### Compute the JVP (Jacobian-Vector Product)

In [32]:
# init input tangents
x1_tangent = nb.randn_like(x1)
x2_tangent = nb.randn_like(x2)
print("x1_tangent:", x1_tangent)
print("x2_tangent:", x2_tangent)

# compute the actual jvp
value, value_tangent = nb.jvp(my_program, (x1, x2), (x1_tangent, x2_tangent))
print("value:", value)
print("value_tangent:", value_tangent)

x1_tangent: [1.7640524 0.4001572 0.978738 ]:f32[3]
x2_tangent: [1.7640524 0.4001572 0.978738 ]:f32[3]
value: [-0.288451   1.0548549  1.16983  ]:f32[3]
value_tangent: [-3.7025769   0.74225295  5.3027043 ]:f32[3]
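
A JVP is a directional derivative, so it can be cross-checked with a central finite difference along the tangent direction. This is a minimal sketch in plain Python; `f` and `jvp_fd` are hypothetical helpers written for this check, and the tangent values are copied from the printed output above (both tangents happen to be identical in this run).

```python
import math


def f(x1, x2):
    """Reference implementation of the example function."""
    return [math.cos(a * b) * math.log(b) for a, b in zip(x1, x2)]


def jvp_fd(x1, x2, t1, t2, eps=1e-6):
    """Central-difference estimate of the directional derivative along (t1, t2)."""
    plus = f([a + eps * t for a, t in zip(x1, t1)],
             [b + eps * t for b, t in zip(x2, t2)])
    minus = f([a - eps * t for a, t in zip(x1, t1)],
              [b - eps * t for b, t in zip(x2, t2)])
    return [(p - m) / (2 * eps) for p, m in zip(plus, minus)]


x1, x2 = [1.0, 2.0, 3.0], [2.0, 3.0, 4.0]
tangent = [1.7640524, 0.4001572, 0.978738]  # copied from the printed tangents
approx = jvp_fd(x1, x2, tangent, tangent)
print("finite-difference jvp:", approx)
```

Finite differences only approximate the derivative, so agreement with `value_tangent` is expected to a few decimal places rather than exactly.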


#### Compute the VJP (Vector-Jacobian Product)

In [33]:
# compute value and pullback function
value, pullback = nb.vjp(my_program, x1, x2)
print("value:", value)

# init output cotangent
value_cotangent = nb.randn_like(value)
print("value_cotangent:", value_cotangent)

# compute the actual vjp
x1_cotangent, x2_cotangent = pullback(value_cotangent)
print("x1_cotangent:", x1_cotangent)
print("x2_cotangent:", x2_cotangent)

value: [-0.288451   1.0548549  1.16983  ]:f32[3]
value_cotangent: [1.7640524 0.4001572 0.978738 ]:f32[3]
x1_cotangent: [-2.223683    0.36850792  2.9121294 ]:f32[3]
x2_cotangent: [-1.478894    0.37374496  2.390575  ]:f32[3]
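
For this element-wise function the Jacobian blocks are diagonal, so the VJP reduces to an element-wise product with the hand-derived partial derivatives. The sketch below reproduces the cotangents that way and also checks the duality between the two modes, $\langle v, J u \rangle = \langle J^T v, u \rangle$. The `partials` helper is hypothetical and encodes formulas derived by hand, not anything taken from `nabla`.

```python
import math


def partials(x1, x2):
    """Diagonal entries of the two Jacobian blocks (hand-derived formulas)."""
    d1 = [-b * math.sin(a * b) * math.log(b) for a, b in zip(x1, x2)]
    d2 = [-a * math.sin(a * b) * math.log(b) + math.cos(a * b) / b
          for a, b in zip(x1, x2)]
    return d1, d2


x1, x2 = [1.0, 2.0, 3.0], [2.0, 3.0, 4.0]
d1, d2 = partials(x1, x2)
v = [1.7640524, 0.4001572, 0.978738]  # cotangent, copied from the printed output
u1 = u2 = v                            # tangents printed in the JVP section

# A diagonal Jacobian equals its transpose, so J^T v is an element-wise product.
x1_cot = [d * c for d, c in zip(d1, v)]
x2_cot = [d * c for d, c in zip(d2, v)]

# JVP along (u1, u2), then the duality check <v, J u> == <J^T v, u>.
jvp = [a * s + b * t for a, b, s, t in zip(d1, d2, u1, u2)]
lhs = sum(c * j for c, j in zip(v, jvp))
rhs = (sum(c * s for c, s in zip(x1_cot, u1))
       + sum(c * t for c, t in zip(x2_cot, u2)))
print("x1_cot:", x1_cot)
print("x2_cot:", x2_cot)
print("duality gap:", abs(lhs - rhs))
```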


#### Compute the full Jacobian automatically (vmap + vjp, a.k.a. batched reverse-mode AD)

In [34]:
jac_fn = nb.jacrev(my_program)
jacobian = jac_fn(x1, x2)
print("jacobian:", jacobian)

jacobian: [[[-1.2605538   0.          0.        ]
 [-0.          0.92090786  0.        ]
 [-0.          0.          2.975392  ]]:f32[3,3], [[-0.8383503   0.          0.        ]
 [-0.          0.93399537  0.        ]
 [-0.          0.          2.4425075 ]]:f32[3,3]]
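
The idea behind reverse-mode Jacobians is that each pullback call with a one-hot cotangent yields one row of the Jacobian. The sketch below illustrates this with a plain Python loop in place of `vmap` and a hand-written `pullback` (a hypothetical stand-in built from the closed-form partials); it makes no claim about how `nb.jacrev` is implemented internally.

```python
import math


def pullback(x1, x2, cot):
    """Hand-written VJP for f: returns (J1^T cot, J2^T cot)."""
    g1 = [-b * math.sin(a * b) * math.log(b) * c
          for a, b, c in zip(x1, x2, cot)]
    g2 = [(-a * math.sin(a * b) * math.log(b) + math.cos(a * b) / b) * c
          for a, b, c in zip(x1, x2, cot)]
    return g1, g2


x1, x2 = [1.0, 2.0, 3.0], [2.0, 3.0, 4.0]
n = len(x1)
jac_x1, jac_x2 = [], []
for i in range(n):
    one_hot = [1.0 if k == i else 0.0 for k in range(n)]
    row1, row2 = pullback(x1, x2, one_hot)  # row i of each Jacobian block
    jac_x1.append(row1)
    jac_x2.append(row2)

print("d f / d x1:", jac_x1)
print("d f / d x2:", jac_x2)
```

One pullback call is needed per output element, which is why reverse mode tends to be the cheaper choice when a function has few outputs and many inputs.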


#### Compute the full Jacobian automatically (vmap + jvp, a.k.a. batched forward-mode AD)

In [35]:
jac_fn = nb.jacfwd(my_program)
jacobian = jac_fn(x1, x2)
print("jacobian:", jacobian)

jacobian: [[[-1.2605538  -0.         -0.        ]
 [ 0.          0.92090786  0.        ]
 [ 0.          0.          2.975392  ]]:f32[3,3], [[-0.8383503  -0.         -0.        ]
 [ 0.          0.93399537  0.        ]
 [ 0.          0.          2.4425075 ]]:f32[3,3]]
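
Forward mode builds the Jacobian the other way around: each JVP with a one-hot tangent yields one column. The following sketch shows this with a hand-written `jvp` helper (hypothetical, based on the closed-form partials) and a plain loop standing in for `vmap`; it is an illustration of the principle, not of `nb.jacfwd` internals.

```python
import math


def jvp(x1, x2, t1, t2):
    """Hand-written JVP for f, using the closed-form element-wise partials."""
    return [
        -b * math.sin(a * b) * math.log(b) * s
        + (-a * math.sin(a * b) * math.log(b) + math.cos(a * b) / b) * t
        for a, b, s, t in zip(x1, x2, t1, t2)
    ]


x1, x2 = [1.0, 2.0, 3.0], [2.0, 3.0, 4.0]
n = len(x1)
zeros = [0.0] * n
cols_x1, cols_x2 = [], []
for j in range(n):
    e_j = [1.0 if k == j else 0.0 for k in range(n)]
    cols_x1.append(jvp(x1, x2, e_j, zeros))  # column j of d f / d x1
    cols_x2.append(jvp(x1, x2, zeros, e_j))  # column j of d f / d x2

# Transpose the lists of columns into row-major matrices.
jac_x1 = [list(row) for row in zip(*cols_x1)]
jac_x2 = [list(row) for row in zip(*cols_x2)]
print("d f / d x1:", jac_x1)
print("d f / d x2:", jac_x2)
```

Here one JVP is needed per input element, so forward mode is usually preferred when inputs are few relative to outputs.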


#### Compute the full Hessian automatically (Forward-over-Reverse)

In [36]:
jac_fn = nb.jacrev(my_program)
hessian_fn = nb.jacfwd(jac_fn)
hessian = hessian_fn(x1, x2)
print("hessian:", hessian)

hessian: [[[[[  1.153804  -0.        -0.      ]
  [  0.         0.         0.      ]
  [  0.         0.         0.      ]]

 [[ -0.        -0.        -0.      ]
  [  0.        -9.493694   0.      ]
  [  0.         0.         0.      ]]

 [[ -0.        -0.        -0.      ]
  [  0.         0.         0.      ]
  [  0.         0.       -18.71728 ]]]:f32[3,3,3], [[[ -0.96267235   0.           0.        ]
  [  0.           0.           0.        ]
  [  0.           0.           0.        ]]

 [[  0.           0.           0.        ]
  [  0.          -5.7427444    0.        ]
  [  0.           0.           0.        ]]

 [[  0.           0.           0.        ]
  [  0.           0.           0.        ]
  [  0.           0.         -12.75754   ]]]:f32[3,3,3], [[[[ -0.96267235  -0.          -0.        ]
  [  0.           0.           0.        ]
  [  0.           0.           0.        ]]

 [[ -0.          -0.          -0.        ]
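
The nonzero Hessian entries printed above can be cross-checked against second derivatives worked out by hand. The sketch below evaluates $\partial^2 f_i / \partial x_{1,i}^2$ and the mixed partial $\partial^2 f_i / \partial x_{1,i} \partial x_{2,i}$ in plain Python; `second_partials` is a hypothetical helper, and its formulas are a manual derivation rather than anything produced by `nabla`.

```python
import math


def second_partials(x1, x2):
    """Hand-derived nonzero entries of d2f/dx1^2 and d2f/(dx1 dx2)."""
    d11 = [-b * b * math.cos(a * b) * math.log(b) for a, b in zip(x1, x2)]
    d12 = [-a * b * math.cos(a * b) * math.log(b)
           - math.sin(a * b) * (math.log(b) + 1.0)
           for a, b in zip(x1, x2)]
    return d11, d12


x1, x2 = [1.0, 2.0, 3.0], [2.0, 3.0, 4.0]
d11, d12 = second_partials(x1, x2)
print("d2f/dx1^2:    ", d11)
print("d2f/(dx1 dx2):", d12)
```

Because mixed partials of a smooth function commute, the same mixed values are expected to appear in both off-diagonal Hessian blocks.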

#### Compute the full Hessian automatically (Reverse-over-Forward)

In [30]:
jac_fn = nb.jacfwd(my_program)
hessian_fn = nb.jacrev(jac_fn)
hessian = hessian_fn(x1, x2)
print("hessian:", hessian)

hessian: [[[[[  1.153804   0.         0.      ]
  [  0.         0.         0.      ]
  [  0.         0.         0.      ]]

 [[  0.         0.         0.      ]
  [  0.        -9.493694   0.      ]
  [  0.         0.         0.      ]]

 [[  0.         0.         0.      ]
  [  0.         0.         0.      ]
  [  0.         0.       -18.71728 ]]]:f32[3,3,3], [[[ -0.9626724   0.          0.       ]
  [  0.          0.          0.       ]
  [  0.          0.          0.       ]]

 [[  0.          0.          0.       ]
  [  0.         -5.7427444   0.       ]
  [  0.          0.          0.       ]]

 [[  0.          0.          0.       ]
  [  0.          0.          0.       ]
  [  0.          0.        -12.757539 ]]]:f32[3,3,3], [[[[ -0.96267235   0.           0.        ]
  [  0.           0.           0.        ]
  [  0.           0.           0.        ]]

 [[  0.           0.           0.        ]
  [  0.          -5.742745