# 03 Automatic Differentiation

Original Documentation: https://docs.jax.dev/en/latest/automatic-differentiation.html


In [1]:
import jax
import jax.numpy as jnp

## Taking gradients

You can differentiate a scalar-valued function with the `jax.grad()` function:


In [2]:
grad_tanh = jax.grad(jnp.tanh)
print(grad_tanh(2.0))

0.070650816


`jax.grad(f)` takes a function $f$ and returns a new function that evaluates its gradient, $\nabla f$.
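
As a quick sanity check, the value printed above can be compared with the closed-form derivative $\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$. The helper name `tanh_prime` is introduced here for illustration only and is not part of JAX.

```python
import jax
import jax.numpy as jnp


# Hand-written derivative of tanh, used only as a reference (hypothetical helper)
def tanh_prime(x):
    return 1.0 - jnp.tanh(x) ** 2


x = 2.0
auto = jax.grad(jnp.tanh)(x)  # derivative from automatic differentiation
analytic = tanh_prime(x)      # derivative from the closed-form expression

print(auto, analytic)
print(jnp.allclose(auto, analytic))
```

Both values should agree to within floating-point tolerance.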

You can chain `jax.grad()` calls to create functions for higher-order derivatives:


In [4]:
print(jax.grad(jax.grad(jnp.tanh))(2.0))
print(jax.grad(jax.grad(jax.grad(jnp.tanh)))(2.0))

-0.13621868
0.25265405


Consider the function $f(x) = x^3 + 2x^2 - 3x + 1$. Its derivatives are:

- $f^\prime(x) = 3x^2 + 4x - 3$
- $f^{\prime\prime}(x) = 6x + 4$
- $f^{\prime\prime\prime}(x) = 6$
- $f^{\prime\prime\prime\prime}(x) = 0$

We can compute all of these in JAX with automatic differentiation:


In [5]:
f = lambda x: x**3 + 2 * x**2 - 3 * x + 1
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

print(dfdx(1.0))
print(d2fdx(1.0))
print(d3fdx(1.0))
print(d4fdx(1.0))

4.0
10.0
6.0
0.0
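
The repeated `jax.grad()` calls above can be folded into a loop. The sketch below assumes a small helper, `nth_derivative`, which is a hypothetical name and not a JAX function, and compares its results at $x = 1$ with the hand-derived formulas listed earlier.

```python
import jax


def nth_derivative(fn, n):
    """Apply jax.grad n times to a scalar function (hypothetical helper)."""
    for _ in range(n):
        fn = jax.grad(fn)
    return fn


f = lambda x: x**3 + 2 * x**2 - 3 * x + 1

# Hand-derived values at x = 1: f' = 3 + 4 - 3, f'' = 6 + 4, f''' = 6, f'''' = 0
expected = [4.0, 10.0, 6.0, 0.0]
computed = [float(nth_derivative(f, n)(1.0)) for n in range(1, 5)]

print(computed)
```

Each nested `jax.grad()` differentiates through the previous one, so the cost of tracing grows with `n`; for low orders like these that is not a concern.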


## Computing gradients in a linear logistic regression

Consider this setup code for a binary logistic regression:


In [None]:
key = jax.random.key(1701)


def sigmoid(x):
    """
    Sigmoid function maps any real-valued input to a probability in (0, 1).
    Common use: in binary classification, interpret sigmoid(x) as P(class A).
    A typical decision rule: if sigmoid(x) > 0.5, predict class A; else class B.
    """
    return 0.5 * (jnp.tanh(x / 2) + 1)


# Note: `inputs` holds a batch of 4 examples, but jnp.dot(inputs, W) already
# handles the batch dimension, so there is no need to use jax.vmap here
@jax.jit
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)


inputs = jnp.array(
    [[0.52, 1.12, 0.77], [0.88, -1.08, 0.15], [0.52, 0.06, -1.30], [0.74, -2.49, 1.39]]
)
targets = jnp.array([True, True, False, True])


@jax.jit
def loss(W, b):
    """
    The binary cross-entropy loss of the predictions under the current weights and bias.
    It is a binary loss since we have two label classes (either True or False).
    """
    preds = predict(W, b, inputs)

    # Compute the probability the predictions gave towards the correct label
    label_preds = preds * targets + (1 - preds) * (1 - targets)

    # Compute the cross-entropy loss between predictions and targets
    # cross-entropy loss is defined as sum(-log(prob of correct class))
    return jnp.sum(-jnp.log(label_preds))


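key, W_key, b_key = jax.random.split(key, 3)
# NOTE: W_key and b_key are created but unused; both draws below reuse `key`.
# JAX recommends a fresh subkey per draw (W_key for W, b_key for b).
# Switching to the subkeys would change the values printed in this notebook.
W = jax.random.normal(key, (3,))
b = jax.random.normal(key, ())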
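
The `sigmoid` above is written in terms of `tanh`, which is algebraically the same as the textbook form $\sigma(x) = \frac{1}{1 + e^{-x}}$. A minimal numerical check, assuming an arbitrary grid of test points, compares it with that form and with JAX's built-in `jax.nn.sigmoid`:

```python
import jax
import jax.numpy as jnp


def sigmoid(x):
    # Same tanh-based definition as in the setup code
    return 0.5 * (jnp.tanh(x / 2) + 1)


# Arbitrary test points chosen for illustration
xs = jnp.linspace(-5.0, 5.0, 11)

textbook = 1.0 / (1.0 + jnp.exp(-xs))  # textbook logistic function
builtin = jax.nn.sigmoid(xs)           # JAX's built-in sigmoid

print(jnp.allclose(sigmoid(xs), textbook, atol=1e-6))
print(jnp.allclose(sigmoid(xs), builtin, atol=1e-6))
```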

In machine learning, the objective is to minimize the loss. The gradient of the loss with respect to the parameters points in the direction of steepest increase, so adjusting the parameters in the opposite direction reduces the loss.

Use `jax.grad()` to differentiate `loss()` with respect to its parameters:


In [7]:
# Compute gradient w.r.t W
W_grad = jax.grad(loss, argnums=(0,))(W, b)
print(f"{W_grad=}")

# Compute gradient w.r.t b
b_grad = jax.grad(loss, argnums=(1,))(W, b)
print(f"{b_grad=}")

# Compute gradient w.r.t W and b
W_grad, b_grad = jax.grad(loss, argnums=(0, 1))(W, b)
print(f"{W_grad=}")
print(f"{b_grad=}")

W_grad=(Array([-1.64406  ,  1.4412102, -1.8369946], dtype=float32),)
b_grad=(Array(-2.307187, dtype=float32),)
W_grad=Array([-1.64406  ,  1.4412102, -1.8369946], dtype=float32)
b_grad=Array(-2.307187, dtype=float32)


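`argnums` selects which positional arguments to differentiate with respect to. Passing a tuple, even a one-element tuple such as `(0,)`, returns a tuple of gradients, which is why the first two results above are wrapped in tuples; passing a plain integer returns the gradient itself. For example, `jax.grad(f, i)` computes $\frac{\partial f}{\partial x_i}$, where $x_i$ is the $i$-th positional argument of $f$, counting from 0.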
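
The difference between integer and tuple `argnums` is easiest to see on a small function with known partial derivatives. The function `g` below is a hypothetical example chosen so the results can be checked by hand.

```python
import jax


# Hypothetical two-argument function, used only for illustration
def g(x, y):
    return x**2 * y + y**3


x, y = 2.0, 3.0

dg_dx = jax.grad(g, argnums=0)(x, y)        # integer argnums: bare value, 2*x*y = 12
dg_dy = jax.grad(g, argnums=1)(x, y)        # integer argnums: bare value, x**2 + 3*y**2 = 31
single = jax.grad(g, argnums=(0,))(x, y)    # tuple argnums: one-element tuple
both = jax.grad(g, argnums=(0, 1))(x, y)    # tuple argnums: tuple of two gradients

print(dg_dx, dg_dy)
print(single)
print(both)
```

When only one argument matters, the integer form avoids having to unpack a one-element tuple.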

## Differentiating with respect to nested lists, tuples, and dicts

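`jax.grad()` also accepts parameters packed into standard Python containers such as dicts, lists, and tuples, nested arbitrarily. The returned gradient has the same structure as the input: for a dict, each key maps to the gradient with respect to that key's value.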


In [8]:
@jax.jit
def loss2(params):
    preds = predict(params["W"], params["b"], inputs)
    label_preds = preds * targets + (1 - preds) * (1 - targets)
    return jnp.sum(-jnp.log(label_preds))


print(jax.grad(loss2)({"W": W, "b": b}))

{'W': Array([-1.64406  ,  1.4412102, -1.8369946], dtype=float32), 'b': Array(-2.307187, dtype=float32)}
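
The same behavior holds for deeper nesting. The sketch below uses a made-up parameter structure (a dict holding a list of tuples) and a simple sum-of-squares function, both assumptions for illustration, so that every gradient leaf should equal twice the corresponding parameter.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested parameters: a dict containing a list of (weight, bias) tuples
params = {
    "layers": [
        (jnp.array([1.0, 2.0]), jnp.array(0.5)),
        (jnp.array([3.0]), jnp.array(-1.0)),
    ],
    "scale": jnp.array(2.0),
}


def sum_of_squares(p):
    # Sum of squares of every leaf, so d/d(leaf) = 2 * leaf
    total = p["scale"] ** 2
    for w, b in p["layers"]:
        total = total + jnp.sum(w**2) + b**2
    return total


grads = jax.grad(sum_of_squares)(params)

# The gradient mirrors the container structure of the input
same_structure = jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
print(same_structure)
print(grads)
```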


## Evaluating a function and its gradient

`jax.value_and_grad()` returns a function that computes both the function’s value and its gradient in one pass, returning them as a `(value, gradient)` pair.


In [9]:
loss_val, Wb_grad = jax.value_and_grad(loss, argnums=(0, 1))(W, b)
print("Loss:", loss_val)
print("Confirmation:", loss(W, b))

Loss: 5.6301613
Confirmation: 5.6301613
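
One common use of `jax.value_and_grad()` is a training loop that logs the loss while updating parameters. The following is a minimal gradient-descent sketch on the same data, not a recommended training recipe. The zero initialization, `learning_rate`, and `num_steps` are assumptions chosen for illustration, and the loss is redefined without `jax.jit` on the inner functions to keep the block self-contained.

```python
import jax
import jax.numpy as jnp

inputs = jnp.array(
    [[0.52, 1.12, 0.77], [0.88, -1.08, 0.15], [0.52, 0.06, -1.30], [0.74, -2.49, 1.39]]
)
targets = jnp.array([True, True, False, True])


def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)


def loss(W, b):
    preds = sigmoid(jnp.dot(inputs, W) + b)
    label_preds = preds * targets + (1 - preds) * (1 - targets)
    return jnp.sum(-jnp.log(label_preds))


# Assumed settings for illustration: zero init, fixed step size and step count
W = jnp.zeros(3)
b = jnp.array(0.0)
learning_rate = 0.1
num_steps = 50

# One compiled function returns the loss and the gradients w.r.t. W and b
loss_and_grad = jax.jit(jax.value_and_grad(loss, argnums=(0, 1)))

initial_loss = loss(W, b)
for step in range(num_steps):
    loss_val, (W_grad, b_grad) = loss_and_grad(W, b)
    # Step against the gradient to reduce the loss
    W = W - learning_rate * W_grad
    b = b - learning_rate * b_grad
final_loss = loss(W, b)

print("Initial loss:", initial_loss)
print("Final loss:", final_loss)
```

Because the value and gradient come from the same call, the loop does not need a separate `loss(W, b)` evaluation at each step.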
