
Type precision issue in BoxOSQP #547

Open · jewillco opened this issue Oct 14, 2023 · 10 comments

@jewillco
When I have float64 support enabled in JAX and try to run:

import jax
import jax.numpy as jnp
import jaxopt

jax.config.update("jax_enable_x64", True)  # float64 support enabled

optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),
        jnp.ones((30,), dtype=jnp.float32),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(-1, 1),
)

I get an internal error from the implementation (partially redacted):

[.../jaxopt/_src/osqp.py](...) in run(self, init_params, params_obj, params_eq, params_ineq)
    763       init_params = self.init_params(None, params_obj, params_eq, params_ineq)
    764 
--> 765     return super().run(init_params, params_obj, params_eq, params_ineq)
    766 
    767   def l2_optimality_error(

[.../jaxopt/_src/base.py](...) in run(self, init_params, *args, **kwargs)
    345       run = decorator(run)
    346 
--> 347     return run(init_params, *args, **kwargs)
    348 
    349   def __post_init__(self):

[.../jaxopt/_src/implicit_diff.py](...) in wrapped_solver_fun(*args, **kwargs)
    249     args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250     keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251     return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
    252 
    253   return wrapped_solver_fun

[.../jaxopt/_src/implicit_diff.py](...) in solver_fun_flat(*flat_args)
    205     def solver_fun_flat(*flat_args):
    206       args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207       return solver_fun(*args, **kwargs)
    208 
    209     def solver_fun_fwd(*flat_args):

[.../jaxopt/_src/base.py](...) in _run(self, init_params, *args, **kwargs)
    307     zero_step = self._make_zero_step(init_params, state)
    308 
--> 309     opt_step = self.update(init_params, state, *args, **kwargs)
    310     init_val = (opt_step, (args, kwargs))
    311 

[.../jaxopt/_src/osqp.py](...) in update(self, params, state, params_obj, params_eq, params_ineq)
    703     # We need our own ifelse cond because automatic jitting of jax.lax.cond branches
    704     # could pose problems with non jittable matvecs, or prevent printing when verbose > 0.
--> 705     rho_bar, solver_state = cond(
    706         jnp.mod(state.iter_num, self.stepsize_updates_frequency) == 0,
    707         lambda _: self._update_stepsize(rho_bar, solver_state, primal_residuals, dual_residuals, Q, c, A, x, y),

[.../jaxopt/_src/cond.py](...) in cond(cond, if_fun, else_fun, jit, *operands)
     22     with jax.disable_jit():
     23       return jax.lax.cond(cond, if_fun, else_fun, *operands)
---> 24   return jax.lax.cond(cond, if_fun, else_fun, *operands)

TypeError: true_fun and false_fun output must have identical types, got
('DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)', ('ShapedArray(float32[30])', ('ShapedArray(float32[30,30])', 'ShapedArray(float32[1,30])', 'ShapedArray(float64[], weak_type=True)', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)'), None)).
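For reference, this class of error can be reproduced outside jaxopt. The following is a minimal standalone sketch (a hypothetical reproduction, not the solver's actual code path): with x64 enabled, a bare Python float traces as a weakly-typed float64, so the two branches of jax.lax.cond disagree on their output dtype.

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

def branchy(pred):
    return jax.lax.cond(
        pred,
        lambda _: jnp.float32(1.0),  # ShapedArray(float32[])
        lambda _: 1.0,               # ShapedArray(float64[], weak_type=True)
        None,
    )

jax.jit(branchy)(True)
# TypeError: true_fun and false_fun output must have identical types, ...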
@jewillco changed the title from "Precision issue in BoxOSQP" to "Type precision issue in BoxOSQP" on Oct 14, 2023
@Algue-Rythme (Collaborator)

Can you try promoting the params_ineq=(-1, 1) tuple to float32 by default? Let me know how it goes.

@jewillco (Author)

I tried float32 and float64 there (using the jnp.float32(1) syntax):

float32: TypeError: true_fun and false_fun output must have identical types, got ('DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)', ('ShapedArray(float32[30])', ('ShapedArray(float32[30,30])', 'ShapedArray(float32[1,30])', 'ShapedArray(float64[], weak_type=True)', 'DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[], weak_type=True)'), None)).

float64: TypeError: body_fun output and input must have identical types, got ('DIFFERENT ShapedArray(float64[30]) vs. ShapedArray(float32[30])', 'ShapedArray(float64[30])', 'ShapedArray(float64[])', 'ShapedArray(float64[30])', 'ShapedArray(int64[], weak_type=True)').
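The float64 failure looks like the same dtype drift, but inside lax.while_loop. A minimal standalone sketch (hypothetical, not the jaxopt internals): the loop carry must keep a fixed dtype, so a float32 carry that gets promoted to float64 inside the body trips the same check.

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x0 = jnp.zeros(30, dtype=jnp.float32)  # float32 carry
step = jnp.float64(1.0)                # strongly-typed float64 scalar

jax.lax.while_loop(
    lambda x: jnp.any(x < 10.0),
    lambda x: x + step,  # float32 + float64 -> float64: carry dtype changes
    x0,
)
# TypeError: body_fun output and input must have identical types, ...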

@Algue-Rythme (Collaborator)

Would you mind sharing your minimal (not) working example in Colab? Thanks in advance.

@Algue-Rythme self-assigned this on Oct 17, 2023
@jewillco (Author)

optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),
        jnp.ones((30,), dtype=jnp.float32),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float32),
    params_ineq=(jnp.float32(-1), jnp.float32(1)),
)

@Algue-Rythme (Collaborator)

You did not give me a Colab link, so I copy/pasted the code into Colab, added a few imports, and it works there! There are no errors... Which versions of jax/jaxopt/Python are you using? Are you using a GPU?

@jewillco (Author)

I am using a TPU, and my Colab has a large number of other things in it, so I can't share it. Did you turn on float64 in JAX? That is the one thing that might be different from the snippet I posted.

@Algue-Rythme (Collaborator)

I did not turn on float64 in my initial test; check for yourself! I tested in float32 on CPU / GPU / TPU in Colab; it works.

With float64 enabled, and on a TPU, I get: XlaRuntimeError: INVALID_ARGUMENT: 64-bit data types are not yet supported on the TPU driver API. Convert inputs to float32/int32_t before using. This is the expected behavior for TPUs anyway, because they are not meant to support float64 arithmetic. Since you don't encounter this error, I wonder whether you actually enabled the TPU in Colab with jax.tools.colab_tpu.setup_tpu().
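A quick way to verify is a sketch like the following, using the legacy jax.tools.colab_tpu setup path mentioned above (it may no longer exist in recent JAX releases):

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
print(jax.default_backend())  # expect 'tpu'; 'cpu' means setup failed
print(jax.devices())          # expect TpuDevice entries

If setup fails silently, everything runs on CPU, and the float64 arrays BoxOSQP allocates land there.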

The error you gave me arises when mixing float32 objects (in your call) with float64 objects that BoxOSQP allocates by default, on CPU (for example, after failing to enable the TPU). This is also expected behavior, because JAX's policy is to prevent aggressive type promotion. However, if you force everything to be float64, it works! Look here
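For reference, a minimal sketch of that all-float64 call (it mirrors the snippet from this issue, not the linked Colab):

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jaxopt

optimizer = jaxopt.BoxOSQP()
sol = optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float64),
        jnp.ones((30,), dtype=jnp.float64),
    ),
    params_eq=jnp.ones((1, 30), dtype=jnp.float64),
    params_ineq=(jnp.float64(-1.0), jnp.float64(1.0)),
)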

my Colab has a large number of other things in it so I can't share it

Well, I am not asking for your whole work, just a minimal working example that reproduces the issue.

That is the one thing that might be different from the snippet I posted.

This is what I meant when I said "share a Colab link": it is not easy to infer what you did in your environment without details, and the code you gave me was clearly insufficient to understand what is really going on. As you can see, on Colab I can trigger different types of errors by juggling types, environments, and initialization at startup, and I consider none of these behaviors a bug.

@jakevdp commented Oct 18, 2023

Hi, JAX developer here. It looks like you're using Colab TPU; as of this writing (October 2023), Colab only provides very old TPU hardware and is only compatible with a very old JAX version. I would not recommend running JAX on Colab TPU until this changes (Colab CPU and GPU are fine, though). I believe this issue is fixed on more modern TPU architectures.

If you'd like to use modern TPUs in a free public notebook, I'd suggest taking a look at Kaggle, which provides more up-to-date TPU runtimes.

@Algue-Rythme (Collaborator)

Thanks for the heads up.

@jewillco: could you clarify your intent with this code? If my understanding is correct, you need:

  1. a TPU for performance
  2. float64 precision enabled by default for some reason (do you expect these computations to run on TPU?)
  3. but you want the BoxOSQP solver to run in float32 by default anyway?

@jewillco (Author)

I want #1 and #2, at least with the option to run other parts of my code in float64 on the TPU (which is semi-supported). I would like BoxOSQP to run in either float32 or float64 depending on what inputs I give it. It turns out that it does work with all float64 inputs to the solver; it still produces NaNs on my problem but that's a different issue.
