In [9]:
import jax
import jax.numpy as jnp
import timeit

def sum_of_squares(x):
    """Return the scalar sum of ``x`` squared element-wise.

    Built from ``jax.numpy`` operations so it can be differentiated with
    ``jax.grad`` (the gradient of this function is ``2 * x``) and compiled
    with ``jax.jit``.
    """
    squared = x ** 2
    return jnp.sum(squared)

# Gradient of sum_of_squares w.r.t. its input, evaluated eagerly (op by op).
grad_sum_of_sq = jax.grad(sum_of_squares)

# Fixed seed so the benchmark input is identical on every run.
key = jax.random.PRNGKey(0)
input_array = jax.random.normal(key, (1000,))

# Calls timed per variant; one shared value keeps the two totals comparable.
n_runs = 1000

# Eager warm-up and reference result. Blocking here matters: JAX dispatches
# work asynchronously, so an un-awaited result could still be pending when the
# timed loop below starts and be billed to the first timed call.
gradient = grad_sum_of_sq(input_array).block_until_ready()

jit_grad_sum_sq = jax.jit(grad_sum_of_sq)

# The first jitted call traces and compiles; do it outside the timed loop so
# compilation time is not counted against the jitted measurement.
jit_gradient = jit_grad_sum_sq(input_array).block_until_ready()

# Compilation must not change the numbers being benchmarked.
if not jnp.allclose(gradient, jit_gradient):
    raise RuntimeError("jit-compiled gradient disagrees with the eager gradient")

# Total wall-clock seconds for n_runs calls of each variant (not per call).
time_without_jit = timeit.timeit(
    lambda: grad_sum_of_sq(input_array).block_until_ready(), number=n_runs
)
time_with_jit = timeit.timeit(
    lambda: jit_grad_sum_sq(input_array).block_until_ready(), number=n_runs
)

print(f"{time_without_jit=}\n{time_with_jit=}")

time_without_jit=5.629175166001005
time_with_jit=0.35892157899979793
