In [None]:
from collections.abc import Callable

import jax
import jax.numpy as jnp
import numpy as np
import pycutest
from jaxtyping import ArrayLike, Scalar


def test_objective(objective: Callable[[ArrayLike], Scalar], problem_iD: str):
    pycutest_problem = pycutest.import_problem(problem_iD)
    y0 = jnp.asarray(pycutest_problem.x0)

    pycutest_f0, pycutest_grad0 = pycutest_problem.obj(y0, gradient=True)
    f0, grad0 = jax.value_and_grad(objective)(y0)

    pycutest_hess0 = pycutest_problem.hess(y0)
    hess0 = jax.hessian(objective)(y0)

    assert np.allclose(f0, pycutest_f0), f"Objective mismatch for {problem_iD}"
    assert np.allclose(grad0, pycutest_grad0), f"Gradient mismatch for {problem_iD}"
    assert np.allclose(hess0, pycutest_hess0), f"Hessian mismatch for {problem_iD}"

    grad0_signs = np.sign(grad0)
    pycutest_grad0_signs = np.sign(pycutest_grad0)
    gradient_signs_match = np.all(grad0_signs == pycutest_grad0_signs)
    assert gradient_signs_match, f"Gradient sign mismatch for {problem_iD}"

    hess0_signs = np.sign(hess0)
    pycutest_hess0_signs = np.sign(pycutest_hess0)
    hessian_signs_match = np.all(hess0_signs == pycutest_hess0_signs)
    assert hessian_signs_match, f"Hessian sign mismatch for {problem_iD}"

    return True


# def test_compilation(objective: Callable[[ArrayLike], Scalar], problem_iD: str):
#     pycutest_problem = pycutest.import_problem(problem_iD)
#     y0 = jnp.asarray(pycutest_problem.x0)

In [80]:
def rosenbrock(y):
    """Two-dimensional Rosenbrock function.

    Evaluates ``f(x1, x2) = (1 - x1)**2 + 100 * (x2 - x1**2)**2`` for a
    length-2 input ``y``. The global minimum is ``f(1, 1) = 0``.
    """
    first, second = y
    # Residual of the curved valley x2 = x1**2.
    valley = second - first**2
    return (1 - first) ** 2 + 100 * valley**2

In [81]:
# Compare the hand-written Rosenbrock objective against CUTEst's ROSENBR.
test_objective(objective=rosenbrock, problem_iD="ROSENBR")

True