In [155]:
import os
import tensorflow as tf
from tensorflow.keras import applications
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import pandas as pd

plt.rcpara

tf.get_logger().setLevel('ERROR')
IMG_SIZE = 224

# Feature extraction model

In [203]:
def build_feature_extraction_model(base_model_class, input_shape=(IMG_SIZE, IMG_SIZE, 3), preprocess_layer=None, num_classes=101):
    """
    Args:
        base_model_class: A class for the base model (e.g., tf.keras.applications.MobileNet).
        preprocess_layer: Optional preprocessing layer (e.g., tf.keras.layers.Lambda with preprocess function).
        input_shape: The shape of the input data (excluding batch size).
        num_classes: The number of output classes for classification.

    Returns:
        tf.keras.Model: A compiled Keras model.
    """
    if base_model_class == "Xception":
        input_shape=(299, 299, 3)
        
    inputs = layers.Input(shape=input_shape, dtype="float32", name="input_layer")
    x = layers.Lambda(preprocess_layer, name="preprocessing_layer")(inputs)

    base_model = base_model_class(input_tensor=x, include_top=False, weights="imagenet")
    base_model.trainable = False
    
    x = base_model(x)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name=base_model_class.__name__)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    return model


def plot_loss_curves(history, title):
    """Return separate loss/accuracy curves for training and validation metrics."""
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]

    accuracy = history.history["accuracy"]
    val_accuracy = history.history["val_accuracy"]

    epochs = range(len(history.history["loss"]))

    plt.figure(figsize=(14,4))
    plt.suptitle(title)

    plt.subplot(121)
    plt.plot(epochs, loss, label="training loss")
    plt.plot(epochs, val_loss, label="val loss")
    plt.title("loss")
    plt.xlabel("epochs")
    plt.legend()

    plt.subplot(122)
    plt.plot(epochs, accuracy, label="training accuracy")
    plt.plot(epochs, val_accuracy, label="val accuracy")
    plt.title("accuracy")
    plt.xlabel("epochs")
    plt.legend()

    plt.show()
    

def test_models(models_config, train_data, test_data, num_classes=101, epochs=3):
    """
    Test various models with different configurations and save the results.
    
    Args:
        models_config (list): List of dicts containing model, its name and preprocessor.
        train_data (tf.data.Dataset): Training data.
        val_data (tf.data.Dataset): Validation data.
        num_classes (int): Number of output classes.
        epochs (int): Number of epochs to train the model.
    
    Returns:
        pd.DataFrame: DataFrame with model performance results.
    """
    results = []

    for config in models_config:
        model_name = config['model_name']
        model_builder = config['model']
        preprocess_layer = config['preprocessor']

        print(f"Training model: {model_name}...")

        model = build_feature_extraction_model(model_builder, preprocess_layer=preprocess_layer)

        history = model.fit(train_data,
                            validation_data=test_data,
                            validation_steps=int(0.33*(len(test_data))),
                            epochs=epochs,
                            verbose=1)

        test_accuracy = model.evaluate(test_data)[1]
        results.append({"Model": model_name,
                        "Test Accuracy": test_accuracy})

        plot_loss_curves(history, model_name)

    results = pd.DataFrame(results)
    results.to_csv("feature_extraction_results.csv", index=False)

    return results

In [204]:
model = build_feature_extraction_model(EfficientNetV2S, preprocess_layer=applications.efficientnet_v2.preprocess_input)
model.summary()

In [195]:
train_data, validation_data = image_dataset_from_directory("food101_train",
                                                           image_size=(IMG_SIZE, IMG_SIZE),
                                                           batch_size=32,
                                                           seed=42,
                                                           validation_split=0.1,
                                                           subset="both"
                                                           )


train_data = train_data.prefetch(buffer_size=tf.data.AUTOTUNE)
validation_data = validation_data.prefetch(buffer_size=tf.data.AUTOTUNE)

test_data= image_dataset_from_directory("food101_test",
                                          batch_size=32,
                                          image_size=(IMG_SIZE, IMG_SIZE),
                                          shuffle=False
                                        ).prefetch(buffer_size=tf.data.AUTOTUNE)

Found 74843 files belonging to 101 classes.
Using 67359 files for training.
Using 7484 files for validation.
Found 25255 files belonging to 101 classes.


In [206]:
models_config = [
    {
        'model_name': "MobileNetV2",
        'model': applications.MobileNetV2,
        'preprocessor': applications.mobilenet_v2.preprocess_input
    },
    {
        'model_name': "EfficientNetV2S",
        'model': EfficientNetV2S,
        'preprocessor': applications.efficientnet_v2.preprocess_input
    },
    {
        'model_name': "Xception",
        'model': applications.Xception,
        'preprocessor': applications.xception.preprocess_input,
    },
    {
        'model_name': "ConvNextTiny",
        'model': applications.ConvNeXtTiny,
        'preprocessor': applications.convnext.preprocess_input,
    },
    {
        'model_name': "ResNet50V2",
        'model': applications.ResNet50V2,
        'preprocessor': applications.resnet_v2.preprocess_input,
    },
    
]

In [210]:
results_df = test_models(models_config, train_data, validation_data, test_data, num_classes=101, epochs=3)
results_df.round(2)

Unnamed: 0,Model,Test Accuracy
0,MobileNetV2,0.04
1,EfficientNetV2S,0.71
2,Xception,0.06
3,ConvNextTiny,0.7
4,ResNet50V2,0.14


# Fine tuning

In [4]:
from tensorflow.keras import mixed_precision

policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)