Skip to content

Commit

Permalink
Compression support for tf.keras.Sequential (#2887)
Browse files Browse the repository at this point in the history
* compression support tf sequential

* minor fix

* move helper functions to Compressor method

Co-authored-by: liuzhe <zhe.liu@microsoft.com>
  • Loading branch information
liuzhe-lz and liuzhe committed Sep 15, 2020
1 parent 98a49b1 commit 10d7ece
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 163 deletions.
40 changes: 15 additions & 25 deletions examples/model_compress/model_prune_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,21 @@ def get_dataset(dataset_name='mnist'):

def create_model(model_name='naive'):
assert model_name == 'naive'
return NaiveModel()

class NaiveModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.seq_layers = [
tf.keras.layers.Conv2D(filters=20, kernel_size=5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Conv2D(filters=20, kernel_size=5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=500),
tf.keras.layers.ReLU(),
tf.keras.layers.Dense(units=10),
tf.keras.layers.Softmax()
]

def call(self, x):
for layer in self.seq_layers:
x = layer(x)
return x
return tf.keras.Sequential([
tf.keras.layers.Conv2D(filters=20, kernel_size=5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Conv2D(filters=20, kernel_size=5),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.MaxPool2D(pool_size=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=500),
tf.keras.layers.ReLU(),
tf.keras.layers.Dense(units=10),
tf.keras.layers.Softmax()
])


def create_pruner(model, pruner_name):
Expand Down
219 changes: 84 additions & 135 deletions src/sdk/pynni/nni/compression/tensorflow/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,90 +15,119 @@
_logger = logging.getLogger(__name__)


class LayerInfo:
"""
This structure contains all infomation needed to compress a TensorFlow ``Layer``.
Attributes
----------
layer : tf.keras.layers.Layer
The layer.
name : str
The layer's name. Note that it's local to sub-model and may differ from its attribute name.
type : str
Name of the layer's class.
path : list of str or tuple of (str, int)
The layer object's and its parents' attribute name / list index.
For example, if the path is `[('cells', 2), 'conv']`, then the layer can be accessed as `model.cells[2].conv`.
config : JSON object
Selected configuration for this layer. The format is detailed in tutorial.
Parameters
----------
layer : tf.keras.layers.Layer
See attributes section.
path : list of str or tuple of (str, int)
See attributes section.
"""

def __init__(self, layer, path=None):
self.layer = layer
self.name = layer.name
self.type = type(layer).__name__
self.path = path
self.config = None


class Compressor:
"""
Common base class for all compressors.
This class is designed for other base classes.
Algorithms should inherit ``Pruner`` or ``Quantizer`` instead.
Attributes
----------
bound_model : tf.keras.Model
compressed_model : tf.keras.Model
Compressed user model.
wrappers : list of tf.keras.Model
A wrapper is an instrumented TF ``Layer``, in ``Model`` format.
The list is ordered by preorder traversal.
Parameters
----------
LayerWrapperClass : a class derive from Model
The class used to instrument layers.
model : tf.keras.Model
The user model to be compressed.
config_list : list of JSON object
User configuration. The format is detailed in tutorial.
LayerWrapperClass : a class derive from Model
The class used to instrument layers.
"""

def __init__(self, LayerWrapperClass, model, config_list):
def __init__(self, model, config_list, LayerWrapperClass):
assert isinstance(model, tf.keras.Model)
if isinstance(model, tf.keras.Sequential):
raise ValueError('NNI model compression does not support `Sequential` model for now')
self.validate_config(model, config_list)

self.bound_model = model
self.wrappers = []
self._original_model = model
self._config_list = config_list
self._wrapper_class = LayerWrapperClass
self._wrappers = {} # key: id(layer) , value: Wrapper(layer)

self.compressed_model = self._instrument(model)
self.wrappers = list(self._wrappers.values())

for layer_info in _detect_layers_to_compress(model, config_list):
self.wrappers.append(LayerWrapperClass(layer_info, self))
if not self.wrappers:
_logger.warning('Nothing is configured to compress, please check your model and config list')

_instrument_model(model, self.wrappers)

def set_wrappers_attribute(self, name, value):
"""
Call ``setattr`` on all wrappers.
"""
for wrapper in self.wrappers:
setattr(wrapper, name, value)

def validate_config(self, model, config_list):
"""
Compression algorithm should overload this function to validate configuration.
"""
pass


def _instrument(self, layer):
if isinstance(layer, tf.keras.Sequential):
return self._instrument_sequential(layer)
if isinstance(layer, tf.keras.Model):
return self._instrument_model(layer)

# a layer can be referenced in multiple attributes of a model,
# but should only be instrumented once
if id(layer) in self._wrappers:
return self._wrappers[id(layer)]

config = self._select_config(layer)
if config is not None:
wrapper = self._wrapper_class(layer, config, self)
self._wrappers[id(layer)] = wrapper
return wrapper

return layer

def _instrument_sequential(self, seq):
layers = list(seq.layers) # seq.layers is read-only property
need_rebuild = False
for i, layer in enumerate(layers):
new_layer = self._instrument(layer)
if new_layer is not layer:
layers[i] = new_layer
need_rebuild = True
return tf.keras.Sequential(layers) if need_rebuild else seq

def _instrument_model(self, model):
for key, value in list(model.__dict__.items()): # avoid "dictionary keys changed during iteration"
if isinstance(value, tf.keras.layers.Layer):
new_layer = self._instrument(value)
if new_layer is not value:
setattr(model, key, new_layer)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, tf.keras.layers.Layer):
value[i] = self._instrument(item)
return model


def _select_config(self, layer):
# Find the last matching config block for given layer.
# Returns None if the layer should not be compressed.
layer_type = type(layer).__name__
last_match = None
for config in self._config_list:
if 'op_types' in config:
match = layer_type in config['op_types']
match_default = 'default' in config['op_types'] and layer_type in default_layers.weighted_modules
if not match and not match_default:
continue
if 'op_names' in config and layer.name not in config['op_names']:
continue
last_match = config
if last_match is None or 'exclude' in last_match:
return None
return last_match


class Pruner(Compressor):
"""
Expand All @@ -121,7 +150,7 @@ class Pruner(Compressor):
User configuration. The format is detailed in tutorial.
"""
def __init__(self, model, config_list):
super().__init__(PrunerLayerWrapper, model, config_list)
super().__init__(model, config_list, PrunerLayerWrapper)
#self.callback = PrunerCallback(self)

def compress(self):
Expand All @@ -133,10 +162,10 @@ def compress(self):
Returns
-------
tf.keras.Model
The compressed model, for convenience. This is exactly the same object to constructor argument.
The compressed model.
"""
self._update_mask()
return self.bound_model
return self.compressed_model

def calc_masks(self, wrapper, **kwargs):
"""
Expand Down Expand Up @@ -195,11 +224,10 @@ class PrunerLayerWrapper(tf.keras.Model):
Afterwards, `masks` is the last return value of ``Pruner.calc_masks``.
See ``Pruner.calc_masks`` for details.
"""
def __init__(self, layer_info, pruner):
def __init__(self, layer, config, pruner):
super().__init__()
self.layer_info = layer_info
self.layer = layer_info.layer
self.config = layer_info.config
self.layer = layer
self.config = config
self.pruner = pruner
self.masks = {}
_logger.info('Layer detected to compress: %s', self.layer.name)
Expand All @@ -226,82 +254,3 @@ def call(self, *inputs):
#
# def on_train_batch_end(self, batch, logs=None):
# self._pruner.update_mask()


def _detect_layers_to_compress(model, config_list):
# Returns list of LayerInfo.
located_layers = _locate_layers(model)
ret = []
for layer in model.layers:
config = _select_config(LayerInfo(layer), config_list)
if config is not None:
if id(layer) not in located_layers:
_logger.error('Failed to locate layer %s in model. The layer will not be compressed. '
'This is a bug in NNI, feel free to fire an issue.', layer.name)
continue
layer_info = located_layers[id(layer)]
layer_info.config = config
ret.append(layer_info)
return ret

def _locate_layers(model, cur_path=[]):
# Find out how to access layers from model object.
# Returns dict of (layer's object ID, LayerInfo).
# This function is required because TF framework does not track layer's attribute name,
# and to my knowledge `Layer.name` is only useful for read-only access.
# `cur_path`s format is documented in `LayerInfo.path`.
# TODO: it can only find layers in `Model` and `list` for now.
assert isinstance(model, tf.keras.Model)
if isinstance(model, tf.keras.Sequential):
_logger.warning('`Sequential` model is not supported yet, ignored.')
ret = {}
for key, value in model.__dict__.items():
if isinstance(value, tf.keras.Model):
ret.update(_locate_layers(value, cur_path + [key]))
elif isinstance(value, tf.keras.layers.Layer):
ret[id(value)] = LayerInfo(value, cur_path + [key])
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, tf.keras.Model):
ret.update(_locate_layers(item, cur_path + [(key, i)]))
elif isinstance(item, tf.keras.layers.Layer):
ret[id(item)] = LayerInfo(item, cur_path + [(key, i)])
return ret

def _select_config(layer_info, config_list):
# Find the last matching config block for given layer.
# Returns None if the layer should not be compressed.
ret = None
for config in config_list:
if 'op_types' in config:
match = layer_info.type in config['op_types']
match_default = 'default' in config['op_types'] and layer_info.type in default_layers.weighted_modules
if not match and not match_default:
continue
if 'op_names' in config and layer_info.name not in config['op_names']:
continue
ret = config
if ret is None or 'exclude' in ret:
return None
return ret


def _instrument_model(model, wrappers):
# Replace layers to wrappers
for wrapper in reversed(wrappers):
cur = model
for key in wrapper.layer_info.path[:-1]:
if isinstance(key, str):
cur = getattr(cur, key)
else:
name, index = key
cur = getattr(cur, name)[index]
key = wrapper.layer_info.path[-1]
if isinstance(key, str):
setattr(cur, key, wrapper)
else:
name, index = key
getattr(cur, name)[index] = wrapper
#if isinstance(cur, tf.keras.Sequential):
# cur._graph_initialized = False
# cur._layer_call_argspecs[wrapper] = cur._layer_call_argspecs[wrapper.layer]
19 changes: 16 additions & 3 deletions src/sdk/pynni/tests/test_compressor_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

####
#
# This file tests pruners on 2 models: a classic CNN model, and a naive model with one linear layer
# This file tests pruners on 3 models:
# A classic CNN model built by inheriting `Model`;
# The same CNN model built with `Sequential`;
# A naive model with only one linear layer.
#
# The CNN model is used to test layer detecting and instrumenting.
# The CNN models are used to test layer detecting and instrumenting.
#
# The naive model is used to test mask calculation.
# It has a single 10x10 linear layer without bias, and `reduce_sum` its result.
Expand All @@ -31,11 +34,12 @@ def test_layer_detection(self):
# Conv and dense layers should be compressed, pool and flatten should not.
# This also tests instrumenting functionality.
self._test_layer_detection_on_model(CnnModel())
self._test_layer_detection_on_model(build_sequential_model())

def _test_layer_detection_on_model(self, model):
pruner = pruners['level'](model)
pruner.compress()
layer_types = sorted(wrapper.layer_info.type for wrapper in pruner.wrappers)
layer_types = sorted(type(wrapper.layer).__name__ for wrapper in pruner.wrappers)
assert layer_types == ['Conv2D', 'Dense', 'Dense'], layer_types

def test_level_pruner(self):
Expand Down Expand Up @@ -73,6 +77,15 @@ def call(self, x):
x = self.fc2(x)
return x

def build_sequential_model():
return Sequential([
Conv2D(filters=10, kernel_size=3, activation='relu'),
MaxPool2D(pool_size=2),
Flatten(),
Dense(units=10, activation='relu'),
Dense(units=5, activation='softmax'),
])

class NaiveModel(Model):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 10d7ece

Please sign in to comment.