
Commit

Revert "DRAFT: Revert "Add Cosine Annealing LR scheduler as a decay m…
Browse files Browse the repository at this point in the history
…ethod (#3507)" (#3545)"

This reverts commit feec8a6.
justinxzhao committed Aug 28, 2023
1 parent f34c272 commit 03657a0
Showing 4 changed files with 183 additions and 14 deletions.
71 changes: 59 additions & 12 deletions ludwig/modules/lr_scheduler.py
@@ -1,9 +1,9 @@
import logging
import math
from typing import Any, Dict
from typing import Any, Callable, Dict

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, ReduceLROnPlateau, SequentialLR

from ludwig.constants import MINIMIZE, TRAINING, VALIDATION
from ludwig.modules.metric_registry import get_metric_objective
@@ -166,22 +166,41 @@ def get_schedule_with_warmup(
step_info: StepInfo,
) -> LambdaLR:
"""Creates a learning rate scheduler that updates each training step."""
decay_fn = decay_registry[config.decay]
schedulers = []

def lr_lambda(current_step: int):
if current_step < step_info.num_warmup_steps:
return float(current_step) / float(max(1, step_info.num_warmup_steps))
return decay_fn(current_step, step_info.num_training_steps, step_info.num_warmup_steps, config)
# Warmup scheduler
if step_info.num_warmup_steps > 0:
warmup_scheduler = LambdaLR(
optimizer,
lambda current_step: float(current_step) / float(max(1, step_info.num_warmup_steps)),
last_epoch=-1,
)
schedulers.append(warmup_scheduler)

return LambdaLR(optimizer, lr_lambda, last_epoch=-1)
# Decay scheduler
decay = config.decay
decay_scheduler = decay_registry[decay](config, optimizer, step_info)
schedulers.append(decay_scheduler)

if len(schedulers) == 1:
# Only one scheduler, no need to wrap in a SequentialLR
return schedulers[0]

# Return a SequentialLR that applies the warmup and decay schedulers in order
# with the warmup scheduler only applied for the first num_warmup_steps steps.
return SequentialLR(optimizer, schedulers=schedulers, milestones=[step_info.num_warmup_steps], last_epoch=-1)


def no_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
return 1.0


def linear_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
return max(
0.0,
float(num_training_steps - num_warmup_steps - current_step)
/ float(max(1, num_training_steps - num_warmup_steps)),
)


def exponential_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
@@ -194,8 +213,36 @@ def exponential_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
return math.pow(decay_rate, exponent)


def wrap_decay_fn(decay_fn: Callable) -> Callable:
def init_fn(config: LRSchedulerConfig, optimizer: Optimizer, step_info: StepInfo) -> LambdaLR:
return LambdaLR(
optimizer,
lambda current_step: decay_fn(
current_step, step_info.num_training_steps, step_info.num_warmup_steps, config
),
last_epoch=-1,
)

return init_fn


def init_cosine_decay(
config: LRSchedulerConfig,
optimizer: Optimizer,
step_info: StepInfo,
) -> CosineAnnealingWarmRestarts:
return CosineAnnealingWarmRestarts(
optimizer,
T_0=config.t_0 or step_info.steps_per_checkpoint,
T_mult=config.t_mult or 1,
eta_min=config.eta_min or 0,
last_epoch=-1,
)


decay_registry = {
None: no_decay,
"linear": linear_decay,
"exponential": exponential_decay,
None: wrap_decay_fn(no_decay),
"linear": wrap_decay_fn(linear_decay),
"exponential": wrap_decay_fn(exponential_decay),
"cosine": init_cosine_decay,
}
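
For readers less familiar with the scheduler composition above, here is a minimal standalone sketch of the same warmup-then-cosine pattern in plain PyTorch; the model, base learning rate, and step counts are illustrative assumptions rather than Ludwig defaults.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR, SequentialLR

num_warmup_steps = 100  # illustrative
t_0 = 1000              # steps before the first cosine restart (illustrative)

model = torch.nn.Linear(4, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# Linear warmup from ~0 up to the base learning rate over num_warmup_steps.
warmup = LambdaLR(optimizer, lambda step: step / max(1, num_warmup_steps), last_epoch=-1)

# Cosine annealing with warm restarts takes over once warmup finishes.
cosine = CosineAnnealingWarmRestarts(optimizer, T_0=t_0, T_mult=1, eta_min=0.0, last_epoch=-1)

# SequentialLR switches from the warmup scheduler to the cosine scheduler at the
# milestone step, mirroring get_schedule_with_warmup above.
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[num_warmup_steps], last_epoch=-1)

for step in range(2 * t_0):
    optimizer.step()   # in real training this follows a forward/backward pass
    scheduler.step()   # stepped once per training step, as Ludwig's per-step schedulers are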
28 changes: 27 additions & 1 deletion ludwig/schema/lr_scheduler.py
@@ -17,7 +17,7 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC):
"""Configuration for learning rate scheduler parameters."""

decay: str = schema_utils.StringOptions(
options=["linear", "exponential"],
options=["linear", "exponential", "cosine"],
default=None,
allow_none=True,
description="Turn on decay of the learning rate.",
@@ -99,6 +99,32 @@ class LRSchedulerConfig(schema_utils.BaseMarshmallowConfig, ABC):
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["reduce_eval_split"],
)

# Parameters for CosineAnnealingWarmRestarts scheduler

t_0: int = schema_utils.PositiveInteger(
default=None,
allow_none=True,
description="Number of steps before the first restart for cosine annealing decay. If not specified, it"
" will be set to `steps_per_checkpoint`.",
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_0"],
)

t_mult: int = schema_utils.PositiveInteger(
default=1,
description="Period multiplier after each restart for cosine annealing decay. Defaults to 1, i.e.,"
" restart every `t_0` steps. If set to a larger value, the period between restarts increases by that"
" multiplier. For e.g., if t_mult is 2, then the periods would be: t_0, 2*t_0, 2^2*t_0, 2^3*t_0, etc.",
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["t_mult"],
)

eta_min: float = schema_utils.FloatRange(
default=0,
min=0,
max=1,
description="Minimum learning rate allowed for cosine annealing decay. Default: 0.",
parameter_metadata=TRAINER_METADATA[MODEL_ECD]["learning_rate_scheduler"]["eta_min"],
)


# TODO(travis): too much boilerplate here, we should find a way to abstract all this and only require specifying the
# minimal amount needed for the new config object.
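To make the t_mult behavior described above concrete, here is a small sketch (a hypothetical helper, not part of Ludwig or PyTorch) that computes the global steps at which warm restarts occur; the values passed in are illustrative.

from typing import List

def restart_steps(t_0: int, t_mult: int, total_steps: int) -> List[int]:
    """Global step indices at which cosine warm restarts occur."""
    restarts, period, boundary = [], t_0, t_0
    while boundary <= total_steps:
        restarts.append(boundary)
        period *= t_mult       # each cycle is t_mult times as long as the previous one
        boundary += period
    return restarts

# t_mult=1: a restart every t_0 steps.
print(restart_steps(1000, 1, 5000))    # [1000, 2000, 3000, 4000, 5000]
# t_mult=2: periods double, so restarts land at 1000, 3000, 7000, 15000, ...
print(restart_steps(1000, 2, 20000))   # [1000, 3000, 7000, 15000]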
14 changes: 13 additions & 1 deletion ludwig/schema/metadata/configs/trainer.yaml
@@ -520,7 +520,10 @@ ecd:
suggested_values_reasoning:
Starting with exponential decay is a safe place to start, as it is a "softer" decrease in the learning
rate over time, as compared with linear, which is steeper after the initial drop. Linear decay is
most useful when the risk of catastrophic forgetting is very high (e.g, for fine-tuning pretrained models).
most useful when the risk of catastrophic forgetting is very high (e.g., for fine-tuning pretrained
models). Cosine annealing is a learning rate schedule that starts with a large learning rate, decreases
it relatively rapidly to a minimum value, and then rapidly increases it again. The resetting of the
learning rate acts like a simulated restart of the learning process.
If you observe your loss curves shooting up (even on the training set) in later epochs, increasing the
decay rate may help mitigate this effect.
ui_display_name: Decay
@@ -600,6 +603,15 @@ ecd:
reduce_eval_split:
expected_impact: 1
ui_display_name: Reduce Eval Split
t_0:
expected_impact: 1
ui_display_name: T_0
t_mult:
expected_impact: 1
ui_display_name: T_mult
eta_min:
expected_impact: 1
ui_display_name: Eta Min
gbm:
learning_rate:
commonly_used: true
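As a concrete reading of the cosine annealing description above, here is a minimal sketch of the per-step learning rate within a single cycle, using the standard CosineAnnealingWarmRestarts formula; the base learning rate and step values are illustrative.

import math

def cosine_lr(step_in_cycle: int, cycle_len: int, base_lr: float = 0.1, eta_min: float = 0.0) -> float:
    """Learning rate at a given step within one cosine annealing cycle."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step_in_cycle / cycle_len))

# The LR starts at base_lr, falls to eta_min by the end of the cycle, then
# jumps back to base_lr at the next restart.
print(cosine_lr(0, 1000))     # 0.1
print(cosine_lr(500, 1000))   # ~0.05
print(cosine_lr(1000, 1000))  # ~0.0 (resets to 0.1 at the restart)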
84 changes: 84 additions & 0 deletions tests/ludwig/modules/test_lr_scheduler.py
@@ -1,3 +1,5 @@
import math

import numpy as np
from torch.optim import SGD

@@ -33,6 +35,11 @@ def test_lr_scheduler_warmup_decay():
exp_scheduler = LRScheduler(config=exp_config, optimizer=exp_optimizer)
exp_scheduler.reset(steps_per_checkpoint, total_steps)

cosine_optimizer = SGD(module.parameters(), lr=base_lr)
cosine_config = LRSchedulerConfig(warmup_fraction=warmup_fraction, decay="cosine", t_0=steps_per_checkpoint)
cosine_scheduler = LRScheduler(config=cosine_config, optimizer=cosine_optimizer)
cosine_scheduler.reset(steps_per_checkpoint, total_steps)

warmup_steps = total_steps * warmup_fraction
for i in range(total_steps):
# Offset by 1
@@ -48,17 +55,25 @@ def test_lr_scheduler_warmup_decay():
exp_scheduler.step()
exp_lr = exp_optimizer.param_groups[0]["lr"]

cosine_scheduler.step()
cosine_lr = cosine_optimizer.param_groups[0]["lr"]

if step < warmup_steps:
assert linear_lr == exp_lr, f"step: {step}"
assert linear_lr == cosine_lr, f"step: {step}"
assert linear_lr < base_lr, f"step: {step}"
elif step == warmup_steps:
assert linear_lr == base_lr, f"step: {step}"
assert cosine_lr == base_lr, f"step: {step}"
assert exp_lr < base_lr, f"step: {step}"
else:
assert linear_lr < base_lr, f"step: {step}"
assert exp_lr < base_lr, f"step: {step}"
assert cosine_lr <= base_lr, f"step: {step}"

assert linear_lr < exp_lr
assert exp_lr < cosine_lr
assert cosine_lr == base_lr


def test_lr_scheduler_reduce_on_plateau():
@@ -119,6 +134,75 @@ def test_lr_scheduler_reduce_on_plateau():
assert np.isclose(lr, 0.001)


def test_lr_scheduler_cosine_decay_fixed_period():
total_steps = 10000
steps_per_checkpoint = 1000
base_lr = 1.0

module = NumberInputFeature(NumberInputFeatureConfig(name="num1", encoder=DenseEncoderConfig()))

optimizer = SGD(module.parameters(), lr=base_lr)
config = LRSchedulerConfig(decay="cosine", t_0=steps_per_checkpoint, decay_rate=0, reduce_on_plateau=0)
scheduler = LRScheduler(config=config, optimizer=optimizer)
scheduler.reset(steps_per_checkpoint, total_steps)

curr_lr = base_lr
prev_lr = base_lr
num_restarts = 0
for step in range(total_steps + 1):
# Cosine annealing formula
expected_lr = base_lr * 0.5 * (1 + math.cos(math.pi * (step % steps_per_checkpoint) / steps_per_checkpoint))
assert np.isclose(curr_lr, expected_lr), f"step: {step}"

if prev_lr < curr_lr:
# Since Cosine decay is periodic, we should see the learning rate
# decrease and then increase again.
num_restarts += 1

prev_lr = curr_lr
scheduler.step()

curr_lr = optimizer.param_groups[0]["lr"]

assert num_restarts == 10, f"num_restarts: {num_restarts}"


def test_lr_scheduler_cosine_decay_increasing_period():
total_steps = 20000
steps_per_checkpoint = 1000
base_lr = 1.0

module = NumberInputFeature(NumberInputFeatureConfig(name="num1", encoder=DenseEncoderConfig()))

optimizer = SGD(module.parameters(), lr=base_lr)
config = LRSchedulerConfig(
decay="cosine",
t_0=steps_per_checkpoint,
t_mult=2,
decay_rate=0,
reduce_on_plateau=0,
)
scheduler = LRScheduler(config=config, optimizer=optimizer)
scheduler.reset(steps_per_checkpoint, total_steps)

curr_lr = base_lr
prev_lr = base_lr
num_restarts = 0
for _ in range(total_steps + 1):
if prev_lr < curr_lr:
# Since Cosine decay is periodic, we should see the learning rate
# decrease and then increase again.
num_restarts += 1

prev_lr = curr_lr
scheduler.step()

curr_lr = optimizer.param_groups[0]["lr"]

# With T_0=1000 and T_mult=2, restarts occur at steps 1000, 3000, 7000, 15000, 31000, ... (but we stop at 20000)
assert num_restarts == 4, f"num_restarts: {num_restarts}"


def test_lr_scheduler_save_load():
steps_per_checkpoint = 10
total_steps = 100
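Putting the pieces together, a user-facing trainer configuration that opts into the new decay method would presumably look something like the sketch below; the placement under learning_rate_scheduler is inferred from the TRAINER_METADATA keys above, and the exact structure should be verified against the Ludwig trainer schema.

# Hypothetical config snippet (keys inferred from the schema and metadata changes above).
trainer_config = {
    "learning_rate": 0.001,
    "learning_rate_scheduler": {
        "warmup_fraction": 0.1,   # optional linear warmup before the decay kicks in
        "decay": "cosine",        # new option added by this commit
        "t_0": 1000,              # steps before the first restart; defaults to steps_per_checkpoint if unset
        "t_mult": 2,              # each restart period is twice as long as the previous one
        "eta_min": 0.0,           # lower bound on the annealed learning rate
    },
}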
