53 changes: 53 additions & 0 deletions src/sparseml/pytorch/optim/modifier_pruning.py
@@ -493,6 +493,9 @@ class GMPruningModifier(_PruningParamsModifier):
immediately after or doing some other prune
:param inter_func: the type of interpolation function to use:
[linear, cubic, inverse_cubic]
:param phased: True to enable a phased approach where pruning will
turn on and off with the update_frequency. Starts with pruning on
at start_epoch, off at start_epoch + update_frequency, and so on.
:param log_types: The loggers to allow the learning rate to be logged to,
default is __ALL__
:param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -514,6 +517,7 @@ def __init__(
params: Union[str, List[str]],
leave_enabled: bool = True,
inter_func: str = "cubic",
phased: bool = False,
log_types: Union[str, List[str]] = ALL_TOKEN,
mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
global_sparsity: bool = False,
@@ -531,6 +535,7 @@ def __init__(
self._final_sparsity = final_sparsity
self._leave_enabled = convert_to_bool(leave_enabled)
self._inter_func = inter_func
self._phased = phased
self._mask_type = mask_type
self._mask_creator = (
mask_type
@@ -612,6 +617,24 @@ def inter_func(self, value: str):
self._inter_func = value
self.validate()

@ModifierProp()
def phased(self) -> bool:
"""
:return: True to enable a phased approach where pruning will
turn on and off with the update_frequency. Starts with pruning on
at start_epoch, off at start_epoch + update_frequency, and so on.
"""
return self._phased

@phased.setter
def phased(self, value: bool):
"""
:param value: True to enable a phased approach where pruning will
turn on and off with the update_frequency. Starts with pruning on
at start_epoch, off at start_epoch + update_frequency, and so on.
"""
self._phased = value
self.validate()

@ModifierProp()
def mask_type(self) -> Union[str, List[int], PruningMaskCreator]:
"""
@@ -763,6 +786,16 @@ def _check_mask_update(self, module: Module, epoch: float, steps_per_epoch: int)
self._final_sparsity,
self._inter_func,
)

# if phased, toggle the sparsity off on odd update phases; skip the toggle
# once the end is pending so the run always finishes at the final sparsity
if self.phased and not self.end_pending(epoch, steps_per_epoch):
# adjust for phased pruning: start=on, start+update=off
phase = math.floor((epoch - self.start_epoch) / self.update_frequency)
if phase % 2 != 0:
# odd update phase, turn sparsity off
self._applied_sparsity = 0.0

self._module_masks.set_param_masks_from_sparsity(self._applied_sparsity)

if self.end_pending(epoch, steps_per_epoch):
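
For intuition about the on/off schedule this parity check produces, here is a minimal standalone sketch (plain Python, separate from the diff; the helper name is illustrative):

import math

def phased_pruning_on(epoch: float, start_epoch: float, update_frequency: float) -> bool:
    # Even-numbered phases prune; odd-numbered phases zero the sparsity,
    # mirroring the parity check in _check_mask_update above.
    phase = math.floor((epoch - start_epoch) / update_frequency)
    return phase % 2 == 0

# With start_epoch=10.0 and update_frequency=2.0 (the values the new test uses),
# pruning is on for epochs [10, 12), off for [12, 14), on again for [14, 16), ...
assert phased_pruning_on(10.0, 10.0, 2.0)
assert not phased_pruning_on(12.0, 10.0, 2.0)
assert phased_pruning_on(14.0, 10.0, 2.0)
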
@@ -843,6 +876,9 @@ class MagnitudePruningModifier(GMPruningModifier):
immediately after or doing some other prune
:param inter_func: the type of interpolation function to use:
[linear, cubic, inverse_cubic]
:param phased: True to enable a phased approach where pruning will
turn on and off with the update_frequency. Starts with pruning on
at start_epoch, off at start_epoch + update_frequency, and so on.
:param log_types: The loggers to allow the learning rate to be logged to,
default is __ALL__
:param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -860,6 +896,7 @@ def __init__(
params: Union[str, List[str]],
leave_enabled: bool = True,
inter_func: str = "cubic",
phased: bool = False,
log_types: Union[str, List[str]] = ALL_TOKEN,
mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
):
@@ -872,6 +909,7 @@ def __init__(
params=params,
leave_enabled=leave_enabled,
inter_func=inter_func,
phased=phased,
log_types=log_types,
mask_type=mask_type,
global_sparsity=False,
@@ -933,6 +971,9 @@ class MovementPruningModifier(GMPruningModifier):
immediately after or doing some other prune
:param inter_func: the type of interpolation function to use:
[linear, cubic, inverse_cubic]
:param phased: True to enable a phased approach where pruning will
turn on and off with the update_frequency. Starts with pruning on
at start_epoch, off at start_epoch + update_frequency, and so on.
:param log_types: The loggers to allow the learning rate to be logged to,
default is __ALL__
:param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -950,6 +991,7 @@ def __init__(
params: Union[str, List[str]],
leave_enabled: bool = True,
inter_func: str = "cubic",
phased: bool = False,
log_types: Union[str, List[str]] = ALL_TOKEN,
mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
):
@@ -962,6 +1004,7 @@ def __init__(
params=params,
leave_enabled=leave_enabled,
inter_func=inter_func,
phased=phased,
log_types=log_types,
mask_type=mask_type,
global_sparsity=False,
@@ -1024,6 +1067,9 @@ class GlobalMagnitudePruningModifier(GMPruningModifier):
immediately after or doing some other prune
:param inter_func: the type of interpolation function to use:
[linear, cubic, inverse_cubic]
:param phased: True to enable a phased approach where pruning will
turn on and off with the update_frequency. Starts with pruning on
at start_epoch, off at start_epoch + update_frequency, and so on.
:param log_types: The loggers to allow the learning rate to be logged to,
default is __ALL__
:param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -1043,6 +1089,7 @@ def __init__(
params: Union[str, List[str]] = ALL_PRUNABLE_TOKEN,
leave_enabled: bool = True,
inter_func: str = "cubic",
phased: bool = False,
log_types: Union[str, List[str]] = ALL_TOKEN,
mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
score_type: Union[str, MFACOptions] = "magnitude",
@@ -1056,6 +1103,7 @@ def __init__(
params=params,
leave_enabled=leave_enabled,
inter_func=inter_func,
phased=phased,
log_types=log_types,
mask_type=mask_type,
global_sparsity=True,
@@ -1115,6 +1163,9 @@ class MFACPruningModifier(GMPruningModifier):
immediately after or doing some other prune
:param inter_func: the type of interpolation function to use:
[linear, cubic, inverse_cubic]
:param phased: True to enable a phased approach where pruning will
turn on and off with the update_frequency. Starts with pruning on
at start_epoch, off at start_epoch + update_frequency, and so on.
:param log_types: The loggers to allow the learning rate to be logged to,
default is __ALL__
:param mask_type: String to define type of sparsity (options: ['unstructured',
@@ -1139,6 +1190,7 @@ def __init__(
params: Union[str, List[str]],
leave_enabled: bool = True,
inter_func: str = "cubic",
phased: bool = False,
log_types: Union[str, List[str]] = ALL_TOKEN,
mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
mfac_options: Dict[str, Any] = None,
@@ -1152,6 +1204,7 @@ def __init__(
params=params,
leave_enabled=leave_enabled,
inter_func=inter_func,
phased=phased,
log_types=log_types,
mask_type=mask_type,
global_sparsity=True,
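A minimal usage sketch of the new flag, mirroring the test fixture added below; the import path is inferred from this file's location, and the hyperparameter values are illustrative rather than prescriptive:

from sparseml.pytorch.optim import GMPruningModifier

modifier = GMPruningModifier(
    init_sparsity=0.05,
    final_sparsity=0.9,
    start_epoch=10.0,
    end_epoch=25.0,
    update_frequency=2.0,
    params=["__ALL_PRUNABLE__"],
    inter_func="cubic",
    phased=True,  # prune on even update phases, relax to zero sparsity on odd ones
)
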
28 changes: 27 additions & 1 deletion tests/sparseml/pytorch/optim/test_modifier_pruning.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import pytest
@@ -246,6 +247,16 @@ def test_constant_pruning_yaml():
inter_func="cubic",
mask_type=[1, 4],
),
lambda: GMPruningModifier(
params=["__ALL_PRUNABLE__"],
init_sparsity=0.9,
final_sparsity=0.9,
start_epoch=10.0,
end_epoch=25.0,
update_frequency=2.0,
inter_func="cubic",
phased=True,
),
],
scope="function",
)
@@ -294,7 +305,22 @@ def test_lifecycle(
epoch += modifier.update_frequency
assert modifier.update_ready(epoch, test_steps_per_epoch)
modifier.scheduled_update(model, optimizer, epoch, test_steps_per_epoch)

if not modifier.phased:
assert modifier.applied_sparsity > last_sparsity
else:
pruned_on = (
math.floor(
(epoch - modifier.start_epoch) / modifier.update_frequency
)
% 2
== 0
)
if pruned_on:
assert modifier.applied_sparsity >= last_sparsity
else:
assert modifier.applied_sparsity == 0

last_sparsity = modifier.applied_sparsity

_ = model(test_batch) # check forward pass