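# Automatic Differentiation with JAX

Compares JAX's automatic derivative of $f(x, n) = x^n$ with the closed form $n\,x^{n-1}$, then times `grad` with and without `jit`.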

In [8]:
# JAX -- all imports live in this first cell so "Restart & Run All" works
# without depending on later, separately executed import cells.
import jax
import jax.numpy as jnp
from jax import grad, jit

# Common
import numpy as np

# Record the JAX version so the outputs below can be tied to it.
print(jax.__version__)

0.2.26


## JAX without JIT

In [9]:
def f(x, n):
    """Power function f(x, n) = x**n."""
    return x**n


def df_th(x, n):
    """Closed-form derivative d/dx x**n = n * x**(n - 1)."""
    return n * f(x, n - 1)


# `grad` differentiates with respect to the first positional argument (x);
# `n` is passed through unchanged.
df_ad = grad(f)

x = 2.0
print("f(x,10) = 2^10 = ", f(x, 10))
print("Theory: df(x,10)/dx = 10 x 2^9 = ", df_th(x, 10))
print("AD: df(x,10) = ", df_ad(x, 10))

# Sanity check: the automatic derivative must agree with the closed form.
np.testing.assert_allclose(df_ad(x, 10), df_th(x, 10), rtol=1e-6)

f(x,10) = 2^10 =  1024.0
Theory: df(x,10)/dx = 10 x 2^9 =  5120.0
AD: df(x,10) =  5120.0


In [10]:
x_all = np.random.randn(1000)
%time y_all = list(map(lambda x: df_ad(x,10), x_all))

CPU times: user 1.41 s, sys: 0 ns, total: 1.41 s
Wall time: 1.41 s


## JAX with JIT

In [None]:
from jax import jit

In [15]:
df_ad_jit = jit(grad(f))

x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: df_ad_jit(x,10), x_all))

CPU times: user 38.9 ms, sys: 0 ns, total: 38.9 ms
Wall time: 35.9 ms


In [18]:
# CPU "user" times in milliseconds, copied by hand from earlier %time outputs.
# These are constants: they do NOT update if the timing cells are re-run.
NOJIT_1K_MS = 1410  # "JAX without JIT" cell, 1,000 points (user 1.41 s)
JIT_1K_MS = 38.9    # "JAX with JIT" cell, 1,000 points (user 38.9 ms)
# NOTE(review): this TensorFlow 2.4.1 figure comes from a benchmark outside this
# notebook -- record its source and setup (input size, hardware) before relying on it.
TF241_MS = 511

print(f"Speed up from JAX without JIT: {NOJIT_1K_MS / JIT_1K_MS:.2f} times faster")
print(f"Speed up from Tensorflow 2.4.1: {TF241_MS / JIT_1K_MS:.2f} times faster")

Speed up from JAX without JIT: 36.25 times faster
Speed up from Tensorflow 2.4.1: 13.14 times faster


In [34]:
x_all = np.random.randn(100000) # 100,000
%time y_all = list(map(lambda x: df_ad_jit(x,10), x_all))

CPU times: user 532 ms, sys: 19.8 ms, total: 552 ms
Wall time: 551 ms
