In [1]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

import numpy as np
import tempfile
import zipfile
import os

In [2]:
# Load the MNIST train/test splits; each split is an (images, labels) pair.
mnist = keras.datasets.mnist
mnist_splits = mnist.load_data()
(train_images, train_labels), (test_images, test_labels) = mnist_splits

In [3]:
# Sanity-check the number of training labels.
np.shape(train_labels)

(60000,)

In [4]:
# Normalize pixel values to the range [0, 1] (assumes raw 0-255 integer
# pixels, as loaded above) and cast to float32, Keras' default float dtype.
# A plain `/ 255.0` on the raw arrays would give float64 and double the memory.
# The integer-dtype guard makes this cell idempotent: once the arrays are
# float, re-running the cell does not divide by 255 a second time.
if np.issubdtype(train_images.dtype, np.integer):
    train_images = train_images.astype(np.float32) / 255.0
if np.issubdtype(test_images.dtype, np.integer):
    test_images = test_images.astype(np.float32) / 255.0

In [5]:
# Seed NumPy and TensorFlow before building the model so weight initialization
# and fit()'s data shuffling repeat across Restart & Run All
# (GPU kernels may still introduce some nondeterminism).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Small CNN digit classifier: one conv + pool stage feeding a Dense layer that
# outputs raw logits (the losses in this notebook use from_logits=True).
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    # Conv2D needs a channel axis, so add a single grayscale channel.
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10),  # one logit per digit class
])

In [6]:
# Compile (not yet train) the baseline classifier: default Adam, sparse
# cross-entropy computed from the Dense head's raw logits, accuracy tracked.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

In [7]:
# Train for 10 epochs. Keras holds out the last 10% of the training arrays
# for validation.
model.fit(train_images, train_labels, epochs=10, validation_split=0.1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x19c4a581190>

In [8]:
# evaluate() returns [loss, accuracy] (metric order comes from compile()).
# Unpack into named variables instead of `_`: assigning `_` in a notebook
# shadows IPython's last-output cache.
baseline_model_loss, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

In [9]:
# Show the baseline test accuracy.
print(f"{baseline_model_accuracy}")

0.9817000031471252


In [10]:
# NOTE(review): the baseline-model export is disabled. Either delete this
# cell or re-enable it. If you re-enable it, close the OS descriptor that
# tempfile.mkstemp returns; the code below discards it into `_` and never
# closes it.
#_,keras_file = tempfile.mkstemp('.h5')
#print('Saving model to: ', keras_file)
#tf.keras.models.save_model(model, keras_file, include_optimizer=False)

In [11]:
# Short aliases for the tfmot clustering API used below.
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# 16 clusters per weight tensor, using LINEAR centroid initialization
# (see the tfmot CentroidInitialization docs for its exact semantics).
clustering_params = dict(
    number_of_clusters=16,
    cluster_centroids_init=CentroidInitialization.LINEAR,
)

# Wrap every layer of the trained baseline model with clustering.
clustered_model = cluster_weights(model, **clustering_params)

In [12]:
# Inspect the clustering-wrapped architecture (default line width).
clustered_model.summary(line_length=None)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_reshape (ClusterWei  (None, 28, 28, 1)        0         
 ghts)                                                           
                                                                 
 cluster_conv2d (ClusterWeig  (None, 26, 26, 12)       244       
 hts)                                                            
                                                                 
 cluster_max_pooling2d (Clus  (None, 13, 13, 12)       0         
 terWeights)                                                     
                                                                 
 cluster_flatten (ClusterWei  (None, 2028)             0         
 ghts)                                                           
                                                                 
 cluster_dense (ClusterWeigh  (None, 10)               4

In [13]:
# Fine-tuning uses a much smaller learning rate (1e-5) than the baseline's
# default Adam, so the clustered weights are only nudged.
opt = keras.optimizers.Adam(learning_rate=1e-5)

clustered_model.compile(
  optimizer=opt,
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['accuracy'],
)

In [14]:
# NOTE(review): this is redundant. compile() does not change the layer graph,
# and the summary printed here matches the one printed just before compile().
# Consider removing this cell.
clustered_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cluster_reshape (ClusterWei  (None, 28, 28, 1)        0         
 ghts)                                                           
                                                                 
 cluster_conv2d (ClusterWeig  (None, 26, 26, 12)       244       
 hts)                                                            
                                                                 
 cluster_max_pooling2d (Clus  (None, 13, 13, 12)       0         
 terWeights)                                                     
                                                                 
 cluster_flatten (ClusterWei  (None, 2028)             0         
 ghts)                                                           
                                                                 
 cluster_dense (ClusterWeigh  (None, 10)               4

In [16]:
# Fine-tune the clustered model for 10 epochs with 500 samples per batch.
# The last 10% of the training arrays are held out for validation.
clustered_model.fit(train_images, train_labels,
                    validation_split=0.1, epochs=10, batch_size=500)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x19c4beb6670>

In [17]:
# Compare test accuracy before and after clustering + fine-tuning.
# The loss goes into a named variable rather than `_` so IPython's
# last-output cache is not shadowed.
clustered_model_loss, clustered_model_accuracy = clustered_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy: ', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)

0.9817000031471252
0.9811999797821045


In [18]:
# Remove the clustering wrappers. Per the tfmot docs, this restores the
# original layer types with the clustered weights.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

# mkstemp() returns an already-open OS-level descriptor along with the path.
# Close it so the descriptor is not leaked; Keras reopens the file by path.
clustered_fd, clustered_keras_file = tempfile.mkstemp('.h5')
os.close(clustered_fd)
print('Saving clustered model to: ', clustered_keras_file)
tf.keras.models.save_model(final_model, clustered_keras_file,
                           include_optimizer=False)

# NOTE(review): a recorded run of this cell failed with FileNotFoundError on
# the hard-coded '/tmp/clustered_mnist.tflite' (there is no /tmp on Windows).
# Build output paths with tempfile.gettempdir() or tempfile.mkstemp instead.

Saving clustered model to:  C:\Users\alexa\AppData\Local\Temp\tmpk32ujy6s.h5




INFO:tensorflow:Assets written to: C:\Users\alexa\AppData\Local\Temp\tmpj2vseuk5\assets


INFO:tensorflow:Assets written to: C:\Users\alexa\AppData\Local\Temp\tmpj2vseuk5\assets


FileNotFoundError: [Errno 2] No such file or directory: '/tmp/clustered_mnist.tflite'