<a href="https://colab.research.google.com/github/danijak/testing/blob/master/Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import numpy as np
from tensorflow import keras

In [3]:
# loading the MNIST dataset from keras
data = keras.datasets.mnist 

In [4]:
# training & testing split
(train_images, train_labels), (test_images, test_labels) = data.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [6]:
# Input Normalize to get it into 0-1
train_images = train_images/255.0
test_images = test_images/255.0

In [30]:
# model
model = keras.models.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28,28)))
model.add(keras.layers.Reshape(target_shape=(28,28,1)))
model.add(keras.layers.Conv2D(12,kernel_size=(3,3),activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=(2,2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10))
model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_2 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 13, 13, 12)        0         
_________________________________________________________________
flatten (Flatten)            (None, 2028)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                20290     
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
_________________________________________________________________


In [34]:
#Train the model
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
model.fit(x=train_images, y=train_labels, batch_size=32, epochs=5, validation_split=0.1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fc06619cf98>

In [36]:
# Evaluate
print(model.evaluate(x=test_images, y=test_labels))


[0.26367828249931335, 0.9244999885559082]


In [39]:
! pip install tensorflow_model_optimization

Collecting tensorflow_model_optimization
[?25l  Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)
[K     |██                              | 10kB 13.0MB/s eta 0:00:01[K     |███▉                            | 20kB 1.7MB/s eta 0:00:01[K     |█████▊                          | 30kB 2.1MB/s eta 0:00:01[K     |███████▋                        | 40kB 2.5MB/s eta 0:00:01[K     |█████████▌                      | 51kB 2.0MB/s eta 0:00:01[K     |███████████▍                    | 61kB 2.2MB/s eta 0:00:01[K     |█████████████▎                  | 71kB 2.5MB/s eta 0:00:01[K     |███████████████▏                | 81kB 2.7MB/s eta 0:00:01[K     |█████████████████               | 92kB 2.9MB/s eta 0:00:01[K     |███████████████████             | 102kB 2.8MB/s eta 0:00:01[K     |████████████████████▉           | 112kB 2.8MB/s eta 0:00:01[K     |██████

In [40]:
# Pre trained model with pruning
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_for_pruning.summary()


Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_reshape_ (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d_4 (None, 26, 26, 12)        230       
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1         
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1         
_________________________________________________________________
prune_low_magnitude_dense_2  (None, 10)                40572     
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395
_________________________________________________________________


In [42]:
import tempfile

In [43]:
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
  
model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

Epoch 1/2
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7fc0641c1780>

In [45]:
print(model_for_pruning.evaluate(x=test_images, y=test_labels))

[0.3080199956893921, 0.9182999730110168]
