In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot
import tensorflow_datasets as tfds

%load_ext tensorboard

from os import path
import pathlib
import tempfile

In [2]:
# normalizing the images to [0, 1]
def normalize(image, label):
    return tf.cast(image, tf.float32) / 255., label

def random_crop(image):
    cropped_image = tf.image.random_crop(
        image, size=[256, 256, 3])

    return cropped_image

def random_jitter(image):
    # resizing to 286 x 286 x 3
    image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # randomly cropping to 256 x 256 x 3
    image = random_crop(image)

    # random mirroring
    image = tf.image.random_flip_left_right(image)

    return image

def preprocess_flowers_train(image, label):
    image = random_jitter(image)
    return image, label

# -------------------------------

def preprocess_flowers(image, label):
    image = tf.image.resize(image, [256, 256],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return image, label

def load_flowers_dataset():  
    (ds_train, ds_validation, ds_test), ds_info = tfds.load(name="tf_flowers", 
                                             with_info=True,
                                             split=['train[:70%]', 'train[70%:85%]', 'train[85%:]'],  #70/15/15 split
                                             as_supervised=True)

    ds_train = ds_train.map(normalize)    
    ds_train = ds_train.map(preprocess_flowers)
    ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    
    ds_validation = ds_validation.map(normalize)
    ds_validation = ds_validation.map(preprocess_flowers)
    
    ds_test = ds_test.map(normalize)
    ds_test = ds_test.map(preprocess_flowers)
    
    return ds_train, ds_validation, ds_test

def load_beans_datasets():
    (ds_train, ds_validation, ds_test), ds_info = tfds.load(
        'beans',
        split=['train', 'validation', 'test'],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )
    
    ds_train = ds_train.map(normalize)
    ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    
    ds_validation = ds_validation.map(normalize)
    
    ds_test = ds_test.map(normalize)
    
    return ds_train, ds_validation, ds_test

# Optimization after training

## Quantization after training

1. Dynamic range quantization
2. Full integer quantization
3. Float16 quantization

### 1. Dynamic range quantization

Only the weights are converted from float to 8 bit int. 

### 2. Full integer quantization

Weights and activation outputs are quantizated. Good for microcontrolers and TPUs

In [3]:
def model_quantization(model_path, ds):
    # check if model was already in optimization folder
    if (len(model_path.split('/')[0].split('_')) == 3):
        optimized_dir = pathlib.Path(model_path.split('/')[0] + '/')
    else:
        optimized_dir = pathlib.Path(model_path.split('/')[0] + '_optimized/')
    model_name = model_path.split('/')[-1].split('.')[0]
    
    model = tf.keras.models.load_model(model_path)
    
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    
    # save converted tflite model
    tf_model_path = optimized_dir/(model_name + '.tflite')
    size = tf_model_path.write_bytes(tflite_model)
    print('Converted TFLite model ('+ str(size) +' Bytes) saved to: ' + str(tf_model_path))
    
    # 1. optimize model using dynamic range quantization
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model_quant = converter.convert()
    
    tf_quant_model_path = optimized_dir/(model_name + '_dynamic_rage_quantization.tflite')
    size = tf_quant_model_path.write_bytes(tflite_model_quant)
    print('Dynamic range quantizatized TFLite model ('+ str(size) +' Bytes) saved to: ' + str(tf_quant_model_path))
    
    # 2. Full integer quantization
    def representative_data_gen():
        for input_value, _ in ds.batch(1).take(100):
            # Model has only one input so each data point has one element.
            yield [input_value]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data_gen

    tflite_model_quant = converter.convert()
    tf_quant_model_path = optimized_dir/(model_name + '_full_integer_quantization.tflite')
    size = tf_quant_model_path.write_bytes(tflite_model_quant)
    print('Full integer quantizatized TFLite model ('+ str(size) +' Bytes) saved to: ' + str(tf_quant_model_path))
    
    # 2.1 Full integer quantization with input and output in integer too
    try:
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_data_gen
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8

        tflite_model_quant = converter.convert()
        tf_quant_model_path = optimized_dir/(model_name + '_full_integer_quantization_integer_io.tflite')
        size = tf_quant_model_path.write_bytes(tflite_model_quant)
        print('Full integer quantizatized with integer io TFLite model ('+ str(size) +' Bytes) saved to: ' + str(tf_quant_model_path))
    except:
        print('ERROR: Failed Full integer quantizatized with integer io TFLite model')
        
    # 3. float16 quantization
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    
    tflite_model_quant = converter.convert()
    tf_quant_model_path = optimized_dir/(model_name + '_float16_quantization.tflite')
    size = tf_quant_model_path.write_bytes(tflite_model_quant)
    print('float16 quantizatized TFLite model ('+ str(size) +' Bytes) saved to: ' + str(tf_quant_model_path))


## Pruning and fine-tuning

1. Prune model to different sparsity  ( tf uses magnitude-based pruning )
    1. ConstantSparsity - sparsity is kept constant during training.
    2. PolynomialDecay - the degree of sparsity is changed during training.
2. Fine-tune model

https://www.machinecurve.com/index.php/2020/09/29/tensorflow-pruning-schedules-constantsparsity-and-polynomialdecay/

In [4]:
BATCH_SIZE = 64
# todo change this
PRUNING_EPOCHS = 1
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

In [5]:
# Some layers cannot be pruned

def prune_prunable_layers(model, pruning_params):
    """returns model for pruning with avoided non prunable layers"""
    
    # Rescaling layer cannot be pruned
    def apply_pruning_to_prunable(layer):
        if isinstance(layer, tf.keras.layers.experimental.preprocessing.Rescaling) or isinstance(layer, tf.keras.layers.experimental.preprocessing.Normalization):
            return layer
        return prune_low_magnitude(layer, **pruning_params)
    model_for_pruning = tf.keras.models.clone_model(
                                model,
                                clone_function=apply_pruning_to_prunable,
                            )
    return model_for_pruning

In [6]:
def prune_model(model_path, batch_size, pruning_epochs, ds_train, ds_validation):
     # check if model was already in optimization folder
    if (len(model_path.split('/')[0].split('_')) == 3):
        optimized_dir = pathlib.Path(model_path.split('/')[0] + '/')
    else:
        optimized_dir = pathlib.Path(model_path.split('/')[0] + '_optimized/')
    model_name = model_path.split('/')[-1].split('.')[0]
    
    model = tf.keras.models.load_model(model_path)
    
    sparsities = [0.5, 0.75, 0.9]
    
    ds_train = ds_train.batch(batch_size)
    ds_validation = ds_validation.batch(batch_size)
    
    ds_train.cache()
    ds_validation.cache()
    
    # get number of images
    num_images = 0
    for i in ds_train.as_numpy_iterator():
        num_images+=1

    end_step = np.ceil(num_images / batch_size).astype(np.int32) * pruning_epochs
        
    for sparsity in sparsities:
        # Define pruning configuration
        pruning_params_constant = {
            'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=sparsity,
                                                                        begin_step=0,
                                                                        end_step=end_step)
        }
        
        pruning_params_polynomial = {
            'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0,
                                                                    final_sparsity=sparsity,
                                                                    begin_step=0,
                                                                    end_step=end_step)
        }
        
        # Rescaling layer cannot be wrapped in prune low magnitude
        try:
            model_for_pruning_constant = prune_low_magnitude(model, **pruning_params_constant)
            model_for_pruning_polynomial = prune_low_magnitude(model, **pruning_params_polynomial)
        except:
            model_for_pruning_constant = prune_prunable_layers(model, pruning_params_constant) 
            model_for_pruning_polynomial = prune_prunable_layers(model, pruning_params_polynomial)
            

        # Compile models for pruning
        model_for_pruning_constant.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])
        model_for_pruning_polynomial.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])


        constant_log_dir = pathlib.Path("./tmp/" + model_name + '_ConstantSparsity' + str(int(sparsity*100)))
        constant_log_dir.mkdir(parents=True, exist_ok=True)
        polynomial_log_dir = pathlib.Path("./tmp/" + model_name + '_PolynomialDecay' + str(int(sparsity*100)))
        polynomial_log_dir.mkdir(parents=True, exist_ok=True)
        
        # Model callbacks
        constant_callbacks = [
            tfmot.sparsity.keras.UpdatePruningStep(),
            tfmot.sparsity.keras.PruningSummaries(log_dir=str(constant_log_dir))
        ]
        polynomial_callbacks = [
            tfmot.sparsity.keras.UpdatePruningStep(),
            tfmot.sparsity.keras.PruningSummaries(log_dir=str(polynomial_log_dir))
        ]

        # Fitting data
        model_for_pruning_constant.fit(ds_train,
                                  validation_data=ds_validation,
                                  epochs=pruning_epochs,
                                  callbacks=constant_callbacks)
        
        model_for_pruning_polynomial.fit(ds_train,
                                  validation_data=ds_validation,
                                  epochs=pruning_epochs,
                                  callbacks=polynomial_callbacks)
        
        # Save pruned models
        model_constant_path = optimized_dir/(model_name + '_ConstantSparsity' + str(int(sparsity*100)) + '.h5')
        model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning_constant)
        model_for_export.save(str(model_constant_path))
        print('saved ' + str(model_constant_path))
        
        model_polynomial_path = optimized_dir/(model_name + '_PolynomialDecay' + str(int(sparsity*100)) + '.h5')
        model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning_polynomial)
        model_for_export.save(str(model_polynomial_path))
        print('saved ' + str(model_polynomial_path))

