<a href="https://colab.research.google.com/github/leticiatdoliveira/embedded-img-classification/blob/features-cnn/cnn_plant_diseases.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tensorflow.keras import backend as K
K.clear_session()

# Install some packages

In [None]:
pip install tensorflow-model-optimization

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m235.5/242.5 kB[0m [31m8.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0


# Mount drive

In [None]:
import os
if not os.path.exists('/content/drive'):
  from google.colab import drive
  drive.mount('/content/drive')
else:
  print('Drive already mounted')


Drive already mounted


# Import packages

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path

import tensorflow as tf
from tensorflow_model_optimization.python.core.keras.compat import keras

import shutil

In [None]:
# shutil.rmtree('/content/data/plantvillage')

In [None]:
# shutil.rmtree('/content/results/plant-village')

Explore the available runtime resources

In [None]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
print("List of GPUs Available: ", tf.config.list_physical_devices('GPU'))

Num GPUs Available:  1
List of GPUs Available:  [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


# Global variables setting

In [None]:
PLANT_CULTURE = "tomato"

In [None]:
DOWNLOAD_DATA = False

In [None]:
DELETE_SOME_SUBFOLDERS = False

Paths

In [None]:
dataset_name = 'PlantVillage'

In [None]:
project_dir = os.getcwd()
print(f"Project dir: {project_dir}")

Project dir: /content


In [None]:
drive_dir = project_dir +'/drive/MyDrive/IC_VANT/PlantVillage'
if not os.path.exists(drive_dir):
    os.makedirs(drive_dir)
print(f"Drive dir: {drive_dir}")

Drive dir: /content/drive/MyDrive/IC_VANT/PlantVillage


In [None]:
data_dir: str = drive_dir + '/data'
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
print(f"Data dir: {data_dir}")

Data dir: /content/drive/MyDrive/IC_VANT/PlantVillage/data


In [None]:
results_dir = drive_dir + '/results'
if not os.path.exists(results_dir):
    os.makedirs(results_dir)
print(f"Results dir: {results_dir}")

Results dir: /content/drive/MyDrive/IC_VANT/PlantVillage/results


Model hyperparameters

In [None]:
BATCH_SIZE = 4
EPOCHS = 10

Image parameters

In [None]:
IMAGE_HEIGHT = 256

In [None]:
IMAGE_WIDTH = 256

In [None]:
IMAGE_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

# Functions

In [None]:
def check_data_dir(silent_console: bool = True):
    """
    Report whether the module-level ``data_dir`` path exists.

    :param silent_console: when False, print a found / not-found message
    :return: True if ``data_dir`` exists on disk, False otherwise
    """
    found = os.path.exists(data_dir)
    if not silent_console:
        message = "Data_dir found !" if found else "The data_dir not found !"
        print(message)
    return found

## Dataset manip

In [None]:
def get_dataset_info(directory: str,
                     extensions: tuple = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')) -> int:
    """
    Count the image files stored one level below ``directory``
    (layout ``<directory>/<class_name>/<image file>``).

    The extension test is case-insensitive, so ``.jpg`` and ``.JPG`` are both
    counted without globbing twice, and each file is counted exactly once.

    :param directory: str, root folder that contains one sub-folder per class
    :param extensions: file suffixes (with the leading dot) to count as images.
        NOTE(review): the default is meant to mirror the formats listed in the
        Keras ``image_dataset_from_directory`` docs, so this count should agree
        with the "Found N files" line it prints - confirm against the installed version.
    :return: int, number of matching files
    """
    allowed = {ext.lower() for ext in extensions}
    return sum(
        1
        for path in Path(directory).glob('*/*')
        if path.is_file() and path.suffix.lower() in allowed
    )


In [None]:
def check_nb_of_data_in_dataset(dataset: tf.data.Dataset):
    """
    Print the number of batches in the dataset and the matching number of samples.

    The sample count is ``nb_of_batches * BATCH_SIZE`` (module-level constant), so it
    is an upper bound whenever the last batch is not full.

    :param dataset: tf.data.Dataset, expected to be batched with ``BATCH_SIZE``
    :return: None
    """
    nb_of_batches = dataset.cardinality().numpy()
    # cardinality() reports negative sentinel values when the size is infinite or
    # cannot be determined statically; multiplying them would print a bogus count.
    if nb_of_batches < 0:
        print("Nb of data: unknown | Nb of batches: unknown (dataset cardinality is not known)")
        return None
    nb_of_data = nb_of_batches * BATCH_SIZE
    print(f"Nb of data: {nb_of_data} | Nb of batches: {nb_of_batches}")
    return None

In [None]:
def check_nb_of_classes_in_dataset(dataset: tf.data.Dataset):
    """
    Print how many classes the dataset exposes, followed by their names.

    :param dataset: tf.data.Dataset carrying a ``class_names`` attribute
    :return: None
    """
    names = dataset.class_names
    count = len(names)
    print(f"Nb of classes: {count} | Class names: {names}")


In [None]:
def load_split_dataset(val_split: float, test_split: float, silent_console: bool = True):
    """
    Load the images under ``data_dir`` and split them into a training set and a held-out set.

    No separate test dataset is built: ``val_split`` and ``test_split`` are added
    together and the whole held-out fraction is returned as the second dataset.
    Both loads use the same seed and fraction so the two subsets come from one split.

    :param val_split: fraction of the data intended for validation
    :param test_split: fraction of the data intended for testing (folded into the held-out set)
    :param silent_console: when False, print the fraction that is held out
    :return: (train_ds, val_ds) - two tf.data.Dataset objects
    :raises ValueError: if ``val_split + test_split`` is not strictly between 0 and 1
    """
    eval_split = val_split + test_split
    if not 0.0 < eval_split < 1.0:
        raise ValueError(
            f"val_split + test_split must be strictly between 0 and 1, got {eval_split}"
        )
    if not silent_console:
        print(f"Holding out {eval_split:.0%} of the data (validation + test)")

    # Shared arguments: both subsets must be drawn with identical settings.
    common_kwargs = dict(
        validation_split=eval_split,
        seed=123,
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
    )

    # get training dataset
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **common_kwargs
    )

    # get data to eval (validation and test)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **common_kwargs
    )

    return train_ds, val_ds

In [None]:
def get_dataset_classes(dataset: tf.data.Dataset, dataset_type: str, silent_console: bool = True):
    """
    Return the class names attached to the dataset, optionally printing them.

    :param dataset: tf.data.Dataset carrying a ``class_names`` attribute
    :param dataset_type: label used in the printed message (e.g. "train")
    :param silent_console: when False, print the class count and names
    :return: the dataset's ``class_names``
    """
    names = dataset.class_names
    if silent_console:
        return names
    print(f"Dataset: {dataset_type} | Nb of classes: {len(names)} | Class names: {names}")
    return names

In [None]:
def check_batch_size(dataset: tf.data.Dataset, dataset_type: str):
    """
    Print the shapes of the first (images, labels) batch of the dataset.

    :param dataset: tf.data.Dataset yielding (images, labels) pairs
    :param dataset_type: label used in the printed header (e.g. "train")
    :return: None
    """
    print(f"\n------ Checking batch size of the {dataset_type} dataset...")
    # take(1) restricts the iteration to the first batch only
    for images, labels in dataset.take(1):
        print(f"Image batch shape: {images.shape}")
        print(f"Label batch shape: {labels.shape}\n")

In [None]:
def display_img_sample_of_dataset(dataset: tf.data.Dataset, dataset_type: str):
    """
    Plot up to nine images of the dataset in a 3x3 grid and save the figure
    as ``sample_images_<dataset_type>.png`` inside ``results_dir``.

    :param dataset: batched tf.data.Dataset of (images, labels) with a ``class_names`` attribute
    :param dataset_type: label used in the figure title and in the file name
    :return: None
    """
    plt.figure(figsize=(10, 10))
    class_names = dataset.class_names

    # Unbatch before sampling so the grid is filled whatever the batch size is
    # (indexing nine images out of a single batch fails when the batch is smaller).
    for i, (image, label) in enumerate(dataset.unbatch().take(9)):
        ax = plt.subplot(3, 3, i + 1)  # Create the subplot
        ax.imshow(image.numpy().astype("uint8"))  # Show the image

        # Set the title on the subplot (not the entire plot)
        ax.set_title(f"{class_names[int(label)]}", fontsize=8)
        ax.axis("off")  # Remove the axis labels

    # Adjust layout to prevent overlapping titles
    plt.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.3, wspace=0.3)

    # Add a title to the entire plot
    plt.suptitle(f"Sample of images from the {dataset_type} dataset", fontsize=16)

    # Save the plot; os.path.join supplies the separator between folder and file name
    os.makedirs(results_dir, exist_ok=True)
    plt.savefig(os.path.join(results_dir, f"sample_images_{dataset_type}.png"))

In [None]:
def set_prefetch(dataset: tf.data.Dataset, shuffle_buffer: int = 1000):
    """
    Shuffle the dataset (optional) and enable prefetching with an auto-tuned buffer.

    The shuffle acts on the elements of ``dataset`` as given; in this notebook the
    datasets are already batched, so whole batches are shuffled.

    :param dataset: tf.data.Dataset
    :param shuffle_buffer: size of the shuffle buffer; pass 0 or None to skip
        shuffling (e.g. for a validation set). Defaults to 1000.
    :return: tf.data.Dataset
        NOTE(review): the result is a new dataset object - read ``class_names``
        from the original dataset before calling this, as the notebook does.
    """
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
def normalize_dataset(dataset: tf.data.Dataset, silent_console: bool = True):
    """
    Normalize image data.
    RGB values are in the [0, 255] range, so we need to scale them to the [0, 1] range.

    :param dataset: tf.data.Dataset of (images, labels); labels pass through unchanged
    :param silent_console: when False, pull one batch and print the pixel range of its
        first image as a sanity check
    :return: the mapped tf.data.Dataset
    """
    print("Normalizing dataset...")
    normalizer_layer = keras.layers.Rescaling(1. / 255)
    normalized_ds = dataset.map(lambda x, y: (normalizer_layer(x), y))

    # Only run the pipeline when the check is actually printed: fetching a batch
    # forces the upstream dataset (including any shuffle buffer) to be read.
    if not silent_console:
        image_batch, _ = next(iter(normalized_ds))
        first_image = image_batch[0]
        print(f"First image pixel values -> Min: {np.min(first_image)} | Max: {np.max(first_image)} | "
              f"Shape: {first_image.shape}")
    return normalized_ds

In [None]:
def show_first_data_in_dataset(dataset: tf.data.Dataset, class_names: list):
    """
    Print and plot the first sample of the first batch of the dataset.

    :param dataset: tf.data.Dataset yielding (images, labels) batches
    :param class_names: list used to translate the label index into a class name
    :return: None
    """
    print("\n---- Showing the first data in the dataset...")
    # take(1) yields only the first batch
    for image_batch, label_batch in dataset.take(1):
        sample_img, sample_label = image_batch[0], label_batch[0]

        # Textual summary of the sample
        print(f"First image shape: {sample_img.shape} | First label: {sample_label}")
        lowest = np.min(sample_img)
        highest = np.max(sample_img)
        print(f"First image pixel values -> Min: {lowest} | Max: {highest}")

        # Plot of the sample
        sample_class = class_names[sample_label]
        plt.figure()
        plt.imshow(sample_img)
        plt.title(f"First image example | Label: {sample_label} | Class name: {sample_class}")
        plt.grid(False)
        plt.show()

## Model

In [None]:
def data_augmentation():
    """
    Create a data augmentation block: random horizontal/vertical flips followed by
    random rotations (factor 0.2).

    The container is built from the same ``keras`` namespace as its layers (and as
    ``create_model``).
    NOTE(review): ``tf.keras`` and the tfmot compat ``keras`` may resolve to different
    Keras implementations depending on the installed TensorFlow version - confirm.

    :return: a ``keras.Sequential`` holding the augmentation layers
    """
    return keras.Sequential([
        keras.layers.RandomFlip("horizontal_and_vertical"),
        keras.layers.RandomRotation(0.2),
    ])

In [None]:
def create_model(class_names: list, img_height: int, img_width: int):
    """
    Build the (uncompiled) CNN image classifier.

    Architecture: input -> identity reshape -> six [3x3 Conv2D + 2x2 MaxPooling]
    stages (32 filters, then 64 five times) -> Flatten -> Dense(64, relu)
    -> Dense(number of classes, softmax).

    :param class_names: list of class names; its length sets the output size
    :param img_height: input image height in pixels
    :param img_width: input image width in pixels
    :return: keras.Sequential model
    """
    input_shape = (img_height, img_width, 3)
    conv_filters = (32, 64, 64, 64, 64, 64)

    stack = [
        keras.layers.Input(shape=input_shape),
        #data_augmentation(),
        keras.layers.Reshape(input_shape),
    ]
    # Each stage: convolution followed by a 2x2 down-sampling
    for filters in conv_filters:
        stack.append(keras.layers.Conv2D(filters, kernel_size=(3, 3), activation='relu'))
        stack.append(keras.layers.MaxPooling2D((2, 2)))

    # Classification head
    stack.append(keras.layers.Flatten())
    stack.append(keras.layers.Dense(64, activation='relu'))
    stack.append(keras.layers.Dense(len(class_names), activation='softmax'))

    return keras.Sequential(stack)


In [None]:
def compile_mode(model: tf.keras.Model):
    """
    Compile the model in place with the Adam optimizer, sparse categorical
    cross-entropy on probabilities (``from_logits=False``) and the accuracy metric.

    :param model: model to compile
    :return: the same model object, compiled
    """
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
    return model


In [None]:
def show_model_fit(hist: tf.keras.callbacks.History, nb_epochs: int):
    """
    Plot training/validation accuracy and loss per epoch and save the figure as
    ``model_fit.png`` inside ``results_dir``.

    :param hist: History object returned by ``model.fit`` (needs the 'accuracy',
        'val_accuracy', 'loss' and 'val_loss' entries)
    :param nb_epochs: index of the last epoch + 1; only used to place the x axis when
        ``hist`` does not expose usable epoch indices
    :return: None
    """
    acc = hist.history['accuracy']
    val_acc = hist.history['val_accuracy']

    loss = hist.history['loss']
    val_loss = hist.history['val_loss']

    # The x axis must have exactly one point per recorded epoch. When training is
    # resumed with initial_epoch > 0 the history is shorter than nb_epochs, so
    # range(nb_epochs) would not match the curves.
    epochs_range = list(getattr(hist, 'epoch', []))
    if len(epochs_range) != len(acc):
        first_epoch = max(nb_epochs - len(acc), 0)
        epochs_range = range(first_epoch, first_epoch + len(acc))

    plt.figure(figsize=(8, 8))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')

    # Save the plot; os.path.join supplies the separator between folder and file name
    os.makedirs(results_dir, exist_ok=True)
    plt.savefig(os.path.join(results_dir, "model_fit.png"))

In [None]:
def save_model(model: tf.keras.Model, model_name: str, file_format: str = "keras"):
    """
    Save the model as ``<results_dir>/<model_name>.<file_format>``.

    :param model: model to save
    :param model_name: file name without extension
    :param file_format: file extension, "keras" by default
    :return: None
    :raises ValueError: if ``results_dir`` does not exist
    """
    if not os.path.exists(results_dir):
        raise ValueError("Results dir not found !")
    model.save(f"{results_dir}/{model_name}.{file_format}")

In [None]:
def create_checkpoint_weights_callback(flag_quant: bool = False):
    """
    Create a checkpoint callback to save trained weights per epoch done.

    Weights are written to ``ckpt_<epoch>.weights.h5`` inside
    ``<results_dir>/training_checkpoints`` (or ``training_checkpoints_quant`` when
    ``flag_quant`` is True). The folder is created if missing.

    :param flag_quant: True for the quantization-aware model, False (default) for the
        base model. The default lets the function be called without arguments.
    :return: (checkpoint_callback, checkpoint_dir)
    """
    subdir = "training_checkpoints_quant" if flag_quant else "training_checkpoints"
    checkpoint_dir = results_dir + "/" + subdir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # "{epoch}" is filled in by the callback each time it saves
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True
    )
    return checkpoint_callback, checkpoint_dir

