# Setup

In [1]:
!pip install tensorflow tensorflow-model-optimization

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.7.3-py2.py3-none-any.whl (238 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m238.9/238.9 KB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.7.3


In [2]:
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras

# Short alias for the tfmot Keras pruning wrapper applied to the model later on.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude


In [3]:
# Mount Google Drive at /content/gdrive; the trained models are saved under this mount point.
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


# Train model without pruning



Load MNIST dataset

In [4]:
# Fetch MNIST through Keras; load_data() yields a (train, test) pair of
# (images, labels) tuples, unpacked here into four arrays.
mnist = keras.datasets.mnist
train_split, test_split = mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Normalize the input images so that each pixel value is between 0 and 1.

In [5]:
# Scale both image sets by 1/255 so pixel values fall in [0, 1].
# Note: the names are rebound from themselves, so re-running this cell rescales again.
train_images, test_images = train_images / 255.0, test_images / 255.0

Define the model architecture


In [6]:
# Small convolutional classifier: one conv layer, one max-pooling layer and a
# 10-unit dense head with no activation (it outputs raw logits).
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28, 28)))
model.add(keras.layers.Reshape((28, 28, 1)))  # add a trailing channel axis for Conv2D
model.add(keras.layers.Conv2D(12, (3, 3), activation="relu"))
model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(units=10))
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 12)       0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
____________________________________________________

Train the digit classification model

In [7]:
# The final Dense layer has no activation, so the loss is told to expect logits.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
# Train for two epochs, holding out 10% of the training data for validation.
model.fit(x=train_images, y=train_labels, epochs=2, validation_split=0.1)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f943988c0d0>

Evaluate baseline test accuracy (loss and accuracy) and save the model

In [8]:
# Score the baseline model on the test set; the result is [loss, accuracy].
model.evaluate(x=test_images, y=test_labels, verbose=0)

[0.08592481911182404, 0.974399983882904]

In [9]:
# Persist the baseline model as an .h5 file on the mounted Drive, without optimizer state.
tf.keras.models.save_model(
    model,
    "/content/gdrive/MyDrive/Curso-Jetson/models/mnist_model.h5",
    include_optimizer=False,
)

# Fine-tune model with pruning

Compute end step to finish pruning after 2 epochs.

In [10]:
batch_size = 128
epochs = 2
validation_split = 0.1

# Number of samples left for training once the validation fraction is held out.
num_images = len(train_images) * (1 - validation_split)
# Batches per epoch (rounded up to count a final partial batch), times the epoch count.
steps_per_epoch = np.ceil(num_images / batch_size).astype(np.int32)
end_step = steps_per_epoch * epochs

print(f"Num_images: {num_images}")
print(f"End step: {end_step}")

Num_images: 54000.0
End step: 844


Define model for pruning.


In [11]:
# Sparsity schedule: ramp from 50% to 80% sparsity between step 0 and end_step.
sparsity_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.50,
    final_sparsity=0.80,
    begin_step=0,
    end_step=end_step,
)
pruning_params = {"pruning_schedule": sparsity_schedule}

# Wrap the already-trained baseline model with the pruning wrapper.
model_for_pruning = prune_low_magnitude(model, **pruning_params)

# Compile the wrapped model with the same optimizer, loss and metric as the baseline.
model_for_pruning.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model_for_pruning.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshape  (None, 28, 28, 1)        1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_conv2d   (None, 26, 26, 12)       230       
 (PruneLowMagnitude)                                             
                                                                 
 prune_low_magnitude_max_poo  (None, 13, 13, 12)       1         
 ling2d (PruneLowMagnitude)                                      
                                                                 
 prune_low_magnitude_flatten  (None, 2028)             1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_dense (  (None, 10)               4

Train with pruning

In [12]:
# UpdatePruningStep is the tfmot callback supplied to fit() while training with pruning.
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

# Same batch size, epoch count and validation split that end_step was derived from.
model_for_pruning.fit(
    x=train_images,
    y=train_labels,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f9436d76190>

Remove pruning layers

In [13]:
# Rebind model_for_pruning to the result of strip_pruning (the pruning wrappers are removed).
model_for_pruning = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

Evaluate the pruned model's test accuracy (loss and accuracy) and save the model

In [14]:
# Compile the stripped model so evaluate() reports loss and accuracy.
model_for_pruning.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Score the pruned model on the test set; the result is [loss, accuracy].
model_for_pruning.evaluate(x=test_images, y=test_labels, verbose=0)

[0.10427340120077133, 0.9696999788284302]

In [15]:
# Persist the pruned model as an .h5 file on the mounted Drive, without optimizer state.
tf.keras.models.save_model(
    model_for_pruning,
    "/content/gdrive/MyDrive/Curso-Jetson/models/mnist_model_0.8sparsity.h5",
    include_optimizer=False,
)