In [1]:
# https://www.kaggle.com/code/goktugguvercin/gradients-and-jacobians-in-jax
import jax
import jax.numpy as jnp

from jax import random
from jax import grad,value_and_grad
from jax.test_util import check_grads

JAX implements the mathematical gradient operator $\nabla$ as a function transformation: `jax.grad` takes a Python function as input and returns another Python function that computes the gradient of the given function.
$$\nabla f(x) = \texttt{jax.grad}(f)(x)$$

In [2]:
# Example in 2D: f(x) = x1^2 + x2^2, whose graph is a paraboloid in 3D space.
# The implementation itself sums the squares of all entries of x, whatever its shape.
def paraboloid(x):
    """Return the sum of the squared entries of ``x``, i.e. f(x) = sum_i x_i^2."""
    squares = x ** 2
    return jnp.sum(squares)

# Hand-derived gradient used as a reference for jax.grad:
# for f(x) = sum_i x_i^2 the gradient is [2*x1, 2*x2, ...].
def g_paraboloid(x):
    """Analytic gradient of ``paraboloid``: returns ``x`` scaled by 2."""
    return x * 2

# jax.grad(f) returns a new function that evaluates the gradient of f.
grad_paraboloid = grad(paraboloid)

# Three different 2D inputs, one per row. Named `inputs` so that Python's
# built-in `input` is not shadowed.
inputs = jnp.array([[0.2, 0.3], [2.4, 3.6], [4.4, 2.1]])
for x in inputs:
    explicit = g_paraboloid(x)
    autodiff = grad_paraboloid(x)
    print("Explicit Gradient Function: ", explicit)
    print("JAX Gradient Function: ", autodiff)
    print("")
    # Verify agreement instead of relying on visual inspection; use a
    # tolerance to be robust to floating-point rounding.
    assert jnp.allclose(explicit, autodiff), (explicit, autodiff)

Explicit Gradient Function:  [0.4 0.6]
JAX Gradient Function:  [0.4 0.6]

Explicit Gradient Function:  [4.8 7.2]
JAX Gradient Function:  [4.8 7.2]

Explicit Gradient Function:  [8.8 4.2]
JAX Gradient Function:  [8.8 4.2]



In [3]:
# Weighted paraboloid. With coeff = [3, 2, 5, 1, 4] (as used below) this is
# f(x) = 3x1^2 + 2x2^2 + 5x3^2 + x4^2 + 4x5^2.
def paraboloid2(coeff, x):
    """Return sum_i coeff_i * x_i^2, where ``coeff`` multiplies ``x ** 2`` element-wise."""
    weighted_squares = coeff * (x ** 2)
    return jnp.sum(weighted_squares)

# Gradient of paraboloid2 w.r.t. its second argument x. argnums=1 is a plain
# int; passing a tuple such as (0, 1) would instead return a tuple of
# gradients, one per selected argument.
grad_paraboloid2 = grad(paraboloid2, argnums=1)

coefficients = jnp.array([3, 2, 5, 1, 4])  # coefficients
x_input = jnp.array([2., 1., 3., 2., 4.])  # input in R^5 (named to avoid shadowing built-in `input`)

grad_x = grad_paraboloid2(coefficients, x_input)
print(grad_x)
# Analytic check: d/dx_i sum_j c_j x_j^2 = 2 * c_i * x_i.
assert jnp.allclose(grad_x, 2 * coefficients * x_input), grad_x

[12.  4. 30.  4. 32.]


In [4]:
# Fixed seed so the random input and matrix are reproducible.
key = random.PRNGKey(137)
key1, key2 = random.split(key)

input = random.uniform(key1, (10, ))  # input vector in R^10
trans_matrix = random.uniform(key2, (20, 10))  # 20x10 matrix projecting R^10 into R^20

def affine_transform(input, matrix):
    """Apply the linear map x -> matrix @ x.

    Despite the name there is no bias term. Adding a constant bias would not
    change the Jacobian with respect to ``input``, which is ``matrix`` either way.

    Args:
        input: vector multiplied by ``matrix`` (a length-10 vector in this cell).
        matrix: 2-D array applied on the left.

    Returns:
        ``matrix @ input``.
    """
    return matrix @ input

# jax.jacfwd returns a new function that computes the Jacobian of
# affine_transform with respect to its first argument (argnums=0).
jacobian_fn = jax.jacfwd(affine_transform, argnums=0)
jacobian = jacobian_fn(input, trans_matrix)  # y = f(x) = Ax, dy/dx = A
# Compare floating-point arrays with a tolerance rather than exact `==`.
print(jnp.allclose(trans_matrix, jacobian))

True