## Weight clustering

In [7]:
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

In [8]:
# Some layers cannot be weight clustered

def cluster_clustred_layers(model, cluster_params):    
    # Rescaling layer cannot be pruned
    def apply_clustering_to_clusterable(layer):
        if model.layers[0] == layer or  model.layers[-1] == layer:
            return layer
        try:
            x = cluster_weights(layer, **cluster_params)
            return x
        except:
            return layer

    model_for_clustering = tf.keras.models.clone_model(
                                model,
                                clone_function=apply_clustering_to_clusterable,
                            )
    return model_for_clustering

In [9]:
def weight_cluster_model(model_path, batch_size, epochs, ds_train, ds_validation, number_of_clusters):
    """ Weight clustering on given moodel 
    note: cannot use for cycle in this function to do different number of clusters because of compatibility issues"""
     # check if model was already in optimization folder
    if (len(model_path.split('/')[0].split('_')) == 3):
        optimized_dir = pathlib.Path(model_path.split('/')[0] + '/')
    else:
        optimized_dir = pathlib.Path(model_path.split('/')[0] + '_optimized/')
    model_name = model_path.split('/')[-1].split('.')[0]
    
    model = tf.keras.models.load_model(model_path)
    
    ds_train = ds_train.batch(batch_size)
    ds_validation = ds_validation.batch(batch_size)
    
    ds_train.cache()
    ds_validation.cache()
    
    # get number of images
    num_images = 0
    for i in ds_train.as_numpy_iterator():
        num_images+=1

    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

    # Define weight clustering configuration
    cluster_params_kmeans = {
                                  'number_of_clusters': number_of_clusters,
                                  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
                                }

    # Rescaling layer cannot be wrapped in prune low magnitude
    try:
        model_for_clustering_kmeans = cluster_weights(model, **cluster_params_kmeans)
    except:
        model_for_clustering_kmeans = cluster_clustred_layers(model, cluster_params_kmeans)

    # Compile models for clustering
    model_for_clustering_kmeans.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

    # Fitting data        
    model_for_clustering_kmeans.fit(ds_train,
                              validation_data=ds_validation,
                              epochs=epochs)

    # Save pruned models
    model_kmeans_path = optimized_dir/(model_name + '_KMeansPlusPlus' + str(number_of_clusters) + '.h5')
    model_for_export = tfmot.clustering.keras.strip_clustering(model_for_clustering_kmeans)
    model_for_export.save(str(model_kmeans_path))
    print('saved ' + str(model_kmeans_path))

