From ceebb9c41e077d7aa68e3f3f3d66893d07dcf336 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 27 Mar 2023 05:41:12 -0700 Subject: [PATCH 1/4] Fix pytype failures related to teaching pytype about NumPy scalar types. PiperOrigin-RevId: 519695398 --- examples/lookahead_mnist.py | 2 +- examples/mnist.py | 4 ++-- optax/_src/alias_test.py | 6 +++--- optax/_src/base_test.py | 8 ++++---- optax/_src/combine.py | 4 ++-- optax/_src/control_variates.py | 4 ++-- optax/_src/lookahead.py | 2 +- optax/_src/state_utils.py | 4 ++-- optax/_src/state_utils_test.py | 16 ++++++++-------- optax/_src/transform_test.py | 8 ++++---- optax/_src/wrappers_test.py | 28 ++++++++++++++-------------- 11 files changed, 43 insertions(+), 43 deletions(-) diff --git a/examples/lookahead_mnist.py b/examples/lookahead_mnist.py index fce49efa..cb07ace9 100644 --- a/examples/lookahead_mnist.py +++ b/examples/lookahead_mnist.py @@ -80,7 +80,7 @@ def train_step(params, optimizer_state, batch): test_dataset.as_numpy_iterator()) print(f'Epoch {epoch+1}: test acc: {test_acc:.2f}') - return test_acc + return test_acc # pytype: disable=bad-return-type # numpy-scalars if __name__ == '__main__': diff --git a/examples/mnist.py b/examples/mnist.py index d79f97af..b006cca6 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -54,7 +54,7 @@ def model_accuracy(model: Callable[[chex.Array], chex.Array], accuracy_sum += _single_batch_accuracy(logits, batch['label']) * batch_size dataset_size += batch_size - return accuracy_sum / dataset_size + return accuracy_sum / dataset_size # pytype: disable=bad-return-type # numpy-scalars # pylint: disable=line-too-long # Optax is agnostic to which (if any) neural network library is used. Below we @@ -111,7 +111,7 @@ def train_step(params, optimizer_state, batch): test_acc = model_accuracy(eval_model, test_dataset.as_numpy_iterator()) print(f'Epoch {epoch+1}: test acc: {test_acc:.2f}') - return test_acc + return test_acc # pytype: disable=bad-return-type # numpy-scalars def main(unused_argv): diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 46a5729b..da6f4f6a 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -188,13 +188,13 @@ def test_explicit_dtype(self, dtype): expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32 tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype) trace_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, trace_state.trace.dtype) + self.assertEqual(expected_dtype, trace_state.trace.dtype) # pytype: disable=attribute-error # numpy-scalars tx = alias.adam(0.1, mu_dtype=dtype) adam_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, adam_state.mu.dtype) + self.assertEqual(expected_dtype, adam_state.mu.dtype) # pytype: disable=attribute-error # numpy-scalars tx = alias.adamw(0.1, mu_dtype=dtype) adam_state, _, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, adam_state.mu.dtype) + self.assertEqual(expected_dtype, adam_state.mu.dtype) # pytype: disable=attribute-error # numpy-scalars if __name__ == '__main__': diff --git a/optax/_src/base_test.py b/optax/_src/base_test.py index 651e481d..4461f60c 100644 --- a/optax/_src/base_test.py +++ b/optax/_src/base_test.py @@ -148,7 +148,7 @@ def test_stateless_no_params(self): def opt(g, _): return jax.tree_util.tree_map(lambda g_: g_ * 2, g) - state = opt.init(None) + state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars update_fn = self.variant(opt.update) new_updates, _ = update_fn(updates, state) expected_updates = {'linear': jnp.full((5, 3), 6.0)} @@ -159,7 +159,7 @@ def weight_decay(g, p): return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p) opt = base.stateless(weight_decay) - state = opt.init(None) + state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars self.assertIsInstance(state, base.EmptyState) @@ -183,7 +183,7 @@ def test_stateless_with_tree_map_no_params(self): updates = {'linear': jnp.full((5, 3), 3.0)} opt = base.stateless_with_tree_map(lambda g, _: g * 2.0) - state = opt.init(None) + state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars update_fn = self.variant(opt.update) new_updates, _ = update_fn(updates, state) expected_updates = {'linear': jnp.full((5, 3), 6.0)} @@ -191,7 +191,7 @@ def test_stateless_with_tree_map_no_params(self): def test_init_returns_emptystate(self): opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p) - state = opt.init(None) + state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars self.assertIsInstance(state, base.EmptyState) diff --git a/optax/_src/combine.py b/optax/_src/combine.py index 40ab061f..f696850a 100644 --- a/optax/_src/combine.py +++ b/optax/_src/combine.py @@ -146,7 +146,7 @@ def init_fn(params): group: wrappers.masked(tx, make_mask(labels, group)).init(params) for group, tx in transforms.items() } - return MultiTransformState(inner_states) + return MultiTransformState(inner_states) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long def update_fn(updates, state, params=None): labels = param_labels(updates) if callable(param_labels) else param_labels @@ -155,6 +155,6 @@ def update_fn(updates, state, params=None): masked_tx = wrappers.masked(tx, make_mask(labels, group)) updates, new_inner_state[group] = masked_tx.update( updates, state.inner_states[group], params) - return updates, MultiTransformState(new_inner_state) + return updates, MultiTransformState(new_inner_state) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long return base.GradientTransformation(init_fn, update_fn) diff --git a/optax/_src/control_variates.py b/optax/_src/control_variates.py index 67ca124f..fc836745 100644 --- a/optax/_src/control_variates.py +++ b/optax/_src/control_variates.py @@ -271,7 +271,7 @@ def control_variates_jacobians( """ control_variate = control_variate_from_function(function) stochastic_cv, expected_value_cv, update_state_cv = control_variate - data_dim = params[0].shape[0] + data_dim = params[0].shape[0] # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long if estimate_cv_coeffs: cv_coeffs = estimate_control_variate_coefficients( function, control_variate_from_function, grad_estimator, params, @@ -330,7 +330,7 @@ def param_fn(x): # \nabla_{\theta} E_{p(x; \theta)}] param_jacobians += cv_coeff * expected_value_grads[param_index] - chex.assert_shape(param_jacobians, (num_samples,) + param.shape) + chex.assert_shape(param_jacobians, (num_samples,) + param.shape) # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long jacobians.append(param_jacobians) return jacobians, control_variate_state diff --git a/optax/_src/lookahead.py b/optax/_src/lookahead.py index bf33e134..2886de88 100644 --- a/optax/_src/lookahead.py +++ b/optax/_src/lookahead.py @@ -97,7 +97,7 @@ def lookahead( def init_fn(params: base.Params) -> LookaheadState: try: - fast_params = params.fast + fast_params = params.fast # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long except AttributeError: # Allowing init_fn to be called with fast parameters reduces the # modifications necessary to adapt code to use lookahead in some cases. diff --git a/optax/_src/state_utils.py b/optax/_src/state_utils.py index 4c4c6acf..8d34422f 100644 --- a/optax/_src/state_utils.py +++ b/optax/_src/state_utils.py @@ -79,9 +79,9 @@ def tree_map_params( """ if isinstance(initable, Initable): - state_with_placeholders = initable.init(_ParamsPlaceholder()) + state_with_placeholders = initable.init(_ParamsPlaceholder()) # type: ignore # numpy-scalars # pylint: disable=line-too-long else: - state_with_placeholders = initable(_ParamsPlaceholder()) + state_with_placeholders = initable(_ParamsPlaceholder()) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long def map_params(maybe_placeholder_value, value): if isinstance(maybe_placeholder_value, _ParamsPlaceholder): diff --git a/optax/_src/state_utils_test.py b/optax/_src/state_utils_test.py index a7476398..27b3a422 100644 --- a/optax/_src/state_utils_test.py +++ b/optax/_src/state_utils_test.py @@ -50,8 +50,8 @@ def init(params): state = t.init(params) return ScaleByAdamStateDict( - count=state.count, - params={'mu': state.mu, 'nu': state.nu}, + count=state.count, # pytype: disable=attribute-error # numpy-scalars + params={'mu': state.mu, 'nu': state.nu}, # pytype: disable=attribute-error # numpy-scalars ) def update(updates, state, params=None): @@ -63,8 +63,8 @@ def update(updates, state, params=None): updates, state = t.update(updates, state, params) return ScaleByAdamStateDict( - count=state.count, - params={'mu': state.mu, 'nu': state.nu}, + count=state.count, # pytype: disable=attribute-error # numpy-scalars + params={'mu': state.mu, 'nu': state.nu}, # pytype: disable=attribute-error # numpy-scalars ) return base.GradientTransformation(init, update) @@ -139,8 +139,8 @@ def init(params): state = init(params) state = state_utils.tree_map_params(init, lambda v: v+1, state) - self.assertEqual(state.count, 0) - self.assertEqual(state.v, {'w': 1}) + self.assertEqual(state.count, 0) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(state.v, {'w': 1}) # pytype: disable=attribute-error # numpy-scalars def test_adam(self): params = _fake_params() @@ -193,9 +193,9 @@ def test_inject_hparams(self): state = opt.init(params) state = state_utils.tree_map_params(opt, lambda v: v+1, state) - self.assertEqual(1e-3, state.hyperparams['learning_rate']) + self.assertEqual(1e-3, state.hyperparams['learning_rate']) # pytype: disable=attribute-error # numpy-scalars params_plus_one = jax.tree_map(lambda v: v+1, params) - chex.assert_trees_all_close(state.inner_state[0].mu, params_plus_one) + chex.assert_trees_all_close(state.inner_state[0].mu, params_plus_one) # pytype: disable=attribute-error # numpy-scalars def test_map_params_to_none(self): opt = alias.adagrad(1e-4) diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index e13bb1bc..b9b6b085 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -221,7 +221,7 @@ def test_scale(self): factor = 0.1 ** i rescaler = transform.scale(factor) # Apply rescaling. - scaled_updates, _ = rescaler.update(updates, None) + scaled_updates, _ = rescaler.update(updates, None) # pytype: disable=wrong-arg-types # numpy-scalars # Manually scale updates. def rescale(t): return t * factor # pylint:disable=cell-var-from-loop @@ -240,7 +240,7 @@ def test_centralize(self, inputs, outputs): inputs = jnp.asarray(inputs) outputs = jnp.asarray(outputs) centralizer = transform.centralize() - centralized_inputs, _ = centralizer.update(inputs, None) + centralized_inputs, _ = centralizer.update(inputs, None) # pytype: disable=wrong-arg-types # numpy-scalars chex.assert_trees_all_close(centralized_inputs, outputs) @chex.all_variants @@ -282,10 +282,10 @@ def f(params: jnp.ndarray) -> jnp.ndarray: og = transform.scale_by_optimistic_gradient() og_state = og.init(initial_params) # Provide some arbitrary previous gradient. - og_state.trace['x'] = 1.5 + og_state.trace['x'] = 1.5 # type: ignore # numpy-scalars g = jax.grad(f)(initial_params) - og_true = 2 * g['x'] - og_state.trace['x'] + og_true = 2 * g['x'] - og_state.trace['x'] # pytype: disable=attribute-error # numpy-scalars og, og_state = og.update(g, og_state) # Compare transformation output with manually computed optimistic gradient. diff --git a/optax/_src/wrappers_test.py b/optax/_src/wrappers_test.py index b9d8b104..dc05d33d 100644 --- a/optax/_src/wrappers_test.py +++ b/optax/_src/wrappers_test.py @@ -111,35 +111,35 @@ def fn(x): # We know exactly what should be the value of params since we are # effectively using sgd in all cases. self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertTrue(bool(state.last_finite)) + self.assertTrue(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars # Check 2 rejected param updates for step in range(2): grads = grads_fn(params, nan) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertFalse(bool(state.last_finite)) - self.assertEqual(step + 1, int(state.notfinite_count)) + self.assertFalse(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(step + 1, int(state.notfinite_count)) # pytype: disable=attribute-error # numpy-scalars # Next successful param update grads = grads_fn(params, one) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertTrue(bool(state.last_finite)) + self.assertTrue(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars # Again 2 rejected param updates for step in range(2): grads = grads_fn(params, nan) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertFalse(bool(state.last_finite)) - self.assertEqual(step + 1, int(state.notfinite_count)) + self.assertFalse(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(step + 1, int(state.notfinite_count)) # pytype: disable=attribute-error # numpy-scalars # Next param update with NaN is accepted since we reached maximum grads = grads_fn(params, nan) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertTrue(bool(jnp.isnan(jax.tree_util.tree_flatten(params)[0][0]))) - self.assertEqual(5, int(state.total_notfinite)) + self.assertEqual(5, int(state.total_notfinite)) # pytype: disable=attribute-error # numpy-scalars def test_apply_if_finite_pmap(self): # Unlike in `test_apply_if_finite`: @@ -247,18 +247,18 @@ def test_multi_steps_every_k_schedule(self): params = dict(a=jnp.zeros([])) opt_state = opt_init(params) grad = dict(a=jnp.zeros([])) - self.assertFalse(ms_opt.has_updated(opt_state)) + self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars # First two steps have 1 mini-step per update. for _ in range(2): _, opt_state = opt_update(grad, opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) + self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars # Subsequently, mini-steps should have 3 mini-steps per update. for _ in range(5): for _ in range(2): _, opt_state = opt_update(grad, opt_state, params) - self.assertFalse(ms_opt.has_updated(opt_state)) + self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars _, opt_state = opt_update(grad, opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) + self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars def test_multi_steps_computes_mean(self): k_steps = 4 @@ -268,16 +268,16 @@ def test_multi_steps_computes_mean(self): params = dict(a=jnp.zeros([])) opt_state = opt_init(params) grads = [dict(a=jnp.ones([]) * i) for i in [1, 2, 3, 4]] - self.assertFalse(ms_opt.has_updated(opt_state)) + self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars # First 3 steps don't update. for grad in grads[:-1]: _, opt_state = opt_update(grad, opt_state, params) - self.assertFalse(ms_opt.has_updated(opt_state)) + self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars # Actual update. new_params, opt_state = opt_update(grads[-1], opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) + self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars np.testing.assert_array_equal(new_params['a'], 2.5) def test_skip_not_finite(self): From e188350323c624132756589da86dc6fa4a4e8112 Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Thu, 20 Apr 2023 07:16:08 -0700 Subject: [PATCH 2/4] Resolve pytype errors. PiperOrigin-RevId: 525737739 --- optax/_src/alias_test.py | 6 +++--- optax/_src/combine.py | 6 +++--- optax/_src/control_variates.py | 6 +++--- optax/_src/lookahead.py | 5 ++--- optax/_src/state_utils.py | 10 +++++++--- optax/_src/state_utils_test.py | 25 +++++++++++++++---------- optax/_src/transform_test.py | 8 ++++---- optax/_src/wrappers.py | 7 +++++-- optax/_src/wrappers_test.py | 28 ++++++++++++++-------------- 9 files changed, 56 insertions(+), 45 deletions(-) diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index da6f4f6a..720c14ab 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -188,13 +188,13 @@ def test_explicit_dtype(self, dtype): expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32 tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype) trace_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, trace_state.trace.dtype) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(expected_dtype, getattr(trace_state, 'trace').dtype) tx = alias.adam(0.1, mu_dtype=dtype) adam_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, adam_state.mu.dtype) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype) tx = alias.adamw(0.1, mu_dtype=dtype) adam_state, _, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, adam_state.mu.dtype) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype) if __name__ == '__main__': diff --git a/optax/_src/combine.py b/optax/_src/combine.py index f696850a..00f81ea7 100644 --- a/optax/_src/combine.py +++ b/optax/_src/combine.py @@ -69,7 +69,7 @@ def update_fn(updates, state, params=None, **extra_args): class MultiTransformState(NamedTuple): - inner_states: Mapping[Hashable, NamedTuple] + inner_states: Mapping[Hashable, base.OptState] def multi_transform( @@ -146,7 +146,7 @@ def init_fn(params): group: wrappers.masked(tx, make_mask(labels, group)).init(params) for group, tx in transforms.items() } - return MultiTransformState(inner_states) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long + return MultiTransformState(inner_states) def update_fn(updates, state, params=None): labels = param_labels(updates) if callable(param_labels) else param_labels @@ -155,6 +155,6 @@ def update_fn(updates, state, params=None): masked_tx = wrappers.masked(tx, make_mask(labels, group)) updates, new_inner_state[group] = masked_tx.update( updates, state.inner_states[group], params) - return updates, MultiTransformState(new_inner_state) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long + return updates, MultiTransformState(new_inner_state) return base.GradientTransformation(init_fn, update_fn) diff --git a/optax/_src/control_variates.py b/optax/_src/control_variates.py index fc836745..bf7ac2b4 100644 --- a/optax/_src/control_variates.py +++ b/optax/_src/control_variates.py @@ -271,7 +271,7 @@ def control_variates_jacobians( """ control_variate = control_variate_from_function(function) stochastic_cv, expected_value_cv, update_state_cv = control_variate - data_dim = params[0].shape[0] # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long + data_dim = jax.tree_util.tree_leaves(params)[0].shape[0] if estimate_cv_coeffs: cv_coeffs = estimate_control_variate_coefficients( function, control_variate_from_function, grad_estimator, params, @@ -315,7 +315,7 @@ def param_fn(x): lambda x: expected_value_cv(x, control_variate_state))(params) jacobians = [] - for param_index, param in enumerate(params): + for param_index, param in enumerate(jax.tree_util.tree_leaves(params)): chex.assert_shape(function_jacobians[param_index], (num_samples, data_dim)) chex.assert_shape(cv_jacobians[param_index], (num_samples, data_dim)) chex.assert_shape(cv_param_grads[param_index], (data_dim,)) @@ -330,7 +330,7 @@ def param_fn(x): # \nabla_{\theta} E_{p(x; \theta)}] param_jacobians += cv_coeff * expected_value_grads[param_index] - chex.assert_shape(param_jacobians, (num_samples,) + param.shape) # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long + chex.assert_shape(param_jacobians, (num_samples,) + param.shape) jacobians.append(param_jacobians) return jacobians, control_variate_state diff --git a/optax/_src/lookahead.py b/optax/_src/lookahead.py index 2886de88..ae77b976 100644 --- a/optax/_src/lookahead.py +++ b/optax/_src/lookahead.py @@ -96,9 +96,8 @@ def lookahead( raise ValueError('Synchronization period must be >= 1.') def init_fn(params: base.Params) -> LookaheadState: - try: - fast_params = params.fast # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long - except AttributeError: + fast_params = getattr(params, 'fast', None) + if fast_params is None: # Allowing init_fn to be called with fast parameters reduces the # modifications necessary to adapt code to use lookahead in some cases. logging.warning( diff --git a/optax/_src/state_utils.py b/optax/_src/state_utils.py index 8d34422f..4f6a5cd0 100644 --- a/optax/_src/state_utils.py +++ b/optax/_src/state_utils.py @@ -15,7 +15,7 @@ """Tools for mapping over optimizer states.""" import typing -from typing import Any, Callable, Optional, Protocol, Union +from typing import Any, Callable, Optional, Protocol, Union, cast import jax from optax._src import base @@ -78,10 +78,14 @@ def tree_map_params( optional extra arguments. """ + # Cast for pytype checks (no-op for other usages). + placeholder = cast(base.chex.ArrayTree, _ParamsPlaceholder()) + if isinstance(initable, Initable): - state_with_placeholders = initable.init(_ParamsPlaceholder()) # type: ignore # numpy-scalars # pylint: disable=line-too-long + initable = cast(Initable, initable) # for pytype checks + state_with_placeholders = initable.init(placeholder) else: - state_with_placeholders = initable(_ParamsPlaceholder()) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long + state_with_placeholders = initable(placeholder) def map_params(maybe_placeholder_value, value): if isinstance(maybe_placeholder_value, _ParamsPlaceholder): diff --git a/optax/_src/state_utils_test.py b/optax/_src/state_utils_test.py index 27b3a422..488df9e0 100644 --- a/optax/_src/state_utils_test.py +++ b/optax/_src/state_utils_test.py @@ -15,7 +15,7 @@ """Tests for state_utils.""" import dataclasses -from typing import Optional, TypedDict +from typing import Optional, TypedDict, cast from absl.testing import absltest import chex @@ -37,7 +37,7 @@ class FakeShardSpec: class ScaleByAdamStateDict(TypedDict): """An opt state that uses dictionaries instead of classes.""" - count: int + count: chex.Array params: TypedDict('Params', {'mu': chex.ArrayTree, 'nu': chex.ArrayTree}) @@ -48,10 +48,11 @@ def _scale_by_adam_with_dicts(): def init(params): state = t.init(params) + state = cast(transform.ScaleByAdamState, state) return ScaleByAdamStateDict( - count=state.count, # pytype: disable=attribute-error # numpy-scalars - params={'mu': state.mu, 'nu': state.nu}, # pytype: disable=attribute-error # numpy-scalars + count=state.count, + params={'mu': state.mu, 'nu': state.nu}, ) def update(updates, state, params=None): @@ -62,9 +63,10 @@ def update(updates, state, params=None): ) updates, state = t.update(updates, state, params) + state = cast(transform.ScaleByAdamState, state) return ScaleByAdamStateDict( - count=state.count, # pytype: disable=attribute-error # numpy-scalars - params={'mu': state.mu, 'nu': state.nu}, # pytype: disable=attribute-error # numpy-scalars + count=state.count, + params={'mu': state.mu, 'nu': state.nu}, ) return base.GradientTransformation(init, update) @@ -138,9 +140,10 @@ def init(params): state = init(params) state = state_utils.tree_map_params(init, lambda v: v+1, state) + state = cast(Foo, state) - self.assertEqual(state.count, 0) # pytype: disable=attribute-error # numpy-scalars - self.assertEqual(state.v, {'w': 1}) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(int(state.count), 0) + self.assertEqual(state.v, {'w': jnp.array(1)}) def test_adam(self): params = _fake_params() @@ -192,10 +195,12 @@ def test_inject_hparams(self): params = _fake_params() state = opt.init(params) state = state_utils.tree_map_params(opt, lambda v: v+1, state) + state = cast(schedule.InjectHyperparamsState, state) - self.assertEqual(1e-3, state.hyperparams['learning_rate']) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(1e-3, state.hyperparams['learning_rate']) params_plus_one = jax.tree_map(lambda v: v+1, params) - chex.assert_trees_all_close(state.inner_state[0].mu, params_plus_one) # pytype: disable=attribute-error # numpy-scalars + mu = getattr(state.inner_state[0], 'mu') + chex.assert_trees_all_close(mu, params_plus_one) def test_map_params_to_none(self): opt = alias.adagrad(1e-4) diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index b9b6b085..0ccefe6d 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -221,7 +221,7 @@ def test_scale(self): factor = 0.1 ** i rescaler = transform.scale(factor) # Apply rescaling. - scaled_updates, _ = rescaler.update(updates, None) # pytype: disable=wrong-arg-types # numpy-scalars + scaled_updates, _ = rescaler.update(updates, {}) # Manually scale updates. def rescale(t): return t * factor # pylint:disable=cell-var-from-loop @@ -240,7 +240,7 @@ def test_centralize(self, inputs, outputs): inputs = jnp.asarray(inputs) outputs = jnp.asarray(outputs) centralizer = transform.centralize() - centralized_inputs, _ = centralizer.update(inputs, None) # pytype: disable=wrong-arg-types # numpy-scalars + centralized_inputs, _ = centralizer.update(inputs, {}) chex.assert_trees_all_close(centralized_inputs, outputs) @chex.all_variants @@ -282,10 +282,10 @@ def f(params: jnp.ndarray) -> jnp.ndarray: og = transform.scale_by_optimistic_gradient() og_state = og.init(initial_params) # Provide some arbitrary previous gradient. - og_state.trace['x'] = 1.5 # type: ignore # numpy-scalars + getattr(og_state, 'trace')['x'] = 1.5 g = jax.grad(f)(initial_params) - og_true = 2 * g['x'] - og_state.trace['x'] # pytype: disable=attribute-error # numpy-scalars + og_true = 2 * g['x'] - getattr(og_state, 'trace')['x'] og, og_state = og.update(g, og_state) # Compare transformation output with manually computed optimistic gradient. diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index 271d1d5b..95e1e7e1 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -427,8 +427,11 @@ def mid_step(args): return new_updates, new_state - def has_updated(self, state: MultiStepsState) -> Array: - return jnp.logical_and(state.mini_step == 0, state.gradient_step > 0) + def has_updated(self, state: Union[MultiStepsState, chex.ArrayTree]) -> Array: + # Use `getattr` to bypass pytype checks. + return jnp.logical_and( + getattr(state, 'mini_step') == 0, getattr(state, 'gradient_step') > 0 + ) def gradient_transformation(self) -> base.GradientTransformation: return base.GradientTransformation(init=self.init, update=self.update) diff --git a/optax/_src/wrappers_test.py b/optax/_src/wrappers_test.py index dc05d33d..ce50bb76 100644 --- a/optax/_src/wrappers_test.py +++ b/optax/_src/wrappers_test.py @@ -111,35 +111,35 @@ def fn(x): # We know exactly what should be the value of params since we are # effectively using sgd in all cases. self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertTrue(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars + self.assertTrue(bool(getattr(state, 'last_finite'))) # Check 2 rejected param updates for step in range(2): grads = grads_fn(params, nan) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertFalse(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars - self.assertEqual(step + 1, int(state.notfinite_count)) # pytype: disable=attribute-error # numpy-scalars + self.assertFalse(bool(getattr(state, 'last_finite'))) + self.assertEqual(step + 1, int(getattr(state, 'notfinite_count'))) # Next successful param update grads = grads_fn(params, one) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertTrue(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars + self.assertTrue(bool(getattr(state, 'last_finite'))) # Again 2 rejected param updates for step in range(2): grads = grads_fn(params, nan) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertFalse(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars - self.assertEqual(step + 1, int(state.notfinite_count)) # pytype: disable=attribute-error # numpy-scalars + self.assertFalse(bool(getattr(state, 'last_finite'))) + self.assertEqual(step + 1, int(getattr(state, 'notfinite_count'))) # Next param update with NaN is accepted since we reached maximum grads = grads_fn(params, nan) updates, state = opt.update(grads, state, params) params = update.apply_updates(params, updates) self.assertTrue(bool(jnp.isnan(jax.tree_util.tree_flatten(params)[0][0]))) - self.assertEqual(5, int(state.total_notfinite)) # pytype: disable=attribute-error # numpy-scalars + self.assertEqual(5, int(getattr(state, 'total_notfinite'))) def test_apply_if_finite_pmap(self): # Unlike in `test_apply_if_finite`: @@ -247,18 +247,18 @@ def test_multi_steps_every_k_schedule(self): params = dict(a=jnp.zeros([])) opt_state = opt_init(params) grad = dict(a=jnp.zeros([])) - self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars + self.assertFalse(ms_opt.has_updated(opt_state)) # First two steps have 1 mini-step per update. for _ in range(2): _, opt_state = opt_update(grad, opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars + self.assertTrue(ms_opt.has_updated(opt_state)) # Subsequently, mini-steps should have 3 mini-steps per update. for _ in range(5): for _ in range(2): _, opt_state = opt_update(grad, opt_state, params) - self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars + self.assertFalse(ms_opt.has_updated(opt_state)) _, opt_state = opt_update(grad, opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars + self.assertTrue(ms_opt.has_updated(opt_state)) def test_multi_steps_computes_mean(self): k_steps = 4 @@ -268,16 +268,16 @@ def test_multi_steps_computes_mean(self): params = dict(a=jnp.zeros([])) opt_state = opt_init(params) grads = [dict(a=jnp.ones([]) * i) for i in [1, 2, 3, 4]] - self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars + self.assertFalse(ms_opt.has_updated(opt_state)) # First 3 steps don't update. for grad in grads[:-1]: _, opt_state = opt_update(grad, opt_state, params) - self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars + self.assertFalse(ms_opt.has_updated(opt_state)) # Actual update. new_params, opt_state = opt_update(grads[-1], opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars + self.assertTrue(ms_opt.has_updated(opt_state)) np.testing.assert_array_equal(new_params['a'], 2.5) def test_skip_not_finite(self): From f9ed8acf69a2a1d0909313ffdcb721b965b0a014 Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Thu, 20 Apr 2023 07:41:11 -0700 Subject: [PATCH 3/4] Release v0.1.5. PiperOrigin-RevId: 525743135 --- optax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/__init__.py b/optax/__init__.py index 5f5c2a15..5a05af45 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -187,7 +187,7 @@ from optax._src.wrappers import skip_large_updates from optax._src.wrappers import skip_not_finite -__version__ = "0.1.5.dev" +__version__ = "0.1.5" __all__ = ( "adabelief", From e5f910e92534c3b444e806582227389b7bb2b3b0 Mon Sep 17 00:00:00 2001 From: Markus Kunesch Date: Fri, 21 Apr 2023 05:23:45 -0700 Subject: [PATCH 4/4] Add empty pyproject.toml to enable PR #513. PiperOrigin-RevId: 526006691 --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b1401d1e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +# Empty file in preparation for PR #513. +# See https://github.com/deepmind/optax/pull/513 for more information.