# JAX grad()

In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import value_and_grad

# Root PRNG key created from seed 0; it is split later to draw the random model coefficients.
key = random.key(0)

* `grad()` takes a function and returns a new function that computes the derivative (gradient) of the input function :)

In [3]:
# grad() turns tanh into its derivative function, d/dx tanh(x) = 1 - tanh(x)**2.
grad_tanh = grad(jnp.tanh)
dtanh_at_2 = grad_tanh(2.0)
print(dtanh_at_2)

0.070650816


In [4]:
# grad() composes: differentiating a derivative function yields the next-higher derivative.
second_derivative = grad(grad(jnp.tanh))
third_derivative = grad(second_derivative)
print(second_derivative(2.0))
print(third_derivative(2.0))

-0.13621868
0.25265405


## Let's test the linear logistic regression model

In [5]:
def sigmoid(x):
    """Logistic sigmoid of ``x``, evaluated via the identity sigmoid(x) = (1 + tanh(x / 2)) / 2."""
    half_tanh = jnp.tanh(x / 2)
    return 0.5 * (half_tanh + 1)

def predict(W, b, inputs):
    """Return the model's probability that each example's label is true.

    Applies ``sigmoid`` to the affine score ``jnp.dot(inputs, W) + b``.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

# Toy dataset: four examples with three features each, and one boolean label per example.
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
targets = jnp.array([True, True, False, True])

def loss(W, b):
    """Negative log-likelihood of the training examples under parameters ``W`` and ``b``.

    Reads the module-level ``inputs`` and ``targets``.
    """
    p_true = predict(W, b, inputs)
    p_false = 1 - p_true
    # Select the probability the model assigns to the observed label:
    # p_true where the target is True, p_false where it is False.
    label_probs = p_true * targets + p_false * (1 - targets)
    log_likelihood = jnp.sum(jnp.log(label_probs))
    return -log_likelihood

# Draw the initial model coefficients. The PRNG key is split three ways: one part
# replaces `key`, the other two seed the draws for W (shape (3,)) and b (scalar).
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, shape=(3,))
b = random.normal(b_key, shape=())

* Checking the shapes of the variables to better understand the procedure

In [6]:
def _show(label, value):
    """Print ``value`` and its ``.shape`` after ``label``, followed by a blank line."""
    print(label, value, '----> ', value.shape, '\n')

_show('inputs : ', inputs)
_show('targets : ', targets)
_show('Loss(W, b) : ', loss(W, b))
_show('W : ', W)
_show('b : ', b)

inputs :  [[ 0.52  1.12  0.77]
 [ 0.88 -1.08  0.15]
 [ 0.52  0.06 -1.3 ]
 [ 0.74 -2.49  1.39]] ---->  (4, 3) 

targets :  [ True  True False  True] ---->  (4,) 

Loss(W, b) :  3.0519385 ---->  () 

W :  [-0.36838785 -2.275689    0.01144757] ---->  (3,) 

b :  0.8535516 ---->  () 



In [7]:
# Differentiate `loss` with respect to the first positional argument (W):
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword;
# argnums=1 differentiates with respect to the second argument (b):
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# A tuple of argnums returns a tuple of gradients, one per listed argument:
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)

W_grad [-0.16965583 -0.8774644  -1.4901346 ]
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245


## Differentiating with respect to nested dicts

In [8]:
def loss2(params_dict):
    """Negative log-likelihood with the parameters supplied as a dict.

    ``params_dict`` must provide the entries ``'W'`` and ``'b'``; the module-level
    ``inputs`` and ``targets`` are used as the training data.
    """
    W_param, b_param = params_dict['W'], params_dict['b']
    p_true = predict(W_param, b_param, inputs)
    label_probs = p_true * targets + (1 - p_true) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# The gradient comes back as a dict with the same keys as the input dict.
print(grad(loss2)({'W': W, 'b': b}))

{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}


# JAX value_and_grad()
* Efficiently computes both a function's value and its gradient in a single call

In [9]:
# value_and_grad() evaluates `loss` and its gradients with respect to both W and b
# in a single call, instead of calling loss() and grad(loss) separately.
# (`value_and_grad` is imported with the other JAX names in the notebook's import cell.)
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
# Sanity check: a plain call to loss() reports the same value.
print('loss value', loss(W, b))

loss value 3.0519385
loss value 3.0519385
