# AutoDiff for a program by JAX

In [1]:
# JAX
from jax import grad
import jax.numpy as jnp
import jax
print(jax.__version__)
# Common
import numpy as np

0.2.26


## JAX without JIT

In [2]:
def fn(x,n): 
    r = 1
    for m in range(n):
        r *= x
    return r
dfn_th = lambda x, n: n*fn(x,n-1)
dfn_ad = grad(fn)
    
x = float(2.0)
print("f(x,10) = 2^10 = ", fn(x,10))
print("Theory: df(x,10)/dx = 10 x 2^9 = ", dfn_th(x,10))
print("AD: df(x,10) = ", dfn_ad(x,10))



f(x,10) = 2^10 =  1024.0
Theory: df(x,10)/dx = 10 x 2^9 =  5120.0
AD: df(x,10) =  5120.0


In [3]:
x_all = np.random.randn(1000)
%time y_all = list(map(lambda x: dfn_ad(x,10), x_all))

CPU times: user 8.15 s, sys: 26.2 ms, total: 8.18 s
Wall time: 8.17 s


In [4]:
f"{7870 / 511}"

'15.401174168297455'

## JAX with JIT

In [5]:
from jax import jit

In [6]:
f = lambda x: fn(x, 10)
# f = lambda x, n: x**n 
df_th = lambda x: dfn_th(x, 10)

In [7]:
df_ad = grad(f)    
x = float(2.0)
print("f(x,10) = 2^10 = ", f(x))
print("Theory: df(x,10)/dx = 10 x 2^9 = ", df_th(x))
print("AD: df(x,10) = ", df_ad(x))

f(x,10) = 2^10 =  1024.0
Theory: df(x,10)/dx = 10 x 2^9 =  5120.0
AD: df(x,10) =  5120.0


In [8]:
df_ad_jit = jit(df_ad)

x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: df_ad(x), x_all))
%time y_all = list(map(lambda x: df_ad_jit(x), x_all))

CPU times: user 7.89 s, sys: 30.4 ms, total: 7.92 s
Wall time: 7.91 s
CPU times: user 40.7 ms, sys: 15 µs, total: 40.7 ms
Wall time: 39.9 ms


In [9]:
print(f"Speed up from JAX without JIT by Jax-JIT: {7650 / 41.7:.2f} times faster")
print(f"Speed up from Tensorflow 2.4.1 by Jax-JIT: {511 / 41.7:.2f} times faster")

Speed up from JAX without JIT by Jax-JIT: 183.45 times faster
Speed up from Tensorflow 2.4.1 by Jax-JIT: 12.25 times faster


## vmap

In [10]:
from jax import vmap

In [11]:
#f = lambda x: x** 10
df_ad = grad(f)
df_ad_jit = jit(df_ad)
vdf_ad = vmap(df_ad, 0)
vdf_ad_jit = vmap(df_ad_jit, 0)

### len(x_all) = 1000

In [12]:
x_all = np.random.randn(1000) # 1,000
%time y_all = list(map(lambda x: df_ad(x), x_all))
%time y_all = list(map(lambda x: df_ad_jit(x), x_all))
%time y_all = vdf_ad(x_all)
%time y_all = vdf_ad_jit(x_all)

CPU times: user 7.77 s, sys: 59.7 ms, total: 7.83 s
Wall time: 7.81 s
CPU times: user 41.4 ms, sys: 23 µs, total: 41.4 ms
Wall time: 40.6 ms
CPU times: user 99 ms, sys: 0 ns, total: 99 ms
Wall time: 95.9 ms
CPU times: user 47.8 ms, sys: 130 µs, total: 48 ms
Wall time: 47.2 ms


### len(x_all) = 1,000,000

In [13]:
x_all = np.random.randn(1000000) # 1,000,000
%time y_all = list(map(lambda x: df_ad_jit(x), x_all))
%time y_all = vdf_ad(x_all)
%time y_all = vdf_ad_jit(x_all)

x_all = np.random.randn(1000000) # 1,000,000
%time y_all = vdf_ad_jit(x_all)

CPU times: user 3 s, sys: 710 ms, total: 3.71 s
Wall time: 3.71 s
CPU times: user 1.4 s, sys: 51.2 ms, total: 1.45 s
Wall time: 1.38 s
CPU times: user 167 ms, sys: 253 µs, total: 167 ms
Wall time: 163 ms
CPU times: user 5.34 ms, sys: 0 ns, total: 5.34 ms
Wall time: 2.35 ms


In [14]:
x_all = np.random.randn(100000000) # 100,000,000
%time y_all = vdf_ad_jit(x_all)
len(y_all)

CPU times: user 596 ms, sys: 0 ns, total: 596 ms
Wall time: 201 ms


100000000

In [15]:
x_all = np.random.randn(100000000) # 100,000,000
%timeit -n10 -r3 y_all = vdf_ad_jit(x_all)

178 ms ± 6.66 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
