In [1]:
import jax
import jax.numpy as jnp

In [2]:
from jax import grad, jit, vmap
from jax import random

In [3]:
def tanh(x):  # Define a function
  """Hyperbolic tangent written with `jnp.exp`, used to demonstrate `grad`.

  Computes tanh(|x|) = (1 - e^{-2|x|}) / (1 + e^{-2|x|}) and then restores
  the sign of x. Working with |x| keeps the exponent non-positive, so
  `jnp.exp` cannot overflow. The naive form (1 - e^{-2x}) / (1 + e^{-2x})
  evaluates inf / inf = nan once -2x overflows (x below roughly -44 in
  float32).

  Args:
    x: scalar or array input.

  Returns:
    tanh(x), elementwise, as a JAX array.
  """
  y = jnp.exp(-2.0 * jnp.abs(x))  # in (0, 1], never overflows
  t = (1.0 - y) / (1.0 + y)       # tanh(|x|)
  # tanh is odd: pick +t or -t with `where` so the gradient at x = 0 stays 1.
  return jnp.where(x >= 0, t, -t)

grad_tanh = grad(tanh)  # Obtain its gradient function

In [4]:
# First derivative of tanh at x = 1.0, i.e. 1 - tanh(1.0)**2.
print(grad_tanh(1.0))  # Evaluate it at x = 1.0

0.4199743


In [5]:
# grad_tanh is already the first derivative; applying grad twice more gives
# the third derivative of tanh, evaluated at x = 1.0.
print(grad(grad(grad_tanh))(1.0))

0.6216266


In [6]:
def f(x):
  """Map a vector to three monomials of its first four entries.

  Used as the running example for Jacobians, JVPs and Hessians: each output
  is a product of powers of x[0], x[1], x[2] and x[3].

  Args:
    x: array whose first four entries are read.

  Returns:
    A length-3 JAX array of the monomial values.
  """
  a, b, c, d = x[0], x[1], x[2], x[3]
  first = a ** 6 * b ** 4 * c ** 9 * d ** 2
  second = a ** 2 * b ** 3 * c ** 5 * d ** 3
  third = a ** 5 * b ** 7 * c ** 7 * d ** 6
  return jnp.array([first, second, third])

In [7]:
# Point (x[0], x[1], x[2], x[3]) at which f and its derivatives are evaluated.
evaluation_point = jnp.array([1.0, 0.5, 1.5, 2.0])

In [8]:
# The three outputs of f at the evaluation point.
f(evaluation_point)

Array([9.61084 , 7.59375 , 8.542969], dtype=float32)

In [9]:
# Jacobian of f built with forward-mode autodiff; this is a new function of x.
f_jacfwd = jax.jacfwd(f)

In [10]:
# Materialize the full Jacobian of f at the evaluation point.
full_jacobian = f_jacfwd(evaluation_point)

In [11]:
# (number of outputs of f, number of inputs of f)
full_jacobian.shape

(3, 4)

In [12]:
# Row i holds the partial derivatives of output i of f w.r.t. x[0..3].
full_jacobian

Array([[ 57.66504 ,  76.88672 ,  57.66504 ,   9.61084 ],
       [ 15.1875  ,  45.5625  ,  25.3125  ,  11.390625],
       [ 42.714844, 119.60156 ,  39.867188,  25.628906]], dtype=float32)

In [13]:
# Tangent vector v that the Jacobian J is multiplied by (J @ v) below.
multiplication_point = jnp.array([0.2, 0.3, 0.4, 0.8])

In [14]:
# Jacobian-vector product computed explicitly from the materialized Jacobian.
jnp.matmul(full_jacobian, multiplication_point) # full_jacobian @ multiplication_point

Array([65.353714, 35.94375 , 80.87344 ], dtype=float32)

In [15]:
# jax.jvp takes the primal point and the tangent vector as tuples and returns
# the pair (f(x), J @ v).
f_evaluated, jvp_evaluated = jax.jvp(f, (evaluation_point,), (multiplication_point,))

In [16]:
# Same values as f(evaluation_point).
f_evaluated

Array([9.61084 , 7.59375 , 8.542969], dtype=float32)

In [17]:
# Matches the explicit full_jacobian @ multiplication_point up to float32 rounding.
jvp_evaluated

Array([65.353714, 35.943752, 80.87344 ], dtype=float32)

In [18]:
# Jacobian of f built with reverse-mode autodiff.
f_jacrev = jax.jacrev(f)

In [21]:
# Same Jacobian as the forward-mode result from f_jacfwd.
f_jacrev(evaluation_point)

Array([[ 57.66504 ,  76.88672 ,  57.66504 ,   9.61084 ],
       [ 15.1875  ,  45.5625  ,  25.3125  ,  11.390625],
       [ 42.714844, 119.60156 ,  39.867188,  25.628906]], dtype=float32)

In [22]:
def hessian(fun):
  """Return a jit-compiled function computing the Hessian of `fun`.

  Forward-over-reverse: `jax.jacrev` gives the Jacobian of `fun`, and
  `jax.jacfwd` differentiates that Jacobian once more. For a function with
  m outputs and n inputs the result has shape (m, n, n).

  Args:
    fun: function of a single array argument.

  Returns:
    A compiled function mapping x to the second derivatives of fun at x.
  """
  first_derivative = jax.jacrev(fun)
  second_derivative = jax.jacfwd(first_derivative)
  return jit(second_derivative)

In [26]:
# Display only the shape: one 4 x 4 second-derivative matrix per output of f.
# Evaluating the bare expression would dump the whole 3 x 4 x 4 array instead.
hessian(f)(evaluation_point).shape

(3, 4, 4)

In [27]:
def abs_val(x):
  """Absolute value written with an ordinary Python branch.

  Returns x when x > 0 and -x otherwise (so 0 maps to -0 and its gradient
  there is -1). `grad` copes with the Python `if` because it re-evaluates
  the function for each input it is called on.

  NOTE(review): a Python branch on a traced value is expected to fail under
  `jit`; confirm before wrapping this function in `jit`.
  """
  return x if x > 0 else -x

In [28]:
# Gradient of abs_val; grad differentiates through its Python `if` branch.
abs_val_grad = grad(abs_val)

In [30]:
print(abs_val_grad(2.0))   # prints 1.0
print(abs_val_grad(-2.0))  # prints -1.0 (abs_val is re-evaluated, so the `-x` branch is taken)

1.0
-1.0
