diff --git a/src/sparseml/pytorch/optim/modifier_lr.py b/src/sparseml/pytorch/optim/modifier_lr.py
index 338cfb0e1aa..c2c6fa6b822 100644
--- a/src/sparseml/pytorch/optim/modifier_lr.py
+++ b/src/sparseml/pytorch/optim/modifier_lr.py
@@ -19,7 +19,7 @@
 
 import math
 import sys
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from torch.nn import Module
 from torch.optim.lr_scheduler import (
@@ -39,13 +39,17 @@
 )
 from sparseml.pytorch.utils import (
     BaseLogger,
-    get_optim_learning_rate,
+    get_optim_groups_learning_rates,
     set_optim_learning_rate,
 )
 from sparseml.utils import ALL_TOKEN, convert_to_bool
 
 
-__all__ = ["SetLearningRateModifier", "LearningRateModifier"]
+__all__ = [
+    "SetLearningRateModifier",
+    "LearningRateFunctionModifier",
+    "LearningRateModifier",
+]
 
 
 CONSTRUCTORS = {
@@ -57,12 +61,16 @@
 
 
 def _log_lr(
-    cur_lr: float, loggers: List[BaseLogger], epoch: float, steps_per_epoch: int
+    group_lrs: List[Tuple[str, float]],
+    loggers: List[BaseLogger],
+    epoch: float,
+    steps_per_epoch: int,
 ):
     step = round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch)
 
     for logger in loggers:
-        logger.log_scalar("Modifier LR", cur_lr, step)
+        for (group_name, group_lr) in group_lrs:
+            logger.log_scalar(f"LearningRateModifier/{group_name}", group_lr, step)
 
 
 @PyTorchModifierYAML()
@@ -93,6 +101,7 @@ class SetLearningRateModifier(ScheduledModifier, SetLearningRate):
     def __init__(
         self,
         learning_rate: Union[float, None],
+        param_groups: Optional[List[int]] = None,
         start_epoch: float = -1.0,
         end_epoch: float = -1.0,
         log_types: Union[str, List[str]] = ALL_TOKEN,
@@ -105,12 +114,29 @@ def __init__(
             end_epoch=-1,
             end_comparator=None,
         )
+        self._param_groups = param_groups
         self._lr_set = False
         self._applied = -1.0
         self._constant_logging = convert_to_bool(constant_logging)
         self._last_logged_lr = None
         self._last_logged_epoch = None
 
+    @ModifierProp()
+    def param_groups(self) -> Optional[List[int]]:
+        """
+        :return: The param group indices to set the lr for within the optimizer,
+            if not set will set the lr for all param groups
+        """
+        return self._param_groups
+
+    @param_groups.setter
+    def param_groups(self, value: Optional[List[int]]):
+        """
+        :param value: The param group indices to set the lr for within the optimizer,
+            if not set will set the lr for all param groups
+        """
+        self._param_groups = value
+
     @ModifierProp()
     def constant_logging(self) -> bool:
         """
@@ -165,16 +191,28 @@ def log_update(
             (calculate batch number using this and epoch)
         """
         super().log_update(module, optimizer, epoch, steps_per_epoch)
-        current_lr = get_optim_learning_rate(optimizer)
+        group_lrs = [
+            (f"ParamGroup{index}", lr)
+            for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
+            if not self.param_groups or index in self.param_groups
+        ]
+
+        if not group_lrs:
+            raise ValueError(
+                "Could not find param groups in the optimizer "
+                f"for given param_groups {self.param_groups}"
+            )
+
+        current_lr = group_lrs[-1][1]
 
         if (
             self._constant_logging
-            or current_lr != self._last_logged_lr
+            or self._last_logged_lr != current_lr
             or math.floor(epoch) != self._last_logged_epoch
         ):
             self._last_logged_lr = current_lr
             self._last_logged_epoch = math.floor(epoch)
-            _log_lr(current_lr, self.loggers, epoch, steps_per_epoch)
+            _log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
 
     def _check_set_lr(self, optimizer: Optimizer, epoch: float):
         if (
@@ -185,11 +223,249 @@ def _check_set_lr(self, optimizer: Optimizer, epoch: float):
             and not self._lr_set
             and self._learning_rate is not None
         ):
-            set_optim_learning_rate(optimizer, self.learning_rate)
-            self._applied = self._learning_rate
+            for (index, group) in enumerate(optimizer.param_groups):
+                if not self.param_groups or index in self.param_groups:
+                    group["lr"] = self.learning_rate
+            self._applied = self.learning_rate
             self._lr_set = True
 
 
+@PyTorchModifierYAML()
+class LearningRateFunctionModifier(ScheduledUpdateModifier):
+    """
+    Modifier to set the learning rate based on supported math functions scaling
+    between an init_lr and a final_lr.
+    Any time an update point is reached, the LR is updated for the parameter groups
+    in the optimizer.
+    Specific parameter groups can be targeted for the optimizer as well.
+
+    | Sample yaml:
+    |    !LearningRateFunctionModifier
+    |        start_epoch: 0.0
+    |        end_epoch: 10.0
+    |        lr_func: linear
+    |        init_lr: 0.1
+    |        final_lr: 0.001
+
+    :param lr_func: The name of the lr function to use: [linear, cosine]
+    :param init_lr: The initial learning rate to use once this modifier starts
+    :param final_lr: The final learning rate to reach by the end of this modifier
+    :param start_epoch: The epoch to start the modifier at
+        (set to -1.0 so it starts immediately)
+    :param end_epoch: The epoch to end the modifier at,
+        (set to -1.0 so it doesn't end)
+    :param param_groups: The param group indices to set the lr for within the
+        optimizer, if not set will set the lr for all param groups
+    :param update_frequency: unused and should not be set
+    :param log_types: The loggers to allow the learning rate to be logged to,
+        default is __ALL__
+    """
+
+    def __init__(
+        self,
+        lr_func: str,
+        init_lr: float,
+        final_lr: float,
+        start_epoch: float,
+        end_epoch: float,
+        param_groups: Optional[List[int]] = None,
+        update_frequency: float = -1.0,
+        log_types: Union[str, List[str]] = ALL_TOKEN,
+    ):
+        super().__init__(
+            log_types=log_types,
+            start_epoch=start_epoch,
+            end_epoch=end_epoch,
+            update_frequency=-1.0,
+            end_comparator=1,
+        )
+        self._lr_func = lr_func
+        self._init_lr = init_lr
+        self._final_lr = final_lr
+        self._param_groups = param_groups
+        self._learning_rate = None
+        self._last_applied_lr = None
+        self._last_logged_lr = None
+        self._last_logged_epoch = None
+        self.validate()
+
+    @ModifierProp()
+    def lr_func(self) -> str:
+        """
+        :return: The name of the lr function to use: [linear, cosine]
+        """
+        return self._lr_func
+
+    @lr_func.setter
+    def lr_func(self, value: str):
+        """
+        :param value: The name of the lr function to use: [linear, cosine]
+        """
+        self._lr_func = value
+        self.validate()
+
+    @ModifierProp()
+    def init_lr(self) -> float:
+        """
+        :return: The initial learning rate to use once this modifier starts
+        """
+        return self._init_lr
+
+    @init_lr.setter
+    def init_lr(self, value: float):
+        """
+        :param value: The initial learning rate to use once this modifier starts
+        """
+        self._init_lr = value
+        self.validate()
+
+    @ModifierProp()
+    def final_lr(self) -> float:
+        """
+        :return: The final learning rate to reach by the end of this modifier
+        """
+        return self._final_lr
+
+    @final_lr.setter
+    def final_lr(self, value: float):
+        """
+        :param value: The final learning rate to reach by the end of this modifier
+        """
+        self._final_lr = value
+        self.validate()
+
+    @ModifierProp()
+    def param_groups(self) -> Optional[List[int]]:
+        """
+        :return: The param group indices to set the lr for within the optimizer,
+            if not set will set the lr for all param groups
+        """
+        return self._param_groups
+
+    @param_groups.setter
+    def param_groups(self, value: Optional[List[int]]):
+        """
+        :param value: The param group indices to set the lr for within the optimizer,
+            if not set will set the lr for all param groups
+        """
+        self._param_groups = value
+        self.validate()
+
+    def update(
+        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
+    ):
+        """
+        Updates the LR based on the given epoch for the optimizer
+
+        :param module: module to modify
+        :param optimizer: optimizer to modify
+        :param epoch: current epoch and progress within the current epoch
+        :param steps_per_epoch: number of steps taken within each epoch
+            (calculate batch number using this and epoch)
+        """
+        super().update(module, optimizer, epoch, steps_per_epoch)
+        lambda_func = getattr(LearningRateFunctionModifier, f"_{self._lr_func}")
+        self._learning_rate = lambda_func(self, epoch)
+        set_optim_learning_rate(optimizer, self._learning_rate, self.param_groups)
+
+    def log_update(
+        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
+    ):
+        """
+        Check whether to log an update for the learning rate of the modifier.
+        Checks for a change in the LR or epoch before logging
+
+        :param module: module to modify
+        :param optimizer: optimizer to modify
+        :param epoch: current epoch and progress within the current epoch
+        :param steps_per_epoch: number of steps taken within each epoch
+            (calculate batch number using this and epoch)
+        """
+        super().log_update(module, optimizer, epoch, steps_per_epoch)
+        group_lrs = [
+            (f"ParamGroup{index}", lr)
+            for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
+            if not self.param_groups or index in self.param_groups
+        ]
+
+        if not group_lrs:
+            raise ValueError(
+                "Could not find param groups in the optimizer "
+                f"for given param_groups {self.param_groups}"
+            )
+
+        current_lr = group_lrs[-1][1]
+
+        if (
+            current_lr != self._last_logged_lr
+            or math.floor(epoch) != self._last_logged_epoch
+        ):
+            _log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
+            self._last_logged_lr = current_lr
+            self._last_logged_epoch = math.floor(epoch)
+
+    def validate(self):
+        """
+        Validate the values of the params for the current instance are valid
+        """
+        lr_funcs = ["linear", "cosine"]
+        if self.lr_func not in lr_funcs:
+            raise ValueError(f"lr_func must be one of {lr_funcs}")
+
+        if (
+            (not self.init_lr and self.init_lr != 0)
+            or self.init_lr < 0.0
+            or self.init_lr > 1.0
+        ):
+            raise ValueError(
+                f"init_lr must be within range [0.0, 1.0], given {self.init_lr}"
+            )
+
+        if (
+            (not self.final_lr and self.final_lr != 0)
+            or self.final_lr < 0.0
+            or self.final_lr > 1.0
+        ):
+            raise ValueError(
+                f"final_lr must be within range [0.0, 1.0], given {self.final_lr}"
+            )
+
+        if self.update_frequency != -1.0:
+            raise ValueError("update_frequency must be kept at -1.0")
+
+    def _linear(self, epoch: float) -> float:
+        # linear interpolation: y = y1 + ((x - x1) / (x2 - x1)) * (y2 - y1)
+        start = self.start_epoch if self.start_epoch > 0 else 0.0
+        end = self.end_epoch
+
+        return self.init_lr + ((epoch - start) / (end - start)) * (
+            self.final_lr - self.init_lr
+        )
+
+    def _cosine(self, epoch: float) -> float:
+        start = self.start_epoch if self.start_epoch > 0 else 0.0
+        end = self.end_epoch
+
+        # scale x to [0-1] for use with cosine
+        x_norm = (epoch - start) / (end - start)
+
+        # conditional to support cosine down to a value and up to a value
+        if self.final_lr < self.init_lr:
+            y_range = self.init_lr - self.final_lr
+            y_shift = self.final_lr
+            x_shift = 0
+        else:
+            y_range = self.final_lr - self.init_lr
+            y_shift = self.init_lr
+            x_shift = math.pi
+
+        return (
+            math.cos(x_norm * math.pi + x_shift) * y_range / 2 + y_range / 2 + y_shift
+        )
+
+
 @PyTorchModifierYAML()
 class LearningRateModifier(ScheduledUpdateModifier, LearningRate):
     """
@@ -337,7 +613,15 @@ def log_update(
             (calculate batch number using this and epoch)
         """
         super().log_update(module, optimizer, epoch, steps_per_epoch)
-        current_lr = get_optim_learning_rate(optimizer)
+        group_lrs = [
+            (f"ParamGroup{index}", lr)
+            for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
+        ]
+
+        if not group_lrs:
+            raise ValueError("Could not find any param groups in the optimizer")
+
+        current_lr = group_lrs[-1][1]
 
         if (
             self._constant_logging
@@ -346,7 +630,7 @@ def log_update(
         ):
             self._last_logged_lr = current_lr
             self._last_logged_epoch = math.floor(epoch)
-            _log_lr(current_lr, self.loggers, epoch, steps_per_epoch)
+            _log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
 
     def validate(self):
         """
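
# Illustrative sketch (not part of the patch): the interpolation math behind
# LearningRateFunctionModifier._linear and ._cosine above. The helper names
# linear_lr/cosine_lr and the example values below are made up for illustration.

import math


def linear_lr(epoch, start, end, init_lr, final_lr):
    # straight-line interpolation from init_lr at `start` to final_lr at `end`
    x_norm = (epoch - start) / (end - start)
    return init_lr + x_norm * (final_lr - init_lr)


def cosine_lr(epoch, start, end, init_lr, final_lr):
    # half-period cosine from init_lr at `start` to final_lr at `end`; the x_shift
    # mirrors the conditional above so the curve works for both decay and warmup
    x_norm = (epoch - start) / (end - start)
    if final_lr < init_lr:
        y_range, y_shift, x_shift = init_lr - final_lr, final_lr, 0.0
    else:
        y_range, y_shift, x_shift = final_lr - init_lr, init_lr, math.pi
    return math.cos(x_norm * math.pi + x_shift) * y_range / 2 + y_range / 2 + y_shift


# With init_lr=0.1 and final_lr=0.001 over epochs 0..10, both schedules start at
# 0.1, end at 0.001, and cross the midpoint value 0.0505 at epoch 5:
#   linear_lr(5, 0, 10, 0.1, 0.001)  # -> 0.0505
#   cosine_lr(5, 0, 10, 0.1, 0.001)  # -> 0.0505
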
diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py
index fd91fd15c27..705f3e277b5 100644
--- a/src/sparseml/pytorch/utils/helpers.py
+++ b/src/sparseml/pytorch/utils/helpers.py
@@ -21,7 +21,7 @@
 from collections import OrderedDict, namedtuple
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy
 import torch
@@ -43,6 +43,7 @@
 __all__ = [
     "default_device",
     "get_optim_learning_rate",
+    "get_optim_groups_learning_rates",
     "set_optim_learning_rate",
     "early_stop_data_loader",
     "infinite_data_loader",
@@ -117,14 +118,27 @@ def get_optim_learning_rate(optim: Optimizer) -> float:
     raise RuntimeError("cannot get learning_rate, no param_groups available")
 
 
-def set_optim_learning_rate(optim: Optimizer, value: float):
+def get_optim_groups_learning_rates(optim: Optimizer) -> List[float]:
+    """
+    :param optim: The optimizer to get the learning rates for
+    :return: a list of the learning rates set for the param groups in the optimizer
+    """
+    return [group["lr"] for group in optim.param_groups]
+
+
+def set_optim_learning_rate(
+    optim: Optimizer, value: float, groups: Optional[List[int]] = None
+):
     """
     :param optim: The optimizer to set the learning rate for
-    :param value: the learning rate to set for the optimizer,
-        will set all param groups in the optim to this value
+    :param value: the learning rate to set for the optimizer's param groups
+    :param groups: the indices of the param groups to set the learning rate for,
+        if not set will update all param groups in the optim
     """
-    for param_group in optim.param_groups:
-        param_group["lr"] = value
+    for (index, group) in enumerate(optim.param_groups):
+        if not groups or index in groups:
+            group["lr"] = value
 
 
 ##############################
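
# Usage sketch for the two helpers above (assumes the patched sparseml is
# installed; the model, group layout, and LR values here are arbitrary examples).

import torch
from torch.optim import SGD

from sparseml.pytorch.utils import (
    get_optim_groups_learning_rates,
    set_optim_learning_rate,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
optim = SGD(
    [
        {"params": model[0].parameters()},              # group 0 uses the default lr
        {"params": model[1].parameters(), "lr": 0.01},  # group 1 overrides it
    ],
    lr=0.1,
)

print(get_optim_groups_learning_rates(optim))  # [0.1, 0.01]

# only param group 1 is touched; group 0 keeps its current lr
set_optim_learning_rate(optim, 0.001, groups=[1])
print(get_optim_groups_learning_rates(optim))  # [0.1, 0.001]

# Note: because the check is `if not groups or index in groups`, passing an empty
# list behaves the same as groups=None and updates every param group.
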
diff --git a/tests/sparseml/pytorch/optim/test_modifier_lr.py b/tests/sparseml/pytorch/optim/test_modifier_lr.py
index 81b5a798584..200e5561496 100644
--- a/tests/sparseml/pytorch/optim/test_modifier_lr.py
+++ b/tests/sparseml/pytorch/optim/test_modifier_lr.py
@@ -19,7 +19,11 @@
 import pytest
 from torch.optim import SGD, Adam
 
-from sparseml.pytorch.optim import LearningRateModifier, SetLearningRateModifier
+from sparseml.pytorch.optim import (
+    LearningRateFunctionModifier,
+    LearningRateModifier,
+    SetLearningRateModifier,
+)
 from sparseml.pytorch.utils import get_optim_learning_rate
 from tests.sparseml.pytorch.helpers import LinearNet
 from tests.sparseml.pytorch.optim.test_modifier import (
@@ -154,6 +158,223 @@ def test_set_lr_yaml():
     )
 
 
+##############################
+#
+# LearningRateFunctionModifier tests
+#
+##############################
+
+
+@pytest.mark.parametrize(
+    "modifier_lambda",
+    [
+        lambda: LearningRateFunctionModifier(
+            lr_func="linear",
+            init_lr=0.1,
+            final_lr=0.001,
+            start_epoch=0,
+            end_epoch=10,
+        ),
+        lambda: LearningRateFunctionModifier(
+            lr_func="linear",
+            init_lr=0.1,
+            final_lr=0.001,
+            start_epoch=5,
+            end_epoch=10,
+        ),
+    ],
+    scope="function",
+)
+@pytest.mark.parametrize("model_lambda", [LinearNet], scope="function")
+@pytest.mark.parametrize(
+    "optim_lambda",
+    [
+        lambda model: SGD(model.parameters(), INIT_LR),
+        lambda model: Adam(model.parameters(), INIT_LR),
+    ],
+    scope="function",
+)
+class TestLearningRateFunctionModifierLinearImpl(ScheduledUpdateModifierTest):
+    def test_lifecycle(
+        self,
+        modifier_lambda,
+        model_lambda,
+        optim_lambda,
+        test_steps_per_epoch,  # noqa: F811
+    ):
+        modifier = modifier_lambda()
+        model = model_lambda()
+        optimizer = optim_lambda(model)
+        self.initialize_helper(modifier, model)
+        assert get_optim_learning_rate(optimizer) == INIT_LR
+        last = 1.0
+
+        for epoch in range(int(modifier.end_epoch) + 5):
+            for step in range(test_steps_per_epoch):
+                epoch_test = float(epoch) + float(step) / float(test_steps_per_epoch)
+
+                if epoch_test < modifier.start_epoch:
+                    expected = INIT_LR
+                    assert not modifier.update_ready(epoch_test, test_steps_per_epoch)
+                elif epoch_test <= modifier.end_epoch:
+                    expected = modifier.init_lr + (
+                        (epoch_test - modifier.start_epoch)
+                        / (modifier.end_epoch - modifier.start_epoch)
+                    ) * (modifier.final_lr - modifier.init_lr)
+                    assert modifier.update_ready(epoch_test, test_steps_per_epoch)
+                    modifier.scheduled_update(
+                        model, optimizer, epoch_test, test_steps_per_epoch
+                    )
+                    assert expected <= last, f"Failed at epoch:{epoch} step:{step}"
+                    last = expected
+                else:
+                    expected = modifier.final_lr
+                    assert not modifier.update_ready(epoch_test, test_steps_per_epoch)
+
+                assert (
+                    abs(get_optim_learning_rate(optimizer) - expected) < EPSILON
+                ), f"Failed at epoch:{epoch} step:{step}"
+
+
+@pytest.mark.parametrize(
+    "modifier_lambda",
+    [
+        lambda: LearningRateFunctionModifier(
+            lr_func="cosine",
+            init_lr=0.1,
+            final_lr=0.001,
+            start_epoch=0,
+            end_epoch=10,
+        ),
+        lambda: LearningRateFunctionModifier(
+            lr_func="cosine",
+            init_lr=0.1,
+            final_lr=0.001,
+            start_epoch=5,
+            end_epoch=10,
+        ),
+    ],
+    scope="function",
+)
+@pytest.mark.parametrize("model_lambda", [LinearNet], scope="function")
+@pytest.mark.parametrize(
+    "optim_lambda",
+    [
+        lambda model: SGD(model.parameters(), INIT_LR),
+        lambda model: Adam(model.parameters(), INIT_LR),
+    ],
+    scope="function",
+)
+class TestLearningRateFunctionModifierCosineImpl(ScheduledUpdateModifierTest):
+    def test_lifecycle(
+        self,
+        modifier_lambda,
+        model_lambda,
+        optim_lambda,
+        test_steps_per_epoch,  # noqa: F811
+    ):
+        modifier = modifier_lambda()
+        model = model_lambda()
+        optimizer = optim_lambda(model)
+        self.initialize_helper(modifier, model)
+        assert get_optim_learning_rate(optimizer) == INIT_LR
+        last = 1.0
+
+        for epoch in range(int(modifier.end_epoch) + 5):
+            for step in range(test_steps_per_epoch):
+                epoch_test = float(epoch) + float(step) / float(test_steps_per_epoch)
+
+                if epoch_test < modifier.start_epoch:
+                    expected = INIT_LR
+                    assert not modifier.update_ready(epoch_test, test_steps_per_epoch)
+                elif epoch_test <= modifier.end_epoch:
+                    expected = (
+                        math.cos(
+                            (
+                                (epoch_test - modifier.start_epoch)
+                                / (modifier.end_epoch - modifier.start_epoch)
+                            )
+                            * math.pi
+                        )
+                        * (modifier.init_lr - modifier.final_lr)
+                        / 2
+                        + (modifier.init_lr - modifier.final_lr) / 2
+                        + modifier.final_lr
+                    )
+                    assert modifier.update_ready(epoch_test, test_steps_per_epoch)
+                    modifier.scheduled_update(
+                        model, optimizer, epoch_test, test_steps_per_epoch
+                    )
+                    assert expected <= last, f"Failed at epoch:{epoch} step:{step}"
+                    last = expected
+                else:
+                    expected = modifier.final_lr
+                    assert not modifier.update_ready(epoch_test, test_steps_per_epoch)
+
+                assert (
+                    abs(get_optim_learning_rate(optimizer) - expected) < EPSILON
+                ), f"Failed at epoch:{epoch} step:{step}"
+
+
+def test_lr_function_modifier_yaml():
+    lr_func = "linear"
+    start_epoch = 10.0
+    end_epoch = 20.0
+    init_lr = 0.1
+    final_lr = 0.001
+    param_groups = [0, 1]
+    yaml_str = f"""
+    !LearningRateFunctionModifier
+        start_epoch: {start_epoch}
+        end_epoch: {end_epoch}
+        lr_func: {lr_func}
+        init_lr: {init_lr}
+        final_lr: {final_lr}
+        param_groups: {param_groups}
+    """
+    yaml_modifier = LearningRateFunctionModifier.load_obj(
+        yaml_str
+    )  # type: LearningRateFunctionModifier
+    serialized_modifier = LearningRateFunctionModifier.load_obj(
+        str(yaml_modifier)
+    )  # type: LearningRateFunctionModifier
+    obj_modifier = LearningRateFunctionModifier(
+        start_epoch=start_epoch,
+        end_epoch=end_epoch,
+        lr_func=lr_func,
+        init_lr=init_lr,
+        final_lr=final_lr,
+        param_groups=param_groups,
+    )
+
+    assert isinstance(yaml_modifier, LearningRateFunctionModifier)
+    assert (
+        yaml_modifier.start_epoch
+        == serialized_modifier.start_epoch
+        == obj_modifier.start_epoch
+    )
+    assert (
+        yaml_modifier.end_epoch
+        == serialized_modifier.end_epoch
+        == obj_modifier.end_epoch
+    )
+    assert (
+        yaml_modifier.update_frequency
+        == serialized_modifier.update_frequency
+        == obj_modifier.update_frequency
+    )
+    assert yaml_modifier.lr_func == serialized_modifier.lr_func == obj_modifier.lr_func
+    assert yaml_modifier.init_lr == serialized_modifier.init_lr == obj_modifier.init_lr
+    assert (
+        yaml_modifier.final_lr == serialized_modifier.final_lr == obj_modifier.final_lr
+    )
+    assert (
+        yaml_modifier.param_groups
+        == serialized_modifier.param_groups
+        == obj_modifier.param_groups
+    )
+
+
 ##############################
 #
 # LearningRateModifier tests