Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 296 additions & 12 deletions src/sparseml/pytorch/optim/modifier_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import math
import sys
from typing import Dict, List, Union
from typing import Dict, List, Optional, Tuple, Union

from torch.nn import Module
from torch.optim.lr_scheduler import (
Expand All @@ -39,13 +39,17 @@
)
from sparseml.pytorch.utils import (
BaseLogger,
get_optim_learning_rate,
get_optim_groups_learning_rates,
set_optim_learning_rate,
)
from sparseml.utils import ALL_TOKEN, convert_to_bool


__all__ = ["SetLearningRateModifier", "LearningRateModifier"]
__all__ = [
"SetLearningRateModifier",
"LearningRateFunctionModifier",
"LearningRateModifier",
]


CONSTRUCTORS = {
Expand All @@ -57,12 +61,16 @@


def _log_lr(
cur_lr: float, loggers: List[BaseLogger], epoch: float, steps_per_epoch: int
group_lrs: List[Tuple[str, float]],
loggers: List[BaseLogger],
epoch: float,
steps_per_epoch: int,
):
step = round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch)

for logger in loggers:
logger.log_scalar("Modifier LR", cur_lr, step)
for (group_name, group_lr) in group_lrs:
logger.log_scalar(f"LearningRateModifier/{group_name}", group_lr, step)


@PyTorchModifierYAML()
Expand Down Expand Up @@ -93,6 +101,7 @@ class SetLearningRateModifier(ScheduledModifier, SetLearningRate):
def __init__(
self,
learning_rate: Union[float, None],
param_groups: Optional[List[int]] = None,
start_epoch: float = -1.0,
end_epoch: float = -1.0,
log_types: Union[str, List[str]] = ALL_TOKEN,
Expand All @@ -105,12 +114,29 @@ def __init__(
end_epoch=-1,
end_comparator=None,
)
self._param_groups = param_groups
self._lr_set = False
self._applied = -1.0
self._constant_logging = convert_to_bool(constant_logging)
self._last_logged_lr = None
self._last_logged_epoch = None

@ModifierProp()
def param_groups(self) -> Optional[List[int]]:
    """
    :return: The param group indices to set the lr for within the optimizer,
        if not set will set the lr for all param groups
    """
    return self._param_groups

@param_groups.setter
def param_groups(self, value: Optional[List[int]]):
    """
    :param value: The param group indices to set the lr for within the optimizer,
        if not set (None or empty) will set the lr for all param groups
    """
    self._param_groups = value

@ModifierProp()
def constant_logging(self) -> bool:
"""
Expand Down Expand Up @@ -165,16 +191,28 @@ def log_update(
(calculate batch number using this and epoch)
"""
super().log_update(module, optimizer, epoch, steps_per_epoch)
current_lr = get_optim_learning_rate(optimizer)
group_lrs = [
(f"ParamGroup{index}", lr)
for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
if not self.param_groups or index in self.param_groups
]

if not group_lrs:
raise ValueError(
"Could not find param groups in the optimizer "
f"for given param_groups {self.param_groups}"
)

current_lr = group_lrs[-1][1]

if (
self._constant_logging
or current_lr != self._last_logged_lr
or self._last_logged_lr != current_lr
or math.floor(epoch) != self._last_logged_epoch
):
self._last_logged_lr = current_lr
self._last_logged_epoch = math.floor(epoch)
_log_lr(current_lr, self.loggers, epoch, steps_per_epoch)
_log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)

def _check_set_lr(self, optimizer: Optimizer, epoch: float):
if (
Expand All @@ -185,11 +223,249 @@ def _check_set_lr(self, optimizer: Optimizer, epoch: float):
and not self._lr_set
and self._learning_rate is not None
):
set_optim_learning_rate(optimizer, self.learning_rate)
self._applied = self._learning_rate
for (index, group) in enumerate(optimizer.param_groups):
if not self.param_groups or index in self.param_groups:
group["lr"] = self.learning_rate
self._applied = self.learning_rate
self._lr_set = True


