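# Automatic differentiation of a Python program with JAX `grad(f, argnums)`

Differentiate a loop-based power function $f(x, n) = x^n$ with `jax.grad`, check it against the closed-form derivative $n x^{n-1}$, and time it without JIT, with `jit`, and batched with `vmap`.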

In [1]:
# JAX: autodiff (grad), compilation (jit) and vectorization (vmap), all used below
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# Common
import numpy as np

# Record the JAX version: the timings in this notebook depend on it
print(jax.__version__)

0.2.26


## JAX without JIT

In [2]:
def f(x: float, n: int):
    """Return ``x**n`` computed by repeated multiplication (``n >= 0`` assumed).

    The explicit Python loop is deliberate: ``grad`` differentiates through the
    traced program, not through a closed-form expression.

    The accumulator starts at the float ``1.0`` rather than the int ``1`` so the
    result is a float even when ``n == 0``; ``jax.grad`` rejects functions with
    integer outputs, so ``grad(f)(x, 0)`` would otherwise raise.
    For ``n < 0`` the loop body never runs and ``1.0`` is returned.
    """
    r = 1.0
    for _ in range(n):
        r *= x
    return r


def dfdx_th(x, n):
    """Closed-form derivative d/dx x**n = n * x**(n - 1), used to check the AD result."""
    return n * f(x, n - 1)


# AD derivative w.r.t. the first positional argument (x); n is not differentiated.
dfdx_ad = grad(f, argnums=0)

x = 2.0
print("f(x,10) = 2^10 = ", f(x,10))
print("Theory: df(x,10)/dx = 10 x 2^9 = ", dfdx_th(x,10))
print("AD: df(x,10) = ", dfdx_ad(x,10))



f(x,10) = 2^10 =  1024.0
Theory: df(x,10)/dx = 10 x 2^9 =  5120.0
AD: df(x,10) =  5120.0


In [3]:
x_all = np.random.randn(1000)
%time y_all = list(map(lambda x: dfdx_ad(x,10), x_all))

CPU times: user 8.16 s, sys: 18.1 ms, total: 8.18 s
Wall time: 8.17 s


In [4]:
# Ratio of two hard-coded wall-clock times in ms (not read from the %time output above):
#   NOJIT_GRAD_MS - the un-jitted grad loop over 1,000 samples, from an earlier run
#                   (the run recorded above took 8.17 s);
#   TF_2_4_1_MS   - presumably the same benchmark in TensorFlow 2.4.1; it is not
#                   measured in this notebook.
# NOTE(review): update both constants whenever the benchmarks are re-run.
NOJIT_GRAD_MS = 8360
TF_2_4_1_MS = 511
f"{NOJIT_GRAD_MS / TF_2_4_1_MS}"

'16.360078277886497'

## JAX with JIT

In [5]:
from jax import jit

In [6]:
# Reuse f, dfdx_th and dfdx_ad from the "JAX without JIT" section instead of
# redefining identical copies here (a duplicate definition silently shadows the first).
#
# static_argnums=1 marks n as static, i.e. a compile-time constant. Per the JAX `jit` docs,
# operations that depend only on static arguments are constant-folded in Python during
# tracing, so the corresponding argument values can be any Python object. Here that lets
# `range(n)` inside f see a concrete int (the loop is unrolled while tracing); calling
# with a different n triggers a recompilation.
dfdx_th_jit = jit(dfdx_th, static_argnums=1)
dfdx_ad_jit = jit(dfdx_ad, static_argnums=1)

x = 2.0
n = 10
print("f(x,10) = 2^10 = ", f(x,n))
print("Theory: df/dx (x,10) = 10 x 2^9 = ", dfdx_th(x,n))
print("Theory_jit: df/dx (x,10) = 10 x 2^9 = ", dfdx_th_jit(x,n))
print("AD: df/dx (x,10) = ", dfdx_ad(x,n))
print("AD_jit: df/dx (x,10) = ", dfdx_ad_jit(x,n))

f(x,10) = 2^10 =  1024.0
Theory: df/dx (x,10) = 10 x 2^9 =  5120.0
Theory_jit: df/dx (x,10) = 10 x 2^9 =  5120.0
AD: df/dx (x,10) =  5120.0
AD_jit: df/dx (x,10) =  5120.0


In [7]:
x_all = np.random.randn(1000) # 1,000
n = 10
%time y_all = list(map(lambda x: dfdx_th(x,n), x_all))
%time y_all = list(map(lambda x: dfdx_th_jit(x,n), x_all))
%time y_all = list(map(lambda x: dfdx_ad(x,n), x_all))
%time y_all = list(map(lambda x: dfdx_ad_jit(x,n), x_all))

CPU times: user 4.66 ms, sys: 0 ns, total: 4.66 ms
Wall time: 4.63 ms
CPU times: user 27.2 ms, sys: 330 µs, total: 27.6 ms
Wall time: 24.3 ms
CPU times: user 7.83 s, sys: 14.5 ms, total: 7.85 s
Wall time: 7.83 s
CPU times: user 42 ms, sys: 105 µs, total: 42.1 ms
Wall time: 41.2 ms


In [8]:
# Speed-ups computed from hard-coded wall-clock times in ms, NOT read from the %time
# output above (the run recorded there measured 7.83 s without JIT and 41.2 ms with JIT).
# TF_2_4_1_MS is presumably the same benchmark in TensorFlow 2.4.1; it is not measured here.
# NOTE(review): update these constants whenever the benchmarks are re-run.
NOJIT_LOOP_MS = 7650
JIT_LOOP_MS = 41.7
TF_2_4_1_MS = 511
print(f"Speed up from JAX without JIT by Jax-JIT: {NOJIT_LOOP_MS / JIT_LOOP_MS:.2f} times faster")
print(f"Speed up from Tensorflow 2.4.1 by Jax-JIT: {TF_2_4_1_MS / JIT_LOOP_MS:.2f} times faster")

Speed up from JAX without JIT by Jax-JIT: 183.45 times faster
Speed up from Tensorflow 2.4.1 by Jax-JIT: 12.25 times faster


## Vectorization with `vmap`

In [9]:
from jax import vmap

In [10]:
# Batched derivatives: map over axis 0 of the first argument (x); in_axes=None leaves
# the second argument (n) unbatched, so every element shares the same n.
vdfdx_ad = vmap(dfdx_ad, in_axes=(0, None))
vdfdx_ad_jit = vmap(dfdx_ad_jit, in_axes=(0, None))

### `len(x_all)` = 1,000

In [11]:
x_all = np.random.randn(1000) # 1,000
n = 10
%time y_all = list(map(lambda x: dfdx_ad(x, n), x_all))
%time y_all = list(map(lambda x: dfdx_ad_jit(x, n), x_all))
%time y_all = vdfdx_ad(x_all, n)
%time y_all = vdfdx_ad_jit(x_all, n)

CPU times: user 8.49 s, sys: 8.79 ms, total: 8.5 s
Wall time: 8.49 s
CPU times: user 5.31 ms, sys: 0 ns, total: 5.31 ms
Wall time: 4.99 ms
CPU times: user 143 ms, sys: 0 ns, total: 143 ms
Wall time: 140 ms
CPU times: user 39.3 ms, sys: 19.6 ms, total: 58.9 ms
Wall time: 57.5 ms


### `len(x_all)` = 1,000,000

In [12]:
x_all = np.random.randn(1000000) # 1,000,000
n = 10
%time y_all = list(map(lambda x: dfdx_ad_jit(x,n), x_all))
%time y_all = vdfdx_ad(x_all,n)
%time y_all = vdfdx_ad_jit(x_all,n)

x_all = np.random.randn(1000000) # 1,000,000
%time y_all = vdfdx_ad_jit(x_all,n)

CPU times: user 3.52 s, sys: 942 ms, total: 4.47 s
Wall time: 4.47 s
CPU times: user 1.11 s, sys: 73 ms, total: 1.19 s
Wall time: 1.11 s
CPU times: user 135 ms, sys: 313 µs, total: 136 ms
Wall time: 132 ms
CPU times: user 5.19 ms, sys: 0 ns, total: 5.19 ms
Wall time: 2.7 ms


In [83]:
x_all = np.random.randn(100000000) # 100,000,000
%time y_all = vdfdx_ad_jit(x_all, n)
len(y_all)

CPU times: user 603 ms, sys: 0 ns, total: 603 ms
Wall time: 202 ms


100000000

In [84]:
x_all = np.random.randn(100000000) # 100,000,000
%timeit -n10 -r3 y_all = vdfdx_ad_jit(x_all,n)

183 ms ± 10.1 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
