# Pruning - how to compress neural networks

<img src="images/pruning.png" height="50%" width="50%">

## Background
Modern state-of-the-art neural network architectures are HUGE. For instance, you have probably heard about GPT-3, OpenAI’s newest revolutionary NLP model, capable of writing poetry and interactive storytelling.

Well, GPT-3 has around 175 billion parameters.

According to the original GPT-3 paper, training it required about 3.14E+23 floating-point operations (FLOPs); the compute alone is estimated to cost millions of dollars.

Beyond training cost, model size also causes problems at inference time on edge devices. Because these devices have limited resources, some models do not fit in their memory at all, and others cannot be optimized to run on them.

<img src="images/gpt3.png">


To reduce these costs by compressing models, three main families of methods have emerged:
* weight pruning
* quantization
* knowledge distillation

## Pruning
We can think of pruning as **"optimal brain damage"** (a name borrowed from LeCun et al.'s early pruning paper), because that is essentially what it is: every pruning method tries to remove the least significant weights from a model without degrading its performance.

Pruning is an active research area and much of the work is still immature. I'll go over one popular approach, based on the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" by Frankle and Carbin, which states:

> A randomly-initialized, dense neural network contains a subnetwork that is initialized such that — when trained in isolation — it can match the test accuracy of the original network after training for at most the same number of iterations.

In other words, if we train a large, randomly initialized model, find the weights that end up closest to zero and remove them, we may "win the lottery": find a subnetwork that, as the hypothesis predicts, trains to the same accuracy as the full network.

The authors suggest the following algorithm to do so:
1. Randomly initialize the network and store the initial weights for later reference.
2. Train the network for a given number of steps.
3. Remove a percentage of the weights with the lowest magnitude.
4. Reset the remaining weights to the values they had at initialization.
5. Go to step 2 and repeat the pruning.
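The loop above can be sketched in a few lines of NumPy. This is a toy illustration, not the authors' code: a logistic-regression model stands in for a real network, and the helper names (`train`, `lottery_ticket`), the learning rate and the 20%-per-round pruning rate are assumptions chosen for the example. Weights are pruned globally by magnitude, and the survivors are rewound to their initial values before each retraining round.

```python
import numpy as np


def train(w_init, mask, X, y, steps=200, lr=0.5):
    """Gradient-descent training of a logistic-regression model.

    Pruned weights (mask == 0) start at zero and stay at zero.
    """
    w = w_init * mask
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))   # predicted probabilities
        grad = X.T @ (p - y) / len(y)        # gradient of the mean log-loss
        w = (w - lr * grad) * mask           # keep pruned weights at zero
    return w


def lottery_ticket(X, y, rounds=3, prune_frac=0.2, seed=0):
    rng = np.random.default_rng(seed)
    w_init = rng.normal(scale=0.1, size=X.shape[1])  # step 1: store the init
    mask = np.ones_like(w_init)
    for _ in range(rounds):
        w = train(w_init, mask, X, y)                 # step 2: train
        alive = np.flatnonzero(mask)
        k = int(prune_frac * alive.size)
        # Step 3: prune the k surviving weights with the smallest magnitude.
        smallest = alive[np.argsort(np.abs(w[alive]))[:k]]
        mask[smallest] = 0.0
        # Step 4 is implicit: the next round restarts from w_init * mask.
    return mask, train(w_init, mask, X, y)           # train the final ticket


# Toy data: 20 features, only the first three carry signal.
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 20))
y = (X[:, 0] + 2 * X[:, 1] - X[:, 2] > 0).astype(float)

mask, w = lottery_ticket(X, y)
accuracy = np.mean(((X @ w) > 0) == y)
print(int(mask.sum()), "of", mask.size, "weights kept; train accuracy", accuracy)
```

On this toy problem the three informative weights should survive every round, while the pruned ones are noise weights.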

## Quantization
Quantization uses smaller data types to represent the model's weights: instead of FP32 arithmetic we can use FP16 or INT8, which are cheaper to store and can run on specialized hardware such as tensor cores.

One way to achieve this is range remapping: the algorithm learns the range over which the weight (and multiplication) values vary, and maps that range onto the smaller data type.
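As a minimal sketch of such a remapping, assuming a simple per-tensor affine scheme (the observed min/max range mapped linearly onto the 256 int8 levels), the NumPy code below quantizes a weight vector and maps it back. `quantize_int8` and `dequantize` are hypothetical helpers; production toolkits add refinements such as per-channel scales and calibrated activation ranges.

```python
import numpy as np


def quantize_int8(x):
    """Affine (min/max) quantization of a float tensor to int8."""
    lo, hi = float(x.min()), float(x.max())
    scale = (hi - lo) / 255.0 if hi > lo else 1.0    # float step per int8 level
    zero_point = int(round(-128 - lo / scale))       # int8 code (approximately) for 0.0
    q = np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map int8 codes back to approximate float values."""
    return (q.astype(np.float32) - zero_point) * scale


rng = np.random.default_rng(0)
w = rng.normal(size=1000).astype(np.float32)   # stand-in for one layer's weights
q, scale, zero_point = quantize_int8(w)
w_hat = dequantize(q, scale, zero_point)

print(q.nbytes, "bytes instead of", w.nbytes)  # 1000 bytes instead of 4000
print("max abs error:", np.max(np.abs(w - w_hat)))
```

Each weight now takes one byte instead of four, and the round-trip error stays within about half a quantization step (`scale / 2`).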

## Knowledge Distillation
Knowledge distillation is an approach developed by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean in their paper "Distilling the Knowledge in a Neural Network".

The idea is to first train a large network with many parameters to serve as the teacher. The teacher then teaches a smaller student architecture, which tries to reach similar accuracy while being small enough to be practical in production.
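In Hinton et al.'s formulation the student is trained on a weighted sum of two losses: cross-entropy against the teacher's outputs softened with a temperature T, and ordinary cross-entropy against the true labels, with the soft term multiplied by T² so its gradients keep a comparable scale. The NumPy sketch below computes that loss for a batch of logits; the temperature `T=4.0`, the weight `alpha=0.1` and the random stand-in logits are illustrative assumptions, not values from the paper.

```python
import numpy as np


def softmax(logits, T=1.0):
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)    # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.1):
    """alpha * CE(labels, student) + (1 - alpha) * T^2 * CE(soft teacher, soft student)."""
    n = len(labels)
    # Ordinary cross-entropy against the true labels (temperature 1).
    hard = -np.mean(np.log(softmax(student_logits)[np.arange(n), labels]))
    # Cross-entropy between temperature-softened teacher and student outputs.
    soft_targets = softmax(teacher_logits, T)
    soft = -np.mean(np.sum(soft_targets * np.log(softmax(student_logits, T)), axis=-1))
    # T^2 keeps the soft-target gradients on the same scale as the hard ones.
    return alpha * hard + (1 - alpha) * T**2 * soft


rng = np.random.default_rng(0)
teacher = 3 * rng.normal(size=(8, 10))    # stand-in logits from a large teacher
student = rng.normal(size=(8, 10))        # stand-in logits from a small student
labels = teacher.argmax(axis=1)
print(distillation_loss(student, teacher, labels))
```

The soft term is smallest when the student's softened outputs match the teacher's, which is what pushes the student to imitate the teacher.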


#### In this article we will apply pruning with TensorFlow: we start from a dense network and create a sparse subnetwork out of it.

## Setup

```
! pip install -q tensorflow-model-optimization
```


```python
import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow import keras

%load_ext tensorboard
```

## Train a model for MNIST without pruning

```python
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
)
```

```
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
```

Evaluate the baseline test accuracy and save the model for later use.

```python
model.summary()

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
```

```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 12)        0         
_________________________________________________________________
flatten (Flatten)            (None, 2028)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                20290     
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
_________________________________________________________________
Baseline test accuracy: 0.9785000085830688
Saved baseline model to: /tmp/tmpnmotx8yj.h5
```


## Fine-tune pre-trained model with pruning


### Define the model

You will apply pruning to the whole model and see this in the model summary.

In this example, you start the model with 50% sparsity (50% zeros in weights)
and end with 80% sparsity.

The sparsity schedule is `tfmot.sparsity.keras.PolynomialDecay`; see its API documentation: https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/PolynomialDecay
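The shape of this schedule is easy to see by evaluating the formula given in the `PolynomialDecay` documentation by hand. The plain-Python sketch below (the helper name `polynomial_decay_sparsity` is hypothetical) uses the same values as the next cell. It ignores the `frequency` argument: the real schedule only updates the masks every `frequency` steps (100 by default).

```python
import math


def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` under a polynomial (default cubic) schedule."""
    # Fraction of the pruning window that has elapsed, clamped to [0, 1].
    progress = min(1.0, max(0.0, (step - begin_step) / (end_step - begin_step)))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power


# Same numbers as the next cell: ceil(60000 * 0.9 / 128) * 2 = 844 steps.
end_step = math.ceil(60000 * 0.9 / 128) * 2
for step in (0, end_step // 4, end_step // 2, end_step):
    print(step, round(polynomial_decay_sparsity(step, 0.50, 0.80, 0, end_step), 4))
```

Sparsity climbs quickly at first and flattens out near the 80% target, so most of the pruning happens early in fine-tuning.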

```python
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()
```



```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape  (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1         
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1         
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10)                40572     
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395
_________________________________________________________________
```
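The extra non-trainable parameters in this summary are pruning bookkeeping rather than model weights. Assuming the wrapper adds a mask of the same shape as each prunable kernel (here the Conv2D and Dense kernels) plus a scalar threshold, and one scalar step counter to every wrapped layer, the counts add up exactly:

```python
# Assumed bookkeeping added by the pruning wrapper (see lead-in).
conv_kernel = 3 * 3 * 1 * 12      # 108 Conv2D kernel weights
dense_kernel = 2028 * 10          # 20,280 Dense kernel weights
wrapped_layers = 5                # reshape, conv2d, max_pool, flatten, dense

masks = conv_kernel + dense_kernel    # one mask entry per prunable weight
thresholds = 2                        # one scalar per prunable kernel
step_counters = wrapped_layers        # one scalar per wrapped layer

non_trainable = masks + thresholds + step_counters
print(non_trainable)                  # 20395
print(20410 + non_trainable)          # 40805, the total in the summary
```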


### Train and evaluate the model against baseline

Fine-tune with pruning for two epochs.

`tfmot.sparsity.keras.UpdatePruningStep` is required during training, and `tfmot.sparsity.keras.PruningSummaries` provides logs for tracking progress and debugging.

```python
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)
```

```
Epoch 1/2
Epoch 2/2
```

For this example, pruning costs very little test accuracy: 97.22% after pruning versus 97.85% for the baseline, a drop of about 0.6 percentage points.

```python
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)
```

```
Baseline test accuracy: 0.9785000085830688
Pruned test accuracy: 0.9721999764442444
```


The logs show the progression of sparsity on a per-layer basis.

```python
%tensorboard --logdir={logdir} --port 8890 --host 0.0.0.0
```

## Create 3x smaller models from pruning

Both `tfmot.sparsity.keras.strip_pruning` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression
benefits of pruning.

*   `strip_pruning` is necessary because it removes every `tf.Variable` that pruning needs only during training; otherwise these variables would add to the model size at inference time.
*   Applying a standard compression algorithm is necessary because the serialized weight matrices are the same size as they were before pruning. However, pruning makes most of the weights zero, and this added redundancy is what compression algorithms exploit to shrink the model further.
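A quick way to see the second point is to compare a dense array with a pruned copy of it. In this sketch (random stand-in weights, with 80% zeroed by magnitude to mirror the final sparsity used here), both arrays serialize to the same number of bytes, but `zlib`, which uses the same DEFLATE algorithm as gzip, shrinks the sparse one far more:

```python
import zlib

import numpy as np

rng = np.random.default_rng(0)
dense = rng.normal(size=10_000).astype(np.float32)   # stand-in for a weight matrix

# Zero out the 80% of weights with the smallest magnitude.
threshold = np.quantile(np.abs(dense), 0.8)
sparse = np.where(np.abs(dense) >= threshold, dense, 0.0).astype(np.float32)

dense_bytes, sparse_bytes = dense.tobytes(), sparse.tobytes()
print("raw:", len(dense_bytes), len(sparse_bytes))   # raw: 40000 40000
print("deflated:", len(zlib.compress(dense_bytes)), len(zlib.compress(sparse_bytes)))
```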

First, create a compressible model for TensorFlow.

```python
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
```

```
Saved pruned Keras model to: /tmp/tmpyfebzim1.h5
```
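To confirm that the exported weights really are sparse, you can count exact zeros. The helper below, `zero_fraction`, is a hypothetical illustration. The commented-out lines sketch how it might be applied to `model_for_export`, filtering on variable names because only the kernels are pruned; if the schedule ran to completion, it should report roughly the 80% target.

```python
import numpy as np


def zero_fraction(arrays):
    """Fraction of entries that are exactly zero across a list of arrays."""
    total = sum(a.size for a in arrays)
    zeros = sum(a.size - np.count_nonzero(a) for a in arrays)
    return zeros / total


# Toy check: 2 zeros out of 4, plus 4 zeros out of 4 -> 6 / 8.
print(zero_fraction([np.array([0.0, 1.0, 0.0, 2.0]), np.zeros(4)]))  # 0.75

# Hypothetical usage with the stripped model (not executed here):
# kernels = [v.numpy() for v in model_for_export.weights if "kernel" in v.name]
# print(zero_fraction(kernels))
```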


Define a helper function that compresses a model file with DEFLATE (the algorithm gzip uses, applied here through Python's `zipfile` module) and measures the compressed size.

```python
import os
import tempfile
import zipfile

def get_gzipped_model_size(file):
  """Returns the size, in bytes, of `file` after DEFLATE compression in a zip archive."""
  fd, zipped_file = tempfile.mkstemp('.zip')
  os.close(fd)  # zipfile reopens the path itself; close the raw descriptor
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)
```

Compare the sizes: after compression, the pruned model is about 3x smaller than the baseline.

```python
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
```

```
Size of gzipped baseline Keras model: 78169.00 bytes
Size of gzipped pruned Keras model: 25795.00 bytes
```
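The "3x" figure follows directly from the two sizes printed above:

```python
baseline_zip_bytes = 78169   # compressed baseline Keras model (output above)
pruned_zip_bytes = 25795     # compressed pruned Keras model (output above)

print(f"{baseline_zip_bytes / pruned_zip_bytes:.2f}x smaller")            # 3.03x smaller
print(f"{1 - pruned_zip_bytes / baseline_zip_bytes:.1%} size reduction")  # 67.0% size reduction
```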
