Infinities and NaNs in quadratic_prog when c=0 #95

Closed
FerranAlet opened this issue Nov 13, 2021 · 6 comments

@FerranAlet

FerranAlet commented Nov 13, 2021

Hi,

I'm using QuadraticProgramming in the special case c=0 (a vector of all zeros). AFAIK this is still well-defined, as it's just minimizing the squared l2 norm of the primal subject to some equality constraints (I don't have inequalities).

However, both my research code and the following modification of this test diverge even for a single step (maxiter=1).

The modification just involves setting c=0, so:

def test_qp_eq_only_c_zero(self):
  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
  c = jnp.array([0.0, 0.0])  # only change from the original test
  A = jnp.array([[1.0, 1.0]])
  b = jnp.array([1.0])
  qp = QuadraticProgramming(tol=1e-7)
  hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
  sol = qp.run(**hyperparams).params
  self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
  self._check_derivative_A_and_b(qp, hyperparams, A, b)

Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen it.

Thanks!
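
For reference, the equality-constrained problem in the test above is small enough to solve directly from its KKT system, which gives a ground truth to compare any iterative solver against. The sketch below assumes the objective convention 0.5 * x^T Q x + c^T x and uses a dense jnp.linalg.solve rather than whatever solver QuadraticProgramming uses internally; the names K, rhs, x_star and nu_star are illustrative only.

```python
import jax.numpy as jnp

# Problem data from the test above, with c set to zero.
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([0.0, 0.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])

n, m = Q.shape[0], A.shape[0]

# KKT system for: minimize 0.5 * x^T Q x + c^T x  subject to  A x = b
#   [Q  A^T] [x ]   [-c]
#   [A   0 ] [nu] = [ b]
K = jnp.block([[Q, A.T], [A, jnp.zeros((m, m))]])
rhs = jnp.concatenate([-c, b])

# Dense direct solve; fine for a 3x3 system, not a substitute for an
# iterative solver on large problems.
z = jnp.linalg.solve(K, rhs)
x_star, nu_star = z[:n], z[n:]
print(x_star, nu_star)
```

Working the 3x3 system by hand gives x = (0.25, 0.75) and nu = -1.75, so the problem with c=0 has a unique, finite solution.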

@FerranAlet
Author

In case it helps, for my research code I've manually implemented the same call that QP would've done except using solve_normal_cg and it does solve it without diverging for c=0. It also gives a solution that is similar to that of gmres when setting c small but non-zero. For gmres, letting c go to zero makes it eventually diverge.
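
The idea behind this workaround can be sketched without jaxopt internals: apply conjugate gradient to the normal equations K^T K z = K^T rhs of the KKT system. This is only an illustration of a normal-equation CG solve built on jax.scipy.sparse.linalg.cg; it is not the code path jaxopt takes, and kkt_matvec and solve_kkt_normal_cg are hypothetical helper names.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Same problem data as the test in the first comment, with c = 0.
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.zeros(2)
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
n = Q.shape[0]


def kkt_matvec(z):
  """Applies the KKT operator [[Q, A^T], [A, 0]] to z = (x, nu)."""
  x, nu = z[:n], z[n:]
  return jnp.concatenate([Q @ x + A.T @ nu, A @ x])


def solve_kkt_normal_cg(rhs, tol=1e-5):
  """Solves K z = rhs by running CG on K^T K z = K^T rhs."""
  # The KKT operator is symmetric because Q is symmetric, so K^T = K and
  # the normal operator K^T K is just K applied twice. K^T K is symmetric
  # positive definite whenever K is nonsingular, which is what CG needs.
  normal_matvec = lambda z: kkt_matvec(kkt_matvec(z))
  z, _ = cg(normal_matvec, kkt_matvec(rhs), tol=tol)
  return z[:n], z[n:]


x_cg, nu_cg = solve_kkt_normal_cg(jnp.concatenate([-c, b]))
print(x_cg, nu_cg)
```

Squaring the operator also squares its condition number, so this approach trades convergence speed for a positive definite system that CG can handle.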

@Algue-Rythme
Collaborator

Hi Ferran

That's weird, I tried to reproduce your bug and did not succeed: https://colab.research.google.com/drive/1-IS1MIkkXfVuON5IhAT2gxt2pw-5IVz8?usp=sharing

Can you give more details on the versions of Python/JAX/JAXopt you are using? Are you working on CPU or GPU? In JAX's default float32 or in float64? Is it wrapped in a jitted block or run in "eager" mode?

If you have a notebook in which you can consistently reproduce the bug that would be great.

Is it diverging during the primal/dual computation in run(), or during implicit differentiation?
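
A snippet along these lines could gather the details requested above in one place. It is a minimal sketch: environment_report is a hypothetical helper, and the jaxopt version is read from package metadata so the snippet still runs when jaxopt is not installed.

```python
import importlib.metadata
import platform

import jax
import jax.numpy as jnp


def environment_report():
  """Collects version, backend and precision details for a bug report."""
  try:
    jaxopt_version = importlib.metadata.version("jaxopt")
  except importlib.metadata.PackageNotFoundError:
    jaxopt_version = "not installed"
  return {
      "python": platform.python_version(),
      "jax": jax.__version__,
      "jaxopt": jaxopt_version,
      # Platform JAX dispatches to by default, e.g. "cpu" or "gpu".
      "backend": jax.default_backend(),
      # float32 unless 64-bit mode has been enabled.
      "default_float": str(jnp.zeros(1).dtype),
  }


report = environment_report()
for key, value in report.items():
  print(f"{key}: {value}")
```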

@FerranAlet
Author

FerranAlet commented Nov 14, 2021

My bad, it was because I had the jaxopt-0.0.1 version; must have been fixed since then!
Just for completeness, it was during the primal/dual computation. Changing the inner solver to solve_normal_cg worked. With jaxopt-0.1.1 gmres (QP's vanilla solver) gives the same result as solve_normal_cg.

Thanks!

@mblondel
Collaborator

Great that the problem disappeared. Not sure what fixed it, maybe ac1bdcd. See https://github.com/google/jaxopt/commits/main/jaxopt/_src/quadratic_prog.py for the history of this file.

@FerranAlet
Author

FerranAlet commented Nov 17, 2021

@mblondel @Algue-Rythme during my research I just ran into more NaNs with the default QP solver, so it may not have been fully solved in 0.1.1. It happens for roughly 5% of the examples and I haven't found any pattern in when it occurs. On the other 95% it matches (modulo small precision errors) the result of using solve_normal_cg as the inner solver.

The failure case is quite entangled with my research code so I can't send it atm. If I find a failure pattern I'll design a minimal example and send it to you.
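
One way to narrow down such intermittent failures is to run the solver over a batch of problem instances and record which ones return non-finite values. The sketch below is an assumption about how that could be organised, not code from this thread: all_finite and find_failures are hypothetical helpers, and toy_solve is a deliberately fragile stand-in for a real solver.

```python
import jax
import jax.numpy as jnp


def all_finite(tree):
  """Returns True if every array leaf of the pytree has only finite entries."""
  return all(bool(jnp.all(jnp.isfinite(leaf)))
             for leaf in jax.tree_util.tree_leaves(tree))


def find_failures(solve_fn, problems):
  """Returns the indices of problems whose solution contains NaN or inf."""
  return [i for i, problem in enumerate(problems)
          if not all_finite(solve_fn(**problem))]


def toy_solve(A, b):
  # Stand-in for a real solver: divides by sum(b), so b = 0 yields 0/0 = NaN.
  return (A @ b) / jnp.sum(b)


problems = [
    dict(A=jnp.eye(2), b=jnp.array([1.0, 2.0])),
    dict(A=jnp.eye(2), b=jnp.array([0.0, 0.0])),
]
failing = find_failures(toy_solve, problems)
print(failing)  # → [1]
```

The inputs at the failing indices can then be saved and shrunk into a minimal example. JAX also provides the jax_debug_nans configuration flag, which raises an error at the operation that first produces a NaN.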

@mblondel
Collaborator

My guess is that the issue is in gmres itself, as our code doesn't seem to contain operations that could result in NaNs.

In #98, @Algue-Rythme is working on a new class, EqualityConstrainedQP, specific to this case. It will expose the solver, which will allow you to switch to solve_normal_cg.
