Skip to content

Commit

Permalink
fix cond in zoom linesearch for non-jittable case, moved if_else_cond…
Browse files Browse the repository at this point in the history
… in loop.py
  • Loading branch information
vroulet committed Jun 28, 2023
1 parent 1572796 commit b374e1d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 24 deletions.
8 changes: 8 additions & 0 deletions jaxopt/_src/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,11 @@ def while_loop(cond_fun, body_fun, init_val, maxiter, unroll=False, jit=False):
fun = jax.jit(fun, static_argnums=(0, 1, 3))

return fun(cond_fun, body_fun, init_val, maxiter)


def ifelse_cond(jit, cond, if_fun, else_fun, *operands):
"""Wrapper to avoid having the condition to be compiled if not wanted."""
if not jit:
with jax.disable_jit():
return jax.lax.cond(cond, if_fun, else_fun, *operands)
return jax.lax.cond(cond, if_fun, else_fun, *operands)
30 changes: 14 additions & 16 deletions jaxopt/_src/osqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src.loop import ifelse_cond
from jaxopt.tree_util import tree_add, tree_sub, tree_mul
from jaxopt.tree_util import tree_scalar_mul, tree_add_scalar_mul
from jaxopt.tree_util import tree_map, tree_vdot
Expand Down Expand Up @@ -256,13 +257,6 @@ def lu_solve(b, lu_factors):
return sol, osqp_state.solver_state


def ifelse_cond(cond, if_fun, else_fun, operand, jit):
if not jit:
with jax.disable_jit():
return jax.lax.cond(cond, if_fun, else_fun, operand)
return jax.lax.cond(cond, if_fun, else_fun, operand)


@dataclass(eq=False)
class BoxOSQP(base.IterativeSolver):
"""Operator Splitting Solver for Quadratic Programs.
Expand Down Expand Up @@ -621,20 +615,24 @@ def update(self, params, state, params_obj, params_eq, params_ineq):
# We need our own ifelse_cond because automatic jitting of jax.lax.cond branches
# could pose problems with non jittable matvecs, or prevent printing when verbose > 0.
rho_bar, solver_state = ifelse_cond(
jnp.mod(state.iter_num, self.stepsize_updates_frequency) == 0,
lambda _: self._update_stepsize(rho_bar, solver_state, primal_residuals, dual_residuals, Q, c, A, x, y),
lambda _: (rho_bar, solver_state),
operand=None, jit=jit)
jit,
jnp.mod(state.iter_num, self.stepsize_updates_frequency) == 0,
lambda _: self._update_stepsize(rho_bar, solver_state, primal_residuals, dual_residuals, Q, c, A, x, y),
lambda _: (rho_bar, solver_state),
None
)

sol = BoxOSQP._get_full_KKT_solution(primal=(x, z), y=y)

# Same remark as above for ifelse_cond.
error, status = ifelse_cond(
jnp.mod(state.iter_num, self.termination_check_frequency) == 0,
lambda _: self._check_termination_conditions(primal_residuals, dual_residuals,
params, sol, Q, c, A, l, u),
lambda s: (state.error, s),
operand=(state.status), jit=jit)
jit,
jnp.mod(state.iter_num, self.termination_check_frequency) == 0,
lambda _: self._check_termination_conditions(primal_residuals, dual_residuals,
params, sol, Q, c, A, l, u),
lambda s: (state.error, s),
state.status
)

if not jit:
if status == BoxOSQP.PRIMAL_INFEASIBLE:
Expand Down
19 changes: 11 additions & 8 deletions jaxopt/_src/zoom_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import jax.numpy as jnp
from jaxopt._src import base
from jaxopt._src.base import _make_funs_with_aux
from jaxopt._src.loop import ifelse_cond
from jaxopt._src.tree_util import tree_single_dtype
from jaxopt.tree_util import tree_add_scalar_mul
from jaxopt.tree_util import tree_scalar_mul
Expand Down Expand Up @@ -714,21 +715,23 @@ def update(
del grad
del descent_direction

best_stepsize, new_state_ = lax.cond(
state.interval_found,
self._zoom_into_interval,
self._search_interval,
best_stepsize_, new_state_ = ifelse_cond(
self.jit,
state.interval_found,
self._zoom_into_interval,
self._search_interval,
init_stepsize,
state,
args,
kwargs,
kwargs
)

best_stepsize, new_state = lax.cond(
(new_state_.failed) & (new_state_.iter_num == self.maxiter),
best_stepsize, new_state = ifelse_cond(
self.jit,
(new_state_.failed) & (new_state_.iter_num == self.maxiter),
self._make_safe_step,
self._keep_step,
best_stepsize,
best_stepsize_,
new_state_,
args,
kwargs,
Expand Down
10 changes: 10 additions & 0 deletions tests/zoom_linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,16 @@ def fun(x):
):
self.assertEqual(getattr(state, name).dtype, out_dtype)

def test_non_jittable(self):
def fun(x):
return -onp.sin(10 * x), -10 * onp.cos(10 * x)
x = 1.
ls = ZoomLineSearch(fun, value_and_grad=True, jit=False)
stepsize, ls_state = ls.run(init_stepsize=1.0, params=x)
self.assertTrue(True)




if __name__ == "__main__":
# Uncomment the line below in order to run in float64.
Expand Down

0 comments on commit b374e1d

Please sign in to comment.