This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
101 changes: 95 additions & 6 deletions src/sparseml/keras/optim/mask_pruning.py
@@ -15,21 +15,34 @@
import abc
import collections
import inspect
from typing import List
from typing import List, Union

import tensorflow as tf

from sparseml.keras.optim.mask_pruning_creator import PruningMaskCreator
from sparseml.keras.optim.mask_pruning_creator import (
PruningMaskCreator,
load_mask_creator,
)


__all__ = ["MaskedLayer", "PruningScheduler", "remove_pruning_masks"]
__all__ = [
"MaskedLayer",
"PruningScheduler",
"remove_pruning_masks",
]


class PruningScheduler(abc.ABC):
"""
Abstract pruning scheduler
"""

_REGISTRY = {}

def __init_subclass__(cls):
super().__init_subclass__()
PruningScheduler._register_class(cls)

@abc.abstractmethod
def should_prune(self, step: int) -> bool:
"""
@@ -51,6 +64,32 @@ def target_sparsity(self, step: int, **kwargs) -> float:
"""
raise NotImplementedError("Not implemented")

@abc.abstractmethod
def get_config(self):
raise NotImplementedError("Not implemented")

@classmethod
def deserialize(cls, config):
"""
Deserialize a pruning scheduler from the config returned by the scheduler's
get_config method

:param config: a pruning scheduler's config
:return: a pruning scheduler instance
"""
if "class_name" not in config:
raise ValueError("The 'class_name' not found in config: {}".format(config))
class_name = config["class_name"]
return tf.keras.utils.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects={class_name: PruningScheduler._REGISTRY[class_name]},
)

@classmethod
def _register_class(cls, target_cls):
PruningScheduler._REGISTRY[target_cls.__name__] = target_cls
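
Any concrete subclass is registered automatically through __init_subclass__, so PruningScheduler.deserialize can rebuild it from the dict returned by its get_config. A minimal sketch of that round trip, using a hypothetical FixedSparsityScheduler that is not part of this diff:

# Hypothetical scheduler, for illustration only: prune every `frequency` steps
# toward a fixed target sparsity. Defining the subclass is enough to register
# it, because PruningScheduler.__init_subclass__ adds it to _REGISTRY.
class FixedSparsityScheduler(PruningScheduler):
    def __init__(self, sparsity: float = 0.8, frequency: int = 100):
        self._sparsity = sparsity
        self._frequency = frequency

    def should_prune(self, step: int) -> bool:
        return step >= 0 and step % self._frequency == 0

    def target_sparsity(self, step: int, **kwargs) -> float:
        return self._sparsity

    def get_config(self):
        return {
            "class_name": self.__class__.__name__,
            "config": {"sparsity": self._sparsity, "frequency": self._frequency},
        }


scheduler = FixedSparsityScheduler(sparsity=0.9, frequency=50)
restored = PruningScheduler.deserialize(scheduler.get_config())
assert isinstance(restored, FixedSparsityScheduler)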


MaskedParamInfo = collections.namedtuple(
"MaskedParamInfo", ["name", "param", "mask", "sparsity"]
@@ -192,7 +231,7 @@ def __init__(
self,
layer: tf.keras.layers.Layer,
pruning_scheduler: PruningScheduler,
mask_creator: PruningMaskCreator,
mask_type: Union[str, List[int]] = "unstructured",
**kwargs,
):
if not isinstance(layer, MaskedLayer) and not isinstance(
@@ -205,15 +244,23 @@ def __init__(
super(MaskedLayer, self).__init__(layer, **kwargs)
self._layer = layer
self._pruning_scheduler = pruning_scheduler
self._mask_creator = mask_creator
self._mask_type = mask_type
self._mask_creator = None
self._pruning_vars = []
self._global_step = None
self._mask_updater = None
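
With this change the constructor takes a mask type instead of a pre-built PruningMaskCreator; the creator itself is resolved later, in build(), via load_mask_creator. A hedged sketch of the new call, reusing the hypothetical FixedSparsityScheduler from the sketch above:

# Sketch only: wrap a layer with the updated constructor. A mask-type string
# (or a block-shape list, per the Union[str, List[int]] hint) replaces the
# old mask_creator argument.
masked = MaskedLayer(
    tf.keras.layers.Dense(64),
    FixedSparsityScheduler(),
    mask_type="unstructured",
)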

def build(self, input_shape):
super(MaskedLayer, self).build(input_shape)
self._mask_creator = load_mask_creator(self._mask_type)
self._pruning_vars = self._reuse_or_create_pruning_vars()
self._global_step = self.add_weight(
"global_step",
shape=[],
initializer=tf.keras.initializers.Constant(-1),
dtype=tf.int64,
trainable=False,
)
self._pruning_vars = self._reuse_or_create_pruning_vars()
self._mask_updater = MaskAndWeightUpdater(
self._pruning_vars,
self._pruning_scheduler,
@@ -276,6 +323,44 @@ def _no_apply_masks_to_weights():
else:
return self._layer.call(inputs)

def get_config(self):
"""
Get the layer config.
Serialization and deserialization should be done using
tf.keras.layers.serialize/deserialize, which add and read the "class_name"
field automatically, so the config returned here does not contain that field.
"""
config = super(MaskedLayer, self).get_config()
if "layer" not in config:
raise RuntimeError("Expected 'layer' field not found in config")
config.update(
{
"pruning_scheduler": self._pruning_scheduler.get_config(),
"mask_type": self._mask_type,
}
)
return config

@classmethod
def from_config(cls, config):
config = config.copy()
layer = tf.keras.layers.deserialize(
config.pop("layer"), custom_objects={"MaskedLayer": MaskedLayer}
)
if not isinstance(layer, MaskedLayer) and not isinstance(
layer, tf.keras.layers.Layer
):
raise RuntimeError("Unexpected layer created from config")
pruning_scheduler = PruningScheduler.deserialize(
config.pop("pruning_scheduler")
)
if not isinstance(pruning_scheduler, PruningScheduler):
raise RuntimeError("Unexpected pruning scheduler type created from config")
mask_type = config.pop("mask_type")
masked_layer = MaskedLayer(layer, pruning_scheduler, mask_type, **config)
return masked_layer
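
These two methods let a MaskedLayer survive the standard Keras serialization round trip described in the get_config docstring. A hedged sketch, again relying on the hypothetical scheduler from the earlier sketch:

# Sketch only: serialize and rebuild a MaskedLayer. tf.keras.layers.serialize
# adds the "class_name" field; from_config then restores the wrapped layer,
# the scheduler (through the PruningScheduler registry), and the mask type.
masked = MaskedLayer(
    tf.keras.layers.Dense(64), FixedSparsityScheduler(), mask_type="unstructured"
)
serialized = tf.keras.layers.serialize(masked)
restored = tf.keras.layers.deserialize(
    serialized, custom_objects={"MaskedLayer": MaskedLayer}
)
assert isinstance(restored, MaskedLayer)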

def compute_output_shape(self, input_shape):
return self._layer.compute_output_shape(input_shape)

@@ -304,6 +389,10 @@ def pruned_layer(self):
else:
raise RuntimeError("Unrecognized layer")

@property
def masked_layer(self):
return self._layer


def remove_pruning_masks(model: tf.keras.Model):
"""
72 changes: 52 additions & 20 deletions src/sparseml/keras/optim/modifier_pruning.py
@@ -26,10 +26,6 @@
PruningScheduler,
remove_pruning_masks,
)
from sparseml.keras.optim.mask_pruning_creator import (
PruningMaskCreator,
load_mask_creator,
)
from sparseml.keras.optim.modifier import (
KerasModifierYAML,
ModifierProp,
@@ -73,6 +69,14 @@ def __init__(
self._update_frequency_steps = update_frequency_steps
self._inter_func = inter_func

@property
def init_sparsity(self):
return self._init_sparsity

@property
def final_sparsity(self):
return self._final_sparsity

@property
def start_step(self):
return self._start_step
@@ -85,6 +89,10 @@ def end_step(self):
def update_frequency_steps(self):
return self._update_frequency_steps

@property
def inter_func(self):
return self._inter_func

@property
def exponent(self) -> float:
"""
@@ -154,6 +162,20 @@ def target_sparsity(self, step: int, **kwargs):
sparsity = self._final_sparsity
return sparsity

def get_config(self):
config = {
"class_name": self.__class__.__name__,
"config": {
"init_sparsity": self.init_sparsity,
"final_sparsity": self.final_sparsity,
"start_step": self.start_step,
"end_step": self.end_step,
"update_frequency_steps": self.update_frequency_steps,
"inter_func": self.inter_func,
},
}
return config
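
The dict above follows the {"class_name": ..., "config": {...}} layout that PruningScheduler.deserialize hands to tf.keras.utils.deserialize_keras_object, so the scheduler can be rebuilt straight from it. A hedged sketch, where scheduler stands for an instance of the class defined above (its definition sits outside the visible hunk):

# Sketch only: round-trip the scheduler through its config.
config = scheduler.get_config()
restored = PruningScheduler.deserialize(config)
assert restored.final_sparsity == scheduler.final_sparsity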


class SparsityFreezer(PruningScheduler):
"""
@@ -172,6 +194,14 @@ def __init__(
self._start_step = start_step
self._end_step = end_step

@property
def start_step(self):
return self._start_step

@property
def end_step(self):
return self._end_step

def should_prune(self, step: int) -> bool:
"""
Check if the given step is a right time for pruning
@@ -203,6 +233,14 @@ def target_sparsity(self, step: int, tensor=None) -> float:
sparsity = None
return sparsity

def get_config(self):
config = {
"class_name": self.__class__.__name__,
"start_step": self.start_step,
"end_step": self.end_step,
}
return config


class PruningModifierCallback(tensorflow.keras.callbacks.Callback):
"""
@@ -345,7 +383,7 @@ def _log(self, logger: KerasLogger, log_data: Dict):


@KerasModifierYAML()
class ConstantPruningModifier(ScheduledModifier, PruningScheduler):
class ConstantPruningModifier(ScheduledModifier):
"""
Holds the sparsity level and shape for a given param constant while training.
Useful for transfer learning use cases.
@@ -387,7 +425,7 @@ def __init__(
self._masked_layers = []

self._sparsity_scheduler = None
self._mask_creator = load_mask_creator("unstructured")
self._mask_type = "unstructured"

@ModifierProp()
def params(self) -> Union[str, List[str]]:
@@ -456,7 +494,7 @@ def _clone_layer(self, layer: tensorflow.keras.layers.Layer):
cloned_layer = layer
if layer.name in self.layer_names: # TODO: handle regex params
cloned_layer = MaskedLayer(
layer, self._sparsity_scheduler, self._mask_creator, name=layer.name
layer, self._sparsity_scheduler, self._mask_type, name=layer.name
)
self._masked_layers.append(cloned_layer)
return cloned_layer
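
A clone function like this is typically handed to Keras model cloning so every matching layer gets wrapped in a MaskedLayer. The modifier's own entry point is outside this hunk, so the call below is only an illustration under that assumption, not the modifier's documented API:

# Sketch only: apply a per-layer clone function such as _clone_layer across a
# model. `modifier` and `model` stand for an already-constructed
# ConstantPruningModifier and the tf.keras model being pruned.
pruned_model = tensorflow.keras.models.clone_model(
    model, clone_function=modifier._clone_layer
)
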
@@ -553,7 +591,7 @@ class GMPruningModifier(ScheduledUpdateModifier):
default is __ALL__
:param mask_type: String to define type of sparsity (options: ['unstructured',
'channel', 'filter']), List to define block shape of a parameter's in and out
channels, or a PruningMaskCreator object. default is 'unstructured'
channels. default is 'unstructured'
:param leave_enabled: True to continue masking the weights after end_epoch,
False to stop masking. Should be set to False if exporting the result
immediately after or doing some other prune
@@ -569,7 +607,7 @@ def __init__(
update_frequency: float,
inter_func: str = "cubic",
log_types: Union[str, List[str]] = ALL_TOKEN,
mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
mask_type: Union[str, List[int]] = "unstructured",
leave_enabled: bool = True,
):
super(GMPruningModifier, self).__init__(
Expand All @@ -591,10 +629,7 @@ def __init__(
self._leave_enabled = convert_to_bool(leave_enabled)
self._inter_func = inter_func
self._mask_type = mask_type
self._mask_creator = mask_type
self._leave_enabled = convert_to_bool(leave_enabled)
if not isinstance(mask_type, PruningMaskCreator):
self._mask_creator = load_mask_creator(mask_type)
self._prune_op_vars = None
self._update_ready = None
self._sparsity = None
@@ -694,21 +729,18 @@ def inter_func(self, value: str):
self.validate()

@ModifierProp()
def mask_type(self) -> Union[str, List[int], PruningMaskCreator]:
def mask_type(self) -> Union[str, List[int]]:
"""
:return: the PruningMaskCreator object used
:return: the mask type used
"""
return self._mask_type

@mask_type.setter
def mask_type(self, value: Union[str, List[int], PruningMaskCreator]):
def mask_type(self, value: Union[str, List[int]]):
"""
:param value: the PruningMaskCreator object to use
:param value: the mask type to use
"""
self._mask_type = value
self._mask_creator = value
if not isinstance(value, PruningMaskCreator):
self._mask_creator = load_mask_creator(value)

@ModifierProp()
def leave_enabled(self) -> bool:
@@ -834,7 +866,7 @@ def _clone_layer(self, layer: tensorflow.keras.layers.Layer):
layer.name in self.layer_names
): # TODO: handle regex params --- see create_ops in TF version
cloned_layer = MaskedLayer(
layer, self._sparsity_scheduler, self._mask_creator, name=layer.name
layer, self._sparsity_scheduler, self._mask_type, name=layer.name
)
self._masked_layers.append(cloned_layer)
return cloned_layer