diff --git a/src/sparseml/keras/optim/mask_pruning.py b/src/sparseml/keras/optim/mask_pruning.py
index 07e45f1b346..95e6f07ba58 100644
--- a/src/sparseml/keras/optim/mask_pruning.py
+++ b/src/sparseml/keras/optim/mask_pruning.py
@@ -15,14 +15,21 @@
 
 import abc
 import collections
 import inspect
-from typing import List
+from typing import List, Union
 
 import tensorflow as tf
 
-from sparseml.keras.optim.mask_pruning_creator import PruningMaskCreator
+from sparseml.keras.optim.mask_pruning_creator import (
+    PruningMaskCreator,
+    load_mask_creator,
+)
 
-__all__ = ["MaskedLayer", "PruningScheduler", "remove_pruning_masks"]
+__all__ = [
+    "MaskedLayer",
+    "PruningScheduler",
+    "remove_pruning_masks",
+]
 
 
 class PruningScheduler(abc.ABC):
@@ -30,6 +37,12 @@ class PruningScheduler(abc.ABC):
     Abstract pruning scheduler
     """
 
+    _REGISTRY = {}
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        PruningScheduler._register_class(cls)
+
     @abc.abstractmethod
     def should_prune(self, step: int) -> bool:
         """
@@ -51,6 +64,32 @@ def target_sparsity(self, step: int, **kwargs) -> float:
         """
         raise NotImplementedError("Not implemented")
 
+    @abc.abstractmethod
+    def get_config(self):
+        raise NotImplementedError("Not implemented")
+
+    @classmethod
+    def deserialize(cls, config):
+        """
+        Deserialize a pruning scheduler from the config returned by the
+        scheduler's get_config method
+
+        :param config: a pruning scheduler's config
+        :return: a pruning scheduler instance
+        """
+        if "class_name" not in config:
+            raise ValueError("'class_name' not found in config: {}".format(config))
+        class_name = config["class_name"]
+        return tf.keras.utils.deserialize_keras_object(
+            config,
+            module_objects=globals(),
+            custom_objects={class_name: PruningScheduler._REGISTRY[class_name]},
+        )
+
+    @classmethod
+    def _register_class(cls, target_cls):
+        PruningScheduler._REGISTRY[target_cls.__name__] = target_cls
+
 
 MaskedParamInfo = collections.namedtuple(
     "MaskedParamInfo", ["name", "param", "mask", "sparsity"]
 )
@@ -192,7 +231,7 @@ def __init__(
         self,
         layer: tf.keras.layers.Layer,
         pruning_scheduler: PruningScheduler,
-        mask_creator: PruningMaskCreator,
+        mask_type: Union[str, List[int]] = "unstructured",
         **kwargs,
     ):
         if not isinstance(layer, MaskedLayer) and not isinstance(
@@ -205,7 +244,16 @@ def __init__(
         super(MaskedLayer, self).__init__(layer, **kwargs)
         self._layer = layer
         self._pruning_scheduler = pruning_scheduler
-        self._mask_creator = mask_creator
+        self._mask_type = mask_type
+        self._mask_creator = None
+        self._pruning_vars = []
+        self._global_step = None
+        self._mask_updater = None
+
+    def build(self, input_shape):
+        super(MaskedLayer, self).build(input_shape)
+        self._mask_creator = load_mask_creator(self._mask_type)
+        self._pruning_vars = self._reuse_or_create_pruning_vars()
         self._global_step = self.add_weight(
             "global_step",
             shape=[],
@@ -213,7 +261,6 @@ def __init__(
             dtype=tf.int64,
             trainable=False,
         )
-        self._pruning_vars = self._reuse_or_create_pruning_vars()
         self._mask_updater = MaskAndWeightUpdater(
             self._pruning_vars,
             self._pruning_scheduler,
@@ -276,6 +323,44 @@ def _no_apply_masks_to_weights():
         else:
             return self._layer.call(inputs)
 
+    def get_config(self):
+        """
+        Get layer config.
+        Serialization and deserialization should be done using
+        tf.keras.layers.serialize/deserialize, which create and consume the
+        "class_name" field automatically; the config returned below therefore
+        does not contain that field.
+        """
+        config = super(MaskedLayer, self).get_config()
+        if "layer" not in config:
+            raise RuntimeError("Expected 'layer' field missing from config")
+        config.update(
+            {
+                "pruning_scheduler": self._pruning_scheduler.get_config(),
+                "mask_type": self._mask_type,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        layer = tf.keras.layers.deserialize(
+            config.pop("layer"), custom_objects={"MaskedLayer": MaskedLayer}
+        )
+        if not isinstance(layer, MaskedLayer) and not isinstance(
+            layer, tf.keras.layers.Layer
+        ):
+            raise RuntimeError("Unexpected layer created from config")
+        pruning_scheduler = PruningScheduler.deserialize(
+            config.pop("pruning_scheduler")
+        )
+        if not isinstance(pruning_scheduler, PruningScheduler):
+            raise RuntimeError("Unexpected pruning scheduler type created from config")
+        mask_type = config.pop("mask_type")
+        masked_layer = MaskedLayer(layer, pruning_scheduler, mask_type, **config)
+        return masked_layer
+
     def compute_output_shape(self, input_shape):
         return self._layer.compute_output_shape(input_shape)
 
@@ -304,6 +389,10 @@ def pruned_layer(self):
         else:
             raise RuntimeError("Unrecognized layer")
 
+    @property
+    def masked_layer(self):
+        return self._layer
+
 
 def remove_pruning_masks(model: tf.keras.Model):
     """
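Note on the scheduler registry introduced above: `__init_subclass__` records every `PruningScheduler` subclass at class-definition time, so `PruningScheduler.deserialize` can rebuild a scheduler from its config without callers wiring up `custom_objects` by hand. A minimal sketch of the intended round trip; the `ConstantScheduler` class here is hypothetical and exists only to illustrate the contract:

```python
from sparseml.keras.optim.mask_pruning import PruningScheduler


class ConstantScheduler(PruningScheduler):
    """Toy scheduler that always prunes toward one fixed sparsity level."""

    def __init__(self, sparsity: float = 0.5):
        self._sparsity = sparsity

    def should_prune(self, step: int) -> bool:
        return True

    def target_sparsity(self, step: int, **kwargs) -> float:
        return self._sparsity

    def get_config(self):
        # The nested "config" dict matters: tf.keras.utils.deserialize_keras_object
        # requires both "class_name" and "config" keys and instantiates the class
        # with the nested dict as keyword arguments.
        return {
            "class_name": self.__class__.__name__,
            "config": {"sparsity": self._sparsity},
        }


# Defining the subclass is enough: __init_subclass__ has already registered it,
# so the round trip needs no manual custom_objects bookkeeping.
scheduler = ConstantScheduler(sparsity=0.8)
restored = PruningScheduler.deserialize(scheduler.get_config())
assert isinstance(restored, ConstantScheduler)
assert restored.target_sparsity(step=0) == 0.8
```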
diff --git a/src/sparseml/keras/optim/modifier_pruning.py b/src/sparseml/keras/optim/modifier_pruning.py
index ced625003ab..996d235fb62 100644
--- a/src/sparseml/keras/optim/modifier_pruning.py
+++ b/src/sparseml/keras/optim/modifier_pruning.py
@@ -26,10 +26,6 @@
     PruningScheduler,
     remove_pruning_masks,
 )
-from sparseml.keras.optim.mask_pruning_creator import (
-    PruningMaskCreator,
-    load_mask_creator,
-)
 from sparseml.keras.optim.modifier import (
     KerasModifierYAML,
     ModifierProp,
@@ -73,6 +69,14 @@ def __init__(
         self._update_frequency_steps = update_frequency_steps
         self._inter_func = inter_func
 
+    @property
+    def init_sparsity(self):
+        return self._init_sparsity
+
+    @property
+    def final_sparsity(self):
+        return self._final_sparsity
+
     @property
     def start_step(self):
         return self._start_step
@@ -85,6 +89,10 @@ def end_step(self):
     def update_frequency_steps(self):
         return self._update_frequency_steps
 
+    @property
+    def inter_func(self):
+        return self._inter_func
+
     @property
     def exponent(self) -> float:
         """
@@ -154,6 +162,20 @@ def target_sparsity(self, step: int, **kwargs):
             sparsity = self._final_sparsity
         return sparsity
 
+    def get_config(self):
+        config = {
+            "class_name": self.__class__.__name__,
+            "config": {
+                "init_sparsity": self.init_sparsity,
+                "final_sparsity": self.final_sparsity,
+                "start_step": self.start_step,
+                "end_step": self.end_step,
+                "update_frequency_steps": self.update_frequency_steps,
+                "inter_func": self.inter_func,
+            },
+        }
+        return config
+
 
 class SparsityFreezer(PruningScheduler):
     """
@@ -172,6 +194,14 @@ def __init__(
         self._start_step = start_step
         self._end_step = end_step
 
+    @property
+    def start_step(self):
+        return self._start_step
+
+    @property
+    def end_step(self):
+        return self._end_step
+
     def should_prune(self, step: int) -> bool:
         """
         Check if the given step is a right time for pruning
@@ -203,6 +233,14 @@ def target_sparsity(self, step: int, tensor=None) -> float:
             sparsity = None
         return sparsity
 
+    def get_config(self):
+        nested = {"start_step": self.start_step, "end_step": self.end_step}
+        config = {
+            "class_name": self.__class__.__name__,
+            "config": nested,
+        }
+        return config
+
 
 class PruningModifierCallback(tensorflow.keras.callbacks.Callback):
     """
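Why both `get_config` implementations above nest their fields under a `"config"` key: `tf.keras.utils.deserialize_keras_object`, which `PruningScheduler.deserialize` delegates to, rejects dicts that do not carry both `"class_name"` and `"config"`, and it instantiates the class with the nested dict as keyword arguments. A sketch of the round trip for `SparsityFreezer`, assuming its constructor takes exactly the `start_step` and `end_step` arguments shown in this diff:

```python
from sparseml.keras.optim.mask_pruning import PruningScheduler
from sparseml.keras.optim.modifier_pruning import SparsityFreezer

scheduler = SparsityFreezer(start_step=0, end_step=100)
config = scheduler.get_config()
# config == {"class_name": "SparsityFreezer",
#            "config": {"start_step": 0, "end_step": 100}}

# The registry filled in by PruningScheduler.__init_subclass__ resolves the
# class name, and the nested dict becomes the constructor keyword arguments.
restored = PruningScheduler.deserialize(config)
assert isinstance(restored, SparsityFreezer)
assert restored.start_step == 0 and restored.end_step == 100
```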
@@ -345,7 +383,7 @@ def _log(self, logger: KerasLogger, log_data: Dict):
 
 
 @KerasModifierYAML()
-class ConstantPruningModifier(ScheduledModifier, PruningScheduler):
+class ConstantPruningModifier(ScheduledModifier):
     """
     Holds the sparsity level and shape for a given param constant while
     training. Useful for transfer learning use cases.
@@ -387,7 +425,7 @@ def __init__(
         self._masked_layers = []
 
         self._sparsity_scheduler = None
-        self._mask_creator = load_mask_creator("unstructured")
+        self._mask_type = "unstructured"
 
     @ModifierProp()
     def params(self) -> Union[str, List[str]]:
@@ -456,7 +494,7 @@ def _clone_layer(self, layer: tensorflow.keras.layers.Layer):
         cloned_layer = layer
         if layer.name in self.layer_names:  # TODO: handle regex params
             cloned_layer = MaskedLayer(
-                layer, self._sparsity_scheduler, self._mask_creator, name=layer.name
+                layer, self._sparsity_scheduler, self._mask_type, name=layer.name
             )
             self._masked_layers.append(cloned_layer)
         return cloned_layer
@@ -553,7 +591,7 @@ class GMPruningModifier(ScheduledUpdateModifier):
         default is __ALL__
     :param mask_type: String to define type of sparsity (options: ['unstructured',
         'channel', 'filter']), List to define block shape of a parameter's in and out
-        channels, or a PruningMaskCreator object. default is 'unstructured'
+        channels. default is 'unstructured'
     :param leave_enabled: True to continue masking the weights after end_epoch,
         False to stop masking. Should be set to False if exporting the result
         immediately after or doing some other prune
@@ -569,7 +607,7 @@ def __init__(
         update_frequency: float,
         inter_func: str = "cubic",
         log_types: Union[str, List[str]] = ALL_TOKEN,
-        mask_type: Union[str, List[int], PruningMaskCreator] = "unstructured",
+        mask_type: Union[str, List[int]] = "unstructured",
         leave_enabled: bool = True,
     ):
         super(GMPruningModifier, self).__init__(
@@ -591,10 +629,7 @@ def __init__(
         self._leave_enabled = convert_to_bool(leave_enabled)
         self._inter_func = inter_func
         self._mask_type = mask_type
-        self._mask_creator = mask_type
         self._leave_enabled = convert_to_bool(leave_enabled)
-        if not isinstance(mask_type, PruningMaskCreator):
-            self._mask_creator = load_mask_creator(mask_type)
         self._prune_op_vars = None
         self._update_ready = None
         self._sparsity = None
@@ -694,21 +729,18 @@ def inter_func(self, value: str):
         self.validate()
 
     @ModifierProp()
-    def mask_type(self) -> Union[str, List[int], PruningMaskCreator]:
+    def mask_type(self) -> Union[str, List[int]]:
         """
-        :return: the PruningMaskCreator object used
+        :return: the mask type used
         """
         return self._mask_type
 
     @mask_type.setter
-    def mask_type(self, value: Union[str, List[int], PruningMaskCreator]):
+    def mask_type(self, value: Union[str, List[int]]):
         """
-        :param value: the PruningMaskCreator object to use
+        :param value: the mask type to use
        """
         self._mask_type = value
-        self._mask_creator = value
-        if not isinstance(value, PruningMaskCreator):
-            self._mask_creator = load_mask_creator(value)
 
     @ModifierProp()
     def leave_enabled(self) -> bool:
@@ -834,7 +866,7 @@ def _clone_layer(self, layer: tensorflow.keras.layers.Layer):
             layer.name in self.layer_names
         ):  # TODO: handle regex params --- see create_ops in TF version
             cloned_layer = MaskedLayer(
-                layer, self._sparsity_scheduler, self._mask_creator, name=layer.name
+                layer, self._sparsity_scheduler, self._mask_type, name=layer.name
             )
             self._masked_layers.append(cloned_layer)
         return cloned_layer
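With `PruningMaskCreator` removed from the modifiers' public surface, `mask_type` now travels as plain data (a string or a block shape) until `MaskedLayer.build()` resolves it through `load_mask_creator`. A usage sketch; the parameter values are illustrative and the `[1, 4]` block shape is an assumption about the accepted list format:

```python
from sparseml.keras.optim import GMPruningModifier

# String mask types: "unstructured", "channel", or "filter"
unstructured = GMPruningModifier(
    params=["conv2d/kernel:0"],
    init_sparsity=0.05,
    final_sparsity=0.85,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    mask_type="unstructured",
)

# A list describes a block shape over a parameter's in and out channels
block = GMPruningModifier(
    params=["conv2d_1/kernel:0"],
    init_sparsity=0.05,
    final_sparsity=0.85,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    mask_type=[1, 4],
)
```

Because the modifier hands `MaskedLayer` the raw `mask_type` instead of a creator instance, the value is a plain serializable field and survives the `get_config()`/`from_config()` round trips of wrapped layers.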
diff --git a/tests/sparseml/keras/optim/mock.py b/tests/sparseml/keras/optim/mock.py
index 36521b34360..cbfde246af7 100644
--- a/tests/sparseml/keras/optim/mock.py
+++ b/tests/sparseml/keras/optim/mock.py
@@ -26,11 +26,13 @@
     "SequentialModelCreator",
     "MockPruningScheduler",
     "model_01",
+    "mnist_model",
 ]
 
 
 class MockPruningScheduler(PruningScheduler):
     def __init__(self, step_and_sparsity_pairs: List[Tuple]):
+        self._org_pairs = step_and_sparsity_pairs
         self.step_and_sparsity_pairs = {
             step: sparsity for (step, sparsity) in step_and_sparsity_pairs
         }
@@ -43,6 +45,12 @@ def target_sparsity(self, step: int):
         sparsity = self.step_and_sparsity_pairs[step] if update_ready else None
         return sparsity
 
+    def get_config(self):
+        return {
+            "class_name": self.__class__.__name__,
+            "config": {"step_and_sparsity_pairs": self._org_pairs},
+        }
+
 
 class DenseLayer(tf.keras.layers.Dense):
     def __init__(self, weight: np.ndarray):
@@ -107,3 +115,33 @@ def model_01():
     outputs = tf.keras.layers.Dense(10, name="dense_02")(x)
     model = Model(inputs=inputs, outputs=outputs)
     return model
+
+
+def mnist_model():
+    inputs = tf.keras.Input(shape=(28, 28, 1), name="inputs")
+
+    # Block 1
+    x = tf.keras.layers.Conv2D(16, 5, strides=1)(inputs)
+    x = tf.keras.layers.BatchNormalization()(x)
+    x = tf.keras.layers.ReLU()(x)
+
+    # Block 2
+    x = tf.keras.layers.Conv2D(32, 5, strides=2)(x)
+    x = tf.keras.layers.BatchNormalization()(x)
+    x = tf.keras.layers.ReLU()(x)
+
+    # Block 3
+    x = tf.keras.layers.Conv2D(64, 5, strides=1)(x)
+    x = tf.keras.layers.BatchNormalization()(x)
+    x = tf.keras.layers.ReLU()(x)
+
+    # Block 4
+    x = tf.keras.layers.Conv2D(128, 5, strides=2)(x)
+    x = tf.keras.layers.BatchNormalization()(x)
+    x = tf.keras.layers.ReLU()(x)
+
+    x = tf.keras.layers.AveragePooling2D(pool_size=1)(x)
+    x = tf.keras.layers.Flatten()(x)
+    outputs = tf.keras.layers.Dense(10, activation="softmax", name="outputs")(x)
+
+    return tf.keras.Model(inputs=inputs, outputs=outputs)
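Because mask variables now come into existence in `MaskedLayer.build()` rather than in `__init__`, code that constructs a `MaskedLayer` directly must build it before touching `masks`, `global_step`, or `mask_updater`, which is why `test_mask_update_explicit` below gains an explicit `build()` call. A short sketch of that contract, reusing the mock above; the `input_shape` value is an assumption for illustration:

```python
import tensorflow as tf

from sparseml.keras.optim import MaskedLayer
from tests.sparseml.keras.optim.mock import MockPruningScheduler

dense = tf.keras.layers.Dense(4)
scheduler = MockPruningScheduler([(1, 0.25), (2, 0.5)])

# mask_type (a string or block shape) replaces the old mask_creator argument;
# __init__ now only records state and allocates nothing
masked = MaskedLayer(dense, scheduler, mask_type="unstructured")

# build() resolves mask_type via load_mask_creator and creates the pruning
# vars, the global_step weight, and the MaskAndWeightUpdater
masked.build(input_shape=(None, 16))
```

After `build()`, `masks`, `global_step`, and `mask_updater` are usable exactly as the updated test exercises them.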
diff --git a/tests/sparseml/keras/optim/test_mask_pruning.py b/tests/sparseml/keras/optim/test_mask_pruning.py
index 9a3de232fc4..62f110a1af4 100644
--- a/tests/sparseml/keras/optim/test_mask_pruning.py
+++ b/tests/sparseml/keras/optim/test_mask_pruning.py
@@ -12,16 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Union
+
 import numpy as np
 import pytest
 import tensorflow as tf
 
-from sparseml.keras.optim import MaskedLayer, UnstructuredPruningMaskCreator
-from tests.sparseml.keras.optim.mock import DenseLayerCreator, MockPruningScheduler
+from sparseml.keras.optim import (
+    GMPruningModifier,
+    MaskedLayer,
+    ScheduledModifierManager,
+)
+from tests.sparseml.keras.optim.mock import (
+    DenseLayerCreator,
+    MockPruningScheduler,
+    mnist_model,
+)
 
 
 @pytest.mark.parametrize(
-    "layer_lambda, pruning_scheduler, mask_creator, expected_mask",
+    "layer_lambda, pruning_scheduler, mask_type, expected_mask",
     [
         (
             # Weight of a dense layer of shape (3, 4)
@@ -32,7 +42,7 @@
             ),
         ),
         MockPruningScheduler([(1, 0.25), (2, 0.5)]),
-        UnstructuredPruningMaskCreator(),
+        "unstructured",
         # List of expected mask, each corresponding to one of the
         # above update step in the MockPruningScheduler
         [
@@ -57,15 +67,128 @@
     ],
 )
 def test_mask_update_explicit(
-    layer_lambda, pruning_scheduler, mask_creator, expected_mask
+    layer_lambda, pruning_scheduler, mask_type, expected_mask
 ):
     if tf.__version__ < "2":
         pytest.skip("Test needs to be fixed to run with tensorflow_v1 1.x")
     layer = layer_lambda()
-    masked_layer = MaskedLayer(layer, pruning_scheduler, mask_creator)
+    masked_layer = MaskedLayer(layer, pruning_scheduler, mask_type)
+    masked_layer.build(input_shape=None)
     update_steps = list(pruning_scheduler.step_and_sparsity_pairs.keys())
     for idx, update_step in enumerate(update_steps):
         tf.keras.backend.batch_set_value([(masked_layer.global_step, update_step)])
         masked_layer.mask_updater.conditional_update(training=True)
         mask = tf.keras.backend.get_value(masked_layer.masks[0])
         assert np.allclose(mask, expected_mask[idx])
+
+
+@pytest.mark.parametrize(
+    "modifier_lambdas",
+    [
+        (
+            lambda: GMPruningModifier(
+                params=["conv2d/kernel:0"],
+                init_sparsity=0.25,
+                final_sparsity=0.75,
+                start_epoch=0.0,
+                end_epoch=2.0,
+                update_frequency=1.0,
+            ),
+            lambda: GMPruningModifier(
+                params=["conv2d_1/kernel:0"],
+                init_sparsity=0.25,
+                final_sparsity=0.75,
+                start_epoch=0.0,
+                end_epoch=2.0,
+                update_frequency=1.0,
+            ),
+        ),
+        (
+            lambda: GMPruningModifier(
+                params=["conv2d/kernel:0", "conv2d_2/kernel:0"],
+                init_sparsity=0.25,
+                final_sparsity=0.75,
+                start_epoch=0.0,
+                end_epoch=2.0,
+                update_frequency=1.0,
+            ),
+            lambda: GMPruningModifier(
+                params=["conv2d_1/kernel:0", "conv2d/kernel:0", "outputs/kernel:0"],
+                init_sparsity=0.25,
+                final_sparsity=0.75,
+                start_epoch=2.0,
+                end_epoch=3.0,
+                update_frequency=1.0,
+            ),
+            lambda: GMPruningModifier(
+                params=["conv2d_2/kernel:0", "outputs/kernel:0"],
+                init_sparsity=0.25,
+                final_sparsity=0.75,
+                start_epoch=3.0,
+                end_epoch=4.0,
+                update_frequency=1.0,
+            ),
+        ),
+    ],
+    scope="function",
+)
+@pytest.mark.parametrize("steps_per_epoch", [10], scope="function")
+def test_nested_layer_structure(modifier_lambdas, steps_per_epoch):
+    model = mnist_model()
+    modifiers = [mod() for mod in modifier_lambdas]
+    manager = ScheduledModifierManager(modifiers)
+    optimizer = tf.keras.optimizers.Adam()
+    model, optimizer, callbacks = manager.modify(model, optimizer, steps_per_epoch)
+
+    model.build(input_shape=(1, 28, 28, 1))
+
+    # Verify the number of (outermost) masked layers
+    modifier_masked_layer_names = [
+        layer_name for mod in modifiers for layer_name in mod.layer_names
+    ]
+    model_masked_layer_names = [
+        layer.name for layer in model.layers if isinstance(layer, MaskedLayer)
+    ]
+    assert len(model_masked_layer_names) == len(set(modifier_masked_layer_names))
+
+    # Verify that if a layer is modified by N modifiers, then the corresponding
+    # MaskedLayer has N-1 MaskedLayer instances nested inside it
+    for layer in model.layers:
+        if isinstance(layer, MaskedLayer):
+            expected_layers = len(
+                [name for name in modifier_masked_layer_names if name == layer.name]
+            )
+            assert _count_nested_masked_layers(layer) == expected_layers
+
+    # Verify the returned config dict has the expected nested structure
+    model_config = model.get_config()
+    for layer_config in model_config["layers"]:
+        if layer_config["class_name"] == "MaskedLayer":
+            layer_name = layer_config["config"]["name"]
+            expected_layers = len(
+                [name for name in modifier_masked_layer_names if name == layer_name]
+            )
+            assert (
+                _count_nested_masked_layers_in_config(layer_config) == expected_layers
+            )
+
+    # Verify model serialization and deserialization work for (nested) masked layers
+    model_config = model.get_config()
+    new_model = model.__class__.from_config(
+        model_config, custom_objects={"MaskedLayer": MaskedLayer}
+    )
+    assert model_config == new_model.get_config()
+
+    tf.keras.backend.clear_session()
+
+
+def _count_nested_masked_layers_in_config(layer_config: Dict):
+    if layer_config["class_name"] != "MaskedLayer":
+        return 0
+    return 1 + _count_nested_masked_layers_in_config(layer_config["config"]["layer"])
+
+
+def _count_nested_masked_layers(layer: Union[MaskedLayer, tf.keras.layers.Layer]):
+    if not isinstance(layer, MaskedLayer):
+        return 0
+    return 1 + _count_nested_masked_layers(layer.masked_layer)
diff --git a/tests/sparseml/keras/optim/test_modifier_pruning.py b/tests/sparseml/keras/optim/test_modifier_pruning.py
index 47954f02e01..1f3b30604df 100644
--- a/tests/sparseml/keras/optim/test_modifier_pruning.py
+++ b/tests/sparseml/keras/optim/test_modifier_pruning.py
@@ -60,7 +60,7 @@
     scope="function",
 )
 @pytest.mark.parametrize("steps_per_epoch", [10], scope="function")
-class TestGradualKSModifier:
+class TestGMPruningModifier:
     def test_lifecycle(self, model_lambda, modifier_lambda, steps_per_epoch):
         model = model_lambda()
         modifier = modifier_lambda()
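A note on the two accessors the helpers above depend on: the pre-existing `pruned_layer` property recurses through every wrapper to the innermost concrete layer, while the new `masked_layer` property peels off exactly one level, which is what makes the recursive count in `_count_nested_masked_layers` terminate at the right depth. A sketch of walking the chain with it; `unwrap_chain` is an illustrative helper, not part of this diff:

```python
from typing import Iterator, Union

import tensorflow as tf

from sparseml.keras.optim import MaskedLayer


def unwrap_chain(
    layer: Union[MaskedLayer, tf.keras.layers.Layer],
) -> Iterator[tf.keras.layers.Layer]:
    """Yield each MaskedLayer from outermost to innermost, then the raw layer."""
    while isinstance(layer, MaskedLayer):
        yield layer
        layer = layer.masked_layer  # one level only, unlike layer.pruned_layer
    yield layer
```

In the second parametrization of `test_nested_layer_structure`, the layer holding `conv2d/kernel:0` is named by two modifiers, so walking its chain would yield two `MaskedLayer` wrappers followed by the original `Conv2D`.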