# optimze models

In [10]:
# datasets
flowers_datasets = load_flowers_dataset()
beans_datasets = load_beans_datasets()

In [11]:
"""Due to bad tensorflow optimization of calling fit function in a loop, calling prune_model in a loop is unusable.
Solution: don't use for loops
Issue: https://github.com/tensorflow/tensorflow/issues/34025"""

"Due to bad tensorflow optimization of calling fit function in a loop, calling prune_model in a loop is unusable.\nSolution: don't use for loops\nIssue: https://github.com/tensorflow/tensorflow/issues/34025"

## 1. Prune base models

In [12]:
prune_model(model_path='flowers_models/MobileNetV2_flowers_model.h5', 
            ds_train=flowers_datasets[0], 
            ds_validation=flowers_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)
prune_model(model_path='flowers_models/EfficentNetB0_flowers_model.h5', 
            ds_train=flowers_datasets[0], 
            ds_validation=flowers_datasets[1], 
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)



saved flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity50.h5
saved flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay50.h5
saved flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity75.h5
saved flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay75.h5

KeyboardInterrupt: 

In [None]:
prune_model(model_path='beans_models/MobileNetV2_beans_model.h5', 
            ds_train=beans_datasets[0], 
            ds_validation=beans_datasets[1], 
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)
prune_model(model_path='beans_models/EfficentNetB0_beans_model.h5', 
            ds_train=beans_datasets[0], 
            ds_validation=beans_datasets[1], 
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)

