In [2]:
# Import libraries
import tensorflow as tf
import numpy as np
import random

from tensorflow.keras.applications import MobileNetV3Small
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import Rescaling

import fewshot_functions as fs

In [None]:
# LOAD AND PROCESS DATA SET
# ...

MobileNetV3Small documentation:
https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV3Small

In [25]:
# Load pre-trained model

# ImageNet-pretrained MobileNetV3Small backbone. The classification head is
# omitted (include_top=False) because a custom head is attached to it instead.
# NOTE: include_preprocessing=False drops the model's built-in Rescaling layer;
# per the Keras docs the backbone then expects float inputs in the [-1, 1] range.
base = MobileNetV3Small(
    input_shape=(224, 224, 3),
    include_top=False,
    weights="imagenet",
    include_preprocessing=False,
)

# Freeze the backbone so the imported weights are not updated during training
base.trainable = False

Custom layers

The embedding layer used by the few-shot (prototypical network) model is set up here.

Paper: "Prototypical Networks for Few-shot Learning" (Snell et al., 2017): https://arxiv.org/pdf/1703.05175

In [None]:
# Build the embedding model: frozen backbone -> global average pooling -> linear projection.

# Length of the embedding vector produced for each image. This is the size of
# the embedding space, NOT the number of classes in the dataset; adjust as necessary.
EMBEDDING_DIM = 128

# Define input tensor (224x224 image with 3 color channels; RGB)
# NOTE(review): `base` is built with include_preprocessing=False, so pixel values
# presumably must already be scaled to [-1, 1] by the data pipeline -- confirm.
inputs = Input(shape=(224, 224, 3))

# Pass inputs through base model. training=False is passed explicitly so the
# backbone is called in inference mode regardless of how the outer model is called.
x = base(inputs, training=False)

# Convert feature maps to single feature vector per image; alternative, flatten(), is prone to overfitting
x = GlobalAveragePooling2D()(x)

# Project the pooled features into the embedding space. The activation is linear
# (no softmax) because the output is an embedding, not class probabilities.
outputs = Dense(EMBEDDING_DIM, activation='linear')(x)

# Create customized model mapping an image batch to its embeddings
embedding_model = Model(inputs, outputs)

**TODO:** Reached a roadblock here. The cell below shows what the few-shot training will look like, but it cannot run yet: `dataset` is not defined until the data has been imported and preprocessed. I do not want to proceed until that step is done.

In [None]:
# Training loop
#
# Episodic training: each iteration samples one episode (support + query sets),
# embeds both sets, computes the loss from fs.prototypical_loss, and applies one
# optimizer step to the trainable variables of `embedding_model`.

# Hyperparameters, named so they can be tuned in one place
LEARNING_RATE = 0.001
LOG_EVERY = 50      # print the loss once every LOG_EVERY episodes
episodes = 1000     # number of training episodes

optimizer = Adam(learning_rate=LEARNING_RATE)

# The set of trainable variables does not change inside the loop, so look it up once
trainable_variables = embedding_model.trainable_variables

for episode in range(episodes):
    # NOTE(review): `dataset` is not defined in the visible cells; it must be
    # created by the data-loading cell before this loop can run.
    support_images, support_labels, query_images, query_labels = fs.sample_episode(dataset)

    with tf.GradientTape() as tape:
        # Embed the support and query images
        support_embeddings = embedding_model(support_images, training=True)
        query_embeddings = embedding_model(query_images, training=True)
        # Compute the loss
        loss = fs.prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels)

    # Optimize each variable and update the model
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))

    # Log 1 episode of every LOG_EVERY in console, plus the final episode so the
    # loss of the model that gets saved is always reported
    if episode % LOG_EVERY == 0 or episode == episodes - 1:
        print(f"Episode {episode}: Loss = {loss.numpy():.4f}")

# Store for use later
# NOTE(review): ".h5" selects the legacy HDF5 format; recent Keras versions
# recommend the native ".keras" format -- confirm what the downstream loader expects.
embedding_model.save("mobilenetv3_fewshot.h5")