In [1]:
import os
import zipfile
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

In [2]:
# Tiny synthetic dataset: a single example with 20 features and a one-hot
# label over NUM_CLASSES classes. Seeded so Restart & Run All is reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
tf.random.set_seed(SEED)

input_shape = [20]
NUM_CLASSES = 20
x_train = rng.standard_normal((1, *input_shape)).astype(np.float32)
# Sample a valid class index in [0, NUM_CLASSES) instead of truncating
# Gaussian noise to an int (which clusters near 0 and can go negative).
y_train = tf.keras.utils.to_categorical(
    rng.integers(0, NUM_CLASSES, size = 1), num_classes = NUM_CLASSES
)

def setup_model():
  """Build the small, uncompiled Dense(20) -> Flatten model used throughout this notebook.

  The Dense layer has no activation, so the model outputs unnormalized values.
  """
  stack = [
      tf.keras.layers.Dense(20, input_shape = input_shape),
      tf.keras.layers.Flatten(),
  ]
  return tf.keras.Sequential(stack)

def setup_pretrained_weights():
  """Briefly train a fresh `setup_model()` on the toy data and save its weights.

  Returns:
    str: path prefix of the saved weights, suitable for `model.load_weights`.
  """
  model = setup_model()

  model.compile(
      # The Dense layer in setup_model has no activation, so the model's
      # outputs are logits rather than probabilities.
      loss = tf.keras.losses.CategoricalCrossentropy(from_logits = True),
      optimizer = 'adam',
      metrics = ['accuracy']
  )

  model.fit(x_train, y_train)

  # TensorFlow-format weights (a non-.h5 path) are written as several files
  # sharing this prefix. A private temp directory keeps them together instead
  # of leaking an unused, still-open mkstemp file in the shared temp dir.
  pretrained_weights = os.path.join(tempfile.mkdtemp(), 'pretrained_weights.tf')
  model.save_weights(pretrained_weights)

  return pretrained_weights

def get_gzipped_model_size(model):
  """Return the size, in bytes, of `model` saved as HDF5 and then zip-compressed.

  The model is saved without its optimizer to a temporary `.h5` file. That file
  is DEFLATE-compressed into a temporary `.zip` archive, whose size is returned.
  Both temporary files are removed before returning.

  Args:
    model: a Keras model (anything with a compatible `save(path, include_optimizer=...)`).

  Returns:
    int: size of the zip archive in bytes.
  """
  # mkstemp returns an already-open OS-level descriptor together with the path.
  # Only the paths are needed, so close the descriptors instead of leaking them.
  keras_fd, keras_file = tempfile.mkstemp('.h5')
  os.close(keras_fd)
  zip_fd, zipped_file = tempfile.mkstemp('.zip')
  os.close(zip_fd)

  try:
    model.save(keras_file, include_optimizer = False)
    with zipfile.ZipFile(zipped_file, 'w', compression = zipfile.ZIP_DEFLATED) as archive:
      archive.write(keras_file)
    return os.path.getsize(zipped_file)
  finally:
    # Best-effort cleanup: the files exist only to measure the compressed size.
    for path in (keras_file, zipped_file):
      try:
        os.remove(path)
      except OSError:
        pass

# Train the small model once and keep the path to its saved weights; the
# cells below load these weights before wrapping the model for pruning.
pretrained_weights = setup_pretrained_weights()



### Prune whole model (Sequential and Functional)

Tips for better model accuracy:
- Try "Prune some layers" to skip pruning the layers that reduce accuracy the most.
- It's generally better to finetune with pruning as opposed to training from scratch.

To make the whole model train with pruning, apply `tfmot.sparsity.keras.prune_low_magnitude` to the model.

In [3]:
# Start from the pretrained weights, then wrap the entire model for pruning:
# every layer gets a `prune_low_magnitude_*` wrapper (see the summary below).
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
model_for_pruning.summary()

Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_dense_2  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________


### Prune some layers (Sequential and Functional)

Pruning a model can have a negative effect on accuracy. You can selectively prune layers of a model to explore the trade-off between accuracy, speed, and model size.

Tips for better model accuracy:
- It's generally better to finetune with pruning as opposed to training from scratch.
- Try pruning the later layers instead of the first layers.
- Avoid pruning critical layers (e.g. attention mechanism).

In [4]:
# Clone-time hook for `tf.keras.models.clone_model`: only Dense layers are
# wrapped for pruning; every other layer is passed through untouched.
def apply_pruning_to_dense(layer):
  """Return `layer` wrapped with `prune_low_magnitude` if it is Dense, else `layer` itself."""
  if not isinstance(layer, tf.keras.layers.Dense):
    return layer
  return tfmot.sparsity.keras.prune_low_magnitude(layer)

# Create a base model from the pretrained weights
# (optional but recommended for model accuracy).
base_model = setup_model()
base_model.load_weights(pretrained_weights)

# Rebuild the model, routing each layer through `apply_pruning_to_dense`.
model_for_pruning = tf.keras.models.clone_model(
    base_model, clone_function = apply_pruning_to_dense
)
model_for_pruning.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_dense_3  (None, 20)                822       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________


In [5]:
# Functional Example: only the Dense layer is wrapped for pruning; Flatten
# is applied as an ordinary Keras layer.
i = tf.keras.Input(shape = tuple(input_shape))
x = tfmot.sparsity.keras.prune_low_magnitude(
    tf.keras.layers.Dense(10)
)(i)
o = tf.keras.layers.Flatten()(x)