In [None]:
def weights_files_key(filename):
    """
    Get the epoch number of each trained weight saved.

    The name is cut at its first '.', then split on '_'; the second field is the
    epoch number (``ckpt_3.weights.h5`` -> 3).

    Parameters
    ----------
    filename: str
        checkpoint file name of the form ``<prefix>_<epoch>.<extension>``

    Returns
    -------
    int
        nb of the epoch

    Raises
    ------
    ValueError
        if the name carries no integer epoch field (e.g. a stray file in the
        checkpoint folder)
    """
    base_name = filename.split('.')[0]
    fields = base_name.split('_')
    try:
        epoch_number = int(fields[1])
    except (IndexError, ValueError):
        # One clear error type instead of an opaque IndexError from the split
        raise ValueError(
            f"Cannot read an epoch number from checkpoint file name: {filename!r}"
        ) from None
    print(f"Epoch number: {fields[1]} | Base name: {base_name}")
    return epoch_number

# Main code

## Download data

In [None]:
if DOWNLOAD_DATA:
  import kagglehub

  # download
  !kaggle datasets download -d emmarex/plantdisease -p /content/data --unzip

  #rename data folder
  !mv content/data/PlantVillage/ /content/drive/MyDrive/IC_VANT/PlantVillage/data

  # delete dir where we have downloaded data
  shutil.rmtree('/content/data')
