Skip to content

Commit

Permalink
Avoid scalar conversion of non-scalar arrays
Browse files Browse the repository at this point in the history
Scalar conversion of non-scalar arrays has been deprecated and has raised a warning since JAX v0.4.16, and will soon result in an error (see google/jax#19181).

PiperOrigin-RevId: 595815388
  • Loading branch information
Jake VanderPlas authored and OptaxDev committed Jan 4, 2024
1 parent bc22961 commit f12cd49
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions optax/_src/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,18 @@ def fn_update(params, opt_state, x):
for step in range(2):
params, opt_state = fn_update(params, opt_state, two)
self.assertFalse(bool(opt_state.last_finite))
self.assertEqual(step + 1, int(opt_state.notfinite_count))
self.assertEqual(step + 1, opt_state.notfinite_count.item())
# Next successful param update
params, opt_state = fn_update(params, opt_state, half)
self.assertTrue(bool(opt_state.last_finite))
# Again 2 rejected param updates
for step in range(2):
params, opt_state = fn_update(params, opt_state, two)
self.assertFalse(bool(opt_state.last_finite))
self.assertEqual(step + 1, int(opt_state.notfinite_count))
self.assertEqual(step + 1, opt_state.notfinite_count.item())
# Next param update with NaN is accepted since we reached maximum
_, opt_state = fn_update(params, opt_state, two)
self.assertEqual(5, int(opt_state.total_notfinite))
self.assertEqual(5, opt_state.total_notfinite.item())

@chex.variants(with_jit=True, without_jit=True, with_pmap=True)
def test_multi_steps(self):
Expand Down

0 comments on commit f12cd49

Please sign in to comment.