Skip to content

Commit

Permalink
fix cond in zoom linesearch for non-jittable case, moved if_else_cond…
Browse files Browse the repository at this point in the history
… in loop.py
  • Loading branch information
vroulet committed Jun 29, 2023
1 parent 1572796 commit 308329a
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 27 deletions.
24 changes: 24 additions & 0 deletions jaxopt/_src/cond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Branching utilities."""

import jax

def cond(cond, if_fun, else_fun, *operands, jit=True):
"""Wrapper to avoid having the condition to be compiled if not wanted."""
if not jit:
with jax.disable_jit():
return jax.lax.cond(cond, if_fun, else_fun, *operands)
return jax.lax.cond(cond, if_fun, else_fun, *operands)
38 changes: 18 additions & 20 deletions jaxopt/_src/osqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src.cond import cond
from jaxopt.tree_util import tree_add, tree_sub, tree_mul
from jaxopt.tree_util import tree_scalar_mul, tree_add_scalar_mul
from jaxopt.tree_util import tree_map, tree_vdot
Expand Down Expand Up @@ -256,13 +257,6 @@ def lu_solve(b, lu_factors):
return sol, osqp_state.solver_state


def ifelse_cond(cond, if_fun, else_fun, operand, jit):
if not jit:
with jax.disable_jit():
return jax.lax.cond(cond, if_fun, else_fun, operand)
return jax.lax.cond(cond, if_fun, else_fun, operand)


@dataclass(eq=False)
class BoxOSQP(base.IterativeSolver):
"""Operator Splitting Solver for Quadratic Programs.
Expand Down Expand Up @@ -618,23 +612,27 @@ def update(self, params, state, params_obj, params_eq, params_ineq):
if self.verbose >= 3:
print(f"primal_residuals={primal_residuals}, dual_residuals={dual_residuals}")

# We need our own ifelse_cond because automatic jitting of jax.lax.cond branches
# We need our own ifelse cond because automatic jitting of jax.lax.cond branches
# could pose problems with non jittable matvecs, or prevent printing when verbose > 0.
rho_bar, solver_state = ifelse_cond(
jnp.mod(state.iter_num, self.stepsize_updates_frequency) == 0,
lambda _: self._update_stepsize(rho_bar, solver_state, primal_residuals, dual_residuals, Q, c, A, x, y),
lambda _: (rho_bar, solver_state),
operand=None, jit=jit)
rho_bar, solver_state = cond(
jnp.mod(state.iter_num, self.stepsize_updates_frequency) == 0,
lambda _: self._update_stepsize(rho_bar, solver_state, primal_residuals, dual_residuals, Q, c, A, x, y),
lambda _: (rho_bar, solver_state),
None,
jit=jit
)

sol = BoxOSQP._get_full_KKT_solution(primal=(x, z), y=y)

# Same remark as above for ifelse_cond.
error, status = ifelse_cond(
jnp.mod(state.iter_num, self.termination_check_frequency) == 0,
lambda _: self._check_termination_conditions(primal_residuals, dual_residuals,
params, sol, Q, c, A, l, u),
lambda s: (state.error, s),
operand=(state.status), jit=jit)
# Same remark as above for ifelse cond.
error, status = cond(
jnp.mod(state.iter_num, self.termination_check_frequency) == 0,
lambda _: self._check_termination_conditions(primal_residuals, dual_residuals,
params, sol, Q, c, A, l, u),
lambda s: (state.error, s),
state.status,
jit=jit
)

if not jit:
if status == BoxOSQP.PRIMAL_INFEASIBLE:
Expand Down
19 changes: 12 additions & 7 deletions jaxopt/_src/zoom_linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import jax.numpy as jnp
from jaxopt._src import base
from jaxopt._src.base import _make_funs_with_aux
from jaxopt._src.cond import cond
from jaxopt._src.tree_util import tree_single_dtype
from jaxopt.tree_util import tree_add_scalar_mul
from jaxopt.tree_util import tree_scalar_mul
Expand Down Expand Up @@ -713,25 +714,29 @@ def update(
del value
del grad
del descent_direction

jit, _ = self._get_loop_options()

best_stepsize, new_state_ = lax.cond(
state.interval_found,
self._zoom_into_interval,
self._search_interval,
best_stepsize_, new_state_ = cond(
state.interval_found,
self._zoom_into_interval,
self._search_interval,
init_stepsize,
state,
args,
kwargs,
jit=jit
)

best_stepsize, new_state = lax.cond(
(new_state_.failed) & (new_state_.iter_num == self.maxiter),
best_stepsize, new_state = cond(
(new_state_.failed) & (new_state_.iter_num == self.maxiter),
self._make_safe_step,
self._keep_step,
best_stepsize,
best_stepsize_,
new_state_,
args,
kwargs,
jit=jit
)

if self.verbose:
Expand Down
15 changes: 15 additions & 0 deletions jaxopt/cond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.cond import cond
45 changes: 45 additions & 0 deletions tests/cond_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp
import numpy as onp

from jaxopt._src.cond import cond
from jaxopt._src import test_util


class CondTest(test_util.JaxoptTestCase):

@parameterized.product(jit=[False, True])
def test_cond(self, jit):
def true_fun(x):
return x
def false_fun(x):
return jnp.zeros_like(x)

def my_relu(x):
return cond(jnp.sum(x)>0, true_fun, false_fun, x, jit=jit)

if jit:
x = onp.array([1.])
else:
x = jnp.array([1.])
self.assertEqual(jax.nn.relu(x), my_relu(x))

if __name__ == '__main__':
absltest.main()
10 changes: 10 additions & 0 deletions tests/zoom_linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,16 @@ def fun(x):
):
self.assertEqual(getattr(state, name).dtype, out_dtype)

@parameterized.product(jit=[False, True])
def test_non_jittable(self, jit):
def fun(x):
return -onp.sin(10 * x), -10 * onp.cos(10 * x)
x = 1.
def run_ls():
ls = ZoomLineSearch(fun, value_and_grad=True, jit=jit)
ls.run(init_stepsize=1.0, params=x)
if jit:
self.assertRaises(jax.errors.TracerArrayConversionError, run_ls)

if __name__ == "__main__":
# Uncomment the line below in order to run in float64.
Expand Down

0 comments on commit 308329a

Please sign in to comment.