In [2]:
import jax.numpy as jnp 
from jax import grad, jit, vmap 
from jax import random 

key = random.PRNGKey(0)

In [2]:
grad_tanh = grad(jnp.tanh)
grad_tanh(2.0)

Array(0.07065082, dtype=float32, weak_type=True)

In [3]:
grad(grad(jnp.tanh))(2.0)

Array(-0.13621868, dtype=float32, weak_type=True)

In [4]:
grad(grad(grad(jnp.tanh)))(2.0)

Array(0.25265405, dtype=float32, weak_type=True)

In [5]:
def sigmoid(x):
    return 0.5 * (jnp.tanh(x/2) + 1)

def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W)+b)

inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds*targets + (1-preds)*(1-targets)
    return -jnp.sum(jnp.log(label_probs))

key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

In [6]:
# d(loss)/d(W)
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# d(loss)/d(b)
b_grad = grad(loss, argnums=1)(W, b)
print('b_grad', b_grad)

W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)

W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245


In [7]:
def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds*targets + (1-preds)*(1-targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))

{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}


In [9]:
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0,1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))

loss value 3.0519388
loss value 3.0519388


In [11]:
eps = 1e-4
print('W', W)
print('b', b)

b_grad_numerical = (loss(W, b + eps/2.) - loss(W, b - eps/2.))/eps 
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

W [-0.36838785 -2.275689    0.01144757]
b 0.8535516
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245


In [12]:
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + (eps/2.)*unitvec, b) - loss(W - (eps/2.)*unitvec, b))/eps
print('W_grad_numerical', W_grad_numerical)
print('W_grad_autodiff', jnp.vdot(grad(loss, 0)(W, b), unitvec))

W_grad_numerical -0.2002716
W_grad_autodiff -0.19909117


In [16]:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)

In [17]:
check_grads?

[0;31mSignature:[0m
[0mcheck_grads[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mf[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0margs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0morder[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmodes[0m[0;34m=[0m[0;34m([0m[0;34m'fwd'[0m[0;34m,[0m [0;34m'rev'[0m[0;34m)[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0matol[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mrtol[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meps[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Check gradients from automatic differentiation against finite differences.

Gradients are only checked in a single randomly chosen direction, which
ensures that the finite difference calculation does not become prohibitively
expensive even for large input/output spaces.

Args:
  f: function to check at ``f(*args)``.
  args: tuple of argument va

In [18]:
def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

In [35]:
print('W', W)
print('b', b)
print('inputs', inputs)

W [-0.36838785 -2.275689    0.01144757]
b 0.8535516
inputs [[ 0.52  1.12  0.77]
 [ 0.88 -1.08  0.15]
 [ 0.52  0.06 -1.3 ]
 [ 0.74 -2.49  1.39]]


In [36]:
import time
from jax import jacfwd, jacrev

f = lambda W: predict(W, b, inputs)
print('W', W)
print('f(W)', f(W))

start = time.time()
J = jacfwd(f)(W)
print("jacfwd took", time.time() - start, "seconds")
print("jacfwd result, with shape", J.shape)
print(J)

start = time.time()
J = jacrev(f)(W)
print("jacrev took", time.time() - start, "seconds")
print("jacrev result, with shape", J.shape)
print(J)

W [-0.36838785 -2.275689    0.01144757]
f(W) [0.13262254 0.952067   0.6249393  0.9980987 ]
jacfwd took 0.004054069519042969 seconds
jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev took 0.0628511905670166 seconds
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]


In [None]:
# Forward-mode automatic differentiation is more efficient
# for "tall" and near-square Jacobian matrices

# Reverse-mode automatic differentiation is more efficient
# for "wide" Jacobian matrices

In [27]:
def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)

Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]


In [31]:
def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)

hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]]


In [51]:
def f(x):
    f1 = x[0]**3 + x[1]**4 + x[2]**5
    f2 = x[0]*x[1] + x[1]*x[2]
    return jnp.array([f1, f2])

In [52]:
x = jnp.array([1., 1., 1.])
f(x)

Array([3., 2.], dtype=float32)

In [53]:
jacfwd(f)(x)

Array([[3., 4., 5.],
       [1., 2., 1.]], dtype=float32)

In [54]:
jacfwd(jacrev(f))(x)

Array([[[ 6.,  0.,  0.],
        [ 0., 12.,  0.],
        [ 0.,  0., 20.]],

       [[ 0.,  1.,  0.],
        [ 1.,  0.,  1.],
        [ 0.,  1.,  0.]]], dtype=float32)

In [59]:
from jax import jvp 
f = lambda W: predict(W, b, inputs)
print('W', W.shape)
print('f', f(W).shape)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
print('v', v.shape)

y, u = jvp(f, (W,), (v,))
print('y', y.shape)
print('u', u.shape)

W (3,)
f (4,)
v (3,)
y (4,)
u (4,)


In [75]:
from jax import vjp
f = lambda W: predict(W, b, inputs)
print('W', W.shape)
print('f', f(W).shape)

key, subkey = random.split(key)

y, vjp_fun = vjp(f, W)
u = random.normal(subkey, y.shape)
print('u', u.shape)

v = vjp_fun(u)
print('v', v)

W (3,)
f (4,)
u (4,)
v (Array([-0.0861915 , -0.07484604,  0.09508413], dtype=float32),)


In [73]:
from jax import vjp 

def vgrad(f, x):
    y, vjp_fn = vjp(f, x)
    print(y)
    return vjp_fn(jnp.ones(y.shape))[0]

v = jnp.array([[1., 2.], [3.,4.]])
print(vgrad(lambda x: 3*x**2, v))

[[ 3. 12.]
 [27. 48.]]
[[ 6. 12.]
 [18. 24.]]


In [68]:
%timeit jacfwd(lambda x: 3*x**2)(jnp.ones((2, 2)))

1.31 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [69]:
%timeit jacrev(lambda x: 3*x**2)(jnp.ones((2, 2)))

1.61 ms ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [76]:
def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

In [4]:
from jax import jvp, grad 

def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]

In [5]:
def f(X):
    return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (10, 10))
V = random.normal(subkey2, (10, 10))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))

NameError: name 'hessian' is not defined

In [20]:
def f(x):
    return jnp.sum(x**2)

x = jnp.array([3., 1.])
f(x)
grad(f)(x)

Array([6., 2.], dtype=float32)

In [21]:
from jax import vjp 

y, vjp_fun = vjp(f, x)
vjp_fun(1.0)

(Array([6., 2.], dtype=float32),)