In [1]:
import jax 
import jax.numpy as jnp 
from jax import grad 

``jax.grad`` takes a function and returns a function. 
So, if we have a Python function ``f`` that evaluates the mathematical function $f$, then ``jax.grad(f)`` is a Python function that evaluates the mathematical function $\nabla f$. That means ``grad(f)(x)`` evaluates $\nabla f(x)$.

In [2]:
# ``grad`` turns tanh into a new Python function that computes its derivative.
grad_tanh = grad(jnp.tanh)
tanh_slope_at_2 = grad_tanh(2.0)
print(tanh_slope_at_2)

0.070650816


Also, because ``jax.grad`` takes a function and returns a function, it can be applied to its own output as many times as needed to obtain higher-order derivatives.

In [3]:
# Second and third derivatives of tanh at x = 2.0, built by nesting ``grad``.
d2_tanh = grad(grad(jnp.tanh))
d3_tanh = grad(d2_tanh)
print(d2_tanh(2.0))
print(d3_tanh(2.0))

-0.13621868
0.25265405


Let's consider a function:

$f(x) = x^3 + 2x^2 - 3x + 1$

We can get its higher-order derivatives as

$f'(x) = 3x^2 + 4x -3$

$f''(x) = 6x + 4$

$f'''(x) = 6$

$f^{(iv)}(x) = 0$

With ``jax.grad``, computing these is simple: we just chain calls to it.

In [4]:
def f(x):
    """Evaluate the cubic polynomial f(x) = x^3 + 2x^2 - 3x + 1.

    Defined with ``def`` rather than a lambda bound to a name, so the function
    has a proper ``__name__`` and can carry this docstring.
    """
    return x**3 + 2*x**2 - 3*x + 1


# First derivative, f'(x) = 3x^2 + 4x - 3, obtained by automatic differentiation.
dfdx = grad(f)

In [5]:
# Each higher-order derivative is the gradient of the one before it.
d2fdx = jax.grad(dfdx)   # second derivative
d3fdx = jax.grad(d2fdx)  # third derivative
d4fdx = jax.grad(d3fdx)  # fourth derivative

# Evaluate the first through fourth derivatives at x = 1.
x0 = 1.0
for derivative in (dfdx, d2fdx, d3fdx, d4fdx):
    print(derivative(x0))

4.0
10.0
6.0
0.0


### Computing Gradients in Linear Logistic Regression

In [6]:
# Root PRNG key for the logistic-regression example, built from a fixed seed.
SEED = 0
key = jax.random.key(SEED)

def sigmoid(x):
    """Logistic sigmoid of ``x``, computed through the equivalent tanh form.

    Uses the identity sigmoid(x) = (tanh(x / 2) + 1) / 2.
    """
    squashed = jnp.tanh(x / 2)   # lies in [-1, 1]
    return 0.5 * (squashed + 1)  # rescaled to [0, 1]

def predict(W, b, inputs):
    """Return the predicted probability that each example's label is True.

    Applies the sigmoid to the affine score ``inputs @ W + b``.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

# Toy dataset: four examples with three features each ...
feature_rows = [
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
]
inputs = jnp.array(feature_rows)

# ... and one boolean label per example.
targets = jnp.array([True, True, False, True])

In [7]:
def loss(W, b):
    """Training loss: negative log-likelihood of the training examples.

    Reads the module-level ``inputs`` and ``targets`` and scores them with
    ``predict(W, b, inputs)``.
    """
    p_true = predict(W, b, inputs)
    p_false = 1 - p_true
    # Probability the model assigns to each example's actual label.
    likelihoods = p_true * targets + p_false * (1 - targets)
    log_likelihood = jnp.sum(jnp.log(likelihoods))
    return -log_likelihood

In [8]:
# Initialise the parameters randomly.
# Splitting yields a replacement for ``key`` plus one subkey per parameter.
new_keys = jax.random.split(key, 3)
key, W_key, b_key = new_keys
W = jax.random.normal(W_key, (3,))  # one weight per input feature
b = jax.random.normal(b_key, ())    # scalar bias

In [9]:
# Gradient with respect to the first argument, W (argnums=0 is the default).
dloss_dW = grad(loss, argnums=0)
W_grad = dloss_dW(W, b)
print(f"{W_grad=}")

# Gradient with respect to the second argument, b.
dloss_db = grad(loss, 1)
b_grad = dloss_db(W, b)
print(f"{b_grad=}")

# A tuple of argnums differentiates with respect to several arguments at once
# and returns a matching tuple of gradients.
dloss_dWb = grad(loss, (0, 1))
W_grad, b_grad = dloss_dWb(W, b)
print(f"{W_grad=}")
print(f"{b_grad=}")

W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)


#### Differentiating with respect to nested lists, tuples, and dicts

Thanks to JAX's pytree abstraction, differentiation works with standard Python containers such as tuples, lists, and dicts, nested in any way.

In [10]:
def loss2(param_dict):
    """Negative log-likelihood of the training examples, with parameters in a dict.

    Args:
        param_dict: mapping with entries ``'W'`` and ``'b'``; they are passed
            to ``loss`` as its ``W`` and ``b`` arguments.

    Returns:
        The value of ``loss(param_dict['W'], param_dict['b'])``.
    """
    # Delegate to ``loss`` instead of repeating its body, so the two
    # formulations of the same objective cannot drift apart.
    return loss(param_dict['W'], param_dict['b'])

# Differentiating with respect to a dict returns a dict of gradients with the same keys.
params = {'W': W, 'b': b}
print(grad(loss2)(params))

{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}


We can evaluate a function and its gradient together using ``jax.value_and_grad``. It computes both the function's value and its gradient's value in one pass.

In [14]:
# One call returns the loss together with its gradients w.r.t. arguments 0 and 1 (W and b).
loss_and_grads = jax.value_and_grad(loss, (0, 1))
loss_value, Wb_grad = loss_and_grads(W, b)
print(f"loss value: {loss_value}")
print(f"Wb_grad: {Wb_grad}")
# Recompute the loss directly to confirm it matches the value returned above.
print(f"loss value: {loss(W, b)}")

loss value: 3.051938533782959
Wb_grad: (Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), Array(-0.29227245, dtype=float32))
loss value: 3.051938533782959


#### Checking against numerical differences

In [16]:
# Step size for the central finite-difference estimates.
# NOTE: the gradients here are float32 arrays, so the numerical estimates
# below agree with autodiff only to a few significant digits.
eps = 1e-4
half_step = eps / 2.

# Scalar check for b: central difference of the loss along b.
loss_b_plus = loss(W, b + half_step)
loss_b_minus = loss(W, b - half_step)
b_grad_numerical = (loss_b_plus - loss_b_minus) / eps
print('b_grad_numerical:', b_grad_numerical)
print('b_grad_autodiff:', grad(loss, 1)(W, b))

# Check for W: compare the directional derivative along a random unit vector.
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_step = half_step * unitvec
W_grad_numerical = (loss(W + W_step, b) - loss(W - W_step, b)) / eps
print('W_grad_numerical:', W_grad_numerical)
print('W_grad_autodiff:', jnp.vdot(grad(loss)(W, b), unitvec))


b_grad_numerical: -0.29325485
b_grad_autodiff: -0.29227245
W_grad_numerical: -0.2002716
W_grad_autodiff: -0.19909117