## 2. Weight cluster base models

In [13]:
weight_cluster_model(model_path='flowers_models/MobileNetV2_flowers_model.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models/MobileNetV2_flowers_model.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models/EfficentNetB0_flowers_model.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models/EfficentNetB0_flowers_model.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

saved flowers_models_optimized/MobileNetV2_flowers_model_KMeansPlusPlus32.h5
saved flowers_models_optimized/MobileNetV2_flowers_model_KMeansPlusPlus128.h5
saved flowers_models_optimized/EfficentNetB0_flowers_model_KMeansPlusPlus32.h5


KeyboardInterrupt: 

In [None]:
weight_cluster_model(model_path='beans_models/MobileNetV2_beans_model.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models/MobileNetV2_beans_model.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models/EfficentNetB0_beans_model.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models/EfficentNetB0_beans_model.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

## 3. prune weight clustered models

In [None]:
prune_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_KMeansPlusPlus32.h5', 
            ds_train=flowers_datasets[0], 
            ds_validation=flowers_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)
prune_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_KMeansPlusPlus128.h5', 
            ds_train=flowers_datasets[0], 
            ds_validation=flowers_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)

prune_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_KMeansPlusPlus32.h5', 
            ds_train=flowers_datasets[0], 
            ds_validation=flowers_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)
prune_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_KMeansPlusPlus32.h5', 
            ds_train=flowers_datasets[0], 
            ds_validation=flowers_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)

In [None]:
prune_model(model_path='beans_models_optimized/MobileNetV2_beans_model_KMeansPlusPlus32.h5', 
            ds_train=beans_datasets[0], 
            ds_validation=beans_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)
prune_model(model_path='beans_models_optimized/MobileNetV2_beans_model_KMeansPlusPlus128.h5', 
            ds_train=beans_datasets[0], 
            ds_validation=beans_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)

prune_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_KMeansPlusPlus32.h5', 
            ds_train=beans_datasets[0], 
            ds_validation=beans_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)
prune_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_KMeansPlusPlus32.h5', 
            ds_train=beans_datasets[0], 
            ds_validation=beans_datasets[1],
            batch_size=BATCH_SIZE, 
            pruning_epochs=PRUNING_EPOCHS)

## 4. weight cluster prunned models

In [None]:
weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_ConstantSparsity90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/MobileNetV2_flowers_model_PolynomialDecay90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

In [None]:
weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_ConstantSparsity50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_ConstantSparsity50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_PolynomialDecay50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_PolynomialDecay50.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_ConstantSparsity75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_ConstantSparsity75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_PolynomialDecay75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_PolynomialDecay75.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_ConstantSparsity90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_ConstantSparsity90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_PolynomialDecay90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='flowers_models_optimized/EfficentNetB0_flowers_model_PolynomialDecay90.h5', 
                     ds_train=flowers_datasets[0], 
                     ds_validation=flowers_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

#### beans models

In [None]:
weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_ConstantSparsity50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_ConstantSparsity50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_PolynomialDecay50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_PolynomialDecay50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_ConstantSparsity75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_ConstantSparsity75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_PolynomialDecay75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_PolynomialDecay75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_ConstantSparsity90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_ConstantSparsity90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_PolynomialDecay90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/MobileNetV2_beans_model_PolynomialDecay90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

