<a href="https://colab.research.google.com/github/rajdeepd/tensorflow_2.0_book_code/blob/master/ch09/prunings_samples_rd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pruning Examples for TensorFlow 2.x
In magnitude-based weight pruning, model sparsity is achieved by gradually zeroing out the lowest-magnitude weights during training, ideally with little or no loss of accuracy. Because a sparse model compresses well, pruning can improve model compression by a factor of about 3.
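
To make the idea concrete, here is a minimal, library-free sketch of a single magnitude-pruning step. It is not how `tensorflow_model_optimization` implements pruning; the helper name `prune_by_magnitude` and the sample weights are illustrative assumptions.

```python
def prune_by_magnitude(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to 0.0.

    `sparsity` is the fraction of entries to zero out, between 0.0 and 1.0.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0.0 and 1.0")
    num_to_prune = int(sparsity * len(weights))
    # Indices ordered by ascending absolute value; the first ones get pruned.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_indices = set(order[:num_to_prune])
    return [0.0 if i in pruned_indices else w for i, w in enumerate(weights)]


# Hypothetical weights, chosen only for illustration.
weights = [0.8, -0.05, 0.3, -0.9, 0.01, 0.4]
pruned = prune_by_magnitude(weights, sparsity=0.5)
print(pruned)  # [0.8, 0.0, 0.0, -0.9, 0.0, 0.4]
```

In actual pruning the target sparsity is raised gradually over many training steps, so the remaining weights can adapt between pruning steps.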

## Install and Setup

In [None]:
! pip install -q tensorflow-model-optimization

import tensorflow as tf
from tensorflow import keras  # the cells below refer to `keras` directly
import numpy as np
import tensorflow_model_optimization as tfmot


## Training Data

In [None]:

%load_ext tensorboard

import tempfile


## Model

In [None]:
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model_mnist = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model_mnist.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_mnist.fit(
  train_images,
  train_labels,
  epochs=4, validation_split=0.1,
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


<keras.callbacks.History at 0x7fb27f9ac6d0>

## Prune All the Weights

In [None]:

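# Save the trained weights to a temporary file.
_, pretrained_weights = tempfile.mkstemp('.tf')
model_mnist.save_weights(pretrained_weights)

# Wrap the whole model so that every layer trains with pruning.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model_mnist)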

model_for_pruning.summary()



Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape_ (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d_3 (None, 26, 26, 12)        230       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1         
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 2028)              1         
_________________________________________________________________
prune_low_magnitude_dense_3  (None, 10)                40572     
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395
_________________________________________________________________
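
The parameter counts in this summary can be checked by hand. The sketch below assumes that each pruning wrapper adds one step counter, and that wrappers around layers with a kernel also add a mask of the same size as the kernel plus a scalar threshold. That bookkeeping is an assumption inferred from the numbers, not something stated in the summary itself; the variable names are illustrative.

```python
# Layer shapes taken from the model definition and summary.
conv_kernel = 3 * 3 * 1 * 12   # 3x3 kernel, 1 input channel, 12 filters
conv_bias = 12
dense_kernel = 2028 * 10       # flattened 13*13*12 inputs, 10 outputs
dense_bias = 10


def wrapper_overhead(kernel_size):
    # Assumed bookkeeping: mask (same size as kernel) + threshold + step counter.
    return kernel_size + 1 + 1


conv_params = conv_kernel + conv_bias + wrapper_overhead(conv_kernel)
dense_params = dense_kernel + dense_bias + wrapper_overhead(dense_kernel)
# Reshape, MaxPooling2D, and Flatten have no weights: only a step counter each.
weightless_params = 3 * 1

total = conv_params + dense_params + weightless_params
trainable = (conv_kernel + conv_bias) + (dense_kernel + dense_bias)
non_trainable = total - trainable

print(conv_params, dense_params)        # 230 40572
print(total, trainable, non_trainable)  # 40805 20410 20395
```

Under this assumption the pruning bookkeeping roughly doubles the reported parameter count, while the number of trainable parameters matches the unpruned model.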


## Prune Some Layers 

The examples below apply to Functional and Sequential models.

Tips for better model accuracy:

* It's generally better to fine-tune with pruning than to train from scratch.
* Try pruning the later layers instead of the first layers.
* Avoid pruning critical layers (e.g. an attention mechanism).

More:

* The `tfmot.sparsity.keras.prune_low_magnitude` API docs provide details on how to vary the pruning configuration per layer.
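
One setting that can differ per layer is the pruning schedule, which controls how quickly sparsity ramps up. The following is a library-free sketch of a polynomial-decay schedule, modeled on the formula documented for `tfmot.sparsity.keras.PolynomialDecay`. The function name `polynomial_sparsity`, the `layer_targets` dictionary, and all numeric values are illustrative assumptions.

```python
def polynomial_sparsity(step, initial_sparsity, final_sparsity,
                        begin_step, end_step, power=3):
    """Target sparsity at `step`, ramping from the initial to the final value."""
    # Fraction of the schedule completed, clamped to the range [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return (final_sparsity
            + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power)


# Hypothetical per-layer targets: prune the Dense layer harder than Conv2D.
layer_targets = {"conv2d": 0.5, "dense": 0.8}

for name, final in layer_targets.items():
    ramp = [round(polynomial_sparsity(s, 0.0, final, 0, 1000), 4)
            for s in (0, 500, 1000)]
    print(name, ramp)
```

With `tensorflow_model_optimization`, a schedule object is passed through the `pruning_schedule` argument of `prune_low_magnitude`, so each wrapped layer can receive its own schedule.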

In [None]:

# Helper function uses `prune_low_magnitude` to make only the 
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense` 
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
    model_mnist,
    clone_function=apply_pruning_to_dense,
)

model_for_pruning.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_3 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 13, 13, 12)        0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 2028)              0         
_________________________________________________________________
prune_low_magnitude_dense_3  (None, 10)                40572     
Total params: 40,692
Trainable params: 20,410
Non-trainable params: 20,282
_________________________________________________________________




### Sequential Example

In [None]:
# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.

model_for_pruning = tf.keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(10))
])

model_for_pruning.summary()

Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_4 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 13, 13, 12)        0         
_________________________________________________________________
flatten_7 (Flatten)          (None, 2028)              0         
_________________________________________________________________
prune_low_magnitude_dense_7  (None, 10)                40572     
Total params: 40,692
Trainable params: 20,410
Non-trainable params: 20,282
_________________________________________________________________


