This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
21 changes: 17 additions & 4 deletions src/sparseml/keras/__init__.py
@@ -22,11 +22,24 @@
 try:
     import tensorflow
 
-    version = [int(v) for v in tensorflow.__version__.split(".")]
-    if version[0] != 2 or version[1] < 2:
-        raise Exception
+    if tensorflow.__version__ < "2.1.0":
+        raise RuntimeError("TensorFlow >= 2.1.0 is required, found {}".format(tensorflow.__version__))
 except:
     raise RuntimeError(
-        "Unable to import tensorflow. tensorflow>=2.2 is required"
+        "Unable to import tensorflow. TensorFlow>=2.1.0 is required"
         " to use sparseml.keras."
     )
+
+
+try:
+    import keras
+
+    v = keras.__version__
+    if v < "2.4.3":
+        raise RuntimeError(
+            "Native keras is found and will be used, but version >= 2.4.3 is required, found {}".format(
+                v
+            )
+        )
+except:
+    pass
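A note on the version gates above: both compare __version__ strings lexicographically, which misorders multi-digit components (as strings, "2.10.0" < "2.4.3" is true), so a hypothetical 2.10+ release would trip the keras check. A minimal sketch of a more robust gate, not part of this diff and assuming the packaging package is available:

# Sketch only (not in this PR): numeric version gating via packaging.version.
from packaging.version import Version

import tensorflow

# Version() parses each dotted component as an integer, so "2.10.0" compares
# greater than "2.4.3", unlike plain string comparison.
if Version(tensorflow.__version__) < Version("2.1.0"):
    raise RuntimeError(
        "TensorFlow >= 2.1.0 is required, found {}".format(tensorflow.__version__)
    )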
13 changes: 7 additions & 6 deletions src/sparseml/keras/optim/manager.py
@@ -20,9 +20,10 @@

 from typing import List, Union
 
-import tensorflow as tf
+from tensorflow import Tensor
 
 from sparseml.keras.optim.modifier import Modifier, ScheduledModifier
+from sparseml.keras.utils.compat import keras
 from sparseml.keras.utils.logger import KerasLogger
 from sparseml.optim import BaseManager
 from sparseml.utils import load_recipe_yaml_str
@@ -71,11 +72,11 @@ def __init__(self, modifiers: List[ScheduledModifier]):

     def modify(
         self,
-        model: Union[tf.keras.Model, tf.keras.Sequential],
-        optimizer: tf.keras.optimizers.Optimizer,
+        model: Union[keras.Model, keras.Sequential],
+        optimizer: keras.optimizers.Optimizer,
         steps_per_epoch: int,
         loggers: Union[KerasLogger, List[KerasLogger]] = None,
-        input_tensors: tf.Tensor = None,
+        input_tensors: Tensor = None,
     ):
         """
         Modify the model and optimizer based on the requirements of modifiers
@@ -106,14 +107,14 @@ def modify(
                 continue
             if isinstance(callback, list):
                 callbacks = callbacks + callback
-            elif isinstance(callback, tf.keras.callbacks.Callback):
+            elif isinstance(callback, keras.callbacks.Callback):
                 callbacks.append(callback)
             else:
                 raise RuntimeError("Invalid callback type")
         self._optimizer = optimizer
         return model, optimizer, callbacks
 
-    def finalize(self, model: tf.keras.Model):
+    def finalize(self, model: keras.Model):
         """
         Remove extra information related to the modifier from the model that is
         not necessary for exporting
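For context on the signatures above, a minimal sketch of how modify and finalize are typically wired into a Keras training flow. The toy model and recipe path are placeholders, and ScheduledModifierManager.from_yaml is assumed from the surrounding SparseML API rather than shown in this diff:

# Hedged usage sketch (assumed API, not part of this diff).
import tensorflow

from sparseml.keras.optim import ScheduledModifierManager

model = tensorflow.keras.Sequential(
    [tensorflow.keras.layers.Dense(10, input_shape=(784,))]
)
optimizer = tensorflow.keras.optimizers.Adam()

manager = ScheduledModifierManager.from_yaml("recipe.yaml")  # hypothetical recipe path
model, optimizer, callbacks = manager.modify(model, optimizer, steps_per_epoch=100)
model.compile(optimizer=optimizer, loss="mse")
# model.fit(..., callbacks=callbacks) would run here; the returned callbacks
# drive the modifiers on their scheduled steps.
model = manager.finalize(model)  # assumed to return the cleaned-up model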
105 changes: 54 additions & 51 deletions src/sparseml/keras/optim/mask_pruning.py
@@ -17,12 +17,13 @@
 import inspect
 from typing import List, Union
 
-import tensorflow as tf
+import tensorflow
 
 from sparseml.keras.optim.mask_pruning_creator import (
     PruningMaskCreator,
     load_mask_creator,
 )
+from sparseml.keras.utils import keras
 
 
 __all__ = [
@@ -80,7 +81,7 @@ def deserialize(cls, config):
         if "class_name" not in config:
             raise ValueError("The 'class_name' not found in config: {}".format(config))
         class_name = config["class_name"]
-        return tf.keras.utils.deserialize_keras_object(
+        return keras.utils.deserialize_keras_object(
             config,
             module_objects=globals(),
             custom_objects={class_name: PruningScheduler._REGISTRY[class_name]},
Expand Down Expand Up @@ -112,7 +113,7 @@ def __init__(
         pruning_vars: List[MaskedParamInfo],
         pruning_scheduler: PruningScheduler,
         mask_creator: PruningMaskCreator,
-        global_step: tf.Tensor,
+        global_step: tensorflow.Tensor,
     ):
         self._pruning_vars = pruning_vars
         self._pruning_scheduler = pruning_scheduler
@@ -121,36 +122,36 @@ def __init__(
         self._update_ready = None
 
     def _is_pruning_step(self) -> bool:
-        global_step_val = tf.keras.backend.get_value(self._global_step)
+        global_step_val = keras.backend.get_value(self._global_step)
         assert global_step_val >= 0
         update_ready = self._pruning_scheduler.should_prune(global_step_val)
         return update_ready
 
     def _conditional_training_update(self):
         def _no_update_masks_and_weights():
-            return tf.no_op("no_update")
+            return tensorflow.no_op("no_update")
 
         def _update_masks_and_weights():
             assignments = []
-            global_step_val = tf.keras.backend.get_value(self._global_step)
+            global_step_val = keras.backend.get_value(self._global_step)
             for masked_param_info in self._pruning_vars:
                 new_sparsity = self._pruning_scheduler.target_sparsity(global_step_val)
                 new_mask = self._mask_creator.create_sparsity_mask(
                     masked_param_info.param, new_sparsity
                 )
                 assignments.append(masked_param_info.mask.assign(new_mask))
                 assignments.append(masked_param_info.sparsity.assign(new_sparsity))
-                masked_param = tf.math.multiply(
+                masked_param = tensorflow.math.multiply(
                     masked_param_info.param, masked_param_info.mask
                 )
                 assignments.append(masked_param_info.param.assign(masked_param))
-            return tf.group(assignments)
+            return tensorflow.group(assignments)
 
         update_ready = self._is_pruning_step()
 
         self._update_ready = update_ready
-        return tf.cond(
-            tf.cast(update_ready, tf.bool),
+        return tensorflow.cond(
+            tensorflow.cast(update_ready, tensorflow.bool),
             _update_masks_and_weights,
             _no_update_masks_and_weights,
         )
@@ -161,11 +162,11 @@ def apply_masks(self):
         """
         assignments = []
         for masked_param_info in self._pruning_vars:
-            masked_param = tf.math.multiply(
+            masked_param = tensorflow.math.multiply(
                 masked_param_info.param, masked_param_info.mask
             )
             assignments.append(masked_param_info.param.assign(masked_param))
-        return tf.group(assignments)
+        return tensorflow.group(assignments)
 
     def conditional_update(self, training=None):
         """
@@ -175,32 +176,34 @@
         """
 
         def _update():
-            with tf.control_dependencies([self._conditional_training_update()]):
-                return tf.no_op("update")
+            with tensorflow.control_dependencies([self._conditional_training_update()]):
+                return tensorflow.no_op("update")
 
         def _no_update():
-            return tf.no_op("no_update")
+            return tensorflow.no_op("no_update")
 
-        training = tf.keras.backend.learning_phase() if training is None else training
-        return tf.cond(tf.cast(training, tf.bool), _update, _no_update)
+        training = keras.backend.learning_phase() if training is None else training
+        return tensorflow.cond(
+            tensorflow.cast(training, tensorflow.bool), _update, _no_update
+        )
 
 
 _LAYER_PRUNABLE_PARAMS_MAP = {
-    tf.keras.layers.Conv1D: ["kernel"],
-    tf.keras.layers.Conv2D: ["kernel"],
-    tf.keras.layers.Conv2DTranspose: ["kernel"],
-    tf.keras.layers.Conv3D: ["kernel"],
-    tf.keras.layers.Conv3DTranspose: ["kernel"],
-    tf.keras.layers.Dense: ["kernel"],
-    tf.keras.layers.Embedding: ["embeddings"],
-    tf.keras.layers.LocallyConnected1D: ["kernel"],
-    tf.keras.layers.LocallyConnected2D: ["kernel"],
-    tf.keras.layers.SeparableConv1D: ["pointwise_kernel"],
-    tf.keras.layers.SeparableConv2D: ["pointwise_kernel"],
+    keras.layers.Conv1D: ["kernel"],
+    keras.layers.Conv2D: ["kernel"],
+    keras.layers.Conv2DTranspose: ["kernel"],
+    keras.layers.Conv3D: ["kernel"],
+    keras.layers.Conv3DTranspose: ["kernel"],
+    keras.layers.Dense: ["kernel"],
+    keras.layers.Embedding: ["embeddings"],
+    keras.layers.LocallyConnected1D: ["kernel"],
+    keras.layers.LocallyConnected2D: ["kernel"],
+    keras.layers.SeparableConv1D: ["pointwise_kernel"],
+    keras.layers.SeparableConv2D: ["pointwise_kernel"],
 }
 
 
-def _get_default_prunable_params(layer: tf.keras.layers.Layer):
+def _get_default_prunable_params(layer: keras.layers.Layer):
     if layer.__class__ in _LAYER_PRUNABLE_PARAMS_MAP:
         prunable_param_names = _LAYER_PRUNABLE_PARAMS_MAP[layer.__class__]
         return {
@@ -216,7 +219,7 @@ def _get_default_prunable_params(layer: tf.keras.layers.Layer):
     )
 
 
-class MaskedLayer(tf.keras.layers.Wrapper):
+class MaskedLayer(keras.layers.Wrapper):
     """
     Masked layer is a layer wrapping around another layer with a mask; the mask however
     is shared if the enclosed layer is again of MaskedLayer type
@@ -229,13 +232,13 @@ class MaskedLayer(tf.keras.layers.Wrapper):

     def __init__(
         self,
-        layer: tf.keras.layers.Layer,
+        layer: keras.layers.Layer,
         pruning_scheduler: PruningScheduler,
         mask_type: Union[str, List[int]] = "unstructured",
         **kwargs,
     ):
         if not isinstance(layer, MaskedLayer) and not isinstance(
-            layer, tf.keras.layers.Layer
+            layer, keras.layers.Layer
         ):
             raise ValueError(
                 "Invalid layer passed in, expected MaskedLayer or a keras Layer, "
@@ -257,8 +260,8 @@ def build(self, input_shape):
         self._global_step = self.add_weight(
             "global_step",
             shape=[],
-            initializer=tf.keras.initializers.Constant(-1),
-            dtype=tf.int64,
+            initializer=keras.initializers.Constant(-1),
+            dtype=tensorflow.int64,
             trainable=False,
         )
         self._mask_updater = MaskAndWeightUpdater(
@@ -276,43 +279,43 @@ def _reuse_or_create_pruning_vars(
             # for the "core", inner-most, Keras built-in layer
             return self._layer.pruning_vars
 
-        assert isinstance(self._layer, tf.keras.layers.Layer)
+        assert isinstance(self._layer, keras.layers.Layer)
         prunable_params = _get_default_prunable_params(self._layer)
 
         pruning_vars = []
         for name, param in prunable_params.items():
             mask = self.add_weight(
                 "mask",
                 shape=param.shape,
-                initializer=tf.keras.initializers.get("ones"),
+                initializer=keras.initializers.get("ones"),
                 dtype=param.dtype,
                 trainable=False,
             )
             sparsity = self.add_weight(
                 "sparsity",
                 shape=[],
-                initializer=tf.keras.initializers.get("zeros"),
+                initializer=keras.initializers.get("zeros"),
                 dtype=param.dtype,
                 trainable=False,
             )
             pruning_vars.append(MaskedParamInfo(name, param, mask, sparsity))
         return pruning_vars
 
-    def call(self, inputs: tf.Tensor, training=None):
+    def call(self, inputs: tensorflow.Tensor, training=None):
         """
         Forward function for calling layer instance as function
         """
-        training = tf.keras.backend.learning_phase() if training is None else training
+        training = keras.backend.learning_phase() if training is None else training
 
         def _apply_masks_to_weights():
-            with tf.control_dependencies([self._mask_updater.apply_masks()]):
-                return tf.no_op("update")
+            with tensorflow.control_dependencies([self._mask_updater.apply_masks()]):
+                return tensorflow.no_op("update")
 
         def _no_apply_masks_to_weights():
-            return tf.no_op("no_update_masks")
+            return tensorflow.no_op("no_update_masks")
 
-        tf.cond(
-            tf.cast(training, tf.bool),
+        tensorflow.cond(
+            tensorflow.cast(training, tensorflow.bool),
             _apply_masks_to_weights,
             _no_apply_masks_to_weights,
         )
@@ -327,7 +330,7 @@ def get_config(self):
         """
         Get layer config
         Serialization and deserialization should be done using
-        tf.keras.serialize/deserialize, which create and retrieve the "class_name"
+        keras.serialize/deserialize, which create and retrieve the "class_name"
         field automatically.
         The resulting config below therefore does not contain the field.
         """
@@ -345,11 +348,11 @@ def get_config(self):
     @classmethod
     def from_config(cls, config):
         config = config.copy()
-        layer = tf.keras.layers.deserialize(
+        layer = keras.layers.deserialize(
             config.pop("layer"), custom_objects={"MaskedLayer": MaskedLayer}
         )
         if not isinstance(layer, MaskedLayer) and not isinstance(
-            layer, tf.keras.layers.Layer
+            layer, keras.layers.Layer
         ):
             raise RuntimeError("Unexpected layer created from config")
         pruning_scheduler = PruningScheduler.deserialize(
@@ -384,7 +387,7 @@ def pruning_vars(self):
     def pruned_layer(self):
         if isinstance(self._layer, MaskedLayer):
             return self._layer.pruned_layer
-        elif isinstance(self._layer, tf.keras.layers.Layer):
+        elif isinstance(self._layer, keras.layers.Layer):
             return self._layer
         else:
             raise RuntimeError("Unrecognized layer")
@@ -394,7 +397,7 @@ def masked_layer(self):
         return self._layer
 
 
-def remove_pruning_masks(model: tf.keras.Model):
+def remove_pruning_masks(model: keras.Model):
     """
     Remove pruning masks from a model that was pruned using the MaskedLayer logic
     :param model: a model that was pruned using MaskedLayer
@@ -410,7 +413,7 @@ def _get_pruned_layer(layer):
         ) or layer.__class__.__name__.endswith("MaskedLayer")
         if is_masked_layer:
             return _get_pruned_layer(layer.layer)
-        elif isinstance(layer, tf.keras.layers.Layer):
+        elif isinstance(layer, keras.layers.Layer):
             return layer
         else:
             raise ValueError("Unknown layer type")
@@ -425,6 +428,6 @@ def _remove_pruning_masks(layer):

     # TODO: while the resulting model could be exported to ONNX, its built status
     # is removed
-    return tf.keras.models.clone_model(
+    return keras.models.clone_model(
         model, input_tensors=None, clone_function=_remove_pruning_masks
     )
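To make the MaskedLayer lifecycle concrete, a hedged sketch of wrapping a layer and later stripping the masks for export. ConstantScheduler is a hypothetical PruningScheduler subclass implementing just the two methods this file calls (should_prune and target_sparsity); import paths follow this module:

# Hedged sketch (not part of this diff): wrap a layer, then unwrap for export.
import tensorflow

from sparseml.keras.optim.mask_pruning import (
    MaskedLayer,
    PruningScheduler,
    remove_pruning_masks,
)


class ConstantScheduler(PruningScheduler):
    # Hypothetical scheduler: prune every 100 steps toward a fixed 80% sparsity.
    def should_prune(self, step: int) -> bool:
        return step % 100 == 0

    def target_sparsity(self, step: int) -> float:
        return 0.8


model = tensorflow.keras.Sequential(
    [
        MaskedLayer(
            tensorflow.keras.layers.Dense(10),
            ConstantScheduler(),
            mask_type="unstructured",
            input_shape=(784,),
        )
    ]
)
# After training with the pruning callbacks, clone the model without wrappers:
exported = remove_pruning_masks(model)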