-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix cond in zoom linesearch for non-jittable case, moved if_else_cond…
… in loop.py
- Loading branch information
Showing
6 changed files
with
124 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Branching utilities.""" | ||
|
||
import jax | ||
|
||
def cond(cond, if_fun, else_fun, *operands, jit=True): | ||
"""Wrapper to avoid having the condition to be compiled if not wanted.""" | ||
if not jit: | ||
with jax.disable_jit(): | ||
return jax.lax.cond(cond, if_fun, else_fun, *operands) | ||
return jax.lax.cond(cond, if_fun, else_fun, *operands) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from jaxopt._src.cond import cond |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as onp | ||
|
||
from jaxopt._src.cond import cond | ||
from jaxopt._src import test_util | ||
|
||
|
||
class CondTest(test_util.JaxoptTestCase): | ||
|
||
@parameterized.product(jit=[False, True]) | ||
def test_cond(self, jit): | ||
def true_fun(x): | ||
return x | ||
def false_fun(x): | ||
return jnp.zeros_like(x) | ||
|
||
def my_relu(x): | ||
return cond(jnp.sum(x)>0, true_fun, false_fun, x, jit=jit) | ||
|
||
if jit: | ||
x = onp.array([1.]) | ||
else: | ||
x = jnp.array([1.]) | ||
self.assertEqual(jax.nn.relu(x), my_relu(x)) | ||
|
||
if __name__ == '__main__': | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters