Skip to content

Commit

Permalink
epochs -> steps
Browse files Browse the repository at this point in the history
  • Loading branch information
TezRomacH committed May 16, 2019
1 parent a8340e3 commit 69ba1ee
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions catalyst/contrib/scheduler/onecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class OneCycleLR(BaseScheduler):
OneCycle scheduler with warm-up & lr decay stages.
First stage increases lr from ``init_lr`` to ``max_lr``,
and called ``warmup``. Also it decreases momentum
from ``init_momentum`` to ``min_momentum``. Takes ``warmup_epochs`` epochs
from ``init_momentum`` to ``min_momentum``. Takes ``warmup_steps`` steps
Second is ``annealing`` stage. Decrease lr from ``max_lr`` to ``min_lr``,
Increase momentum from ``min_momentum`` to ``max_momentum``.
Expand Down Expand Up @@ -41,11 +41,11 @@ def __init__(
init_lr (float, optional): initial lr
warmup_steps (int): count of steps for warm-up stage
warmup_fraction (float, optional): fraction in [0; 1) to calculate
number of warmup epochs.
number of warmup steps.
Cannot be set together with ``warmup_steps``
decay_steps (int): count of steps for lr decay stage
decay_fraction (float, optional): fraction in [0; 1) to calculate
number of decay epochs.
number of decay steps.
Cannot be set together with ``decay_steps``
momentum_range: tuple with two or three elements
(min_momentum, max_momentum, [final_momentum])
Expand Down Expand Up @@ -104,7 +104,7 @@ def _calculate_warmup(
):
if warmup_fraction is not None:
assert 0.0 <= warmup_fraction < 1.0 and warmup_steps == 0, \
"You should pass either warmup_epochs or " \
"You should pass either warmup_steps or " \
"warmup_fraction in range [0; 1) "
warmup_steps = int(num_steps * warmup_fraction)

Expand All @@ -120,7 +120,7 @@ def _calculate_decay(
):
if decay_fraction is not None:
assert 0.0 <= decay_fraction < 1.0 and decay_steps == 0, \
"You should pass either decay_epochs or " \
"You should pass either decay_steps or " \
"decay_fraction in range [0; 1) "
decay_steps = int(num_steps * decay_fraction)

Expand Down Expand Up @@ -157,7 +157,7 @@ def _calculate_lr_momentum(
momentum_decay, momentum_annealing, momentum_warmup
))

def _get_epoch_lr_momentum(self, step_num: int):
def _get_steps_lr_momentum(self, step_num: int):
if step_num < len(self.learning_rates):
lr = self.learning_rates[step_num]
else:
Expand All @@ -177,7 +177,7 @@ def get_lr(self) -> List[float]:
Returns:
List[float]: calculated lr for every param groups
"""
lr, _ = self._get_epoch_lr_momentum(self.last_epoch)
lr, _ = self._get_steps_lr_momentum(self.last_epoch)
return [lr] * self.total_groups

def get_momentum(self) -> List[float]:
Expand All @@ -186,7 +186,7 @@ def get_momentum(self) -> List[float]:
Returns:
List[float]: calculated momentum for every param groups
"""
_, momentum = self._get_epoch_lr_momentum(self.last_epoch)
_, momentum = self._get_steps_lr_momentum(self.last_epoch)
return [momentum] * self.total_groups

def reset(self):
Expand Down

0 comments on commit 69ba1ee

Please sign in to comment.