In [2]:
# Import libraries
import tensorflow as tf
import numpy as np
import random

from tensorflow.keras.applications import MobileNetV3Small
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import Rescaling

import fewshot_functions as fs

In [None]:
# Define dataset parameters
batch_size = 32
img_size = (224, 224)  # Resize images to 224x224 (adjust to your model's input size)

# Load dataset from 'data' folder with train-validation split
train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'data',                   # Root directory for your dataset
    image_size=img_size,       # Resize images to 224x224 (or your model's input size)
    batch_size=batch_size,     # Batch size for training
    label_mode='int',          # Labels are returned as integer indices based on folder names
    validation_split=0.2,      # Split 20% for validation
    subset="training",         # This is the training set
    seed=123                   # Seed for reproducibility
)

# Normalize the pixel values to [0, 1] range (from 0-255)
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_dataset = train_dataset.map(lambda x, y: (normalization_layer(x), y))

# Prefetch data for better performance
AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.cache().prefetch(buffer_size=AUTOTUNE)

Found 9695 files belonging to 97 classes.
Using 7756 files for training.
Found 9695 files belonging to 97 classes.
Using 1939 files for validation.


MobileNetV3Small documentation:
https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV3Small

In [4]:
# Load pre-trained model

# Preloading with imagenet weights, excluding top layer as it the one we will be retraining
base = MobileNetV3Small(input_shape=(224,224,3), 
                        weights="imagenet", 
                        include_top=False,
                        include_preprocessing=False
                       )

# Prevent imported weights from being retrained
base.trainable = False

Custom layers

The embeddeding layer to be utilized by Fewshot will be set up here

Prototypical Fewshot paper: https://arxiv.org/pdf/1703.05175

In [5]:
# Define input tensor (224x224 image with 3 color channels; RGB)
inputs = tf.keras.Input(shape=(224,224,3))

# Pass inputs through base model
x = base(inputs, training=False)

# Convert feature maps to single feature vector per image; alternative, flatten(), is prone to overfitting
x = GlobalAveragePooling2D()(x)

# Define number of classes in dataset; adjust as necessary
outputs = Dense(128, activation='linear')(x)

# Create customized model
embedding_model = Model(inputs,outputs)

Reached a roadblock here, but this is what the fewshot training will look like. I do not want to proceed until we have the data imported and preprocessed.

In [6]:
# Training loop

optimizer = Adam(learning_rate=0.001)

episodes = 1000
for episode in range(episodes):  # Train for 1000 episodes
    support_images, support_labels, query_images, query_labels = fs.sample_episode(train_dataset)

    with tf.GradientTape() as tape:
        # Embed the support and query images
        support_embeddings = embedding_model(support_images, training=True)
        query_embeddings = embedding_model(query_images, training=True)
        # Compute the loss
        loss = fs.prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels)

    # Update the model
    gradients = tape.gradient(loss, embedding_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, embedding_model.trainable_variables))

    # Log 1 episode in every 50
    if episode % 50 == 0:
        print(f"Episode {episode}: Loss = {loss.numpy():.4f}")

    # Store for use later
    embedding_model.save("mobilenetv3_fewshot.keras")

2025-03-30 18:05:29.687904: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Episode 0: Loss = 3.7574


2025-03-30 18:05:31.015762: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-03-30 18:05:33.242176: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-03-30 18:05:37.731198: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-03-30 18:05:46.710123: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-03-30 18:06:04.586505: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Episode 50: Loss = 1.6773


2025-03-30 18:06:40.470119: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Episode 100: Loss = 1.0439


2025-03-30 18:07:52.572682: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Episode 150: Loss = 1.6694
Episode 200: Loss = 2.1463
Episode 250: Loss = 1.0018


2025-03-30 18:10:16.331443: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Episode 300: Loss = 1.4995
Episode 350: Loss = 0.5849
Episode 400: Loss = 0.5105
Episode 450: Loss = 0.6546
Episode 500: Loss = 0.3345


2025-03-30 18:15:06.018492: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Episode 550: Loss = 0.3323
Episode 600: Loss = 0.5868
Episode 650: Loss = 1.7164
Episode 700: Loss = 0.7739
Episode 750: Loss = 1.2735
Episode 800: Loss = 0.3468
Episode 850: Loss = 0.8548
Episode 900: Loss = 0.9512
Episode 950: Loss = 0.3840
