# Metrics

https://towardsdatascience.com/the-5-classification-evaluation-metrics-you-must-know-aa97784ff226

https://towardsdatascience.com/20-popular-machine-learning-metrics-part-1-classification-regression-evaluation-metrics-1ca3e282a2ce

In [26]:
import numpy as np
import json
import copy
from json import JSONEncoder
import os

In [2]:
class EvaluatedModel:
    y_pred = []
    y_true = []
    model_name = ''
    avg_time = 0.0
    gzip_size = 0
    
    accuracy = None
    precision = None
    recall = None
    f1_score = None
    

    def __init__(self, dict1=None, y_pred=[], y_true=[], model_name='', avg_time=0.0, gzip_size=0):
        if (dict1 == None):
            self.y_pred = copy.deepcopy(y_pred)
            self.y_true = copy.deepcopy(y_true)
            self.model_name = model_name
            self.avg_time = avg_time
            self.gzip_size = gzip_size
        else:
            self.__dict__.update(dict1)


## 1. Compute metrics for models

In [97]:
from sklearn.metrics import accuracy_score 
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score 
from sklearn.metrics import f1_score

def annote_models(path):
    with open(path) as f:
        evaluated_models = json.load(f, object_hook=EvaluatedModel)

    for evaluated_model in evaluated_models:
        evaluated_model.accuracy = accuracy_score(y_pred=evaluated_model.y_pred, y_true=evaluated_model.y_true)
        evaluated_model.precision = precision_score(y_pred=evaluated_model.y_pred, y_true=evaluated_model.y_true, average='macro')
        evaluated_model.recall = recall_score(y_pred=evaluated_model.y_pred, y_true=evaluated_model.y_true, average='macro')
        evaluated_model.f1_score = f1_score(y_pred=evaluated_model.y_pred, y_true=evaluated_model.y_true, average='macro')
        # change model name for table
        evaluated_model.model_name = evaluated_model.model_name.replace("Efficent", "Efficient")
        evaluated_model.model_name = evaluated_model.model_name.replace("KMeansPlusPlus32", "WQ32")
        evaluated_model.model_name = evaluated_model.model_name.replace("KMeansPlusPlus128", "WQ128")
        evaluated_model.model_name = evaluated_model.model_name.replace("full_integer_quantization", "FIQ")
        evaluated_model.model_name = evaluated_model.model_name.replace("dynamic_rage_quantization", "DRQ")
        evaluated_model.model_name = evaluated_model.model_name.replace("float16_quantization", "F16Q")
        evaluated_model.model_name = evaluated_model.model_name.replace("_integer_io", "-Iio")
        evaluated_model.model_name = evaluated_model.model_name.replace("PolynomialDecay90", "PD90")
        evaluated_model.model_name = evaluated_model.model_name.replace("PolynomialDecay75", "PD75")
        evaluated_model.model_name = evaluated_model.model_name.replace("PolynomialDecay50", "PD50")
        evaluated_model.model_name = evaluated_model.model_name.replace("flowers_model_", "")
        evaluated_model.model_name = evaluated_model.model_name.replace("beans_model_", "")
        evaluated_model.model_name = evaluated_model.model_name.replace("_flowers_model", "")
        evaluated_model.model_name = evaluated_model.model_name.replace("_beans_model", "")
        evaluated_model.model_name = evaluated_model.model_name.replace("ConstantSparsity90", "CS90")
        evaluated_model.model_name = evaluated_model.model_name.replace("ConstantSparsity75", "CS75")
        evaluated_model.model_name = evaluated_model.model_name.replace("ConstantSparsity50", "CS50")
        evaluated_model.model_name = evaluated_model.model_name.replace(".tflite", "")
        evaluated_model.model_name = evaluated_model.model_name.replace("_", " -> ")
    evaluated_models.sort(key=lambda x: x.model_name, reverse=False)
    return evaluated_models

