In [4]:
import jax.numpy as jnp
from jax import grad, vmap
from timeit import default_timer as timer

# Define the Gaussian function
def gaussian_function(x, sigma=1.0, mu=0.0):
    prefactor = 1 / (sigma * jnp.sqrt(2 * jnp.pi))
    exponent = -0.5 * ((x - mu) ** 2) / (sigma ** 2)
    return prefactor * jnp.exp(exponent)

# Compute the 1st-order derivative
dgaussian_dx = grad(gaussian_function)

# Compute the 2nd-order derivative by taking the gradient of the 1st-order derivative
d2gaussian_dx2 = grad(dgaussian_dx)

# Vectorize the 2nd-order derivative function to operate on arrays
v_d2gaussian_dx2 = vmap(d2gaussian_dx2)

# Generate 10,000 uniform points in [-5, 5]
x_values = jnp.linspace(-5, 5, 10000)

# Measure the runtime of computing the 2nd-order derivative for these points
start = timer()
d2gaussian_dx2_values = v_d2gaussian_dx2(x_values)
end = timer()
runtime = end - start
print(f"Runtime: {runtime} seconds")

# Compute the exact 2nd-order derivative values for validation
sigma = 1.0  # Standard deviation
mu = 0.0  # Mean
exact_d2gaussian_dx2_values = (((x_values - mu) ** 2 - sigma**2) / sigma**4) * gaussian_function(x_values, sigma, mu)

# Validate the computation by comparing with the exact formula
assert jnp.allclose(d2gaussian_dx2_values, exact_d2gaussian_dx2_values, atol=1e-5), "The computed 2nd-order derivatives do not match the exact formula"


Runtime: 0.019399541022721678 seconds