In [None]:
weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_ConstantSparsity50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_ConstantSparsity50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_PolynomialDecay50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_PolynomialDecay50.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_ConstantSparsity75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_ConstantSparsity75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_PolynomialDecay75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_PolynomialDecay75.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_ConstantSparsity90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_ConstantSparsity90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_PolynomialDecay90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=32)
weight_cluster_model(model_path='beans_models_optimized/EfficentNetB0_beans_model_PolynomialDecay90.h5', 
                     ds_train=beans_datasets[0], 
                     ds_validation=beans_datasets[1], 
                     batch_size=BATCH_SIZE, 
                     epochs=PRUNING_EPOCHS, 
                     number_of_clusters=128)

## Quantization of all models

In [16]:
# get all models paths
model_paths = []
import os
for file in os.listdir("beans_models/"):
    if file.endswith(".h5"):
        model_paths.append(str(os.path.join("beans_models/", file)))
for file in os.listdir("beans_models_optimized/"):
    if file.endswith(".h5"):
        model_paths.append(str(os.path.join("beans_models/", file)))
for file in os.listdir("flowers_models/"):
    if file.endswith(".h5"):
        model_paths.append(str(os.path.join("beans_models/", file)))
for file in os.listdir("flowers_models_optimized/"):
    if file.endswith(".h5"):
        model_paths.append(str(os.path.join("beans_models/", file)))

['beans_models/MobileNetV2_beans_model.h5',
 'beans_models/EfficentNetB0_beans_model.h5',
 'beans_models/MobileNetV2_beans_model_PolynomialDecay75.h5',
 'beans_models/MobileNetV2_beans_model_PolynomialDecay87.h5',
 'beans_models/MobileNetV2_beans_model_PolynomialDecay50.h5',
 'beans_models/MobileNetV2_beans_model_ConstantSparsity87.h5',
 'beans_models/MobileNetV2_beans_model_ConstantSparsity50.h5',
 'beans_models/MobileNetV2_beans_model_ConstantSparsity75.h5',
 'beans_models/MobileNetV2_flowers_model.h5',
 'beans_models/EfficentNetB0_flowers_model.h5',
 'beans_models/EfficentNetB0_flowers_model_KMeansPlusPlus8.h5',
 'beans_models/MobileNetV2_flowers_model_KMeansPlusPlus8.h5',
 'beans_models/MobileNetV2_flowers_model_PolynomialDecay50.h5',
 'beans_models/MobileNetV2_flowers_model_KMeansPlusPlus128.h5',
 'beans_models/MobileNetV2_flowers_model_KMeansPlusPlus16.h5',
 'beans_models/MobileNetV2_flowers_model_KMeansPlusPlus32.h5',
 'beans_models/MobileNetV2_flowers_model_PolynomialDecay75.h5

In [17]:
for model_path in model_paths:
    if "flowers" in model_path:
        model_quantization(model_path=model_path, ds=flowers_datasets[0])
    if "beans" in model_path:
        model_quantization(model_path=model_path, ds=beans_datasets[0])

INFO:tensorflow:Assets written to: /var/folders/_1/7lg8klcj1d55272nzrt0k5740000gn/T/tmp7e8_zb3g/assets


INFO:tensorflow:Assets written to: /var/folders/_1/7lg8klcj1d55272nzrt0k5740000gn/T/tmp7e8_zb3g/assets


Converted TFLite model (8873924 Bytes) saved to: beans_models_optimized/MobileNetV2_beans_model.tflite


KeyboardInterrupt: 

In [None]:
# model_quantization(model_path='flowers_models/EfficentNetB0_flowers_model.h5', ds=load_flowers_dataset()[0])

# ds_train, ds_validation, ds_test = load_flowers_dataset()

# for i in [4, 8, 16]:
#     weight_cluster_model(model_path='flowers_models/EfficentNetB0_flowers_model.h5', ds_train=ds_train, ds_validation=ds_validation, batch_size=BATCH_SIZE, epochs=PRUNING_EPOCHS, number_of_clusters=i)
    
# ds_train, ds_validation, ds_test = load_flowers_dataset()

# prune_model(model_path='flowers_models/MobileNetV2_flowers_model.h5', ds_train=ds_train, ds_validation=ds_validation, batch_size=BATCH_SIZE, pruning_epochs=PRUNING_EPOCHS)