# AutoDiff by JAX

In [3]:
# JAX
from jax import grad
import jax.numpy as jnp
import jax
print(jax.__version__)
# Common
import numpy as np

0.2.26


## JAX without JIT

In [11]:
def f(x, n):
    """Return x raised to the power n."""
    return x**n


def dfdx_th(x, n):
    """Closed-form derivative of f with respect to x: n * x**(n-1)."""
    return n * f(x, n - 1)


# Automatic derivative of f with respect to its 0th argument (x);
# the 1st argument (n) is not differentiated.
dfdx_ad = grad(f, argnums=0)

# Compare the analytic and the automatic derivative at a single point.
x = 2.0
print("f(x,10) = 2^10 = ", f(x, 10))
print("Theory: df(x,10)/dx = 10 x 2^9 = ", dfdx_th(x, 10))
print("AD: df(x,10) = ", dfdx_ad(x, 10))

f(x,10) = 2^10 =  1024.0
Theory: df(x,10)/dx = 10 x 2^9 =  5120.0
AD: df(x,10) =  5120.0


In [12]:
x_all = np.random.randn(1000)
%time y_all = list(map(lambda x: dfdx_ad(x,10), x_all))

CPU times: user 1.4 s, sys: 0 ns, total: 1.4 s
Wall time: 1.4 s


In [13]:
# Ratio of two hard-coded wall times in milliseconds.
# NOTE(review): 511 is labelled as the TensorFlow 2.4.1 timing in a later cell; 1510 is
# presumably a JAX-without-JIT timing from another run (the recorded run above took 1.4 s) — confirm.
f"{1510 / 511}"

'2.954990215264188'

## JAX with JIT

In [6]:
from jax import jit

In [17]:
def f(x, n):
    """Return x raised to the power n."""
    return x**n


# Gradient of f in its first argument (x), and a jit-wrapped version of that gradient.
dfdx_ad = grad(f, argnums=0)
dfdx_ad_jit = jit(dfdx_ad)

x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: dfdx_ad(x,10), x_all))
%time y_all = list(map(lambda x: dfdx_ad_jit(x,10), x_all))

CPU times: user 1.34 s, sys: 3.71 ms, total: 1.35 s
Wall time: 1.34 s
CPU times: user 23.7 ms, sys: 205 µs, total: 23.9 ms
Wall time: 23.2 ms


In [18]:
# Speed-up of the jitted per-element loop, computed from hard-coded wall times in milliseconds.
# 1340 and 23.2 are the two wall times shown in the recorded output of the timing cell above.
# NOTE(review): 511 is presumably a TensorFlow 2.4.1 timing measured outside this notebook — confirm.
print(f"Speed up from JAX without JIT: {1340 / 23.2:.2f} times faster")
print(f"Speed up from Tensorflow 2.4.1: {511 / 23.2:.2f} times faster")

Speed up from JAX without JIT: 57.76 times faster
Speed up from Tensorflow 2.4.1: 22.03 times faster


## vmap

In [7]:
from jax import vmap

In [20]:
def f(x, n):
    """Return x raised to the power n."""
    return x**n


# Scalar derivative df/dx (argument 0), plain and jit-wrapped.
dfdx_ad = grad(f, argnums=0)
dfdx_ad_jit = jit(dfdx_ad)

# Vectorised variants: map over axis 0 of x and pass n through unbatched.
vdfdx_ad = vmap(dfdx_ad, in_axes=(0, None))
vdfdx_ad_jit = vmap(dfdx_ad_jit, in_axes=(0, None))

### len(x_all) = 1,000

In [18]:
x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: dfdx_ad(x,10), x_all))
%time y_all = list(map(lambda x: dfdx_ad_jit(x,10), x_all))
%time y_all = vdfdx_ad(x_all,10)
%time y_all = vdfdx_ad_jit(x_all,10)

CPU times: user 1.34 s, sys: 0 ns, total: 1.34 s
Wall time: 1.33 s
CPU times: user 6 ms, sys: 0 ns, total: 6 ms
Wall time: 5.44 ms
CPU times: user 3.27 ms, sys: 0 ns, total: 3.27 ms
Wall time: 2.98 ms
CPU times: user 1.34 ms, sys: 0 ns, total: 1.34 ms
Wall time: 1.03 ms


In [21]:
# Speed-up ratio computed from two hard-coded wall times in milliseconds.
# NOTE(review): the recorded output above shows 1.33 s and 1.03 ms rather than 1340 and 1.02;
# presumably these figures come from an earlier run — confirm.
1340 / 1.02

1313.7254901960785

### len(x_all) = 1,000,000

In [23]:
x_all = np.random.randn(1000000) # 1,000,000
%time y_all = list(map(lambda x: dfdx_ad_jit(x,10), x_all))
%time y_all = vdfdx_ad(x_all, 10)
%time y_all = vdfdx_ad_jit(x_all, 10)

print("Re-run to remove compling time in the performance")
x_all = np.random.randn(1000000) # 1,000,000
%time y_all = vdfdx_ad_jit(x_all, 10)

CPU times: user 3.87 s, sys: 98.2 ms, total: 3.96 s
Wall time: 3.97 s
CPU times: user 848 ms, sys: 58.9 ms, total: 907 ms
Wall time: 899 ms
CPU times: user 113 ms, sys: 0 ns, total: 113 ms
Wall time: 103 ms
Re-run to remove compling time in the performance
CPU times: user 16 ms, sys: 0 ns, total: 16 ms
Wall time: 6.7 ms


In [25]:
x_all = np.random.randn(100000000) # 100,000,000
%time y_all = vdfdx_ad_jit(x_all, 10)
len(y_all)

CPU times: user 2.15 s, sys: 0 ns, total: 2.15 s
Wall time: 813 ms


100000000

In [27]:
x_all = np.random.randn(100000000) # 100,000,000
%timeit -n10 -r3 y_all = vdfdx_ad_jit(x_all, 10)

724 ms ± 9.55 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