def print_table(models):
    print("|Model name|Accuarcy|Precision|Recall|F1 score|Compressed size|Average time|")
    print("|-----|-----|-----|-----|-----|-----|-----|")
    for model in models:
        print("|"+model.model_name+"|"+str(round(model.accuracy * 100, 2))+"|"+str(round(model.precision * 100, 2))
              +"|"+str(round(model.recall * 100, 2))+"|"+str(round(model.f1_score * 100, 2))+"|"
              +str(round(model.gzip_size / 1024))+" kB|"+str(round(model.avg_time, 2))+" ms|")


In [98]:
flowers_models = annote_models('results/evaluated_beans_models_on_MobileNetV2.json')
print_table(flowers_models)


  _warn_prf(average, modifier, msg_start, len(result))


|Model name|Accuarcy|Precision|Recall|F1 score|Compressed size|Average time|
|-----|-----|-----|-----|-----|-----|-----|
|MobileNetV2|70.31|72.89|70.41|70.38|7993 kB|230.08 ms|
|MobileNetV2 -> CS50|52.34|68.81|52.68|49.75|8106 kB|230.85 ms|
|MobileNetV2 -> CS50 -> DRQ|53.12|69.51|53.45|50.76|2220 kB|335.23 ms|
|MobileNetV2 -> CS50 -> F16Q|53.12|69.51|53.45|50.76|4052 kB|270.1 ms|
|MobileNetV2 -> CS50 -> FIQ|52.34|68.4|52.69|49.19|2252 kB|253.78 ms|
|MobileNetV2 -> CS50 -> FIQ-Iio|53.12|68.45|53.45|50.53|2252 kB|256.27 ms|
|MobileNetV2 -> CS50 -> WQ128|33.59|11.2|33.33|16.76|4894 kB|241.25 ms|
|MobileNetV2 -> CS50 -> WQ128 -> DRQ|33.59|11.2|33.33|16.76|2142 kB|326.78 ms|
|MobileNetV2 -> CS50 -> WQ128 -> F16Q|33.59|11.2|33.33|16.76|3734 kB|241.66 ms|
|MobileNetV2 -> CS50 -> WQ128 -> FIQ|33.59|11.2|33.33|16.76|2177 kB|251.11 ms|
|MobileNetV2 -> CS50 -> WQ128 -> FIQ-Iio|33.59|11.2|33.33|16.76|2177 kB|167.69 ms|
|MobileNetV2 -> CS50 -> WQ32|33.59|11.2|33.33|16.76|3252 kB|270.94 ms|
|MobileN