else:
  print("Data already downloaded !")

Data already downloaded !


Check the subfolders of data_dir and move aside every subdirectory whose culture differs from the selected one (PLANT_CULTURE)

In [None]:
if DELETE_SOME_SUBFOLDERS:

  def delete_subfolders(data_dir, target_culture):
    """
    Keep only the class folders of one culture inside ``data_dir``.

    Every entry of ``data_dir`` that is not a directory whose name contains
    ``target_culture`` (case-insensitive) is moved, under its own name, to the
    sibling folder ``<data_dir>-others``. Nothing is deleted.

    :param data_dir: folder containing one sub-folder per class
    :param target_culture: culture to keep, e.g. "tomato"
    :return: None
    """
    others_dir = data_dir + '-others'
    wanted = target_culture.lower()
    for subdir in os.listdir(data_dir):
        subdir_path = os.path.join(data_dir, subdir)
        print(f"Checking subdirectory: {subdir_path}")
        if os.path.isdir(subdir_path) and wanted in subdir.lower():
            print(f"Keeping subdirectory: {subdir_path}")
        else:
            os.makedirs(others_dir, exist_ok=True)
            # Keep the original folder name so the moved classes stay separate
            destination = os.path.join(others_dir, subdir)
            print(f"Moving subdirectory: {subdir_path} to {destination}")
            shutil.move(subdir_path, destination)
        print("\n")

  delete_subfolders(data_dir, PLANT_CULTURE)

