Skip to content

Commit

Permalink
Ignore some linesearch tests on gpu/tpu
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 643364817
  • Loading branch information
vroulet authored and OptaxDev committed Jun 18, 2024
1 parent 6a9808c commit 0e303a4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions optax/_src/linesearch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def test_failure_descent_direction(self):
# @absltest.skipIf(jax.devices()[0] == 'tpu', reason = ...))
# because jax arrays cannot be manipulated from the top level of the python
# program
if jax.default_backend() == 'tpu':
if jax.default_backend() in ['tpu', 'gpu']:
return
else:
# For this f and p, starting at a point on axis 0, the strong Wolfe
Expand Down Expand Up @@ -523,7 +523,7 @@ def test_failure_too_small_max_stepsize(self):
# @absltest.skipIf(jax.devices()[0] == 'tpu', reason = ...))
# because jax arrays cannot be manipulated from the top level of the python
# program
if jax.default_backend() == 'tpu':
if jax.default_backend() in ['tpu', 'gpu']:
return
else:
def fn(x):
Expand Down Expand Up @@ -558,7 +558,7 @@ def test_failure_not_enough_iter(self):
# @absltest.skipIf(jax.devices()[0] == 'tpu', reason = ...))
# because jax arrays cannot be manipulated from the top level of the python
# program
if jax.default_backend() == 'tpu':
if jax.default_backend() in ['tpu', 'gpu']:
return
else:
def fn(x):
Expand Down Expand Up @@ -608,7 +608,7 @@ def test_failure_flat_fun(self):
# @absltest.skipIf(jax.devices()[0] == 'tpu', reason = ...))
# because jax arrays cannot be manipulated from the top level of the python
# program
if jax.default_backend() == 'tpu':
if jax.default_backend() in ['tpu', 'gpu']:
return
else:
def fun_flat(x):
Expand All @@ -631,7 +631,7 @@ def test_failure_inf_value(self):
# @absltest.skipIf(jax.devices()[0] == 'tpu', reason = ...))
# because jax arrays cannot be manipulated from the top level of the python
# program
if jax.default_backend() == 'tpu':
if jax.default_backend() in ['tpu', 'gpu']:
return
else:
def fun_inf(x):
Expand Down

0 comments on commit 0e303a4

Please sign in to comment.