In [1]:
import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report

# Load MNIST and rescale the 8-bit pixel values into [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Append a trailing channel axis so each image becomes (28, 28, 1).
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

# Index the images by digit label so episodes can draw examples per class.
num_classes = 10
train_classes = {digit: x_train[y_train == digit] for digit in range(num_classes)}
test_classes = {digit: x_test[y_test == digit] for digit in range(num_classes)}

def sample_episode(classes, n_classes=5, k_shot=5, q_query=10, rng=None):
    """Sample one N-way, K-shot few-shot episode.

    Args:
        classes: Mapping from class id to an array of that class's examples,
            indexed along axis 0 (e.g. the ``train_classes`` dict).
        n_classes: Number of distinct classes drawn for the episode (N-way).
        k_shot: Support examples drawn per class (K-shot).
        q_query: Query examples drawn per class.
        rng: Optional random source with a NumPy-style ``choice`` method, such
            as ``np.random.default_rng(seed)``. When omitted, the global
            ``np.random`` state is used, so ``np.random.seed`` keeps working.

    Returns:
        ``(support_set, query_set, labels)`` where
        - support_set has shape (n_classes, k_shot, *example_shape),
        - query_set has shape (n_classes, q_query, *example_shape),
        - labels has shape (n_classes * q_query,) and holds episode-local
          labels 0..n_classes-1, ordered like ``query_set`` flattened over its
          first two axes. These are NOT the original class ids.

    Raises:
        ValueError: if ``classes`` has fewer than ``n_classes`` entries, or a
            selected class has fewer than ``k_shot + q_query`` examples.
    """
    rng = np.random if rng is None else rng
    class_ids = list(classes.keys())
    if n_classes > len(class_ids):
        raise ValueError(
            f"requested {n_classes} classes but only {len(class_ids)} are available"
        )
    per_class = k_shot + q_query
    selected_classes = rng.choice(class_ids, n_classes, replace=False)

    support_set, query_set = [], []
    for cls in selected_classes:
        examples = classes[cls]
        if len(examples) < per_class:
            raise ValueError(
                f"class {cls} has {len(examples)} examples; "
                f"need k_shot + q_query = {per_class}"
            )
        # Draw support and query indices together, without replacement, so the
        # two sets never share an example.
        picks = rng.choice(len(examples), per_class, replace=False)
        support_set.append(examples[picks[:k_shot]])
        query_set.append(examples[picks[k_shot:]])

    # Episode-local label i stands for selected_classes[i].
    labels = np.repeat(np.arange(n_classes), q_query)
    return np.array(support_set), np.array(query_set), labels

def compute_prototypes(support_set):
    """Return one prototype per class: the mean of that class's support examples.

    The average is taken over axis 1, so a support set shaped (N, K, ...)
    produces prototypes shaped (N, ...). The mean is computed directly on the
    input values; no learned embedding is applied.
    """
    support = np.asanyarray(support_set)
    return support.mean(axis=1)

def classify_query(query_set, prototypes):
    """Assign every query example to its nearest prototype by Euclidean distance.

    Args:
        query_set: Array whose trailing axes match the shape of one prototype;
            all leading axes are treated as batch axes. From ``sample_episode``
            this is (N, Q, H, W, C).
        prototypes: Array of shape (P, ...) with one prototype per class, e.g.
            (N, H, W, C) from ``compute_prototypes``. P does not have to equal
            the leading size of ``query_set``.

    Returns:
        Integer array with one entry per query, ordered row-major over the
        query batch axes, giving the index in [0, P) of the closest prototype.
        Ties resolve to the lowest index (``np.argmin``).

    Raises:
        ValueError: if the trailing axes of ``query_set`` do not match the
            per-prototype shape.
    """
    query_set = np.asarray(query_set)
    prototypes = np.asarray(prototypes)

    feature_shape = prototypes.shape[1:]
    n_feature_axes = len(feature_shape)
    # Validate explicitly: a bare reshape could silently split one query into
    # several rows whenever the sizes happen to divide evenly.
    if query_set.shape[query_set.ndim - n_feature_axes:] != feature_shape:
        raise ValueError(
            f"query_set trailing shape {query_set.shape} does not end with "
            f"prototype shape {feature_shape}"
        )

    flat_protos = prototypes.reshape(prototypes.shape[0], -1)   # (P, D)
    flat_queries = query_set.reshape(-1, flat_protos.shape[1])  # (num_queries, D)

    # Broadcasting materializes a (num_queries, P, D) difference array: fine for
    # episode-sized inputs, but memory grows with num_queries * P * D.
    distances = np.linalg.norm(
        flat_queries[:, np.newaxis, :] - flat_protos[np.newaxis, :, :], axis=2
    )  # (num_queries, P)

    return np.argmin(distances, axis=1)

# Seed NumPy's global RNG so the sampled episode, and therefore the report
# below, is reproducible under Restart & Run All (sample_episode draws from
# np.random when no generator is passed).
SEED = 42
np.random.seed(SEED)

# Sample one episode with sample_episode's defaults (5 classes, 5 support and
# 10 query examples per class) from the training split.
support, query, true_labels = sample_episode(train_classes)

# Prototypes are per-class means of the raw pixels; each query is assigned to
# the nearest prototype.
prototypes = compute_prototypes(support)
pred_labels = classify_query(query, prototypes)

# Report labels are episode-local indices (0..n_classes-1), not digit values.
# zero_division=0 avoids an UndefinedMetricWarning if a class gets no predictions.
print("Prototypical Network Results:")
print(classification_report(true_labels, pred_labels, zero_division=0))


Prototypical Network Results:
              precision    recall  f1-score   support

           0       0.80      0.80      0.80        10
           1       0.80      0.80      0.80        10
           2       0.89      0.80      0.84        10
           3       0.83      1.00      0.91        10
           4       1.00      0.90      0.95        10

    accuracy                           0.86        50
   macro avg       0.86      0.86      0.86        50
weighted avg       0.86      0.86      0.86        50