## Load data

In [None]:
if not check_data_dir():
  raise ValueError("Data dir not found !")
nb_img_data = get_dataset_info(data_dir)
print(f"# {dataset_name.upper()} dataset contains {nb_img_data} images\n")

# PLANTVILLAGE dataset contains 16010 images



In [None]:
train_dataset, val_dataset = load_split_dataset(0.2, 0.2, silent_console=True)

Found 16011 files belonging to 10 classes.
Using 9607 files for training.
Found 16011 files belonging to 10 classes.
Using 6404 files for validation.


### Check dataset classes

In [None]:
train_classes = get_dataset_classes(train_dataset, "train")
val_classes = get_dataset_classes(val_dataset, "validation")
if train_classes != val_classes:
    raise ValueError("The classes in the train and validation datasets are different")
else:
    print(f"Nb of class: {len(train_classes)} | Classes: {train_classes}\n")

Nb of class: 10 | Classes: ['Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']



### Display sample of images

Train img sample

In [None]:
check_batch_size(train_dataset, "train")
# display_img_sample_of_dataset(train_dataset, "train")


------ Checking batch size of the train dataset...
Image batch shape: (4, 256, 256, 3)
Label batch shape: (4,)



Validation img sample

In [None]:
check_batch_size(val_dataset, "validation")
# display_img_sample_of_dataset(val_dataset, "validation")


