From cb4a473d8b9eda01b71968e92dd66f3edc93f949 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 12 Jan 2024 06:05:15 -0800 Subject: [PATCH] Avoid scalar conversion of non-scalar arrays Scalar conversion of non-scalar arrays has been deprecated and has raised a warning since JAX v0.4.16, and will soon result in an error (see https://github.com/google/jax/pull/19181). PiperOrigin-RevId: 597821246 --- optax/_src/wrappers_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optax/_src/wrappers_test.py b/optax/_src/wrappers_test.py index d55fe389..c8fee120 100644 --- a/optax/_src/wrappers_test.py +++ b/optax/_src/wrappers_test.py @@ -175,7 +175,7 @@ def fn_update(params, opt_state, x): for step in range(2): params, opt_state = fn_update(params, opt_state, two) self.assertFalse(bool(opt_state.last_finite)) - self.assertEqual(step + 1, int(opt_state.notfinite_count)) + self.assertEqual(step + 1, opt_state.notfinite_count.item()) # Next successful param update params, opt_state = fn_update(params, opt_state, half) self.assertTrue(bool(opt_state.last_finite)) @@ -183,10 +183,10 @@ def fn_update(params, opt_state, x): for step in range(2): params, opt_state = fn_update(params, opt_state, two) self.assertFalse(bool(opt_state.last_finite)) - self.assertEqual(step + 1, int(opt_state.notfinite_count)) + self.assertEqual(step + 1, opt_state.notfinite_count.item()) # Next param update with NaN is accepted since we reached maximum _, opt_state = fn_update(params, opt_state, two) - self.assertEqual(5, int(opt_state.total_notfinite)) + self.assertEqual(5, opt_state.total_notfinite.item()) @chex.variants(with_jit=True, without_jit=True, with_pmap=True) def test_multi_steps(self):