The partial derivative of the given function $f(x, y) = yx^2 + \sin(x)y^3$ with respect to $x$: $$\frac{\partial f}{\partial x} = 2xy + \cos(x)y^3$$
The partial derivative with respect to $y$: $$\frac{\partial f}{\partial y} = x^2 + 3\sin(x)y^2$$
Both follow from differentiating $f$ term by term, treating the other variable as a constant:
$$\frac{\partial f}{\partial x} = y \cdot 2x + \cos(x) \cdot y^3, \qquad \frac{\partial f}{\partial y} = x^2 \cdot 1 + \sin(x) \cdot 3y^2$$
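As a quick sanity check of these formulas, the sketch below evaluates both partial derivatives at a single point where the hand-derived values are easy to compute. At $(x, y) = (0, 2)$ the formulas give $\partial f/\partial x = 2 \cdot 0 \cdot 2 + \cos(0) \cdot 2^3 = 8$ and $\partial f/\partial y = 0^2 + 3\sin(0) \cdot 2^2 = 0$. The evaluation point is an arbitrary choice for illustration; `jax.grad` with `argnums=(0, 1)` is the same call used in the notebook below.

```python
import jax.numpy as jnp
from jax import grad

# Same function as in the notebook: f(x, y) = y * x^2 + sin(x) * y^3
def f(x, y):
    return y * x**2 + jnp.sin(x) * y**3

# argnums=(0, 1) makes grad return a tuple (df/dx, df/dy)
# (0.0, 2.0) is an illustrative point chosen so the hand values are exact
df_dx, df_dy = grad(f, argnums=(0, 1))(0.0, 2.0)

print(float(df_dx), float(df_dy))  # 8.0 0.0
```

The inputs are Python floats on purpose: `jax.grad` only differentiates with respect to floating-point arguments, so passing integers such as `(0, 2)` would raise an error.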

In [15]:
import jax.numpy as jnp
from jax import grad, vmap, random

# Define the function
def f(x, y):
  return y * x**2 + jnp.sin(x) * y**3

# Build a function returning (df/dx, df/dy), vectorized over arrays of inputs
grad_f = vmap(grad(f, argnums=(0, 1)))

# Split the key so x and y are independent samples (reusing one key would give x == y)
key = random.PRNGKey(0)  # Create a random key for reproducibility
key_x, key_y = random.split(key)
x = random.uniform(key_x, (5,))  # Generate 5 random values for x in [0, 1)
y = random.uniform(key_y, (5,))  # Generate 5 random values for y in [0, 1)

# Evaluate the gradients wrt x and y with automatic differentiation
grad_x, grad_y = grad_f(x, y)

# Evaluate the analytical partial derivatives derived above
analytical_value = jnp.array([2 * x * y + y**3 * jnp.cos(x), x**2 + 3 * jnp.sin(x) * y**2])

# Compare the two gradients element-wise
if jnp.allclose(grad_x, analytical_value[0]) and jnp.allclose(grad_y, analytical_value[1]):
  print("The gradient evaluated matches the analytical solution.")
else:
  print("The gradient evaluated does not match the analytical solution.")


The gradient evaluated matches the analytical solution.
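A third estimate that relies on neither the hand derivation nor automatic differentiation is the central finite difference, $\partial f/\partial x \approx [f(x+h, y) - f(x-h, y)]/(2h)$, and likewise for $y$. The sketch below computes it in NumPy float64 and compares it with the float32 JAX gradients. The helper `f_np`, the step size `h = 1e-5`, and the tolerance `atol=1e-4` are illustrative assumptions, not part of the notebook.

```python
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap, random

def f(x, y):
    return y * x**2 + jnp.sin(x) * y**3

def f_np(x, y):
    # Hypothetical NumPy (float64) copy of f, used only for finite differences
    return y * x**2 + np.sin(x) * y**3

# Independent samples for x and y via a split key
key_x, key_y = random.split(random.PRNGKey(0))
x = random.uniform(key_x, (5,))
y = random.uniform(key_y, (5,))

# Reverse-mode AD gradients, vectorized over the 5 sample points
grad_x, grad_y = vmap(grad(f, argnums=(0, 1)))(x, y)

# Central differences in float64; h is an assumed step size
h = 1e-5
xn = np.asarray(x, dtype=np.float64)
yn = np.asarray(y, dtype=np.float64)
fd_x = (f_np(xn + h, yn) - f_np(xn - h, yn)) / (2 * h)
fd_y = (f_np(xn, yn + h) - f_np(xn, yn - h)) / (2 * h)

# The loose atol absorbs float32 rounding in the JAX results
print(np.allclose(np.asarray(grad_x), fd_x, atol=1e-4))
print(np.allclose(np.asarray(grad_y), fd_y, atol=1e-4))
```

The differences are taken in float64 because in float32 the subtraction $f(x+h) - f(x-h)$ loses most of its significant digits for small `h`.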


In [14]:
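# Derive the gradient symbolically to confirm the hand derivation
import sympy as sp

# Symbolic variables and the same function f(x, y)
x, y = sp.symbols('x y')
f = y * x**2 + sp.sin(x) * y**3

# Gradient as a column vector [df/dx, df/dy]
grad_f = sp.Matrix([f.diff(x), f.diff(y)])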

print("Analytical gradient wrt x :", grad_f[0])
print("Analytical gradient wrt y :", grad_f[1])

Analytical gradient wrt x : 2*x*y + y**3*cos(x)
Analytical gradient wrt y : x**2 + 3*y**2*sin(x)
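To tie the symbolic and automatic results together, the SymPy partial derivatives can be compiled into numerical functions with `sp.lambdify` and evaluated at the same sample points as the JAX gradient. This sketch assumes NumPy as the `lambdify` backend and reuses the split-key sampling from `PRNGKey(0)`; the names `x_s`, `y_s`, `grad_exprs`, and `grad_numeric` are illustrative.

```python
import numpy as np
import sympy as sp
import jax.numpy as jnp
from jax import grad, vmap, random

# Symbolic gradient, as in the SymPy cell
x_s, y_s = sp.symbols('x y')
f_s = y_s * x_s**2 + sp.sin(x_s) * y_s**3
grad_exprs = [f_s.diff(x_s), f_s.diff(y_s)]

# Compile the symbolic partials into a NumPy-callable function
grad_numeric = sp.lambdify((x_s, y_s), grad_exprs, modules="numpy")

# Automatic gradient, as in the JAX cell
def f(x, y):
    return y * x**2 + jnp.sin(x) * y**3

key_x, key_y = random.split(random.PRNGKey(0))
x = random.uniform(key_x, (5,))
y = random.uniform(key_y, (5,))
ad_x, ad_y = vmap(grad(f, argnums=(0, 1)))(x, y)

# Evaluate the symbolic gradient at the same points, in float64
sym_x, sym_y = grad_numeric(np.asarray(x, dtype=np.float64),
                            np.asarray(y, dtype=np.float64))

print(np.allclose(np.asarray(ad_x), sym_x, rtol=1e-5, atol=1e-6))
print(np.allclose(np.asarray(ad_y), sym_y, rtol=1e-5, atol=1e-6))
```

Because `grad_exprs` is a list, the compiled function returns a list of two arrays, one per partial derivative, which unpacks directly into `sym_x` and `sym_y`.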