------ Checking batch size of the validation dataset...
Image batch shape: (4, 256, 256, 3)
Label batch shape: (4,)



## Data pre-processing and cleaning

### Prefetch data

In [None]:
print("\n---- Setting prefetch for the dataset...")
train_dataset = set_prefetch(train_dataset)
val_dataset = set_prefetch(val_dataset)


---- Setting prefetch for the dataset...


### Normalize data

In [None]:
print("\n---- Normalizing dataset...")
train_dataset_normalized = normalize_dataset(train_dataset)
val_dataset_normalized = normalize_dataset(val_dataset)


---- Normalizing dataset...
Normalizing dataset...
Normalizing dataset...


In [None]:
# show_first_data_in_dataset(train_dataset_normalized, train_classes)

## Create and fit a CNN model

### Visualize some augmented images

### Build model

In [None]:
print("\n---- Building the model...")
model = create_model(train_classes, IMAGE_HEIGHT, IMAGE_WIDTH)
model = compile_mode(model)

In [None]:
model.summary()

### Train model

In [None]:
start_ep = -1
fit_model = True

Create callbacks to save fit by epoch

In [None]:
# Base (non-quantized) model: checkpoints go to the "training_checkpoints" folder.
# The flag is passed explicitly because the helper declares it as a required parameter.
checkpoint_callback, checkpoint_dir = create_checkpoint_weights_callback(flag_quant=False)

