# Automatic differentiation with JAX

## Main features

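- NumPy-compatible API (```jax.numpy```)
- Auto-vectorization (```jax.vmap```)
- Auto-parallelization, SPMD paradigm (```jax.pmap```)
- Automatic differentiation (```jax.grad```, ```jax.jvp```, ```jax.vjp```)
- XLA backend and JIT compilation (```jax.jit```)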
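These transformations compose freely. Below is a minimal sketch that assumes a hypothetical toy loss ```loss(w, x, y)```, the squared error of a linear model, which is not part of the original notebook. ```jax.grad``` differentiates the loss with respect to ```w```. ```jax.vmap``` maps that gradient over a batch of examples, and ```jax.jit``` compiles the result with XLA.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical toy loss: squared error of a linear model on a single example
    return (jnp.dot(w, x) - y) ** 2

# grad w.r.t. w (argnum 0), vectorized over the batch axis of x and y
# (w is shared, hence in_axes=None for it), then JIT-compiled
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
X = jnp.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
y = jnp.array([0.0, 0.0, 1.0])
G = per_example_grads(w, X, y)  # row i is the gradient for example i
print(G)
```

With ```in_axes=(None, 0, 0)```, ```w``` is broadcast to every example while ```X``` and ```y``` are batched along their first axis.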

## How to compute the gradient of your objective?

- Define the objective as a standard Python function
- Call ```jax.grad``` on it, and voila!
- Do not forget to wrap the resulting functions with ```jax.jit``` to speed them up
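The recipe above fits in a few lines. This minimal sketch uses a simple quadratic objective chosen only for illustration. It also shows ```jax.value_and_grad```, which returns the objective value together with its gradient in a single call.

```python
import jax
import jax.numpy as jnp

def objective(x):
    # Illustrative objective: f(x) = sum_i x_i^2, so grad f(x) = 2x
    return jnp.sum(x ** 2)

grad_fn = jax.jit(jax.grad(objective))                # gradient only
val_grad_fn = jax.jit(jax.value_and_grad(objective))  # (value, gradient)

x = jnp.array([1.0, 2.0, 3.0])
g = grad_fn(x)
val, g2 = val_grad_fn(x)
print(val, g)
```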

In [1]:
import jax
import jax.numpy as jnp

- By default, JAX uses single-precision numbers (```float32```)
- You can enable double precision (```float64```) manually

In [2]:
# Enable float64 at startup, before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

In [3]:
f = lambda x: jnp.sin(x)
x = 1.
print(f(x))
print(jax.grad(f)(x), jnp.cos(x))
print(jax.grad(jax.grad(f))(x), -jnp.sin(x))
# print(jax.grad(jax.grad(jax.grad(f)))(x), -jnp.cos(x))



0.8414709848078965
0.5403023058681398 0.5403023058681398
-0.8414709848078965 -0.8414709848078965


In [4]:
@jax.jit
def f(x, A, b):
    res = A @ x - b
    return res @ res

gradf = jax.grad(f, argnums=0, has_aux=False)
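For reference, here are the analytical derivatives of this least-squares objective. The correctness checks below compare the autodiff results against them. Write the residual as $r = Ax - b$:

$$
f(x, A, b) = \|Ax - b\|_2^2 = r^\top r, \qquad r = Ax - b
$$

$$
\nabla_x f = 2A^\top r, \qquad \nabla_x^2 f = 2A^\top A, \qquad \nabla_A f = 2\, r\, x^\top, \qquad \nabla_x^2 f \, z = 2A^\top (A z)
$$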

## Random numbers in JAX 

- JAX puts the emphasis on reproducible runs
- An explicit PRNG key, the analogue of a random seed, is a **required argument** of every function that generates random numbers
- More details and references on the design of the ```random``` submodule are [here](https://github.com/google/jax/blob/master/design_notes/prng.md)

In [5]:
n = 1000
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (n, n))
b = jax.random.normal(jax.random.PRNGKey(0), (n, ))
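The cell above draws all three arrays with the same key, ```jax.random.PRNGKey(0)```. As a result, ```x``` and ```b``` (same key, same shape) are identical. This does not affect the gradient checks in this notebook. When independent draws are intended, the usual pattern is to split the key into fresh subkeys with ```jax.random.split```. A minimal sketch:

```python
import jax
import jax.numpy as jnp

n = 1000
key = jax.random.PRNGKey(0)

# Same key + same shape -> identical samples (this is what makes runs reproducible)
x_same = jax.random.normal(key, (n,))
b_same = jax.random.normal(key, (n,))

# Split the key into independent subkeys, one per random quantity
key_x, key_A, key_b = jax.random.split(key, 3)
x = jax.random.normal(key_x, (n,))
A = jax.random.normal(key_A, (n, n))
b = jax.random.normal(key_b, (n,))
```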

In [6]:
gradf(x, A, b).shape


(1000,)

In [7]:
print("Check correctness", jnp.linalg.norm(gradf(x, A, b) - 2 * A.T @ (A @ x - b)))
print("Compare speed")
print("Analytical gradient")
%timeit (2 * A.T @ (A @ x - b)).block_until_ready()
print("Grad function")
%timeit gradf(x, A, b).block_until_ready()
jit_gradf = jax.jit(gradf)
print("Jitted grad function")
%timeit jit_gradf(x, A, b).block_until_ready()

Check correctness 8.943165194422474e-11
Compare speed
Analytical gradient
4.03 ms ± 151 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Grad function
3.61 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jitted grad function
851 µs ± 28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


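- JAX dispatches computations asynchronously, so ```block_until_ready()``` is needed to time the actual computation rather than just its dispatch; more details can be found [here](https://jax.readthedocs.io/en/latest/async_dispatch.html)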

In [8]:
grad_max = jax.grad(jnp.max)
# x = jax.random.normal(jax.random.PRNGKey(100), (3, ))
x = jnp.array([1., 1., 0.])
print(x)
grad_max(x)

[1. 1. 0.]


DeviceArray([0.5, 0.5, 0. ], dtype=float64)
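The output shows how ```jax.grad``` handles ```max``` at a non-differentiable point. When several entries attain the maximum, the gradient is split evenly among them, here $1/2$ each. A small sketch checks the same behaviour on a three-way tie:

```python
import jax
import jax.numpy as jnp

grad_max = jax.grad(jnp.max)
# Three entries tie for the maximum, so each should receive 1/3 of the gradient
g = grad_max(jnp.array([2.0, 2.0, 2.0, -1.0]))
print(g)
```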

In [10]:
gradf = jax.grad(f, argnums=1, has_aux=False)
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
Agrad = gradf(x, A, b)
s = jnp.linalg.svd(Agrad, compute_uv=False)  # Agrad = 2 (Ax - b) x^T has rank one: only s[0] is nonzero up to round-off
jnp.linalg.norm(Agrad - 2 * jnp.outer(A @ x - b, x))

DeviceArray(0., dtype=float64)
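You do not need to redefine ```gradf``` for each argument. ```argnums``` also accepts a tuple, and ```jax.grad``` then returns a tuple of gradients. A small sketch with illustrative sizes, checked against the closed-form gradients $2A^\top(Ax - b)$ and $2(Ax - b)x^\top$:

```python
import jax
import jax.numpy as jnp

def f(x, A, b):
    res = A @ x - b
    return res @ res

# Gradients w.r.t. x (argnum 0) and A (argnum 1) from one reverse pass
grad_xA = jax.jit(jax.grad(f, argnums=(0, 1)))

n = 50
kx, kA, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (n, n))
b = jax.random.normal(kb, (n,))

gx, gA = grad_xA(x, A, b)
r = A @ x - b  # residual
print(jnp.linalg.norm(gx - 2 * A.T @ r), jnp.linalg.norm(gA - 2 * jnp.outer(r, x)))
```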

In [11]:
hess_func = jax.jit(jax.hessian(f))  # jax.hessian is forward-over-reverse: jacfwd(jacrev(f))
print("Check correctness", jnp.linalg.norm(2 * A.T @ A - hess_func(x, A, b)))
print("Time for hessian")
%timeit hess_func(x, A, b).block_until_ready()
%timeit (2 * A.T @ A).block_until_ready()  # reference: explicit 2 A^T A
print("Emulate hessian and check correctness",
      jnp.linalg.norm(hess_func(x, A, b) - jax.jacrev(jax.jacrev(f))(x, A, b)))
print("Time of emulating hessian")
hess_emul_func = jax.jit(jax.jacrev(jax.jacrev(f)))  # reverse-over-reverse emulation
%timeit hess_emul_func(x, A, b).block_until_ready()

Check correctness 0.0
Time for hessian
49.2 ms ± 1.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.6 ms ± 6.45 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
Emulate hessian and check correctness 0.0
Time of emulating hessian
62.8 ms ± 8.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


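## Forward mode vs. backward mode: $m \ll n$

- $n$ is the input dimension and $m$ the output dimension of the differentiated function; for the scalar objective ```f```, $m = 1 \ll n = 1000$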

In [12]:
fmode_f = jax.jit(jax.jacfwd(f))
bmode_f = jax.jit(jax.jacrev(f))
print("Check correctness", jnp.linalg.norm(fmode_f(x, A, b) - bmode_f(x, A, b)))
print("Forward mode")
%timeit fmode_f(x, A, b).block_until_ready()
print("Backward mode")
%timeit bmode_f(x, A, b).block_until_ready()

Check correctness 1.187134637918824e-10
Forward mode
24.6 ms ± 266 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Backward mode
793 µs ± 33.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
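The gap comes from how the Jacobian of a map $\mathbb{R}^n \to \mathbb{R}^m$ is assembled. Forward mode needs one JVP per input direction: $n$ passes, each giving one column. Reverse mode needs one VJP per output direction: $m$ passes, each giving one row. The sketch below builds both versions with ```jax.jvp```, ```jax.vjp``` and ```jax.vmap```. ```jac_by_jvp```, ```jac_by_vjp``` and the small map ```g``` are hypothetical names for illustration, not JAX API.

```python
import jax
import jax.numpy as jnp

def jac_by_jvp(g, x):
    # Forward mode: push each input basis vector e_j through a JVP -> column J[:, j]
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    cols = jax.vmap(lambda v: jax.jvp(g, (x,), (v,))[1])(basis)  # shape (n, m)
    return cols.T

def jac_by_vjp(g, x):
    # Reverse mode: pull each output basis vector u_i back through a VJP -> row J[i, :]
    y, vjp_fn = jax.vjp(g, x)
    basis = jnp.eye(y.shape[0], dtype=y.dtype)
    return jax.vmap(lambda u: vjp_fn(u)[0])(basis)  # shape (m, n)

def g(x):
    # Illustrative map R^3 -> R^2
    return jnp.stack([jnp.sum(x ** 2), jnp.sin(x[0]) * x[1]])

x0 = jnp.array([0.5, -1.0, 2.0])
J_fwd = jac_by_jvp(g, x0)
J_rev = jac_by_vjp(g, x0)
```

This is the same column-versus-row trade-off that ```jax.jacfwd``` and ```jax.jacrev``` make.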


## Forward mode vs. backward mode: $m \geq n$

In [13]:
def fvec(x, A, b):
    y = A @ x + b
    e = jnp.exp(y - jnp.max(y))  # subtract the max so that exp cannot overflow
    return e / jnp.sum(e)

In [14]:
grad_fvec = jax.jit(jax.grad(fvec))  # fails when called: fvec is vector-valued
jac_fvec = jax.jacobian(fvec)  # jax.jacobian is an alias of jax.jacrev
fmode_fvec = jax.jit(jax.jacfwd(fvec))
bmode_fvec = jax.jit(jax.jacrev(fvec))

In [15]:
n = 1000
m = 1000
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (m, n))
b = jax.random.normal(jax.random.PRNGKey(0), (m, ))

In [16]:
J = jac_fvec(x, A, b)
print(J.shape)
grad_fvec(x, A, b)

(1000, 1000)


TypeError: Gradient only defined for scalar-output functions. Output had shape: (1000,).

In [17]:
print("Check correctness", jnp.linalg.norm(fmode_fvec(x, A, b) - bmode_fvec(x, A, b)))
print("Check shape", fmode_fvec(x, A, b).shape, bmode_fvec(x, A, b).shape)
print("Time forward mode")
%timeit fmode_fvec(x, A, b).block_until_ready()
print("Time backward mode")
%timeit bmode_fvec(x, A, b).block_until_ready()

Check correctness 7.940999414400821e-16
Check shape (1000, 1000) (1000, 1000)
Time forward mode
34 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Time backward mode
36.7 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
n = 10
m = 1000
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (m, n))
b = jax.random.normal(jax.random.PRNGKey(0), (m, ))

In [19]:
print("Check correctness", jnp.linalg.norm(fmode_fvec(x, A, b) - bmode_fvec(x, A, b)))
print("Check shape", fmode_fvec(x, A, b).shape, bmode_fvec(x, A, b).shape)
print("Time forward mode")
%timeit fmode_fvec(x, A, b).block_until_ready()
print("Time backward mode")
%timeit bmode_fvec(x, A, b).block_until_ready()

Check correctness 8.802051314519494e-16
Check shape (1000, 10) (1000, 10)
Time forward mode
113 µs ± 3.75 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Time backward mode
11.2 ms ± 201 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Hessian-vector product

In [20]:
def hvp(f, x, z, *args):
    # Hessian-vector product H(x) z, forward-over-reverse: JVP of grad in direction z
    def g(x):
        return f(x, *args)
    return jax.jvp(jax.grad(g), (x,), (z,))[1]

In [21]:
n = 3000
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (n, n))
b = jax.random.normal(jax.random.PRNGKey(0), (n, ))
z = jax.random.normal(jax.random.PRNGKey(0), (n, ))

In [22]:
print("Check correctness", jnp.linalg.norm(2 * A.T @ (A @ z) - hvp(f, x, z, A, b)))
print("Time for hvp by hands")
%timeit (2 * A.T @ (A @ z)).block_until_ready()
print("Time for hvp via jvp, NO jit")
%timeit hvp(f, x, z, A, b).block_until_ready()
print("Time for hvp via jvp, WITH jit")
hvp_jit = jax.jit(hvp, static_argnums=0)  # f is a Python callable, so it must be static
%timeit hvp_jit(f, x, z, A, b).block_until_ready()

Check correctness 8.36322417348524e-10
Time for hvp by hands
49.2 ms ± 4.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Time for hvp via jvp, NO jit
32.3 ms ± 520 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Time for hvp via jvp, WITH jit
8.74 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Summary

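- JAX is a simple and extensible tool for problems where autodiff is crucial
- JIT compilation is key to fast JAX code
- Input/output dimensions determine whether forward or backward mode is cheaper
- A Hessian-vector product computed via ```jax.jvp``` of ```jax.grad``` avoids forming the full $n \times n$ Hessian, so it is cheaper than building the explicit Hessian matrix and multiplying it by a vector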
- Complete docs can be found [here](https://jax.readthedocs.io/en/latest/)
- GitHub repository is [here](https://github.com/google/jax)