Skip to content

Commit

Permalink
Start inject_hyperparams step count at 0.
Browse files Browse the repository at this point in the history
This fixes issue #415. Previously, `inject_hyperparams` started with
`step_count=1` in the first update when using schedules (it incremented before
passing it to the schedule) whereas `scale_by_schedule` started with
`step_count=0`. To make this consistent, this PR changes `inject_hyperparams`
to also start at 0, i.e. increment the count only after passing it to the
schedule.

The PR comes with a test that breaks without the change. Furthermore, the step
counts in the existing tests of `inject_hyperparams` had to be decremented by
one in order for the tests to pass.

PiperOrigin-RevId: 474015093
  • Loading branch information
mkunesch authored and OptaxDev committed Sep 13, 2022
1 parent 44df918 commit 7532b60
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
4 changes: 2 additions & 2 deletions optax/_src/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,14 +593,14 @@ def init_fn(params):
count, hparams, inner_factory(**other_hps, **hparams).init(params))

def update_fn(updates, state, params=None):
count_inc = numerics.safe_int32_increment(state.count)
dtype = getattr(next(iter(
jax.tree_util.tree_leaves(updates)), None), 'dtype', None)
hparams = {k: _convert_floats(v, dtype)
for k, v in state.hyperparams.items()}
hparams.update(schedule_fn(count_inc, dtype))
hparams.update(schedule_fn(state.count, dtype))
updates, inner_state = inner_factory(**other_hps, **hparams).update(
updates, state.inner_state, params)
count_inc = numerics.safe_int32_increment(state.count)

# pylint:disable=too-many-function-args
return updates, InjectHyperparamsState(count_inc, hparams, inner_state)
Expand Down
14 changes: 12 additions & 2 deletions optax/_src/schedule_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class InjectHyperparamsTest(chex.TestCase):
def test_updates(self):
optim = schedule.inject_hyperparams(transform.scale)( # stateless
step_size=schedule.piecewise_constant_schedule(
3.0, {2: 5, 8: 2, 13: 1.5}))
3.0, {1: 5, 7: 2, 12: 1.5}))

params = [jnp.zeros([], dtype=jnp.float32)]
state = self.variant(optim.init)(params)
Expand All @@ -518,7 +518,7 @@ def test_updates(self):
def test_hyperparams_state(self):
optim = schedule.inject_hyperparams(transform.trace)( # stateful
decay=schedule.piecewise_constant_schedule(
0.8, {4: 0.5, 10: 1.25}),
0.8, {3: 0.5, 9: 1.25}),
nesterov=True)

params = [jnp.zeros([2, 3]) for _ in range(3)]
Expand Down Expand Up @@ -603,6 +603,16 @@ def test_static_args_error(self, static_args):
with self.assertRaises(ValueError):
schedule.inject_hyperparams(transform.scale, static_args=static_args)

@chex.all_variants
def test_inject_hyperparams_starts_with_step_count_zero(self):
"""Checks that inject_hyperparams uses step count 0 in the first update."""
# See also: https://github.com/deepmind/optax/issues/415.
opt = schedule.inject_hyperparams(transform.scale)(lambda count: count)
params = jnp.zeros(3)
grads = jnp.array([-1, 0, 1])
updates, _ = self.variant(opt.update)(grads, opt.init(params))
np.testing.assert_array_equal(updates, np.zeros(3))


if __name__ == '__main__':
absltest.main()

0 comments on commit 7532b60

Please sign in to comment.