In [None]:
print(f"Checkpoint_dir: {checkpoint_dir}")

Check whether a fully trained model has already been saved

In [None]:
if start_ep == -1:
  # check the entire model: look for exactly the file that save_model() writes
  # after a complete training run, and load that same path
  full_model_path = os.path.join(results_dir, f"cnn_model_ep={EPOCHS}.keras")
  if os.path.exists(full_model_path):
    print("\n---- An entire model is already saved !")
    print("\n---- Loading the model...")
    model = keras.models.load_model(full_model_path)
    fit_model = False
  else:
    # no complete model: the next cell looks for per-epoch checkpoints
    start_ep = 0
    print("\n---- No entire model saved !")

Check checkpoints of epochs fit

In [None]:
if start_ep == 0:
  print(f"\n---- Checking ckpt results")

  if os.path.exists(checkpoint_dir):
    # Only consider files written by the checkpoint callback; anything else in the
    # folder would break the epoch parsing.
    files = [
        name for name in os.listdir(checkpoint_dir)
        if name.startswith("ckpt_") and name.endswith(".weights.h5")
    ]
    if len(files) > 0:
      print(f"\n We found some ckpt model files !")

      highest_epoch = max(weights_files_key(name) for name in files)
      print(f"Highest epoch: {highest_epoch}")

      # get the last epoch fitting
      if highest_epoch < EPOCHS:
        # training was interrupted: resume from the last saved epoch
        start_ep = highest_epoch
        path_to_load_model = os.path.join(checkpoint_dir, "ckpt_{}.weights.h5".format(start_ep))
        print("\n---- Loading weights of: {}".format(path_to_load_model))
        model.load_weights(path_to_load_model)
        fit_model = True
      else:
        print("more or equal weights saved than EPOCHS")
        fit_model = False
        start_ep = -1
        # load checkpoint fitting
        path_to_load_model = os.path.join(checkpoint_dir, "ckpt_{}.weights.h5".format(EPOCHS))
        print("\n---- Loading weights of: {}".format(path_to_load_model))
        model.load_weights(path_to_load_model)
        # save as complete model
        print("\n---- Saving as an entire model...")
        save_model(model, f"cnn_model_ep={EPOCHS}")
        print("\n---- Complete model saved !")
    else:
      print(f"\n No ckpt model files found !")
      fit_model = True


Fit model

In [None]:
print(start_ep)

In [None]:
# Train only when no complete saved model / full set of checkpoints was found above.
# `hist` is defined only inside this branch, so later cells must guard on fit_model too.
if fit_model:
  print("\n---- Fitting the model...")
  print(f"Start epoch: {start_ep}")
  print(f"End epoch: {EPOCHS}\n")

  # fit using GPU
  # initial_epoch=start_ep: 0 for a fresh run, or the highest checkpointed epoch whose
  # weights were loaded by the checkpoint-check cell.
  with tf.device('/device:GPU:0'):
    hist = model.fit(train_dataset_normalized,
                    validation_data=val_dataset_normalized,
                    epochs=EPOCHS,
                    initial_epoch=start_ep,
                    callbacks=[checkpoint_callback]
                    )

In [None]:
for layer in model.layers:
    print(layer.name, layer.output.shape)

In [None]:
if fit_model:
  show_model_fit(hist, EPOCHS)
  save_model(model, f"cnn_model_ep={EPOCHS}")

### Evaluate model

In [None]:
print("\n---- Evaluating the model...")

In [None]:
print("### Val dataset evaluation:")
val_score = model.evaluate(val_dataset_normalized)
print(f"Val accuracy: {val_score[1]*100.00:.2f} %")

In [None]:
print("### Train dataset evaluation:")
train_score = model.evaluate(train_dataset_normalized)
print(f"Train accuracy: {train_score[1]*100.00:.2f} %")

## Create a Quantization Aware Model

In [None]:
import tensorflow_model_optimization as tfmot

### Quantize layers

In [None]:
base_model = keras.models.clone_model(model)
base_model.set_weights(model.get_weights())
base_model = compile_mode(base_model)

Test model copy

In [None]:
print("### Train dataset evaluation (COPY MODEL):")
train_score = base_model.evaluate(train_dataset_normalized)
print(f"Train accuracy: {train_score[1]*100.00:.2f} %")

