In [1]:
import jax.numpy as jnp
import numpy as np
from jax import jacfwd, jacrev
import time

In [10]:
# Chain F = f3 ∘ f2 ∘ f1 : R^5000 -> R (many inputs, a single scalar output).
# f1: product of each consecutive group of 5 entries (5000 -> 1000).
f1 = lambda x: jnp.array([jnp.prod(x[i:i+5]) for i in range(0, len(x), 5)])
# f2: product of each group of 30 entries. The "-100" stops the range before
# index 900, so a length-1000 input yields exactly 30 groups (1000 -> 30).
f2 = lambda x: jnp.array([jnp.prod(x[i:i+30]) for i in range(0, len(x)-100, 30)])
# f3: product of all entries (30 -> 1).
f3 = lambda x: jnp.prod(x)
F = lambda x: f3(f2(f1(x)))

# Seeded so that Restart & Run All reproduces the same input.
# NOTE: each f2 output multiplies 150 values drawn from [0, 1), so with JAX's
# default float32 the values (and derivatives) of F almost surely underflow
# to 0. Only the timings below are meaningful, not the numbers themselves.
x = jnp.array(np.random.default_rng(0).random(5000))

# Time the full Jacobian of F (R^5000 -> R) in forward and reverse mode.
# JAX dispatches work asynchronously, so block_until_ready() is required for
# the timer to cover the actual computation rather than just its dispatch.
# These are single un-jitted calls, so the numbers also include tracing overhead.
start = time.perf_counter()
jacfwd(F)(x).block_until_ready()
print(f'Time for jacfwd {time.perf_counter()-start}')

start = time.perf_counter()
jacrev(F)(x).block_until_ready()
print(f'Time for jacrev {time.perf_counter()-start}')


Time for jacfwd 0.6742861270904541
Time for jacrev 1.1058948040008545


In [11]:
# Chain G = g3 ∘ g2 ∘ g1 : R^1 -> R^5000 (a single input, many outputs),
# with intermediate sizes 1 -> 30 -> 1000 -> 5000.
# g1: the powers x**0 .. x**29.
g1 = lambda x: jnp.array([x**i for i in range(30)])
# g2: 1000 pairwise products of entries of a length-30 input.
g2 = lambda x: jnp.array([x[i%30]*x[i%10] for i in range(1000)])
# g3: 5000 triple products of entries of a length-1000 input.
g3 = lambda x: jnp.array([x[i%1000]*x[i%30]*x[i%6] for i in range(5000)])
G = lambda x: g3(g2(g1(x)))

# A shape-(1,) input, so G(x) has shape (5000, 1): the trailing axis is x's.
# Seeded for reproducibility. Note this overwrites the 5000-element x used for F.
x = jnp.array(np.random.default_rng(0).random(1))

In [12]:
# Time the full Jacobian of G (R^1 -> R^5000) in forward and reverse mode.
# block_until_ready() makes the timer wait for JAX's asynchronously
# dispatched computation instead of stopping as soon as it is queued.
start = time.perf_counter()
jacfwd(G)(x).block_until_ready()
print(f'Time for jacfwd {time.perf_counter()-start}')

start = time.perf_counter()
jacrev(G)(x).block_until_ready()
print(f'Time for jacrev {time.perf_counter()-start}')


Time for jacfwd 11.758502721786499
Time for jacrev 53.85372304916382


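In the first experiment, `jacfwd` was faster than `jacrev`, which doesn't match the book: $F:\mathbb{R}^{5000}\to\mathbb{R}$ has many inputs and a single output, so reverse mode should be cheaper. Timings of a single un-jitted call are noisy, so this comparison should be taken with caution: they include tracing overhead and, unless `.block_until_ready()` is called, may not wait for JAX's asynchronous dispatch to finish. In the second experiment forward mode is faster, as expected for $G:\mathbb{R}^{1}\to\mathbb{R}^{5000}$. I suspect this is because of the intermediate sizes $1 \to 30 \to 1000 \to 5000$. Multiplying the Jacobians starting from the input side (forward mode) costs $O(1\cdot 30\cdot 1000) + O(1\cdot 1000\cdot 5000)$, whereas starting from the output side (reverse mode) costs $O(5000\cdot 1000\cdot 30) + O(5000\cdot 30\cdot 1)$.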

In [13]:
# Recreate a 5000-element input for F (the cell defining G replaced x with a
# single-element array). Seeded for reproducibility.
x = jnp.array(np.random.default_rng(0).random(5000))

# Time the 5000 x 5000 Hessian of F: forward-over-reverse vs reverse-over-reverse.
# block_until_ready() makes the timer wait for JAX's asynchronous computation.
start = time.perf_counter()
jacfwd(jacrev(F))(x).block_until_ready()
print(f'Time for fwd(rev) {time.perf_counter() - start}')

start = time.perf_counter()
jacrev(jacrev(F))(x).block_until_ready()
print(f'Time for rev(rev) {time.perf_counter() - start}')

Time for fwd(rev) 18.935773134231567
Time for rev(rev) 21.346662998199463


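I don't see a large difference between forward-over-reverse and reverse-over-reverse. I think this is because both variants first compute the gradient of $F$ in reverse mode. The outer pass then differentiates a map $\mathbb{R}^{5000}\to\mathbb{R}^{5000}$ with as many inputs as outputs. So neither forward nor reverse mode needs fewer passes to build the $5000\times 5000$ Hessian.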

In [14]:
# H = G ∘ F: with the 5000-element x from the previous cell this maps
# R^5000 -> R^5000, passing through the scalar output of F.
H = lambda x: G(F(x))

# Full Jacobian of H in forward vs reverse mode. block_until_ready() makes the
# timer wait for JAX's asynchronously dispatched computation to finish.
start = time.perf_counter()
jacfwd(H)(x).block_until_ready()
print(f'Time for fwd {time.perf_counter() - start}')

start = time.perf_counter()
jacrev(H)(x).block_until_ready()
print(f'Time for rev {time.perf_counter() - start}')

Time for fwd 9.071834802627563
Time for rev 65.92943000793457
