diff --git a/src/sparseml/pytorch/optim/modifier_pruning.py b/src/sparseml/pytorch/optim/modifier_pruning.py
index 0b1bd58684e..6fb55a6467c 100644
--- a/src/sparseml/pytorch/optim/modifier_pruning.py
+++ b/src/sparseml/pytorch/optim/modifier_pruning.py
@@ -493,6 +493,9 @@ class GMPruningModifier(_PruningParamsModifier):
         immediately after or doing some other prune
     :param inter_func: the type of interpolation function to use:
         [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
     :param log_types: The loggers to allow the learning rate to be
         logged to, default is __ALL__
     :param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -514,6 +517,7 @@ def __init__(
         params: Union[str, List[str]],
         leave_enabled: bool = True,
         inter_func: str = "cubic",
+        phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
         global_sparsity: bool = False,
@@ -531,6 +535,7 @@ def __init__(
         self._final_sparsity = final_sparsity
         self._leave_enabled = convert_to_bool(leave_enabled)
         self._inter_func = inter_func
+        self._phased = phased
         self._mask_type = mask_type
         self._mask_creator = (
             mask_type
@@ -612,6 +617,24 @@ def inter_func(self, value: str):
         self._inter_func = value
         self.validate()
 
+    @ModifierProp()
+    def phased(self) -> bool:
+        """
+        :return: True to enable a phased approach where pruning will
+            turn on and off with the update_frequency. Starts with pruning on
+            at start_epoch, off at start_epoch + update_frequency, and so on.
+        """
+        return self._phased
+
+    @phased.setter
+    def phased(self, value: bool):
+        """
+        :param value: True to enable a phased approach where pruning will
+            turn on and off with the update_frequency
+        """
+        self._phased = value
+        self.validate()
+
     @ModifierProp()
     def mask_type(self) -> Union[str, List[int], PruningMaskCreator]:
         """
@@ -763,6 +786,16 @@ def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int)
                 self._final_sparsity,
                 self._inter_func,
             )
+
+            # if phased, toggle pruning off on odd phases; skip the toggle when
+            # the end is pending so the schedule still finishes at final_sparsity
+            if self.phased and not self.end_pending(epoch, steps_per_epoch):
+                # adjust for phased pruning: start=on, start+update=off
+                phase = math.floor((epoch - self.start_epoch) / self.update_frequency)
+                if phase % 2 != 0:
+                    # odd update phase, turn sparsity off
+                    self._applied_sparsity = 0.0
+
             self._module_masks.set_param_masks_from_sparsity(self._applied_sparsity)
 
         if self.end_pending(epoch, steps_per_epoch):
@@ -843,6 +876,9 @@ class MagnitudePruningModifier(GMPruningModifier):
         immediately after or doing some other prune
     :param inter_func: the type of interpolation function to use:
         [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
     :param log_types: The loggers to allow the learning rate to be
         logged to, default is __ALL__
     :param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -860,6 +896,7 @@ def __init__(
         params: Union[str, List[str]],
         leave_enabled: bool = True,
         inter_func: str = "cubic",
+        phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
     ):
@@ -872,6 +909,7 @@ def __init__(
             params=params,
             leave_enabled=leave_enabled,
             inter_func=inter_func,
+            phased=phased,
             log_types=log_types,
             mask_type=mask_type,
             global_sparsity=False,
         )
@@ -933,6 +971,9 @@ class MovementPruningModifier(GMPruningModifier):
         immediately after or doing some other prune
     :param inter_func: the type of interpolation function to use:
         [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
     :param log_types: The loggers to allow the learning rate to be
         logged to, default is __ALL__
     :param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -950,6 +991,7 @@ def __init__(
         params: Union[str, List[str]],
         leave_enabled: bool = True,
         inter_func: str = "cubic",
+        phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
     ):
@@ -962,6 +1004,7 @@ def __init__(
             params=params,
             leave_enabled=leave_enabled,
             inter_func=inter_func,
+            phased=phased,
             log_types=log_types,
             mask_type=mask_type,
             global_sparsity=False,
@@ -1024,6 +1067,9 @@ class GlobalMagnitudePruningModifier(GMPruningModifier):
         immediately after or doing some other prune
     :param inter_func: the type of interpolation function to use:
         [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
     :param log_types: The loggers to allow the learning rate to be
         logged to, default is __ALL__
     :param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -1043,6 +1089,7 @@ def __init__(
         params: Union[str, List[str]] = ALL_PRUNABLE_TOKEN,
         leave_enabled: bool = True,
         inter_func: str = "cubic",
+        phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
         score_type: Union[str, MFACOptions] = "magnitude",
@@ -1056,6 +1103,7 @@ def __init__(
             params=params,
             leave_enabled=leave_enabled,
             inter_func=inter_func,
+            phased=phased,
             log_types=log_types,
             mask_type=mask_type,
             global_sparsity=True,
@@ -1115,6 +1163,9 @@ class MFACPruningModifier(GMPruningModifier):
         immediately after or doing some other prune
     :param inter_func: the type of interpolation function to use:
         [linear, cubic, inverse_cubic]
+    :param phased: True to enable a phased approach where pruning will
+        turn on and off with the update_frequency. Starts with pruning on
+        at start_epoch, off at start_epoch + update_frequency, and so on.
     :param log_types: The loggers to allow the learning rate to be
         logged to, default is __ALL__
     :param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -1139,6 +1190,7 @@ def __init__(
         params: Union[str, List[str]],
         leave_enabled: bool = True,
         inter_func: str = "cubic",
+        phased: bool = False,
         log_types: Union[str, List[str]] = ALL_TOKEN,
         mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
         mfac_options: Dict[str, Any] = None,
@@ -1152,6 +1204,7 @@ def __init__(
             params=params,
             leave_enabled=leave_enabled,
             inter_func=inter_func,
+            phased=phased,
             log_types=log_types,
             mask_type=mask_type,
             global_sparsity=True,
diff --git a/tests/sparseml/pytorch/optim/test_modifier_pruning.py b/tests/sparseml/pytorch/optim/test_modifier_pruning.py
index 9ef2c98614a..9b02d1c651e 100644
--- a/tests/sparseml/pytorch/optim/test_modifier_pruning.py
+++ b/tests/sparseml/pytorch/optim/test_modifier_pruning.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import os
 
 import pytest
@@ -246,6 +247,16 @@ def test_constant_pruning_yaml():
             inter_func="cubic",
             mask_type=[1, 4],
         ),
+        lambda: GMPruningModifier(
+            params=["__ALL_PRUNABLE__"],
+            init_sparsity=0.9,
+            final_sparsity=0.9,
+            start_epoch=10.0,
+            end_epoch=25.0,
+            update_frequency=2.0,
+            inter_func="cubic",
+            phased=True,
+        ),
     ],
     scope="function",
 )
@@ -294,7 +305,22 @@ def test_lifecycle(
             epoch += modifier.update_frequency
             assert modifier.update_ready(epoch, test_steps_per_epoch)
             modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)
-            assert modifier.applied_sparsity > last_sparsity
+
+            if not modifier.phased:
+                assert modifier.applied_sparsity > last_sparsity
+            else:
+                pruned_on = (
+                    math.floor(
+                        (epoch - modifier.start_epoch) / modifier.update_frequency
+                    )
+                    % 2
+                    == 0
+                )
+                if pruned_on:
+                    assert modifier.applied_sparsity >= last_sparsity
+                else:
+                    assert modifier.applied_sparsity == 0
+
             last_sparsity = modifier.applied_sparsity
 
             _ = model(test_batch)  # check forward pass
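
The schedule introduced by `phased` can be summarized independently of the modifier classes: even-numbered phases (counted from `start_epoch` in steps of `update_frequency`) apply the interpolated sparsity, odd-numbered phases set it to zero, and the final sparsity is always applied once the end of the schedule is pending. The sketch below is a minimal, standalone illustration of that phase math; the helper name `phased_applied_sparsity` and the example values are hypothetical and not part of the SparseML API.

```python
import math


def phased_applied_sparsity(
    epoch: float,
    start_epoch: float,
    update_frequency: float,
    scheduled_sparsity: float,
    end_pending: bool,
) -> float:
    # Hypothetical helper mirroring the phased logic added to _check_mask_update:
    # even phases keep pruning on at the scheduled sparsity, odd phases turn it off.
    # When the end of the schedule is pending, the scheduled sparsity is always
    # applied so the modifier still finishes at final_sparsity.
    if end_pending:
        return scheduled_sparsity
    phase = math.floor((epoch - start_epoch) / update_frequency)
    return scheduled_sparsity if phase % 2 == 0 else 0.0


if __name__ == "__main__":
    # start_epoch=10.0, update_frequency=2.0 -> pruning on for [10, 12), off for [12, 14), ...
    for epoch in [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]:
        print(epoch, phased_applied_sparsity(epoch, 10.0, 2.0, 0.9, end_pending=False))
```

With `start_epoch=10.0` and `update_frequency=2.0`, as in the new test case above, sparsity is applied on `[10, 12)`, zeroed on `[12, 14)`, and so on until the end epoch is pending.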