In [1]:
# Import libraries
import tensorflow as tf
import numpy as np
import random

from tensorflow.keras.applications import MobileNetV3Small
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import Rescaling

import fewshot_functions as fs

2025-03-30 18:02:47.819568: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-30 18:02:47.840961: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [2]:
# Define dataset parameters
batch_size = 32
img_size = (224, 224)  # Resize images to 224x224 (adjust to your model's input size)

# Load dataset from 'data' folder with train-validation split
train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'data',                   # Root directory for your dataset
    image_size=img_size,       # Resize images to 224x224 (or your model's input size)
    batch_size=batch_size,     # Batch size for training
    label_mode='int',          # Labels are returned as integer indices based on folder names
    validation_split=0.2,      # Split 20% for validation
    subset="training",         # This is the training set
    seed=123                   # Seed for reproducibility
)

validation_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    'data',
    image_size=img_size,
    batch_size=batch_size,
    label_mode='int',
    validation_split=0.2,      # 20% validation
    subset="validation",       # This is the validation set
    seed=123
)

# Normalize the pixel values to [0, 1] range (from 0-255)
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_dataset = train_dataset.map(lambda x, y: (normalization_layer(x), y))
validation_dataset = validation_dataset.map(lambda x, y: (normalization_layer(x), y))

# Prefetch data for better performance
AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.cache().prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.cache().prefetch(buffer_size=AUTOTUNE)

Found 9695 files belonging to 97 classes.
Using 7756 files for training.
Found 9695 files belonging to 97 classes.
Using 1939 files for validation.


MobileNetV3Small documentation:
https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV3Small

In [3]:
# Load pre-trained model

# Preloading with imagenet weights, excluding top layer as it the one we will be retraining
base = MobileNetV3Small(input_shape=(224,224,3), 
                        weights="imagenet", 
                        include_top=False,
                        include_preprocessing=False
                       )

# Prevent imported weights from being retrained
base.trainable = False

Custom layers

The embeddeding layer to be utilized by Fewshot will be set up here

Prototypical Fewshot paper: https://arxiv.org/pdf/1703.05175

In [4]:
# Define input tensor (224x224 image with 3 color channels; RGB)
inputs = tf.keras.Input(shape=(224,224,3))

# Pass inputs through base model
x = base(inputs, training=False)

# Convert feature maps to single feature vector per image; alternative, flatten(), is prone to overfitting
x = GlobalAveragePooling2D()(x)

# Define number of classes in dataset; adjust as necessary
outputs = Dense(128, activation='linear')(x)

# Create customized model
embedding_model = Model(inputs,outputs)

Reached a roadblock here, but this is what the fewshot training will look like. I do not want to proceed until we have the data imported and preprocessed.

In [None]:
# Training loop

optimizer = Adam(learning_rate=0.001)

episodes = 1000
for episode in range(episodes):  # Train for 1000 episodes
    support_images, support_labels, query_images, query_labels = fs.sample_episode(train_dataset)

    with tf.GradientTape() as tape:
        # Embed the support and query images
        support_embeddings = embedding_model(support_images, training=True)
        query_embeddings = embedding_model(query_images, training=True)
        # Compute the loss
        loss = fs.prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels)

    # Update the model
    gradients = tape.gradient(loss, embedding_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, embedding_model.trainable_variables))

    # Log 1 episode in every 50
    if episode % 50 == 0:
        print(f"Episode {episode}: Loss = {loss.numpy():.4f}")

    # Store for use later
    embedding_model.save("mobilenetv3_fewshot.keras")

2025-03-30 18:02:50.182561: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Episode 0: Loss = 4.4548


2025-03-30 18:02:51.506732: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-03-30 18:02:53.684643: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-03-30 18:02:58.009385: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2025-03-30 18:03:06.725574: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


KeyboardInterrupt: 