Skip to content

Commit

Permalink
Updates warmup_cosine_decay_schedule to allow 0 as peak_value.
Browse files Browse the repository at this point in the history
Currently it errors out with a division by zero.

A peak_value of 0 is needed when we want to turn off training for certain portions of the network by setting their learning rate to 0.

PiperOrigin-RevId: 604272302
  • Loading branch information
OptaxDev committed Feb 5, 2024
1 parent c86b9a9 commit c902205
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
10 changes: 7 additions & 3 deletions optax/schedules/_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,16 +403,20 @@ def warmup_cosine_decay_schedule(
Returns:
schedule: A function that maps step counts to values.
"""
alpha = 0. if peak_value == 0. else end_value / peak_value
schedules = [
linear_schedule(
init_value=init_value,
end_value=peak_value,
transition_steps=warmup_steps),
transition_steps=warmup_steps,
),
cosine_decay_schedule(
init_value=peak_value,
decay_steps=decay_steps - warmup_steps,
alpha=end_value/peak_value,
exponent=exponent)]
alpha=alpha,
exponent=exponent,
),
]
return _join.join_schedules(schedules, [warmup_steps])


Expand Down
18 changes: 18 additions & 0 deletions optax/schedules/_schedule_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,24 @@ def test_with_exponent(self):
rtol=1e-6, atol=1e-8
)

@chex.all_variants
def test_zero_peak_value(self):
  """Verifies the schedule's output when `peak_value` is zero."""
  schedule = _schedule.warmup_cosine_decay_schedule(
      init_value=0.2,
      peak_value=0,
      end_value=-3.0,
      warmup_steps=50,
      decay_steps=100,
      exponent=2,
  )
  wrapped_schedule = self.variant(schedule)

  steps = np.array([0, 10, 50, 75, 100])
  actual = wrapped_schedule(steps)

  # Expected: 0.2 at step 0, 0.16 at step 10, and 0.0 from step 50 (the end
  # of warmup) onward, even though `end_value` is -3.0.
  expected = np.array([0.2, 0.16, 0.0, 0.0, 0.0])
  np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-8)


class SGDRTest(chex.TestCase):

Expand Down

0 comments on commit c902205

Please sign in to comment.