In [3]:
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from model_profiler import model_profiler


# Load Dataset


In [16]:
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images  = test_images / 255.0


In [17]:
original_model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# for layer in model.layers:
#     if isinstance(layer, tf.keras.layers.Conv2D) or isinstance(layer, tf.keras.layers.Dense):
#         layer.set_weights([tf.constant(w, dtype=tf.float32) for w in layer.get_weights()])

original_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
original_model.summary()


Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape_3 (Reshape)         (None, 28, 28, 1)         0         
                                                                 
 conv2d_3 (Conv2D)           (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d_3 (MaxPoolin  (None, 13, 13, 12)        0         
 g2D)                                                            
                                                                 
 flatten_3 (Flatten)         (None, 2028)              0         
                                                                 
 dense_3 (Dense)             (None, 10)                20290     
                                                                 
Total params: 20410 (79.73 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 0 (0.00 Byte)
__________________

In [19]:
print(model_profiler(original_model, 12800))


| Model Profile                    | Value         | Unit    |
|----------------------------------|---------------|---------|
| Selected GPUs                    | None Detected | GPU IDs |
| No. of FLOPs                     | 0.0           | BFLOPs  |
| GPU Memory Requirement           | 0.6181        | GB      |
| Model Parameters                 | 0.0204        | Million |
| Memory Required by Model Weights | 0.0779        | MB      |


In [21]:
original_model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=3
);


Epoch 1/3
Epoch 2/3
Epoch 3/3


In [22]:
_, baseline_model_accuracy = original_model.evaluate(
    test_images, test_labels, verbose=0)

print('Original model test accuracy:', baseline_model_accuracy)


Original model test accuracy: 0.9805999994277954


In [23]:
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(original_model, **clustering_params)


In [24]:
clustered_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)


In [25]:
print(model_profiler(clustered_model, 48))


| Model Profile                    | Value         | Unit    |
|----------------------------------|---------------|---------|
| Selected GPUs                    | None Detected | GPU IDs |
| No. of FLOPs                     | 0.0           | BFLOPs  |
| GPU Memory Requirement           | 0.0024        | GB      |
| Model Parameters                 | 0.0408        | Million |
| Memory Required by Model Weights | 0.1557        | MB      |


In [26]:
_, clustered_model_accuracy = clustered_model.evaluate(
    test_images, test_labels, verbose=0)

print('Clustered test accuracy:', clustered_model_accuracy)


Clustered test accuracy: 0.9553999900817871
