In [1]:
import jax.numpy as jnp 
from jax import grad, jit, vmap 
from jax import random 

# Root PRNG key for the notebook; later cells derive subkeys from it with `random.split`.
key = random.PRNGKey(seed=0)

In [2]:
# `grad` turns tanh into its derivative function: d/dx tanh(x) = 1 - tanh(x)**2.
grad_tanh = grad(fun=jnp.tanh)
grad_tanh(2.0)

Array(0.07065082, dtype=float32, weak_type=True)

In [3]:
# Nesting `grad` gives higher derivatives:
# d²/dx² tanh(x) = -2 tanh(x) (1 - tanh(x)**2)
grad(
    grad(jnp.tanh)
)(2.0)

Array(-0.13621868, dtype=float32, weak_type=True)

In [4]:
# Third derivative: d³/dx³ tanh(x) = -2 (1 - tanh(x)**2) (1 - 3 tanh(x)**2)
grad(
    grad(
        grad(jnp.tanh)
    )
)(2.0)

Array(0.25265405, dtype=float32, weak_type=True)

In [5]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)).

    Evaluated through the identity sigmoid(x) = (1 + tanh(x / 2)) / 2; tanh is
    bounded, so no intermediate value can overflow for large |x|.
    """
    return (1 + jnp.tanh(x * 0.5)) / 2

def predict(W, b, inputs):
    """Logistic-regression prediction sigmoid(inputs · W + b).

    For a 2-D `inputs` and 1-D `W`, `jnp.dot` contracts the last axis of
    `inputs` with `W`, giving one probability per row.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

# Toy dataset: four examples with three features each, and one boolean label per example.
inputs = jnp.array(
    [
        [0.52, 1.12, 0.77],
        [0.88, -1.08, 0.15],
        [0.52, 0.06, -1.30],
        [0.74, -2.49, 1.39],
    ]
)
targets = jnp.array([True, True, False, True], dtype=bool)

def loss(W, b, *, features=None, labels=None):
    """Negative log-likelihood of binary labels under the logistic model.

    Mathematically equal to ``-sum(log(p if label else 1 - p))`` with
    ``p = sigmoid(features · W + b)`` (i.e. ``predict(W, b, features)``), but
    evaluated directly from the logits. The naive form takes ``log`` of a
    probability: once the sigmoid saturates in float32, ``1 - p`` (or ``p``)
    rounds to 0, the loss becomes ``inf`` and gradients become NaN. This form
    stays finite.

    Args:
        W: weight vector; ``jnp.dot`` contracts it with the last axis of
            ``features``.
        b: bias added to every logit (a scalar in this notebook).
        features: optional input matrix; defaults to the global ``inputs``.
        labels: optional binary labels (bool or 0/1); defaults to the global
            ``targets``. Non-binary (soft) labels are not supported.

    Returns:
        The loss summed over all examples.
    """
    # Keyword-only overrides let the loss be evaluated on other data without
    # mutating globals; the defaults keep `loss(W, b)` working as before.
    if features is None:
        features = inputs
    if labels is None:
        labels = targets
    logits = jnp.dot(jnp.asarray(features), W) + b
    # -log(sigmoid(z)) == softplus(-z) and -log(1 - sigmoid(z)) == softplus(z),
    # so flip the logit's sign for positive labels and use the stable
    # softplus(x) = logaddexp(0, x).
    signed_logits = jnp.where(jnp.asarray(labels), -logits, logits)
    return jnp.sum(jnp.logaddexp(0.0, signed_logits))

# Split the current key three ways: a fresh `key` for later draws plus one
# subkey each for the weights and the bias.
key, W_key, b_key = random.split(key, num=3)
# Standard-normal initialisation: one weight per input feature (three columns
# in `inputs`) and a scalar bias.
W = random.normal(W_key, shape=(3,))
b = random.normal(b_key, shape=())

In [6]:
# Gradient of the loss w.r.t. W only: `grad` differentiates positional argument 0 by default.
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# Gradient w.r.t. b only: select positional argument 1.
b_grad = grad(loss, argnums=1)(W, b)
print('b_grad', b_grad)

# A tuple of argnums returns both gradients at once, in the same order;
# the printed values should match the two single-argument results above.
W_grad, b_grad = grad(loss, argnums=(0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)

W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
