Forward fix for learning rate schedulers.
justinxzhao committed Aug 29, 2023
1 parent 03657a0 commit 0141421
Showing 2 changed files with 38 additions and 33 deletions.
ludwig/modules/lr_scheduler.py (26 additions, 27 deletions)
@@ -11,6 +11,8 @@
 from ludwig.utils.metric_utils import TrainerMetric
 from ludwig.utils.trainer_utils import ProgressTracker
 
+logger = logging.getLogger(__name__)
+
 
 class ReduceLROnPLateauCappedDecreases(ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, mode: str, reduce_limit: int, factor: float, patience: int):
@@ -29,11 +31,11 @@ def step(self, metrics):
     def num_reduce_lr(self) -> int:
         return self._num_reduce_lr
 
-    def _reduce_lr(self, epoch):
+    def _reduce_lr(self):
         self._num_reduce_lr += 1
-        self.apply_lr(epoch)
+        self.apply_lr()
 
-    def apply_lr(self, epoch=None):
+    def apply_lr(self):
         if self._num_reduce_lr == 0:
             return

@@ -43,24 +45,23 @@ def apply_lr(self, epoch=None):
             if old_lr - new_lr > self.eps:
                 param_group["lr"] = new_lr
                 if self.verbose:
-                    epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
-                    print("Epoch {}: reducing learning rate" " of group {} to {:.4e}.".format(epoch_str, i, new_lr))
+                    logger.info(f"From ReduceLROnPLateauCappedDecreases, reducing learning rate to {new_lr}")
 
 
 class LRScheduler:
     def __init__(
         self,
         config: LRSchedulerConfig,
         optimizer: Optimizer,
-        steps_per_checkpoint: int = 1000,
-        total_steps: int = 10000,
+        steps_per_checkpoint: int,
+        total_steps: int,
     ):
         self.config = config
         self.optimizer = optimizer
 
         # Scheduler updated each training step
         self.step_info = StepInfo(steps_per_checkpoint, total_steps, self.config)
-        self._train_scheduler = get_schedule_with_warmup(self.config, self.optimizer, self.step_info)
+        self._train_scheduler = get_schedule_with_warmup_and_decay(self.config, self.optimizer, self.step_info)
 
         # Scheduler updated each eval step
         self._eval_scheduler = None
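For readers unfamiliar with the capped-decrease pattern in ReduceLROnPLateauCappedDecreases above, here is a minimal, self-contained sketch of the same idea built only on ReduceLROnPlateau's public API. The class name and the reduce_limit default are illustrative, not Ludwig's; the reduction counter is detected by comparing learning rates before and after the parent's step.

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

class CappedReduceLROnPlateau(ReduceLROnPlateau):
    """Illustrative sketch: stop reducing the LR after `reduce_limit` cuts."""

    def __init__(self, optimizer, reduce_limit: int = 3, **kwargs):
        super().__init__(optimizer, **kwargs)
        self.reduce_limit = reduce_limit
        self._num_reduce_lr = 0  # number of reductions applied so far

    def step(self, metrics):
        if self._num_reduce_lr >= self.reduce_limit:
            return  # cap reached; leave the learning rate alone
        lrs_before = [group["lr"] for group in self.optimizer.param_groups]
        super().step(metrics)
        if any(group["lr"] < lr for group, lr in zip(self.optimizer.param_groups, lrs_before)):
            self._num_reduce_lr += 1  # the parent reduced the LR this step

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = CappedReduceLROnPlateau(optimizer, reduce_limit=2, mode="min", factor=0.5, patience=0)
for _ in range(10):
    scheduler.step(1.0)  # a metric that never improves
print(optimizer.param_groups[0]["lr"])  # 0.025: halved twice, then capped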
@@ -74,13 +75,8 @@ def __init__(
                 patience=self.config.reduce_on_plateau_patience,
             )
 
-        self.reset(steps_per_checkpoint, total_steps)
-
-    def reset(self, steps_per_checkpoint: int, total_steps: int):
-        # Retain state but update number of steps for training
-        self.step_info.reset(steps_per_checkpoint, total_steps)
-
     def step(self):
+        """Called every step of training."""
         self._train_scheduler.step()
 
         if self._eval_scheduler is not None:
@@ -90,6 +86,7 @@ def step(self):
             self._eval_scheduler.apply_lr()
 
     def eval_step(self, progress_tracker: ProgressTracker, validation_field: str):
+        """Called every checkpoint evaluation step."""
         if self._eval_scheduler is None:
             # No reduce on plateau
             return
@@ -140,14 +137,11 @@ class StepInfo:
 
     def __init__(self, steps_per_checkpoint: int, total_steps: int, config: LRSchedulerConfig):
         self.config = config
-        self.reset(steps_per_checkpoint, total_steps)
-
-    def reset(self, steps_per_checkpoint: int, total_steps: int):
         self.steps_per_checkpoint = steps_per_checkpoint
         self.num_training_steps = total_steps
 
         if self.config.warmup_fraction > 0 and self.config.warmup_evaluations > 0:
-            logging.info(
+            logger.info(
                 "Both `learning_rate_scheduler.warmup_fraction` and `learning_rate_scheduler.warmup_evaluations` "
                 "provided. The larger of the two (as a function of the total training steps) will be used."
             )
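As a concrete illustration of the precedence rule in that log message: both options express warmup in training steps, and the larger wins. The values below are made-up; the actual conversion lives in code elided from this diff.

# Hypothetical example values; Ludwig reads these from the scheduler config.
warmup_fraction, warmup_evaluations = 0.1, 3
total_steps, steps_per_checkpoint = 10_000, 500

num_warmup_steps = max(
    warmup_fraction * total_steps,              # 1000 steps
    warmup_evaluations * steps_per_checkpoint,  # 1500 steps
)
print(num_warmup_steps)  # 1500.0, the larger of the two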
@@ -160,35 +154,34 @@ def reset(self, steps_per_checkpoint: int, total_steps: int):
         self.num_warmup_steps = num_warmup_steps
 
 
-def get_schedule_with_warmup(
+def get_schedule_with_warmup_and_decay(
     config: LRSchedulerConfig,
     optimizer: Optimizer,
     step_info: StepInfo,
 ) -> LambdaLR:
     """Creates a learning rate scheduler that updates each training step."""
     schedulers = []
 
-    # Warmup scheduler
+    # Warmup scheduler.
     if step_info.num_warmup_steps > 0:
         warmup_scheduler = LambdaLR(
             optimizer,
             lambda current_step: float(current_step) / float(max(1, step_info.num_warmup_steps)),
             last_epoch=-1,
         )
         schedulers.append(warmup_scheduler)
 
-    # Decay scheduler
+    # Decay scheduler.
     decay = config.decay
     decay_scheduler = decay_registry[decay](config, optimizer, step_info)
     schedulers.append(decay_scheduler)
 
     if len(schedulers) == 1:
-        # Only one scheduler, no need to wrap in a SequentialLR
+        # Only one scheduler, so no need to wrap in a SequentialLR.
         return schedulers[0]
 
     # Return a SequentialLR that applies the warmup and decay schedulers in order
     # with the warmup scheduler only applied for the first num_warmup_steps steps.
-    return SequentialLR(optimizer, schedulers=schedulers, milestones=[step_info.num_warmup_steps], last_epoch=-1)
+    return SequentialLR(optimizer, schedulers=schedulers, milestones=[step_info.num_warmup_steps])
 
 
 def no_decay(current_step: int, num_training_steps: int, num_warmup_steps: int, config: LRSchedulerConfig):
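To see what get_schedule_with_warmup_and_decay builds, here is a self-contained sketch of the same composition. The step counts and the linear-decay lambda are stand-ins; Ludwig takes both from StepInfo and the decay registry.

import torch
from torch.optim.lr_scheduler import LambdaLR, SequentialLR

# Illustrative step counts; Ludwig derives these from StepInfo.
num_warmup_steps, num_training_steps = 100, 1000
num_decay_steps = num_training_steps - num_warmup_steps

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Linear warmup from 0 up to the base LR, as in the diff above.
warmup = LambdaLR(optimizer, lambda step: step / max(1, num_warmup_steps))

# Stand-in decay: linear from the base LR down to 0 over the remaining steps.
decay = LambdaLR(optimizer, lambda step: max(0.0, (num_decay_steps - step) / num_decay_steps))

# SequentialLR runs warmup first, then hands off to decay at the milestone step.
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[num_warmup_steps])

for step in range(num_training_steps):
    optimizer.step()
    scheduler.step()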
@@ -220,7 +213,6 @@ def init_fn(config: LRSchedulerConfig, optimizer: Optimizer, step_info: StepInfo
             lambda current_step: decay_fn(
                 current_step, step_info.num_training_steps, step_info.num_warmup_steps, config
             ),
-            last_epoch=-1,
         )
 
     return init_fn
@@ -231,12 +223,19 @@ def init_cosine_decay(
     optimizer: Optimizer,
     step_info: StepInfo,
 ) -> CosineAnnealingWarmRestarts:
+    t_0 = config.t_0
+    if not t_0:
+        t_0 = step_info.steps_per_checkpoint
+    if not t_0:
+        # A scheduler may be initialized with dummy values like at the start of training.
+        # Ensure that t_0 != 0, as this causes an error to be raised.
+        t_0 = 1
+
     return CosineAnnealingWarmRestarts(
         optimizer,
-        T_0=config.t_0 or step_info.steps_per_checkpoint,
+        T_0=t_0,
         T_mult=config.t_mult or 1,
         eta_min=config.eta_min or 0,
-        last_epoch=-1,
     )


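The t_0 fallback above matters because PyTorch validates T_0 at construction time. A quick check (hedged on the exact message text):

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)

# T_0 must be a positive integer. With the placeholder scheduler built as
# LRScheduler(config, optimizer, 0, 0), both config.t_0 and
# steps_per_checkpoint can be 0, which previously crashed here.
try:
    CosineAnnealingWarmRestarts(optimizer, T_0=0)
except ValueError as err:
    print(err)  # roughly: "Expected positive integer T_0, but got 0"

# The fallback value used by the fix works fine:
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=1, eta_min=0)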
ludwig/trainers/trainer.py (12 additions, 6 deletions)
@@ -216,7 +216,9 @@ def prepare(self):
             self.config,
             self.base_learning_rate,
         )
-        self.scheduler = LRScheduler(self.config.learning_rate_scheduler, self.optimizer)
+
+        # NOTE: This is a partially configured LRScheduler. It will be updated in the first call to train_step.
+        self.scheduler = LRScheduler(self.config.learning_rate_scheduler, self.optimizer, 0, 0)
 
     def train_step(
         self, inputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], should_step: bool = True
@@ -762,8 +764,13 @@ def train(
         final_steps_per_checkpoint = min(final_steps_per_checkpoint, self.total_steps)
         early_stopping_steps = final_steps_per_checkpoint * self.early_stop
 
-        # Update learning rate scheduler which depends on number of steps
-        self.scheduler.reset(final_steps_per_checkpoint, self.total_steps)
+        # Initialize the learning rate scheduler.
+        self.scheduler = LRScheduler(
+            self.config.learning_rate_scheduler,
+            self.optimizer,
+            steps_per_checkpoint=final_steps_per_checkpoint,
+            total_steps=self.total_steps,
+        )
 
         if self.is_coordinator():
             logger.info(
@@ -944,9 +951,8 @@ def _train_loop(
             loss, all_losses = self.train_step(inputs, targets, should_step=should_step)
             logger.info(f"Train loss for step {progress_tracker.steps}: {loss:.3f}")
 
-            if should_step:
-                # Update LR scheduler here instead of train loop to avoid updating during batch size tuning, etc.
-                self.scheduler.step()
+            # Update LR scheduler here instead of train loop to avoid updating during batch size tuning, etc.
+            self.scheduler.step()
 
             if self.is_coordinator() and not self.skip_save_log:
                 self.write_step_summary(
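Taken together, the intended call cadence looks roughly like the following. This is a sketch with stand-in objects, not Ludwig's actual trainer; LRScheduler is replaced by a plain PyTorch scheduler so the snippet runs on its own.

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Stand-in for the fully configured LRScheduler built inside train().
total_steps, steps_per_checkpoint = 20, 5
scheduler = LambdaLR(optimizer, lambda step: 1.0)

for step in range(1, total_steps + 1):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()

    # Per the _train_loop change above: the scheduler now advances on every
    # train_step call, no longer gated behind `if should_step`.
    scheduler.step()

    if step % steps_per_checkpoint == 0:
        # In Ludwig, this is where scheduler.eval_step(progress_tracker,
        # validation_field) would apply the reduce-on-plateau logic.
        pass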