model_for_pruning = tf.keras.Model(inputs = i, outputs = o)
model_for_pruning.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
prune_low_magnitude_dense_4  (None, 10)                412       
_________________________________________________________________
flatten_4 (Flatten)          (None, 10)                0         
Total params: 412
Trainable params: 210
Non-trainable params: 202
_________________________________________________________________


In [6]:
# Sequential Example: wrap just the Dense layer for pruning; Flatten stays unpruned.
model_for_pruning = tf.keras.Sequential()
model_for_pruning.add(
    tfmot.sparsity.keras.prune_low_magnitude(
        tf.keras.layers.Dense(20, input_shape = input_shape)
    )
)
model_for_pruning.add(tf.keras.layers.Flatten())

model_for_pruning.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_dense_5  (None, 20)                822       
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0         
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________


### Prune custom Keras layer or modify parts of layer to prune

Common mistake: pruning the bias usually harms model accuracy too much.

`tfmot.sparsity.keras.PrunableLayer` serves two use cases:
- Prune a custom Keras layer
- Modify parts of a built-in Keras layer to prune.

For example, by default the API prunes only the kernel of a `Dense` layer. The example below prunes the bias as well.

In [7]:
class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):
  """Dense layer that exposes both its kernel and its bias for pruning."""

  def get_prunable_weights(self):
    """Return the weights that `prune_low_magnitude` should prune.

    Pruning the bias as well as the kernel usually harms model accuracy too
    much; this layer does it only to demonstrate the `PrunableLayer` API.
    `self.bias` may be None (e.g. when the layer is built with
    `use_bias=False`), so only weights that actually exist are returned.
    """
    return [w for w in (self.kernel, self.bias) if w is not None]

# Wrap the custom `MyDenseLayer` with `prune_low_magnitude` so it trains with
# pruning; the wrapper prunes whatever `get_prunable_weights` returns.
model_for_pruning = tf.keras.Sequential()
model_for_pruning.add(
    tfmot.sparsity.keras.prune_low_magnitude(
        MyDenseLayer(20, input_shape = input_shape)
    )
)
model_for_pruning.add(tf.keras.layers.Flatten())

model_for_pruning.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_my_dense (None, 20)                843       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
Total params: 843
Trainable params: 420
Non-trainable params: 423
_________________________________________________________________


In [8]:
# Load the TensorBoard notebook extension. The PruningSummaries callbacks
# below write sparsity/metric logs to `log_dir` for viewing in TensorBoard.
# NOTE(review): no visible cell runs `%tensorboard --logdir=...`; add one
# to actually display those logs.
%load_ext tensorboard

### Train Model

In [9]:
# Start from the pretrained weights, then wrap the whole model for pruning.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir = log_dir)
]

model_for_pruning.compile(
      # The Dense layer has no activation, so the outputs are logits.
      loss = tf.keras.losses.CategoricalCrossentropy(from_logits = True),
      optimizer = 'adam',
      metrics = ['accuracy']
)

model_for_pruning.fit(
    x_train,
    y_train,
    callbacks = callbacks,
    epochs = 2,
)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f9cc85ce350>

### Custom training loop

In [10]:
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate: invoke by hand the same pruning callbacks that
# `model.fit` received in the previous cell.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir = log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)

step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
  log_callback.on_epoch_begin(epoch = unused_arg) # run pruning callback
  for _ in range(batches):
    step_callback.on_train_batch_begin(batch = unused_arg) # run pruning callback

    with tf.GradientTape() as tape:
      logits = model_for_pruning(x_train, training = True)
      # The model has no output activation, so the loss must be told it receives logits.
      loss_value = loss(y_train, logits, from_logits = True)
    # Differentiate and update outside the tape context so these ops are not recorded on it.
    grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
    optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))

  step_callback.on_epoch_end(batch = unused_arg) # run pruning callback

### Checkpoint and deserialize
You must preserve the optimizer step when checkpointing. This means that while you can checkpoint with Keras HDF5 models (saved with the optimizer), you cannot checkpoint with Keras HDF5 weights alone.

In [11]:
# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# mkstemp creates the file and returns an open OS-level descriptor. Only the
# path is needed here, so close the descriptor instead of leaking it.
_fd, keras_model_file = tempfile.mkstemp('.h5')
os.close(_fd)

# Checkpoint: saving the optimizer is necessary (include_optimizer = True is the default).
# NOTE(review): `model_for_pruning` is not compiled in this cell, so there is
# presumably no optimizer state to save yet. Compile/train it first for a
# meaningful checkpoint — confirm.
model_for_pruning.save(keras_model_file, include_optimizer = True)

In [12]:
# Deserialize model. Loading happens inside `prune_scope()`, presumably so the
# `prune_low_magnitude_*` wrapper layers saved in the file can be resolved.
with tfmot.sparsity.keras.prune_scope():
  loaded_model = tf.keras.models.load_model(filepath = keras_model_file)
loaded_model.summary()

Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_dense_8  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________


### Deploy pruned model

In [13]:
# Build a pruning-wrapped copy of the pretrained model, then strip the pruning
# wrappers to obtain the plain Keras model that would actually be exported.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

print("final model")
model_for_export.summary()

# Compare compressed on-disk sizes with and without the pruning wrappers.
print("\n")
print(
    "Size of gzipped pruned model without stripping: %.2f bytes"
    % get_gzipped_model_size(model_for_pruning)
)
print(
    "Size of gzipped pruned model with stripping: %.2f bytes"
    % get_gzipped_model_size(model_for_export)
)

final model
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_9 (Dense)              (None, 20)                420       
_________________________________________________________________
flatten_10 (Flatten)         (None, 20)                0         
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


Size of gzipped pruned model without stripping: 3256.00 bytes
Size of gzipped pruned model with stripping: 2836.00 bytes
