In [9]:
import jax.numpy as jnp
from jax import grad, vmap
from timeit import default_timer as timer

# Standard deviation (sigma) and mean (mu) of the Gaussian under test.
sigma, mu = 1.0, 0.0

# Define the Gaussian function
def gaussian_function(x, sigma=1.0, mu=0.0):
    prefactor = 1 / (sigma * jnp.sqrt(2 * jnp.pi))
    exponent = -0.5 * ((x - mu) ** 2) / (sigma ** 2)
    return prefactor * jnp.exp(exponent)

# 1st-order derivative of gaussian_function w.r.t. its first argument (x).
# Any further arguments (sigma, mu) are forwarded but not differentiated.
dgaussian_dx = grad(gaussian_function)

# Vectorize over x only, forwarding the module-level sigma and mu so the
# autodiff result and the closed-form check below use the same parameters
# (calling dgaussian_dx(x) alone would silently use the default sigma/mu).
v_dgaussian_dx = vmap(lambda x: dgaussian_dx(x, sigma, mu))

# Generate 10,000 uniform points in [-5, 5]
x_values = jnp.linspace(-5, 5, 10000)

# Warm-up call so one-off setup/compilation cost is excluded from the timing.
v_dgaussian_dx(x_values).block_until_ready()

# JAX dispatches work asynchronously: block until the result is actually
# computed before stopping the timer, otherwise only dispatch is measured.
start = timer()
dgaussian_dx_values = v_dgaussian_dx(x_values).block_until_ready()
end = timer()
runtime = end - start
print(f"Runtime: {runtime} seconds")

# Closed-form derivative of the normal density:
#   d/dx N(x; mu, sigma) = -(x - mu) / sigma**2 * N(x; mu, sigma)
exact_dgaussian_dx_values = -((x_values - mu) / sigma**2) * gaussian_function(x_values, sigma, mu)

# Validate the computation by comparing with the exact formula
assert jnp.allclose(dgaussian_dx_values, exact_dgaussian_dx_values, atol=1e-5), "The computed derivatives do not match the exact formula"


Runtime: 0.011726166994776577 seconds
