# AutoDiff by JAX

In [1]:
# JAX
from jax import jit, grad
import jax.numpy as jnp
import jax
print(jax.__version__)  # show the installed JAX version
# Common
import numpy as np

0.2.26


## JAX without JIT

In [2]:
def f(x, n):
    """Return ``x`` raised to the power ``n``."""
    return x ** n


def df_th(x, n):
    """Closed-form derivative of ``f`` with respect to ``x``: ``n * x**(n-1)``."""
    return n * f(x, n - 1)


# Automatic derivative of f; grad differentiates with respect to the first
# positional argument (x) by default.
df_ad = grad(f)

x = 2.0
print("f(x,10) = 2^10 = ", f(x, 10))
print("Theory: df(x,10)/dx = 10 x 2^9 = ", df_th(x, 10))
print("AD: df(x,10) = ", df_ad(x, 10))



f(x,10) = 2^10 =  1024.0
Theory: df(x,10)/dx = 10 x 2^9 =  5120.0
AD: df(x,10) =  5120.0


In [3]:
x_all = np.random.randn(1000)
%time y_all = list(map(lambda x: df_ad(x,10), x_all))

CPU times: user 1.46 s, sys: 16.9 ms, total: 1.48 s
Wall time: 1.47 s


## JAX with JIT

In [4]:
df_ad_jit = jit(grad(f))

x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: df_ad_jit(x,10), x_all))

CPU times: user 34.2 ms, sys: 0 ns, total: 34.2 ms
Wall time: 31.3 ms


In [5]:
# Timings entered by hand (presumably wall times in milliseconds).
# NOTE(review): 1410 and 38.9 do not match the Wall times recorded in the cells
# above (1.47 s without JIT, 31.3 ms with JIT), and the TensorFlow figure (511)
# is not measured in this notebook -- confirm or refresh these numbers.
jax_nojit_time = 1410
tensorflow_time = 511
jax_jit_time = 38.9

speedup_vs_nojit = jax_nojit_time / jax_jit_time
speedup_vs_tensorflow = tensorflow_time / jax_jit_time

print(f"Speed up from JAX without JIT: {speedup_vs_nojit:.2f} times faster")
print(f"Speed up from Tensorflow 2.4.1: {speedup_vs_tensorflow:.2f} times faster")

Speed up from JAX without JIT: 36.25 times faster
Speed up from Tensorflow 2.4.1: 13.14 times faster


## vmap

In [30]:
from jax import vmap

In [31]:
# NOTE: this rebinds the name `f`, which earlier in the notebook was the
# two-argument function x**n; from here on `f` takes a single argument.
def f(x):
    """Return ``x`` raised to the 10th power."""
    return x ** 10

### len(x_all) = 1000

In [34]:
x_all = np.random.randn(1000) # 1,000
df_ad = grad(f)
vdf_ad = vmap(df_ad, 0)
%time y_all = list(map(lambda x: df_ad(x), x_all))
%time y_all = vdf_ad(x_all)

CPU times: user 1.38 s, sys: 9.65 ms, total: 1.39 s
Wall time: 1.39 s
CPU times: user 2.67 ms, sys: 227 µs, total: 2.9 ms
Wall time: 2.91 ms


In [35]:
x_all = np.random.randn(1000) # 1,000

df_ad_jit = jit(grad(f))
vdf_ad_jit = vmap(df_ad_jit, 0)
%time y_all = list(map(lambda x: df_ad_jit(x), x_all))
%time y_all = vdf_ad_jit(x_all)

CPU times: user 11.3 ms, sys: 10.2 ms, total: 21.5 ms
Wall time: 19.9 ms
CPU times: user 21.7 ms, sys: 140 µs, total: 21.9 ms
Wall time: 20.8 ms


### Sanity check: identity function, len(x) = 10

In [40]:
from jax import vmap

In [41]:
# NOTE: this rebinds the name `f` again, now to the identity function.
def f(x):
    """Identity function: return ``x`` unchanged."""
    return x


f(2.0)  # scalar input

x = jnp.ones(10)  # length-10 vector of ones
f(x)

DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [42]:
# Vectorized version of f: applies f along axis 0 of its argument.
vf = vmap(f, in_axes=0)

In [43]:
# Apply the vmapped f to the vector x.
vf(x)

DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [44]:
# Derivative of f with respect to its first (only) argument.
df = grad(f, argnums=0)

In [45]:
# Evaluate the derivative of f at the scalar point 2.0.
df(2.0)

DeviceArray(1., dtype=float32, weak_type=True)

In [46]:
# Vectorized derivative: applies df along axis 0 of its argument.
vdf = vmap(df, in_axes=0)

In [47]:
# Evaluate the derivative of f at every element of x in one vmapped call.
vdf(x)

DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)