# AutoDiff by JAX

In [48]:
# JAX
from jax import grad
import jax.numpy as jnp
import jax
print(jax.__version__)
# Common
import numpy as np

0.2.26


## JAX without JIT

In [2]:
f = lambda x, n: x**n 
df_th = lambda x, n: n*f(x,n-1)

df_ad = grad(f)
    
x = float(2.0)
print("f(x,10) = 2^10 = ", f(x,10))
print("Theory: df(x,10)/dx = 10 x 2^9 = ", df_th(x,10))
print("AD: df(x,10) = ", df_ad(x,10))



f(x,10) = 2^10 =  1024.0
Theory: df(x,10)/dx = 10 x 2^9 =  5120.0
AD: df(x,10) =  5120.0


In [3]:
x_all = np.random.randn(1000)
%time y_all = list(map(lambda x: df_ad(x,10), x_all))

CPU times: user 1.51 s, sys: 0 ns, total: 1.51 s
Wall time: 1.5 s


In [50]:
f"{1510 / 511}"

'2.954990215264188'

## JAX with JIT

In [49]:
from jax import jit

In [4]:
df_ad_jit = jit(grad(f))

x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: df_ad_jit(x,10), x_all))

CPU times: user 38.2 ms, sys: 907 µs, total: 39.1 ms
Wall time: 34.5 ms


In [5]:
print(f"Speed up from JAX without JIT: {1410 / 38.9:.2f} times faster")
print(f"Speed up from Tensorflow 2.4.1: {511 / 38.9:.2f} times faster")

Speed up from JAX without JIT: 36.25 times faster
Speed up from Tensorflow 2.4.1: 13.14 times faster


## vmap

In [6]:
from jax import vmap

In [7]:
f = lambda x: x** 10
df_ad = grad(f)
df_ad_jit = jit(df_ad)
vdf_ad = vmap(df_ad, 0)
vdf_ad_jit = vmap(df_ad_jit, 0)

### len(x_all) = 1000

In [44]:
x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: df_ad(x), x_all))
%time y_all = list(map(lambda x: df_ad_jit(x), x_all))
%time y_all = vdf_ad(x_all)
%time y_all = vdf_ad_jit(x_all)

CPU times: user 1.44 s, sys: 0 ns, total: 1.44 s
Wall time: 1.44 s
CPU times: user 4.74 ms, sys: 266 µs, total: 5.01 ms
Wall time: 4.75 ms
CPU times: user 3.63 ms, sys: 0 ns, total: 3.63 ms
Wall time: 3.35 ms
CPU times: user 1.06 ms, sys: 0 ns, total: 1.06 ms
Wall time: 839 µs


### len(x_all) = 1,000,000

In [45]:
x_all = np.random.randn(1000000) # 1,000,000
%time y_all = list(map(lambda x: df_ad_jit(x), x_all))
%time y_all = vdf_ad(x_all)
%time y_all = vdf_ad_jit(x_all)

x_all = np.random.randn(1000000) # 1,000,000
%time y_all = vdf_ad_jit(x_all)

CPU times: user 3.32 s, sys: 228 ms, total: 3.55 s
Wall time: 3.55 s
CPU times: user 911 ms, sys: 29.8 ms, total: 941 ms
Wall time: 934 ms
CPU times: user 93.6 ms, sys: 10.6 ms, total: 104 ms
Wall time: 102 ms
CPU times: user 2.7 ms, sys: 444 µs, total: 3.14 ms
Wall time: 2.28 ms


In [46]:
x_all = np.random.randn(100000000) # 100,000,000
%time y_all = vdf_ad_jit(x_all)
len(y_all)

CPU times: user 227 ms, sys: 161 ms, total: 388 ms
Wall time: 257 ms


100000000

In [47]:
x_all = np.random.randn(100000000) # 100,000,000
%timeit -n10 -r3 y_all = vdf_ad_jit(x_all)

175 ms ± 5.75 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
