diff --git a/optax/_src/wrappers_test.py b/optax/_src/wrappers_test.py index d55fe389..c8fee120 100644 --- a/optax/_src/wrappers_test.py +++ b/optax/_src/wrappers_test.py @@ -175,7 +175,7 @@ def fn_update(params, opt_state, x): for step in range(2): params, opt_state = fn_update(params, opt_state, two) self.assertFalse(bool(opt_state.last_finite)) - self.assertEqual(step + 1, int(opt_state.notfinite_count)) + self.assertEqual(step + 1, opt_state.notfinite_count.item()) # Next successful param update params, opt_state = fn_update(params, opt_state, half) self.assertTrue(bool(opt_state.last_finite)) @@ -183,10 +183,10 @@ def fn_update(params, opt_state, x): for step in range(2): params, opt_state = fn_update(params, opt_state, two) self.assertFalse(bool(opt_state.last_finite)) - self.assertEqual(step + 1, int(opt_state.notfinite_count)) + self.assertEqual(step + 1, opt_state.notfinite_count.item()) # Next param update with NaN is accepted since we reached maximum _, opt_state = fn_update(params, opt_state, two) - self.assertEqual(5, int(opt_state.total_notfinite)) + self.assertEqual(5, opt_state.total_notfinite.item()) @chex.variants(with_jit=True, without_jit=True, with_pmap=True) def test_multi_steps(self):