|Model name|Accuarcy|Precision|Recall|F1 score|Compressed size|Average time|
|-----|-----|-----|-----|-----|-----|-----|
|MobileNetV2|70.31|72.89|70.41|70.38|7993 kB|230.08 ms|
|MobileNetV2 -> CS50|52.34|68.81|52.68|49.75|8106 kB|230.85 ms|
|MobileNetV2 -> CS50 -> DRQ|53.12|69.51|53.45|50.76|2220 kB|335.23 ms|
|MobileNetV2 -> CS50 -> F16Q|53.12|69.51|53.45|50.76|4052 kB|270.1 ms|
|MobileNetV2 -> CS50 -> FIQ|52.34|68.4|52.69|49.19|2252 kB|253.78 ms|
|MobileNetV2 -> CS50 -> FIQ-Iio|53.12|68.45|53.45|50.53|2252 kB|256.27 ms|
|MobileNetV2 -> CS50 -> WQ128|33.59|11.2|33.33|16.76|4894 kB|241.25 ms|
|MobileNetV2 -> CS50 -> WQ128 -> DRQ|33.59|11.2|33.33|16.76|2142 kB|326.78 ms|
|MobileNetV2 -> CS50 -> WQ128 -> F16Q|33.59|11.2|33.33|16.76|3734 kB|241.66 ms|
|MobileNetV2 -> CS50 -> WQ128 -> FIQ|33.59|11.2|33.33|16.76|2177 kB|251.11 ms|
|MobileNetV2 -> CS50 -> WQ128 -> FIQ-Iio|33.59|11.2|33.33|16.76|2177 kB|167.69 ms|
|MobileNetV2 -> CS50 -> WQ32|33.59|11.2|33.33|16.76|3252 kB|270.94 ms|
|MobileNetV2 -> CS50 -> WQ32 -> DRQ|33.59|11.2|33.33|16.76|1844 kB|328.9 ms|
|MobileNetV2 -> CS50 -> WQ32 -> F16Q|33.59|11.2|33.33|16.76|2951 kB|243.48 ms|
|MobileNetV2 -> CS50 -> WQ32 -> FIQ|33.59|11.2|33.33|16.76|1881 kB|220.0 ms|
|MobileNetV2 -> CS50 -> WQ32 -> FIQ-Iio|33.59|11.2|33.33|16.76|1881 kB|250.93 ms|
|MobileNetV2 -> CS75|33.59|11.2|33.33|16.76|8137 kB|242.03 ms|
|MobileNetV2 -> CS75 -> DRQ|33.59|11.2|33.33|16.76|2026 kB|294.05 ms|
|MobileNetV2 -> CS75 -> F16Q|33.59|11.2|33.33|16.76|4071 kB|229.96 ms|
|MobileNetV2 -> CS75 -> FIQ|33.59|11.2|33.33|16.76|2061 kB|167.37 ms|
|MobileNetV2 -> CS75 -> FIQ-Iio|33.59|11.2|33.33|16.76|2061 kB|167.61 ms|
|MobileNetV2 -> CS75 -> WQ128|60.94|62.99|61.06|61.01|4560 kB|231.82 ms|
|MobileNetV2 -> CS75 -> WQ128 -> DRQ|60.94|62.99|61.06|61.01|1965 kB|370.42 ms|
|MobileNetV2 -> CS75 -> WQ128 -> F16Q|60.94|62.99|61.06|61.01|3645 kB|229.83 ms|
|MobileNetV2 -> CS75 -> WQ128 -> FIQ|62.5|63.91|62.62|62.42|2000 kB|167.14 ms|
|MobileNetV2 -> CS75 -> WQ128 -> FIQ-Iio|60.94|62.74|61.06|60.91|2000 kB|173.05 ms|
|MobileNetV2 -> CS75 -> WQ32|61.72|65.5|61.79|61.04|2983 kB|294.12 ms|
|MobileNetV2 -> CS75 -> WQ32 -> DRQ|60.16|65.23|60.21|59.31|1706 kB|375.53 ms|
|MobileNetV2 -> CS75 -> WQ32 -> F16Q|61.72|65.5|61.79|61.04|2641 kB|229.02 ms|
|MobileNetV2 -> CS75 -> WQ32 -> FIQ|61.72|66.12|61.79|60.74|1742 kB|166.85 ms|
|MobileNetV2 -> CS75 -> WQ32 -> FIQ-Iio|61.72|66.12|61.79|60.74|1741 kB|167.81 ms|
|MobileNetV2 -> CS90|33.59|11.2|33.33|16.76|8093 kB|292.85 ms|
|MobileNetV2 -> CS90 -> DRQ|33.59|11.2|33.33|16.76|1907 kB|370.93 ms|
|MobileNetV2 -> CS90 -> F16Q|33.59|11.2|33.33|16.76|4053 kB|233.16 ms|
|MobileNetV2 -> CS90 -> FIQ|33.59|11.2|33.33|16.76|1946 kB|166.54 ms|
|MobileNetV2 -> CS90 -> FIQ-Iio|33.59|11.2|33.33|16.76|1946 kB|167.71 ms|
|MobileNetV2 -> CS90 -> WQ128|32.81|10.94|33.33|16.47|4357 kB|292.03 ms|
|MobileNetV2 -> CS90 -> WQ128 -> DRQ|32.81|10.94|33.33|16.47|1812 kB|314.93 ms|
|MobileNetV2 -> CS90 -> WQ128 -> F16Q|32.81|10.94|33.33|16.47|3566 kB|291.92 ms|
|MobileNetV2 -> CS90 -> WQ128 -> FIQ|32.81|10.94|33.33|16.47|1853 kB|166.81 ms|
|MobileNetV2 -> CS90 -> WQ128 -> FIQ-Iio|32.81|10.94|33.33|16.47|1853 kB|167.38 ms|
|MobileNetV2 -> CS90 -> WQ32|32.81|10.94|33.33|16.47|2875 kB|288.91 ms|
|MobileNetV2 -> CS90 -> WQ32 -> DRQ|32.81|10.94|33.33|16.47|1612 kB|452.12 ms|
|MobileNetV2 -> CS90 -> WQ32 -> F16Q|32.81|10.94|33.33|16.47|2557 kB|259.03 ms|
|MobileNetV2 -> CS90 -> WQ32 -> FIQ|32.81|10.94|33.33|16.47|1652 kB|166.51 ms|
|MobileNetV2 -> CS90 -> WQ32 -> FIQ-Iio|32.81|10.94|33.33|16.47|1651 kB|167.29 ms|
|MobileNetV2 -> DRQ|69.53|72.23|69.64|69.59|2346 kB|299.25 ms|
|MobileNetV2 -> F16Q|70.31|72.89|70.41|70.38|3961 kB|229.3 ms|
|MobileNetV2 -> FIQ|71.09|73.6|71.19|71.12|2380 kB|218.98 ms|
|MobileNetV2 -> FIQ-Iio|70.31|72.75|70.41|70.43|2380 kB|167.91 ms|
|MobileNetV2 -> PD50|52.34|68.81|52.68|49.75|8106 kB|230.93 ms|
|MobileNetV2 -> PD50 -> DRQ|53.91|71.13|54.25|51.34|2220 kB|353.04 ms|
|MobileNetV2 -> PD50 -> F16Q|53.12|69.51|53.45|50.76|4052 kB|229.28 ms|
|MobileNetV2 -> PD50 -> FIQ|51.56|69.44|51.9|48.78|2253 kB|167.04 ms|
|MobileNetV2 -> PD50 -> FIQ-Iio|52.34|71.39|52.69|49.34|2252 kB|167.1 ms|
|MobileNetV2 -> PD50 -> WQ128|33.59|27.78|33.33|18.05|4894 kB|282.94 ms|
|MobileNetV2 -> PD50 -> WQ128 -> DRQ|36.72|28.18|36.43|26.39|2135 kB|345.72 ms|
|MobileNetV2 -> PD50 -> WQ128 -> F16Q|33.59|27.78|33.33|18.05|3736 kB|229.65 ms|
|MobileNetV2 -> PD50 -> WQ128 -> FIQ|33.59|11.2|33.33|16.76|2169 kB|167.08 ms|
|MobileNetV2 -> PD50 -> WQ128 -> FIQ-Iio|33.59|11.2|33.33|16.76|2169 kB|167.51 ms|
|MobileNetV2 -> PD50 -> WQ32|49.22|43.84|49.32|44.4|3252 kB|279.6 ms|
|MobileNetV2 -> PD50 -> WQ32 -> DRQ|48.44|44.12|48.5|44.67|1832 kB|302.42 ms|
|MobileNetV2 -> PD50 -> WQ32 -> F16Q|49.22|43.84|49.32|44.4|2963 kB|246.56 ms|
|MobileNetV2 -> PD50 -> WQ32 -> FIQ|48.44|41.93|48.54|43.15|1869 kB|239.35 ms|
|MobileNetV2 -> PD50 -> WQ32 -> FIQ-Iio|50.78|45.56|50.89|45.71|1869 kB|167.76 ms|
|MobileNetV2 -> PD75|33.59|11.2|33.33|16.76|8137 kB|276.78 ms|
|MobileNetV2 -> PD75 -> DRQ|33.59|11.2|33.33|16.76|2026 kB|362.46 ms|
|MobileNetV2 -> PD75 -> F16Q|33.59|11.2|33.33|16.76|4071 kB|229.15 ms|
|MobileNetV2 -> PD75 -> FIQ|33.59|11.2|33.33|16.76|2061 kB|203.59 ms|
|MobileNetV2 -> PD75 -> FIQ-Iio|33.59|11.2|33.33|16.76|2061 kB|249.3 ms|
|MobileNetV2 -> PD75 -> WQ128|64.84|74.12|65.01|61.97|4564 kB|229.89 ms|
|MobileNetV2 -> PD75 -> WQ128 -> DRQ|64.84|74.12|65.01|61.97|1968 kB|350.95 ms|
|MobileNetV2 -> PD75 -> WQ128 -> F16Q|64.84|74.12|65.01|61.97|3651 kB|289.63 ms|
|MobileNetV2 -> PD75 -> WQ128 -> FIQ|64.84|74.12|65.01|61.97|2001 kB|206.75 ms|
|MobileNetV2 -> PD75 -> WQ128 -> FIQ-Iio|64.84|74.12|65.01|61.97|2001 kB|251.34 ms|
|MobileNetV2 -> PD75 -> WQ32|42.19|30.05|41.86|33.03|2983 kB|229.94 ms|
|MobileNetV2 -> PD75 -> WQ32 -> DRQ|42.19|63.38|41.88|33.71|1699 kB|366.34 ms|
|MobileNetV2 -> PD75 -> WQ32 -> F16Q|42.19|30.05|41.86|33.03|2645 kB|291.23 ms|
|MobileNetV2 -> PD75 -> WQ32 -> FIQ|42.97|64.65|42.65|34.33|1735 kB|171.77 ms|
|MobileNetV2 -> PD75 -> WQ32 -> FIQ-Iio|42.19|64.31|41.88|33.86|1735 kB|167.78 ms|
|MobileNetV2 -> PD90|33.59|11.2|33.33|16.76|8093 kB|291.39 ms|
|MobileNetV2 -> PD90 -> DRQ|33.59|11.2|33.33|16.76|1907 kB|315.33 ms|
|MobileNetV2 -> PD90 -> F16Q|33.59|11.2|33.33|16.76|4053 kB|292.41 ms|
|MobileNetV2 -> PD90 -> FIQ|33.59|11.2|33.33|16.76|1946 kB|167.08 ms|
|MobileNetV2 -> PD90 -> FIQ-Iio|33.59|11.2|33.33|16.76|1946 kB|167.82 ms|
|MobileNetV2 -> PD90 -> WQ128|33.59|11.2|33.33|16.76|4358 kB|230.82 ms|
|MobileNetV2 -> PD90 -> WQ128 -> DRQ|33.59|11.2|33.33|16.76|1833 kB|311.15 ms|
|MobileNetV2 -> PD90 -> WQ128 -> F16Q|33.59|11.2|33.33|16.76|3577 kB|290.67 ms|
|MobileNetV2 -> PD90 -> WQ128 -> FIQ|33.59|11.2|33.33|16.76|1871 kB|167.14 ms|
|MobileNetV2 -> PD90 -> WQ128 -> FIQ-Iio|33.59|11.2|33.33|16.76|1871 kB|167.75 ms|
|MobileNetV2 -> PD90 -> WQ32|32.81|10.94|33.33|16.47|2877 kB|263.46 ms|
|MobileNetV2 -> PD90 -> WQ32 -> DRQ|32.81|10.94|33.33|16.47|1616 kB|392.8 ms|
|MobileNetV2 -> PD90 -> WQ32 -> F16Q|32.81|10.94|33.33|16.47|2554 kB|259.43 ms|
|MobileNetV2 -> PD90 -> WQ32 -> FIQ|32.81|10.94|33.33|16.47|1658 kB|167.02 ms|
|MobileNetV2 -> PD90 -> WQ32 -> FIQ-Iio|32.81|10.94|33.33|16.47|1658 kB|169.16 ms|
|MobileNetV2 -> WQ128|63.28|62.79|63.4|62.52|5079 kB|244.69 ms|
|MobileNetV2 -> WQ128 -> DRQ|61.72|61.28|61.83|61.27|2285 kB|326.55 ms|
|MobileNetV2 -> WQ128 -> F16Q|63.28|62.79|63.4|62.52|3819 kB|291.02 ms|
|MobileNetV2 -> WQ128 -> FIQ|62.5|61.98|62.62|61.85|2318 kB|226.11 ms|
|MobileNetV2 -> WQ128 -> FIQ-Iio|62.5|62.77|62.59|62.2|2318 kB|167.32 ms|
|MobileNetV2 -> WQ32|68.75|69.39|68.81|68.97|3406 kB|291.11 ms|
|MobileNetV2 -> WQ32 -> DRQ|68.75|69.04|68.81|68.83|1893 kB|374.83 ms|
|MobileNetV2 -> WQ32 -> F16Q|68.75|69.39|68.81|68.97|3194 kB|291.86 ms|
|MobileNetV2 -> WQ32 -> FIQ|68.75|69.32|68.77|68.92|1923 kB|235.43 ms|
|MobileNetV2 -> WQ32 -> FIQ-Iio|68.75|69.49|68.79|68.98|1923 kB|167.7 ms|

