<table class="tfo-notebook-buttons" align="left">

  <td>
    <a target="_blank" href="https://colab.research.google.com/github/rajdeepd/tensorflow_2.0_book_code/blob/master/ch09/weight_clustering_example_2_mnist.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/rajdeepd/tensorflow_2.0_book_code/blob/master/ch09/weight_clustering_example_2_mnist.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View on GitHub</a>
  </td>

</table>

# Weight Clustering Keras example

## Overview

This is an  example showing the usage of the **Weight Clustering** API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline. We will use 4 and 8 clusters and see the affect on accuracy



## Setup

You can run this Jupyter Notebook in local [virtualenv](https://www.tensorflow.org/install/pip?lang=python3#2.-create-a-virtual-environment-recommended) or [colab](https://colab.sandbox.google.com/).

In [1]:
! pip install -q tensorflow-model-optimization

## Make the necessary import

In [2]:
import tensorflow as tf

import numpy as np
import tempfile
import zipfile
import os

## Train a tf.keras model for MNIST without clustering
1. load the dataset
2. train and test images normalize
3. Create Sequential model
4. Compile the model with following parameters
  
   * Use `adam` optimizer
   * `SparseCategoricalCrossentropy`
   * Optimize for `accuracy` metrics
5. Run model.fit(..) with `train_images` and `train_labels` for 10 epochs and validation split of 0.1

In [3]:
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=1
)



<keras.callbacks.History at 0x7fa572f65790>

### Evaluate the baseline model and save it for later usage

In [4]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.958299994468689
Saving model to:  /tmp/tmp9gpn1bs4.h5


## Cluster and fine-tune the model with 4 and 8 clusters

Apply the `cluster_weights()` API to cluster the whole pre-trained model to demonstrate and observe its effectiveness in reducing the model size when applying zip, while maintaining accuracy. For more details refer to the  [clustering comprehensive guide](https://www.tensorflow.org/model_optimization/guide/clustering/clustering_comprehensive_guide).

### Define the model and apply the clustering API

The model needs to be pre-trained before using the clustering API. This function wraps a keras model or layer with clustering functionality which clusters the layer's weights during training. For examples, using this with number_of_clusters equals 8 will ensure that each weight tensor has no more than 8 unique values.

In [5]:
import tensorflow_model_optimization as tfmot


In [10]:
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
cluster_weights_4 = tfmot.clustering.keras.cluster_weights
model_4 = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model_4.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_4.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=1
)
cluster_weights_4_ = tfmot.clustering.keras.cluster_weights
clustering_params_4 = {
  'number_of_clusters': 4,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model_4 = cluster_weights_4(model_4, **clustering_params_4)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

clustered_model_4.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])



In [11]:
cluster_weights_8 = tfmot.clustering.keras.cluster_weights
model_8 = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model_8.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model_8.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=1
)
cluster_weights_8_ = tfmot.clustering.keras.cluster_weights
clustering_params_8 = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model_8 = cluster_weights_8(model_8, **clustering_params_4)

# Use smaller learning rate for fine-tuning
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

clustered_model_8.compile(
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])



### Fine-tune the model and evaluate the accuracy against baseline

Fine-tune the model with clustering for 3 epochs.

In [12]:
# Fine-tune model
clustered_model_4.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7fa572c16790>

In [13]:
# Fine-tune model
clustered_model_8.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7fa5727941d0>

Define helper functions to calculate and print the number of clustering in each kernel of the model.

In [14]:
def print_model_weight_clusters(model):

    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Check that the model kernels were correctly clustered. We need to strip the clustering wrapper first.

In [15]:
stripped_clustered_model_4 = tfmot.clustering.keras.strip_clustering(clustered_model_4)
print_model_weight_clusters(stripped_clustered_model_4)

stripped_clustered_model_8 = tfmot.clustering.keras.strip_clustering(clustered_model_8)
print_model_weight_clusters(stripped_clustered_model_8)

conv2d_5/kernel:0: 4 clusters 
dense_5/kernel:0: 4 clusters 
conv2d_6/kernel:0: 4 clusters 
dense_6/kernel:0: 4 clusters 


For this example, there is minimal loss in test accuracy after clustering, compared to the baseline.

In [16]:
_, clustered_model_accuracy_4 = clustered_model_4.evaluate(
  test_images, test_labels, verbose=0)
_, clustered_model_accuracy_8 = clustered_model_8.evaluate(
  test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered 4 clusters test accuracy:', clustered_model_accuracy_4)
print('Clustered 8 clusters test accuracy:', clustered_model_accuracy_8)

Baseline test accuracy: 0.958299994468689
Clustered 4 clusters test accuracy: 0.9391999840736389
Clustered 8 clusters test accuracy: 0.9294999837875366


## Conclusion

In this sample, we use`cluster_weights()` API to create two clustered models with 4 and 8 clusters and compared the model accuracy.