## Pruning

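Before converting the network, we can apply [pruning](https://www.tensorflow.org/model_optimization/guide/pruning), which sets low-magnitude weights to zero during training so that the saved model compresses better.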


Documentation: [tfmot.sparsity.keras.prune_low_magnitude](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude)
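To make the idea behind `prune_low_magnitude` concrete, here is a toy, framework-free sketch in plain Python. Given a flat list of weights and a target sparsity, it keeps the largest-magnitude weights and zeroes the rest. This is a simplification for illustration only: `magnitude_prune` is a hypothetical helper, and tfmot itself applies such masks gradually during training according to a pruning schedule rather than in one shot.

```python
def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to 0.

    Hypothetical helper for illustration; not part of tfmot.
    `sparsity` is the target fraction of zeros, e.g. 0.8 for 80%.
    """
    n = len(weights)
    # Number of weights to keep (at least one), rounded to the nearest integer.
    num_keep = max(round(n * (1 - sparsity)), 1)
    # Indices sorted by decreasing absolute value; the first num_keep survive.
    order = sorted(range(n), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(order[:num_keep])
    return [w if i in keep else 0.0 for i, w in enumerate(weights)]


weights = [0.9, -0.05, 0.3, -0.7, 0.01, 0.2, -0.4, 0.02, 0.6, -0.1]
pruned = magnitude_prune(weights, sparsity=0.8)
print(pruned)  # → [0.9, 0.0, 0.0, -0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

With 80% sparsity only the two largest-magnitude weights (0.9 and -0.7) survive; note that the sign of a weight does not matter, only its magnitude.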


In [13]:
%%script false --no-raise-error

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 128
epochs = 2
validation_split = 0.1

num_images = x_train_normalized.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
                                                            final_sparsity=0.8,
                                                            begin_step=0,
                                                            end_step=end_step)
}

model_for_pruning = prune_low_magnitude(tf_model, **pruning_params)

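model_for_pruning.compile(optimizer='adam',
                          # NOTE: if the model already ends in a softmax (the manual
                          # cross-entropy check further down suggests it does), this
                          # should be from_logits=False; the outputs below use True.
                          loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])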

model_for_pruning.summary()


Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_conv2d_4 (None, 30, 30, 6)         116       
_________________________________________________________________
prune_low_magnitude_average_ (None, 15, 15, 6)         1         
_________________________________________________________________
prune_low_magnitude_conv2d_5 (None, 13, 13, 16)        1746      
_________________________________________________________________
prune_low_magnitude_average_ (None, 6, 6, 16)          1         
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 576)               1         
_________________________________________________________________
prune_low_magnitude_dense_6  (None, 120)               138362    
_______________________________________________
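The `end_step` above is chosen so that the pruning schedule finishes exactly when training does: `validation_split` holds out 10% of the images, each epoch runs `ceil(num_images / batch_size)` steps, and there are `epochs` epochs. The sketch below reproduces that arithmetic and evaluates the polynomial sparsity curve in plain Python. It assumes 60,000 training images (the standard MNIST training set; the actual `x_train_normalized.shape[0]` is not shown here) and, to our understanding, the formula and default exponent `power=3` of tfmot's `PolynomialDecay`. `polynomial_sparsity` is a hypothetical helper, not a tfmot function.

```python
import math

# Assumption: 60,000 training images (standard MNIST); not shown in this notebook.
num_train = 60_000
batch_size = 128
epochs = 2
validation_split = 0.1

num_images = num_train * (1 - validation_split)        # images actually used for training
steps_per_epoch = math.ceil(num_images / batch_size)   # batches per epoch
end_step = steps_per_epoch * epochs                     # last step of the pruning schedule


def polynomial_sparsity(step, initial_sparsity=0.5, final_sparsity=0.8,
                        begin_step=0, end_step=end_step, power=3):
    """Hypothetical helper: polynomial decay from initial to final sparsity."""
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = min(1.0, max(0.0, (step - begin_step) / (end_step - begin_step)))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power


print(steps_per_epoch, end_step)
for step in (0, end_step // 2, end_step):
    print(step, round(polynomial_sparsity(step), 4))
```

Under these assumptions there are 422 steps per epoch and `end_step` is 844; sparsity starts at 50%, is about 76% halfway through, and reaches 80% at the final training step. tfmot only updates the masks every `frequency` steps (default 100), so in practice the curve is followed in discrete jumps.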

In [14]:
%%script false --no-raise-error

logdir = './logs'

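callbacks = [
    # Required: advances the pruning step so the schedule (and masks) update during training.
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Optional: writes sparsity summaries for TensorBoard to `logdir`.
    # tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(x_train_normalized, y_train,
                      batch_size=batch_size, epochs=epochs,
                      validation_split=validation_split,
                      callbacks=callbacks)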

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x162fafca0>

In [92]:
%%script false --no-raise-error

!tensorboard --logdir={logdir}

W0628 16:06:07.765544 123145418182656 plugin_event_accumulator.py:332] Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events.  Overwriting the graph with the newest event.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.2.1 at http://localhost:6006/ (Press CTRL+C to quit)
^C


#### Save the pruned model

Both `tfmot.sparsity.keras.strip_pruning` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of pruning.

The `strip_pruning` documentation describes the function as follows:

> Once a model has been pruned to required sparsity, this method can be used to restore the original model with the sparse weights.
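Why is compression needed at all? `strip_pruning` removes the wrapper variables, but the zeroed weights are still stored as ordinary 32-bit floats, so the file does not shrink by itself. What a compressor such as DEFLATE (used by both gzip and `zipfile.ZIP_DEFLATED`) can exploit is the long runs of zero bytes. The standard-library sketch below illustrates this on synthetic data: it builds a random float32 buffer with as many elements as the `dense_6` kernel (69,120), zeroes the smallest 80% by magnitude, and compares compressed sizes. The data is random, so the byte counts will differ from the real model.

```python
import array
import random
import zlib

random.seed(0)
n = 69_120          # same element count as the dense_6 kernel in this notebook
sparsity = 0.8

# Dense float32 weights drawn from a normal distribution (synthetic data).
dense = array.array('f', (random.gauss(0.0, 0.05) for _ in range(n)))

# Zero the smallest 80% by magnitude, mimicking a magnitude-pruned kernel.
threshold = sorted(abs(w) for w in dense)[int(n * sparsity)]
pruned = array.array('f', (w if abs(w) >= threshold else 0.0 for w in dense))

raw_dense, raw_pruned = dense.tobytes(), pruned.tobytes()
print('uncompressed:', len(raw_dense), len(raw_pruned))
print('deflated:    ', len(zlib.compress(raw_dense)), len(zlib.compress(raw_pruned)))
```

Both buffers have the same uncompressed size; only after DEFLATE does the pruned one become much smaller, which is the same effect measured on the real model files further down.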

In [15]:
%%script false --no-raise-error

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

tf_model_pruned_file = './keras-model/LeNet-MNIST_pruned.h5'

tf.keras.models.save_model(model_for_export, tf_model_pruned_file, include_optimizer=False)
print('Saved pruned Keras model to:', tf_model_pruned_file)
model_for_export.summary()

Saved pruned Keras model to: ./keras-model/LeNet-MNIST_pruned.h5
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 30, 30, 6)         60        
_________________________________________________________________
average_pooling2d_4 (Average (None, 15, 15, 6)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 13, 13, 16)        880       
_________________________________________________________________
average_pooling2d_5 (Average (None, 6, 6, 16)          0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 576)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 120)               69240     
_______________________________________________________
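The parameter counts in the two summaries are consistent with each prunable layer's wrapper adding a non-trainable mask with the same shape as the kernel, plus one scalar threshold and one scalar step counter; layers without a kernel (pooling, flatten) only gain the step counter, hence their `1`. The short check below reproduces the wrapped counts from the stripped ones under that assumption. `wrapped_params` is a hypothetical helper written for this illustration, not a tfmot API.

```python
def wrapped_params(kernel_size, bias_size):
    """Hypothetical helper: parameters of a prune_low_magnitude-wrapped layer.

    Assumes the wrapper adds a mask the size of the kernel, one scalar
    threshold and one scalar pruning step on top of the original weights.
    """
    original = kernel_size + bias_size
    return original + kernel_size + 1 + 1


# (layer, kernel elements, bias elements), consistent with the summaries above.
layers = [
    ('conv2d_4', 3 * 3 * 1 * 6, 6),     # 3x3 kernel, 1 input channel, 6 filters
    ('conv2d_5', 3 * 3 * 6 * 16, 16),   # 3x3 kernel, 6 input channels, 16 filters
    ('dense_6', 576 * 120, 120),        # 576 flattened inputs, 120 units
]
for name, kernel, bias in layers:
    print(name, kernel + bias, wrapped_params(kernel, bias))
```

This prints 60/116, 880/1746 and 69240/138362, matching the stripped and wrapped summaries: roughly half of the wrapped model's "parameters" are masks, which is why they disappear after `strip_pruning`.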

In [16]:
%%script false --no-raise-error

score = model_for_pruning.evaluate(x=x_test_normalized, y=y_test, verbose=0)

tf_model_pruned_loss = score[0]
tf_model_pruned_accuracy = score[1]

print('Test accuracy:\t\t', tf_model_accuracy)
print('Test accuracy (pruned):\t', tf_model_pruned_accuracy)
print('Test loss:\t\t', tf_model_loss)
print('Test loss (pruned):\t', tf_model_pruned_loss)

Test accuracy:		 0.9879000186920166
Test accuracy (pruned):	 0.98089998960495
Test loss:		 0.04126233980059624
Test loss (pruned):	 1.4850536584854126


In [17]:
%%script false --no-raise-error

prediction = model_for_pruning.predict(x_test_normalized)
prediction_delta = np.mean((prediction - y_test) ** 2)

print('Test squared loss (manually):\t', prediction_delta)

Test squared loss (manually):	 0.0029721695


In [18]:
%%script false --no-raise-error

loss_fn = tf.keras.losses.MeanSquaredError()
squared_loss = loss_fn(y_test, prediction)
print('Test squared loss:\t', squared_loss.numpy())

Test squared loss:	 0.002972168


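Both ways of computing the mean squared error agree (about 0.00297). Neither is the loss reported by `model_for_pruning.evaluate()`, though: `evaluate()` reports the categorical cross-entropy the model was compiled with, which the next cell computes directly.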

In [19]:
%%script false --no-raise-error

loss_fn = tf.keras.losses.CategoricalCrossentropy()
crossentropy_loss = loss_fn(y_test, prediction)
print('Test crossentropy loss:\t', crossentropy_loss.numpy())

Test crossentropy loss:	 0.062741384
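The cross-entropy computed here (about 0.063) is far below the 1.485 reported by `evaluate()`. One plausible explanation: the model's outputs already look like softmax probabilities (their squared error against the one-hot labels is tiny), yet it was compiled with `CategoricalCrossentropy(from_logits=True)`, which applies softmax a second time. The plain-Python sketch below uses a made-up, confident 10-class prediction to show how much that double softmax inflates the loss; the numbers are illustrative, not taken from the model.

```python
import math


def cross_entropy(probs, true_index):
    """Categorical cross-entropy for one sample, given probabilities."""
    return -math.log(probs[true_index])


def softmax(values):
    exps = [math.exp(v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up prediction: 99.1% on the correct class, 0.1% on each of the other nine.
probs = [0.991] + [0.001] * 9

loss_as_probs = cross_entropy(probs, 0)             # what from_logits=False computes
loss_as_logits = cross_entropy(softmax(probs), 0)   # what from_logits=True computes
print(round(loss_as_probs, 4), round(loss_as_logits, 4))
```

Treating probabilities as logits squashes them into a nearly uniform distribution, giving a loss of about 1.47 even for a correct, confident prediction, close to the value `evaluate()` reported. If that is the cause, compiling with `from_logits=False` (or removing the final softmax) should make `evaluate()` agree with the manual value.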


In [35]:
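import os
import tempfile
import zipfile


def get_gzipped_model_size(file):
  # Returns the size, in bytes, of `file` after ZIP/DEFLATE compression
  # (DEFLATE is the same algorithm gzip uses).
  fd, zipped_file = tempfile.mkstemp('.zip')  # fresh temporary file per call
  os.close(fd)  # zipfile reopens the path itself
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)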

In [139]:
%%script false --no-raise-error

print("Size of baseline Keras model:\t %.2f bytes" % os.path.getsize(tf_model_file))
print("Size of pruned Keras model:\t %.2f bytes" % os.path.getsize(tf_model_pruned_file))
print("\ngzipped:")
print("Size of gzipped baseline Keras model:\t %.2f bytes" % (get_gzipped_model_size(tf_model_file)))
print("Size of gzipped pruned Keras model:\t %.2f bytes" % (get_gzipped_model_size(tf_model_pruned_file)))


Size of baseline Keras model:	 1021632.00 bytes
Size of pruned Keras model:	 349032.00 bytes

gzipped:
Size of gzipped baseline Keras model:	 914806.00 bytes
Size of gzipped pruned Keras model:	 99547.00 bytes
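The pruned file is already about a third of the baseline's size before compression, even though zeroed weights still take 4 bytes each. A plausible reading, assuming the baseline file was saved together with its Adam optimizer state (the pruned model was saved with `include_optimizer=False`): Adam keeps two extra moment tensors per weight, roughly tripling the stored floats. The arithmetic below checks this against the parameter counts in this notebook (kernel and bias sizes from the per-layer weight listing at the end of this section); the remaining few tens of kilobytes would be HDF5 and configuration overhead.

```python
# Weight element counts (kernel + bias) from the per-layer weight listing.
param_counts = {
    'conv2d_4': 54 + 6,
    'conv2d_5': 864 + 16,
    'dense_6': 69120 + 120,
    'dense_7': 10080 + 84,
    'dense_8': 840 + 10,
}
total_params = sum(param_counts.values())
weight_bytes = total_params * 4           # float32 = 4 bytes each

# Measured file sizes from the cell above.
baseline_h5, pruned_h5 = 1021632, 349032
baseline_zip, pruned_zip = 914806, 99547

print('parameters:', total_params)
print('float32 weights only:', weight_bytes, 'bytes')
print('weights + two Adam moments (assumption):', 3 * weight_bytes, 'bytes')
print('pruned file overhead:', pruned_h5 - weight_bytes, 'bytes')
print('compression ratio of pruned file: %.1fx' % (pruned_h5 / pruned_zip))
print('gzipped baseline / gzipped pruned: %.1fx' % (baseline_zip / pruned_zip))
```

So the fair comparison for pruning is between the compressed sizes: the pruned model compresses by about 3.5x on its own and ends up roughly 9x smaller than the compressed baseline, part of which is still due to the dropped optimizer state.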


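For hardware-specific optimizations, see the corresponding section of the [comprehensive pruning guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md#hardware-specific_optimizations).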

In [20]:
tf_model_pruned = tf.keras.models.load_model('./keras-model/LeNet-MNIST_pruned.h5')
tf_model_pruned.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 30, 30, 6)         60        
_________________________________________________________________
average_pooling2d_4 (Average (None, 15, 15, 6)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 13, 13, 16)        880       
_________________________________________________________________
average_pooling2d_5 (Average (None, 6, 6, 16)          0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 576)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 120)               69240     
_________________________________________________________________
dense_7 (Dense)              (None, 84)               

#### Explore sparsity and check zeros for each layer

In [21]:
%%script false --no-raise-error

import numpy as np
from tensorflow.keras.models import load_model

# Compare baseline and pruned model: percentage of exact zeros in each weight tensor.
for file in [tf_model_file, tf_model_pruned_file]:
    model = load_model(file)

    for variable, w in zip(model.weights, model.get_weights()):
        print(
            "{} -- \tTotal:{}, Zeros: {:.2f}%".format(
                variable.name, w.size, np.sum(w == 0) / w.size * 100
            )
        )
    print('---')

conv2d_4_4/kernel:0 -- 	Total:54, Zeros: 0.00%
conv2d_4_4/bias:0 -- 	Total:6, Zeros: 0.00%
conv2d_5_4/kernel:0 -- 	Total:864, Zeros: 0.00%
conv2d_5_4/bias:0 -- 	Total:16, Zeros: 0.00%
dense_6_4/kernel:0 -- 	Total:69120, Zeros: 0.00%
dense_6_4/bias:0 -- 	Total:120, Zeros: 0.00%
dense_7_4/kernel:0 -- 	Total:10080, Zeros: 0.00%
dense_7_4/bias:0 -- 	Total:84, Zeros: 0.00%
dense_8_4/kernel:0 -- 	Total:840, Zeros: 0.00%
dense_8_4/bias:0 -- 	Total:10, Zeros: 0.00%
---
conv2d_4_5/kernel:0 -- 	Total:54, Zeros: 79.63%
conv2d_4_5/bias:0 -- 	Total:6, Zeros: 0.00%
conv2d_5_5/kernel:0 -- 	Total:864, Zeros: 79.98%
conv2d_5_5/bias:0 -- 	Total:16, Zeros: 0.00%
dense_6_5/kernel:0 -- 	Total:69120, Zeros: 80.00%
dense_6_5/bias:0 -- 	Total:120, Zeros: 0.00%
dense_7_5/kernel:0 -- 	Total:10080, Zeros: 80.00%
dense_7_5/bias:0 -- 	Total:84, Zeros: 0.00%
dense_8_5/kernel:0 -- 	Total:840, Zeros: 80.00%
dense_8_5/bias:0 -- 	Total:10, Zeros: 0.00%
---
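In the pruned model every kernel is at (or very near) the 80% target, while the biases are untouched: for `Conv2D` and `Dense` layers, only the kernels are pruned. The two small convolutional kernels land at 79.63% and 79.98% rather than exactly 80%, which is consistent with the mask keeping a whole number of weights, roughly `round(n * (1 - sparsity))` of the `n` kernel entries. That rounding rule is our assumption about tfmot, checked here only against the printed percentages; `expected_zero_pct` is a hypothetical helper.

```python
def expected_zero_pct(n, sparsity=0.8):
    """Hypothetical helper: % of zeros if round(n * (1 - sparsity)) weights are kept."""
    kept = max(round(n * (1 - sparsity)), 1)
    return (n - kept) / n * 100


# Kernel sizes from the listing above.
for name, n in [('conv2d_4', 54), ('conv2d_5', 864), ('dense_6', 69120),
                ('dense_7', 10080), ('dense_8', 840)]:
    print('{}: {:.2f}%'.format(name, expected_zero_pct(n)))
```

For the 54-weight kernel, 20% would be 10.8 weights, so 11 are kept and 43 are zeroed (79.63%); the larger kernels are big enough that the rounding is invisible at two decimals.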