In [74]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot
import tensorflow_datasets as tfds

%load_ext tensorboard

from os import path
import pathlib
import tempfile

# normalizing the images to [0, 1]
def normalize(image, label):
    return tf.cast(image, tf.float32) / 255., label

def random_crop(image):
    cropped_image = tf.image.random_crop(
        image, size=[256, 256, 3])

    return cropped_image

def random_jitter(image):
    # resizing to 286 x 286 x 3
    image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # randomly cropping to 256 x 256 x 3
    image = random_crop(image)

    # random mirroring
    image = tf.image.random_flip_left_right(image)

    return image

def preprocess_flowers_train(image, label):
    image = random_jitter(image)
    return image, label

# -------------------------------

def preprocess_flowers(image, label):
    image = tf.image.resize(image, [256, 256],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return image, label

def load_flowers_dataset():  
    ds_test, ds_info = tfds.load(name="tf_flowers", 
                                with_info=True,
                                split=['test'],
                                as_supervised=True)

#     ds_train = ds_train.map(normalize)    
#     ds_train = ds_train.map(preprocess_flowers)
#     ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    
#     ds_validation = ds_validation.map(normalize)
#     ds_validation = ds_validation.map(preprocess_flowers)
    ds_train = None
    ds_validation= None
    ds_test = ds_test.map(normalize)
    ds_test = ds_test.map(preprocess_flowers)
    
    return ds_train, ds_validation, ds_test

def load_beans_datasets():
    (ds_train, ds_validation, ds_test), ds_info = tfds.load(
        'beans',
        split=['train', 'validation', 'test'],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )
    
    ds_train = ds_train.map(normalize)
    ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    
    ds_validation = ds_validation.map(normalize)
    
    ds_test = ds_test.map(normalize)
    
    return ds_train, ds_validation, ds_test

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [75]:
new_model = tf.keras.models.load_model('flowers_models_optimized/EfficentNetB0_flowers_model_ConstantSparsity50.h5')





In [76]:
new_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

In [77]:
flowers_test = load_flowers_dataset()[2]
flowers_test = flowers_test.batch(1)
new_model.evaluate(flowers_test)

ValueError: Unknown split "test". Should be one of ['train'].