How to prevent jitting in jaxopt solver? #444

Closed
itk22 opened this issue Jun 21, 2023 · 3 comments
Labels
bug Something isn't working


itk22 commented Jun 21, 2023

Dear jaxopt team,

I am writing to inquire about an issue I encountered when using jaxopt's LBFGS solver with a custom objective function that is not jittable but has a custom VJP defined. Despite specifying jit=False to disable JIT compilation, the solver still throws a TracerArrayConversionError, as if it were attempting to JIT-compile the objective function. I have prepared a minimal example to reproduce the issue:

import jax
import numpy as np
from jax import custom_vjp
from jaxopt import LBFGS

@custom_vjp
def f(x):
    # The function deliberately uses numpy functions which are not jittable
    return -(x[0] + np.sin(x[0])) * np.exp(-x[0]**2.0)

def f_fwd(x):
    return f(x), (x, )

def f_bwd(x, g):
    return -(1.0 + np.cos(x[0])) * np.exp(-x[0]**2.0) - 2.0 * x[0] * (
        x[0] + np.sin(x[0])) * np.exp(-x[0]**2.0),

f.defvjp(f_fwd, f_bwd)

print_errors = False  # Flag for printing the errors

# Check if the custom function is jittable
try:
    f_jitted = jax.jit(f)
    f_jitted(1.0)
    print("Function is jittable")
except Exception as e:
    print("Function is not jittable")
    if print_errors:
        print("Error: ", e)

# Check if the custom function is differentiable
try:
    f_grad = jax.grad(f)
    print("Function is differentiable")
except Exception as e:
    print("Function is not differentiable")
    if print_errors:
        print("Error: ", e)

# Run the solver (optimal input is around 0.679579)
solver = LBFGS(fun=f, jit=False, maxiter=10)
res = solver.run(np.array([1.0]))

If my understanding is correct, setting jit=False should prevent JIT compilation of the objective function. However, this does not seem to be the case. Could you please let me know whether I am missing something here, or whether this is perhaps a bug?
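
For context, my expectation of what jit=False permits (just a sketch of my mental model, not jaxopt code) is that the solver would only ever evaluate the objective and its custom VJP eagerly on concrete arrays, which works fine in isolation:

x = np.array([1.0])
value = f(x)            # plain eager call on a concrete NumPy array, no tracing
grad = jax.grad(f)(x)   # evaluates the custom VJP, also on concrete values

The tracer error suggests the objective is nevertheless being traced somewhere inside the solver.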


itk22 commented Jun 22, 2023

Update: I noticed that my implementation of the backward function f_bwd was incorrect. Here is an updated example:

import jax
import numpy as np
#import jax.numpy as np
from jax import custom_vjp
import jaxopt

@custom_vjp
def f(x):
    # The function deliberately uses numpy functions which are not jittable
    return -(x[0] + np.sin(x[0])) * np.exp(-x[0]**2.0)


def f_fwd(x):
    return f(x), (x, )


def f_bwd(res, g):
    x = res[0]
    grad_x = np.exp(-x[0]**2) * (2 * x[0]**2 + 2 * x[0] * np.sin(x[0]) - np.cos(x[0]) - 1)
    return (grad_x * g,)


f.defvjp(f_fwd, f_bwd)

print_errors = False  # Flag for printing the errors

# Check if the custom function is jittable
try:
    f_jitted = jax.jit(f)
    f_jitted(np.array([1.0]))
    print("Function is jittable")
except Exception as e:
    print("Function is not jittable")
    if print_errors:
        print("Error: ", e)

# Check if the custom function is differentiable
try:
    f_grad = jax.grad(f)
    f_grad(np.array([1.0]))
    print("Function is differentiable")
except Exception as e:
    print("Function is not differentiable")
    if print_errors:
        print("Error: ", e)

# Run the solver (optimal input is around 0.679579)
solver = jaxopt.LBFGS(fun=f, jit=False, maxiter=10000)
res = solver.run(np.array([0.0]))

solution = res.params
ground_truth = np.array([0.679579])

print("Solution: ", solution)
print("Ground truth: ", ground_truth)
print("Error: ", np.linalg.norm(solution - ground_truth))
print("Ground truth value ", f(ground_truth))
print("Ground truth gradient ", jax.grad(f)(ground_truth))
print("Solution value: ", f(solution))
print("Solution gradient: ", jax.grad(f)(solution))

However, I now get a different error, which might be specific to this particular example:

Traceback (most recent call last):
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_jitting.py", line 49, in <module>
    res = solver.run(np.array([0.0]))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/base.py", line 354, in run
    return run(init_params, *args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 614, in __call__
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 763, in bind
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/core.py", line 809, in process_custom_vjp_call
    return fun.call_wrapped(*tracers)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 207, in solver_fun_flat
    return solver_fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/base.py", line 316, in _run
    opt_step = self.update(init_params, state, *args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 313, in update
    product = inv_hessian_product(pytree=grad, s_history=state.s_history,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 109, in inv_hessian_product
    return tree_map(fun, pytree, s_history, y_history)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 60, in inv_hessian_product_leaf
    r, alpha = jax.lax.scan(body_right, v, indices, reverse=True)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 257, in scan
    _check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 314, in _check_scan_carry_type
    raise TypeError(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Scanned function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
  * the input carry r has type float32[] but the corresponding output carry component has type float32[1], so the shapes do not match

Revise the scanned function so that all output types (e.g. shapes and dtypes) match the corresponding input types.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_jitting.py", line 49, in <module>
    res = solver.run(np.array([0.0]))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/base.py", line 354, in run
    return run(init_params, *args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 251, in wrapped_solver_fun
    return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py", line 207, in solver_fun_flat
    return solver_fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/base.py", line 316, in _run
    opt_step = self.update(init_params, state, *args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 313, in update
    product = inv_hessian_product(pytree=grad, s_history=state.s_history,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 109, in inv_hessian_product
    return tree_map(fun, pytree, s_history, y_history)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 60, in inv_hessian_product_leaf
    r, alpha = jax.lax.scan(body_right, v, indices, reverse=True)
TypeError: Scanned function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
  * the input carry r has type float32[] but the corresponding output carry component has type float32[1], so the shapes do not match

Revise the scanned function so that all output types (e.g. shapes and dtypes) match the corresponding input types.
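
For what it's worth, this appears to be jax.lax.scan's generic carry-shape check; a standalone sketch (independent of jaxopt, names are my own) that raises the same kind of error:

import jax
import jax.numpy as jnp

def body(carry, x):
    # The carry enters as a scalar but leaves with shape (1,) after broadcasting,
    # which is exactly the input/output mismatch the check complains about.
    return carry * jnp.ones(1), x

jax.lax.scan(body, jnp.array(0.0), jnp.arange(3.0))
# TypeError: Scanned function carry input and carry output must have equal types ...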

I also tried two other solvers: GradientDescent worked as expected, while BFGS reproduced the TracerArrayConversionError I mentioned at the beginning.
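
The GradientDescent run looked roughly like this (reconstructed sketch, same f and starting point as above):

solver = jaxopt.GradientDescent(fun=f, jit=False, maxiter=10000)
res = solver.run(np.array([0.0]))
print(res.params)  # converges towards the expected optimum near 0.679579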

@mblondel (Collaborator) commented

Hi! Thanks for the bug report. There were two problems here.

One is on your side: the VJP should return an array of size 1, not a scalar (the output of the VJP should always have the same shape as the corresponding input of the function):

    return (np.array([grad_x * g]),)
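
In full, the backward rule from your second example then reads (sketch reusing your grad_x expression):

def f_bwd(res, g):
    x = res[0]
    grad_x = np.exp(-x[0]**2) * (2 * x[0]**2 + 2 * x[0] * np.sin(x[0]) - np.cos(x[0]) - 1)
    # The cotangent must have the same shape as the primal input, i.e. (1,).
    return (np.array([grad_x * g]),)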

One is on our side: the zoom and Hager-Zhang line searches don't work with non-jittable functions for the moment (we'll fix this). As a temporary workaround, you can use the backtracking line search:

solver = jaxopt.LBFGS(fun=f, jit=False, maxiter=10, linesearch="backtracking")
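
Putting both changes together, rerunning your updated script should then converge (untested sketch on my end; expected optimum taken from your ground truth):

res = solver.run(np.array([0.0]))
print(res.params)  # should end up close to 0.679579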


itk22 commented Jul 1, 2023

Hi @mblondel. Thank you so much for your response! I really appreciate the fact that the team is already working on addressing this!

itk22 closed this as completed Jul 1, 2023