Fix pytype failures related to teaching pytype about NumPy scalar types.
PiperOrigin-RevId: 519695398
hawkinsp authored and OptaxDev committed Apr 20, 2023
1 parent 04768d2 commit ceebb9c
Showing 11 changed files with 43 additions and 43 deletions.
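
Every change in this commit follows the same pattern: lines that newer pytype now flags (because NumPy scalar types such as np.float64 are no longer treated as plain Python float/int) are left functionally unchanged, and the specific error class is silenced with an inline, line-scoped "# pytype: disable=..." comment tagged "# numpy-scalars". A minimal sketch of the pattern on a hypothetical helper (not code from this commit):

import numpy as np


def mean_accuracy(correct: np.ndarray) -> float:
  # np.mean returns a NumPy scalar (np.float64). Once pytype distinguishes
  # NumPy scalars from Python floats, the annotated return type looks
  # violated, so the error is suppressed inline for this line only.
  return np.mean(correct)  # pytype: disable=bad-return-type  # numpy-scalars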
2 changes: 1 addition & 1 deletion examples/lookahead_mnist.py
@@ -80,7 +80,7 @@ def train_step(params, optimizer_state, batch):
test_dataset.as_numpy_iterator())
print(f'Epoch {epoch+1}: test acc: {test_acc:.2f}')

-return test_acc
+return test_acc # pytype: disable=bad-return-type # numpy-scalars


if __name__ == '__main__':
4 changes: 2 additions & 2 deletions examples/mnist.py
@@ -54,7 +54,7 @@ def model_accuracy(model: Callable[[chex.Array], chex.Array],
accuracy_sum += _single_batch_accuracy(logits, batch['label']) * batch_size
dataset_size += batch_size

-return accuracy_sum / dataset_size
+return accuracy_sum / dataset_size # pytype: disable=bad-return-type # numpy-scalars # pylint: disable=line-too-long


# Optax is agnostic to which (if any) neural network library is used. Below we
@@ -111,7 +111,7 @@ def train_step(params, optimizer_state, batch):
test_acc = model_accuracy(eval_model, test_dataset.as_numpy_iterator())
print(f'Epoch {epoch+1}: test acc: {test_acc:.2f}')

-return test_acc
+return test_acc # pytype: disable=bad-return-type # numpy-scalars


def main(unused_argv):
6 changes: 3 additions & 3 deletions optax/_src/alias_test.py
@@ -188,13 +188,13 @@ def test_explicit_dtype(self, dtype):
expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32
tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype)
trace_state, _ = tx.init(jnp.array([0.0, 0.0]))
-self.assertEqual(expected_dtype, trace_state.trace.dtype)
+self.assertEqual(expected_dtype, trace_state.trace.dtype) # pytype: disable=attribute-error # numpy-scalars
tx = alias.adam(0.1, mu_dtype=dtype)
adam_state, _ = tx.init(jnp.array([0.0, 0.0]))
-self.assertEqual(expected_dtype, adam_state.mu.dtype)
+self.assertEqual(expected_dtype, adam_state.mu.dtype) # pytype: disable=attribute-error # numpy-scalars
tx = alias.adamw(0.1, mu_dtype=dtype)
adam_state, _, _ = tx.init(jnp.array([0.0, 0.0]))
-self.assertEqual(expected_dtype, adam_state.mu.dtype)
+self.assertEqual(expected_dtype, adam_state.mu.dtype) # pytype: disable=attribute-error # numpy-scalars


if __name__ == '__main__':
8 changes: 4 additions & 4 deletions optax/_src/base_test.py
@@ -148,7 +148,7 @@ def test_stateless_no_params(self):
def opt(g, _):
return jax.tree_util.tree_map(lambda g_: g_ * 2, g)

-state = opt.init(None)
+state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars
update_fn = self.variant(opt.update)
new_updates, _ = update_fn(updates, state)
expected_updates = {'linear': jnp.full((5, 3), 6.0)}
@@ -159,7 +159,7 @@ def weight_decay(g, p):
return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p)

opt = base.stateless(weight_decay)
-state = opt.init(None)
+state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars
self.assertIsInstance(state, base.EmptyState)


@@ -183,15 +183,15 @@ def test_stateless_with_tree_map_no_params(self):
updates = {'linear': jnp.full((5, 3), 3.0)}

opt = base.stateless_with_tree_map(lambda g, _: g * 2.0)
-state = opt.init(None)
+state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars
update_fn = self.variant(opt.update)
new_updates, _ = update_fn(updates, state)
expected_updates = {'linear': jnp.full((5, 3), 6.0)}
chex.assert_trees_all_close(new_updates, expected_updates)

def test_init_returns_emptystate(self):
opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p)
-state = opt.init(None)
+state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars
self.assertIsInstance(state, base.EmptyState)


4 changes: 2 additions & 2 deletions optax/_src/combine.py
@@ -146,7 +146,7 @@ def init_fn(params):
group: wrappers.masked(tx, make_mask(labels, group)).init(params)
for group, tx in transforms.items()
}
-return MultiTransformState(inner_states)
+return MultiTransformState(inner_states) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long

def update_fn(updates, state, params=None):
labels = param_labels(updates) if callable(param_labels) else param_labels
@@ -155,6 +155,6 @@ def update_fn(updates, state, params=None):
masked_tx = wrappers.masked(tx, make_mask(labels, group))
updates, new_inner_state[group] = masked_tx.update(
updates, state.inner_states[group], params)
-return updates, MultiTransformState(new_inner_state)
+return updates, MultiTransformState(new_inner_state) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long

return base.GradientTransformation(init_fn, update_fn)
4 changes: 2 additions & 2 deletions optax/_src/control_variates.py
@@ -271,7 +271,7 @@ def control_variates_jacobians(
"""
control_variate = control_variate_from_function(function)
stochastic_cv, expected_value_cv, update_state_cv = control_variate
-data_dim = params[0].shape[0]
+data_dim = params[0].shape[0] # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long
if estimate_cv_coeffs:
cv_coeffs = estimate_control_variate_coefficients(
function, control_variate_from_function, grad_estimator, params,
@@ -330,7 +330,7 @@ def param_fn(x):
# \nabla_{\theta} E_{p(x; \theta)}]
param_jacobians += cv_coeff * expected_value_grads[param_index]

-chex.assert_shape(param_jacobians, (num_samples,) + param.shape)
+chex.assert_shape(param_jacobians, (num_samples,) + param.shape) # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long
jacobians.append(param_jacobians)

return jacobians, control_variate_state
2 changes: 1 addition & 1 deletion optax/_src/lookahead.py
@@ -97,7 +97,7 @@ def lookahead(

def init_fn(params: base.Params) -> LookaheadState:
try:
-fast_params = params.fast
+fast_params = params.fast # pytype: disable=attribute-error # numpy-scalars # pylint: disable=line-too-long
except AttributeError:
# Allowing init_fn to be called with fast parameters reduces the
# modifications necessary to adapt code to use lookahead in some cases.
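
Aside: the try/except AttributeError in this hunk is what lets lookahead's init_fn accept either a full fast/slow parameter pair or just the fast parameters. A minimal sketch of the two call patterns, assuming the public optax.lookahead / optax.LookaheadParams API and a toy parameter tree:

import jax.numpy as jnp
import optax

fast_params = {'w': jnp.zeros(3)}  # toy parameter tree, for illustration only
opt = optax.lookahead(optax.sgd(0.1), sync_period=5, slow_step_size=0.5)

# Initialise from an explicit fast/slow pair...
state = opt.init(optax.LookaheadParams(fast=fast_params, slow=fast_params))

# ...or pass only the fast parameters: a plain dict has no .fast attribute,
# so init_fn falls back to treating the whole tree as the fast params (the
# AttributeError branch in the hunk above).
state = opt.init(fast_params)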
4 changes: 2 additions & 2 deletions optax/_src/state_utils.py
@@ -79,9 +79,9 @@ def tree_map_params(
"""

if isinstance(initable, Initable):
-state_with_placeholders = initable.init(_ParamsPlaceholder())
+state_with_placeholders = initable.init(_ParamsPlaceholder()) # type: ignore # numpy-scalars # pylint: disable=line-too-long
else:
-state_with_placeholders = initable(_ParamsPlaceholder())
+state_with_placeholders = initable(_ParamsPlaceholder()) # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long

def map_params(maybe_placeholder_value, value):
if isinstance(maybe_placeholder_value, _ParamsPlaceholder):
16 changes: 8 additions & 8 deletions optax/_src/state_utils_test.py
@@ -50,8 +50,8 @@ def init(params):
state = t.init(params)

return ScaleByAdamStateDict(
-count=state.count,
-params={'mu': state.mu, 'nu': state.nu},
+count=state.count, # pytype: disable=attribute-error # numpy-scalars
+params={'mu': state.mu, 'nu': state.nu}, # pytype: disable=attribute-error # numpy-scalars
)

def update(updates, state, params=None):
@@ -63,8 +63,8 @@ def update(updates, state, params=None):

updates, state = t.update(updates, state, params)
return ScaleByAdamStateDict(
-count=state.count,
-params={'mu': state.mu, 'nu': state.nu},
+count=state.count, # pytype: disable=attribute-error # numpy-scalars
+params={'mu': state.mu, 'nu': state.nu}, # pytype: disable=attribute-error # numpy-scalars
)

return base.GradientTransformation(init, update)
@@ -139,8 +139,8 @@ def init(params):
state = init(params)
state = state_utils.tree_map_params(init, lambda v: v+1, state)

-self.assertEqual(state.count, 0)
-self.assertEqual(state.v, {'w': 1})
+self.assertEqual(state.count, 0) # pytype: disable=attribute-error # numpy-scalars
+self.assertEqual(state.v, {'w': 1}) # pytype: disable=attribute-error # numpy-scalars

def test_adam(self):
params = _fake_params()
@@ -193,9 +193,9 @@ def test_inject_hparams(self):
state = opt.init(params)
state = state_utils.tree_map_params(opt, lambda v: v+1, state)

-self.assertEqual(1e-3, state.hyperparams['learning_rate'])
+self.assertEqual(1e-3, state.hyperparams['learning_rate']) # pytype: disable=attribute-error # numpy-scalars
params_plus_one = jax.tree_map(lambda v: v+1, params)
-chex.assert_trees_all_close(state.inner_state[0].mu, params_plus_one)
+chex.assert_trees_all_close(state.inner_state[0].mu, params_plus_one) # pytype: disable=attribute-error # numpy-scalars

def test_map_params_to_none(self):
opt = alias.adagrad(1e-4)
8 changes: 4 additions & 4 deletions optax/_src/transform_test.py
@@ -221,7 +221,7 @@ def test_scale(self):
factor = 0.1 ** i
rescaler = transform.scale(factor)
# Apply rescaling.
-scaled_updates, _ = rescaler.update(updates, None)
+scaled_updates, _ = rescaler.update(updates, None) # pytype: disable=wrong-arg-types # numpy-scalars
# Manually scale updates.
def rescale(t):
return t * factor # pylint:disable=cell-var-from-loop
@@ -240,7 +240,7 @@ def test_centralize(self, inputs, outputs):
inputs = jnp.asarray(inputs)
outputs = jnp.asarray(outputs)
centralizer = transform.centralize()
-centralized_inputs, _ = centralizer.update(inputs, None)
+centralized_inputs, _ = centralizer.update(inputs, None) # pytype: disable=wrong-arg-types # numpy-scalars
chex.assert_trees_all_close(centralized_inputs, outputs)

@chex.all_variants
@@ -282,10 +282,10 @@ def f(params: jnp.ndarray) -> jnp.ndarray:
og = transform.scale_by_optimistic_gradient()
og_state = og.init(initial_params)
# Provide some arbitrary previous gradient.
-og_state.trace['x'] = 1.5
+og_state.trace['x'] = 1.5 # type: ignore # numpy-scalars

g = jax.grad(f)(initial_params)
-og_true = 2 * g['x'] - og_state.trace['x']
+og_true = 2 * g['x'] - og_state.trace['x'] # pytype: disable=attribute-error # numpy-scalars
og, og_state = og.update(g, og_state)

# Compare transformation output with manually computed optimistic gradient.
28 changes: 14 additions & 14 deletions optax/_src/wrappers_test.py
@@ -111,35 +111,35 @@ def fn(x):
# We know exactly what should be the value of params since we are
# effectively using sgd in all cases.
self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0]))
-self.assertTrue(bool(state.last_finite))
+self.assertTrue(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars
# Check 2 rejected param updates
for step in range(2):
grads = grads_fn(params, nan)
updates, state = opt.update(grads, state, params)
params = update.apply_updates(params, updates)
self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0]))
-self.assertFalse(bool(state.last_finite))
-self.assertEqual(step + 1, int(state.notfinite_count))
+self.assertFalse(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars
+self.assertEqual(step + 1, int(state.notfinite_count)) # pytype: disable=attribute-error # numpy-scalars
# Next successful param update
grads = grads_fn(params, one)
updates, state = opt.update(grads, state, params)
params = update.apply_updates(params, updates)
self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0]))
-self.assertTrue(bool(state.last_finite))
+self.assertTrue(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars
# Again 2 rejected param updates
for step in range(2):
grads = grads_fn(params, nan)
updates, state = opt.update(grads, state, params)
params = update.apply_updates(params, updates)
self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0]))
-self.assertFalse(bool(state.last_finite))
-self.assertEqual(step + 1, int(state.notfinite_count))
+self.assertFalse(bool(state.last_finite)) # pytype: disable=attribute-error # numpy-scalars
+self.assertEqual(step + 1, int(state.notfinite_count)) # pytype: disable=attribute-error # numpy-scalars
# Next param update with NaN is accepted since we reached maximum
grads = grads_fn(params, nan)
updates, state = opt.update(grads, state, params)
params = update.apply_updates(params, updates)
self.assertTrue(bool(jnp.isnan(jax.tree_util.tree_flatten(params)[0][0])))
-self.assertEqual(5, int(state.total_notfinite))
+self.assertEqual(5, int(state.total_notfinite)) # pytype: disable=attribute-error # numpy-scalars

def test_apply_if_finite_pmap(self):
# Unlike in `test_apply_if_finite`:
@@ -247,18 +247,18 @@ def test_multi_steps_every_k_schedule(self):
params = dict(a=jnp.zeros([]))
opt_state = opt_init(params)
grad = dict(a=jnp.zeros([]))
-self.assertFalse(ms_opt.has_updated(opt_state))
+self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars
# First two steps have 1 mini-step per update.
for _ in range(2):
_, opt_state = opt_update(grad, opt_state, params)
-self.assertTrue(ms_opt.has_updated(opt_state))
+self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars
# Subsequently, mini-steps should have 3 mini-steps per update.
for _ in range(5):
for _ in range(2):
_, opt_state = opt_update(grad, opt_state, params)
-self.assertFalse(ms_opt.has_updated(opt_state))
+self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars
_, opt_state = opt_update(grad, opt_state, params)
-self.assertTrue(ms_opt.has_updated(opt_state))
+self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars

def test_multi_steps_computes_mean(self):
k_steps = 4
Expand All @@ -268,16 +268,16 @@ def test_multi_steps_computes_mean(self):
params = dict(a=jnp.zeros([]))
opt_state = opt_init(params)
grads = [dict(a=jnp.ones([]) * i) for i in [1, 2, 3, 4]]
-self.assertFalse(ms_opt.has_updated(opt_state))
+self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars

# First 3 steps don't update.
for grad in grads[:-1]:
_, opt_state = opt_update(grad, opt_state, params)
-self.assertFalse(ms_opt.has_updated(opt_state))
+self.assertFalse(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars

# Actual update.
new_params, opt_state = opt_update(grads[-1], opt_state, params)
-self.assertTrue(ms_opt.has_updated(opt_state))
+self.assertTrue(ms_opt.has_updated(opt_state)) # pytype: disable=wrong-arg-types # numpy-scalars
np.testing.assert_array_equal(new_params['a'], 2.5)

def test_skip_not_finite(self):