@PyTorchModifierYAML()
class LearningRateFunctionModifier(ScheduledUpdateModifier):
    """
    Modifier to set the learning rate based on supported math functions scaling
    between an init_lr and a final_lr.
    Any time an update point is reached, the LR is updated for the param groups
    in the optimizer.
    Specific param groups can be targeted for the optimizer as well.

    | Sample yaml:
    |    !LearningRateFunctionModifier
    |        start_epoch: 0.0
    |        end_epoch: 10.0
    |        lr_func: linear
    |        init_lr: 0.1
    |        final_lr: 0.001

    :param lr_func: The name of the lr function to use: [linear, cosine]
    :param init_lr: The initial learning rate to use once this modifier starts
    :param final_lr: The final learning rate to reach by the time this
        modifier ends
    :param start_epoch: The epoch to start the modifier at
        (set to -1.0 so it starts immediately)
    :param end_epoch: The epoch to end the modifier at,
        (set to -1.0 so it doesn't end)
    :param param_groups: The param group indices to set the lr for within the
        optimizer, if not set will set the lr for all param groups
    :param update_frequency: unused and should not be set
    :param log_types: The loggers to allow the learning rate to be logged to,
        default is __ALL__
    """

    def __init__(
        self,
        lr_func: str,
        init_lr: float,
        final_lr: float,
        start_epoch: float,
        end_epoch: float,
        param_groups: Optional[List[int]] = None,
        update_frequency: float = -1.0,
        log_types: Union[str, List[str]] = ALL_TOKEN,
    ):
        super().__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            # the LR must be recalculated on every update point regardless of
            # the update_frequency passed in, so it is pinned to -1.0
            update_frequency=-1.0,
            end_comparator=1,
        )
        self._lr_func = lr_func
        self._init_lr = init_lr
        self._final_lr = final_lr
        self._param_groups = param_groups
        self._learning_rate = None  # most recently computed LR
        self._last_applied_lr = None
        self._last_logged_lr = None  # dedup state for log_update
        self._last_logged_epoch = None
        self.validate()

    @ModifierProp()
    def lr_func(self) -> str:
        """
        :return: The name of the lr function to use: [linear, cosine]
        """
        return self._lr_func

    @lr_func.setter
    def lr_func(self, value: str):
        """
        :param value: The name of the lr function to use: [linear, cosine]
        """
        self._lr_func = value
        self.validate()

    @ModifierProp()
    def init_lr(self) -> float:
        """
        :return: The initial learning rate to use once this modifier starts
        """
        return self._init_lr

    @init_lr.setter
    def init_lr(self, value: float):
        """
        :param value: The initial learning rate to use once this modifier starts
        """
        self._init_lr = value
        self.validate()

    @ModifierProp()
    def final_lr(self) -> float:
        """
        :return: The final learning rate to reach by the time this modifier ends
        """
        return self._final_lr

    @final_lr.setter
    def final_lr(self, value: float):
        """
        :param value: The final learning rate to reach by the time this
            modifier ends
        """
        self._final_lr = value
        self.validate()

    @ModifierProp()
    def param_groups(self) -> Optional[List[int]]:
        """
        :return: The param group indices to set the lr for within the optimizer,
            if not set will set the lr for all param groups
        """
        return self._param_groups

    @param_groups.setter
    def param_groups(self, value: Optional[List[int]]):
        """
        :param value: The param group indices to set the lr for within the
            optimizer, if not set will set the lr for all param groups
        """
        self._param_groups = value
        self.validate()

    def update(
        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
    ):
        """
        Updates the LR based on the given epoch for the optimizer

        :param module: module to modify
        :param optimizer: optimizer to modify
        :param epoch: current epoch and progress within the current epoch
        :param steps_per_epoch: number of steps taken within each epoch
            (calculate batch number using this and epoch)
        """
        super().update(module, optimizer, epoch, steps_per_epoch)
        # dispatch to the schedule implementation (_linear or _cosine);
        # validate() guarantees self._lr_func names one of them
        lr_function = getattr(LearningRateFunctionModifier, f"_{self._lr_func}")
        self._learning_rate = lr_function(self, epoch)
        set_optim_learning_rate(optimizer, self._learning_rate, self.param_groups)

    def log_update(
        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
    ):
        """
        Check whether to log an update for the learning rate of the modifier.
        Checks for a change in the LR or epoch before logging

        :param module: module to modify
        :param optimizer: optimizer to modify
        :param epoch: current epoch and progress within the current epoch
        :param steps_per_epoch: number of steps taken within each epoch
            (calculate batch number using this and epoch)
        """
        super().log_update(module, optimizer, epoch, steps_per_epoch)
        group_lrs = [
            (f"ParamGroup{index}", lr)
            for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
            if not self.param_groups or index in self.param_groups
        ]

        if not group_lrs:
            raise ValueError(
                "Could not find param groups in the optimizer "
                f"for given param_groups {self.param_groups}"
            )

        # the last targeted group's LR is used as the change detector
        current_lr = group_lrs[-1][1]

        if (
            current_lr != self._last_logged_lr
            or math.floor(epoch) != self._last_logged_epoch
        ):
            _log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
            self._last_logged_lr = current_lr
            self._last_logged_epoch = math.floor(epoch)

    def validate(self):
        """
        Validate the values of the params for the current instance are valid

        :raises ValueError: if lr_func, init_lr, final_lr, or update_frequency
            hold invalid values
        """
        lr_funcs = ["linear", "cosine"]
        if self.lr_func not in lr_funcs:
            raise ValueError(f"lr_func must be one of {lr_funcs}")

        if (
            (not self.init_lr and self.init_lr != 0)
            or self.init_lr < 0.0
            or self.init_lr > 1.0
        ):
            raise ValueError(
                f"init_lr must be within range [0.0, 1.0], given {self.init_lr}"
            )

        if (
            (not self.final_lr and self.final_lr != 0)
            or self.final_lr < 0.0
            or self.final_lr > 1.0
        ):
            raise ValueError(
                f"final_lr must be within range [0.0, 1.0], given {self.final_lr}"
            )

        if self.update_frequency != -1.0:
            raise ValueError("update_frequency must be kept at -1.0")

    def _linear(self, epoch: float) -> float:
        """
        Linear interpolation from init_lr to final_lr over
        [start_epoch, end_epoch]: y = y1 + ((x - x1) / (x2 - x1)) * (y2 - y1)
        """
        start = self.start_epoch if self.start_epoch > 0 else 0.0
        end = self.end_epoch

        return self.init_lr + ((epoch - start) / (end - start)) * (
            self.final_lr - self.init_lr
        )

    def _cosine(self, epoch: float) -> float:
        """
        Half-period cosine interpolation from init_lr to final_lr over
        [start_epoch, end_epoch]; supports annealing both down and up.
        """
        start = self.start_epoch if self.start_epoch > 0 else 0.0
        end = self.end_epoch

        # scale x to [0-1] for use with cosine
        x_norm = (epoch - start) / (end - start)

        # conditional to support cosine down to a value and up to a value
        if self.final_lr < self.init_lr:
            y_range = self.init_lr - self.final_lr
            y_shift = self.final_lr
            x_shift = 0
        else:
            # shifting by pi flips the curve so it rises from init_lr instead
            y_range = self.final_lr - self.init_lr
            y_shift = self.init_lr
            x_shift = math.pi

        return (
            math.cos(x_norm * math.pi + x_shift) * y_range / 2 + y_range / 2 + y_shift
        )


@PyTorchModifierYAML()
class LearningRateModifier(ScheduledUpdateModifier, LearningRate):
"""
Expand Down Expand Up @@ -337,7 +613,15 @@ def log_update(
(calculate batch number using this and epoch)
"""
super().log_update(module, optimizer, epoch, steps_per_epoch)
current_lr = get_optim_learning_rate(optimizer)
group_lrs = [
(f"ParamGroup{index}", lr)
for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
]

if not group_lrs:
raise ValueError("Could not find any param groups in the optimizer")

current_lr = group_lrs[-1][1]

if (
self._constant_logging
Expand All @@ -346,7 +630,7 @@ def log_update(
):
self._last_logged_lr = current_lr
self._last_logged_epoch = math.floor(epoch)
_log_lr(current_lr, self.loggers, epoch, steps_per_epoch)
_log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)

def validate(self):
"""
Expand Down
Loading