In [33]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.tree_util import register_pytree_node

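### Gradients of single-variable functions

`grad(f)` returns a new function that evaluates the derivative $f'(x)$; applying `grad` to that result gives the second derivative $f''(x)$.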

In [34]:
def pow2(x):
    """Square ``x`` elementwise using ``jnp.power``.

    Written with jax.numpy so that ``grad`` can trace it; for a scalar float
    input the result is a 0-d JAX array (e.g. ``pow2(4.0)`` evaluates to 16.0).
    """
    exponent = 2
    return jnp.power(x, exponent)

In [35]:
pow2(x=4.0)  # forward evaluation: 4.0 ** 2

DeviceArray(16., dtype=float32)

In [36]:
# grad(pow2) builds a new function computing d/dx x**2 = 2x.
# argnums=0 is grad's default; it is spelled out here for clarity.
pow2_der = grad(pow2, argnums=0)
pow2_der(4.0)

DeviceArray(8., dtype=float32)

In [37]:
# Differentiating the derivative gives the second derivative: d^2/dx^2 x**2 = 2.
pow2_der_der = grad(pow2_der, argnums=0)
pow2_der_der(4.0)

DeviceArray(2., dtype=float32)

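### Gradients of multi-parameter functions

For a function of several arguments, `grad`'s `argnums` parameter selects which positional argument to differentiate with respect to (it defaults to `0`).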

In [38]:
def pow(x, y):
    """Raise ``x`` to the power ``y`` elementwise using ``jnp.power``.

    NOTE(review): this definition shadows Python's builtin ``pow`` for every
    later cell; a name such as ``power`` would avoid that hazard.
    """
    result = jnp.power(x, y)
    return result

In [39]:
pow(x=4.0, y=2.0)  # forward evaluation: 4.0 ** 2.0 (keywords make the notebook's pow explicit)

DeviceArray(16., dtype=float32)

In [40]:
# Differentiate pow with respect to its first positional argument x:
#   d/dx x**y = y * x**(y - 1)
# argnums=0 (a bare int) makes pow_der return a single gradient. Note that the
# parenthesized form `(0)` is also just the int 0; a tuple would be `(0,)` and
# would make pow_der return a 1-tuple of gradients instead.
pow_der = grad(pow, argnums=0)

pow_der(4.0, 2.0), pow_der(4.0, 3.0)

(DeviceArray(8., dtype=float32), DeviceArray(48., dtype=float32))

In [41]:
# Second derivative with respect to x: differentiate pow_der again in argument 0.
#   d^2/dx^2 x**y = y * (y - 1) * x**(y - 2)
# e.g. y=2 -> 2, and y=3 at x=4 -> 3 * 2 * 4 = 24.
pow_der_der = grad(pow_der, argnums=0)

pow_der_der(4.0, 2.0), pow_der_der(4.0, 3.0)

(DeviceArray(2., dtype=float32), DeviceArray(24., dtype=float32))