Check model type

In [None]:
print(type(base_model))

Check tensorflow and keras version

In [None]:
print(tf.__version__)
print(keras.__version__)

Create quantize

In [None]:
quantize_model = tfmot.quantization.keras.quantize_model

Apply quantization-aware wrapping to the model

In [None]:
quant_aware_model = quantize_model(base_model)

### Compile and fit the quantization model

Compile the quantization-aware model with the same settings as the base model

In [None]:
quant_aware_model = compile_mode(quant_aware_model)

In [None]:
quant_aware_model.summary()

Fit the quantization model

In [None]:
checkpoint_callback_quant, checkpoint_dir_quant = create_checkpoint_weights_callback(flag_quant=True)

In [None]:
print(f"Checkpoint_dir_quant: {checkpoint_dir_quant}")

In [None]:
# The datasets are already batched (BATCH_SIZE) when they are loaded, so no batch_size
# is passed here: Keras documents that it must not be specified for dataset inputs.
quant_aware_model.fit(train_dataset_normalized,
                      validation_data=val_dataset_normalized,
                      epochs=EPOCHS,
                      callbacks=[checkpoint_callback_quant]
                      )

### Save q-model

In [None]:
save_model(quant_aware_model, f"cnn_quant_aware_model_ep={EPOCHS}")

### Evaluate q-model

In [None]:
print("[INFO] Calculating Quant Aware model accuracy")
scores_val = quant_aware_model.evaluate(val_dataset_normalized)
print(f"Val Accuracy: {scores_val[1]*100:.2f}%")

In [None]:
scores_train = quant_aware_model.evaluate(train_dataset_normalized)
print(f"Train Accuracy: {scores_train[1]*100:.2f}%")

## TF to TFlite

Convert the model to a TensorFlow Lite model

In [None]:
print(quant_aware_model.input_shape)

In [None]:
input_shape = (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)
def representative_data_gen():
  """
  Yield calibration samples for the TFLite converter.

  Samples are real, normalized training images (one image per yield, batch
  dimension of 1) rather than random noise, so any value ranges the converter
  derives from them reflect the data the model actually sees.

  :return: generator of single-element lists, each holding a float32 tensor of
      shape ``input_shape``
  """
  # Re-batch to 1 so every yielded tensor has the (1, H, W, 3) shape
  samples = train_dataset_normalized.unbatch().batch(1).take(100)  # Adjust the number of samples as needed
  for image, _ in samples:
    yield [tf.cast(image, tf.float32)]  # Yield as a list containing a single input

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

In [None]:
quantized_tflite_model = converter.convert()

Save TFlite model

In [None]:
file_to_save_TFlite = results_dir + '/' + f"tflite_model_ep={EPOCHS}.tflite"
print(file_to_save_TFlite)

In [None]:
with open(file_to_save_TFlite,"wb") as f:
    f.write(quantized_tflite_model)

Test the TFLite model

In [None]:
def evaluate_tflite_model(dataset, interpreter):
    """
    Compute the classification accuracy of a TFLite model, one sample at a time.

    :param dataset: batched dataset of (images, labels); it is unbatched here and
        each image is fed with a leading batch dimension of 1, cast to float32
    :param interpreter: TFLite interpreter with tensors already allocated
    :return: float, fraction of samples whose arg-max output equals the label
        (nan when the dataset is empty)
    """
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    predictions = []
    expected = []
    # Iterate the whole unbatched dataset directly: its cardinality is not needed
    # (and may be unknown after unbatching).
    for image, label in dataset.unbatch():
        test_image = np.expand_dims(np.asarray(image), axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)
        interpreter.invoke()

        # get_tensor returns a copy, so no view of interpreter memory is kept
        # alive across the next invoke()
        output = interpreter.get_tensor(output_index)
        predictions.append(int(np.argmax(output[0])))
        # Store labels as plain ints so the comparison below is element-wise
        expected.append(int(np.asarray(label)))

    if not predictions:
        return float("nan")
    return float(np.mean(np.array(predictions) == np.array(expected)))

In [None]:
interpreter = tf.lite.Interpreter(model_path = file_to_save_TFlite)
interpreter.allocate_tensors()

In [None]:
test_accuracy = evaluate_tflite_model(val_dataset_normalized, interpreter)
print('Quant TFLite test_accuracy:', test_accuracy)