In [6]:
#abc

In [7]:
from torchvision.datasets import Omniglot
from torchvision import transforms
import matplotlib.pyplot as plt

In [8]:
image_size = 32


train_set = Omniglot(
    root="./data",
    background=True,
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
test_set = Omniglot(
    root="./data",
    background=False,
    transform=transforms.Compose(
        [
            # Omniglot images have 1 channel, but our model will expect 3-channel images
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)


Files already downloaded and verified
Files already downloaded and verified


In [9]:
import tensorflow as tf
from tensorflow.keras.layers import Flatten
from tensorflow.keras.applications import VGG16

class PrototypicalNetworks(tf.keras.Model):
    def __init__(self, backbone):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def call(
        self,
        support_images,
        support_labels,
        query_images,
    ):
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone(support_images, training=False)
        z_query = self.backbone(query_images, training=False)

        # Infer the number of different classes from the labels of the support set
        n_way = tf.shape(tf.unique(support_labels)[0])[0]
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = tf.concat(
            [
                tf.reduce_mean(tf.boolean_mask(z_support, support_labels == label), axis=0)
                for label in tf.range(n_way)
            ],
            axis=0,
        )

        # Compute the euclidean distance from queries to prototypes
        dists = tf.norm(tf.expand_dims(z_query, 1) - tf.expand_dims(z_proto, 0), axis=-1)

        # And here is the super complicated operation to transform those distances into classification scores!
        scores = -dists
        return scores

# Replace `resnet18` with the appropriate TensorFlow backbone model or implement it using TensorFlow if it doesn't exist.
# The preprocessing might also differ for the TensorFlow backbone, depending on the model you use.
# You can use a pre-trained model or a custom implementation.
backbone_model = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
backbone_model = Flatten()((backbone_model.output))

# Create the Prototypical Networks model with the TensorFlow backbone
model = PrototypicalNetworks(backbone_model)


In [10]:
import tensorflow as tf

N_WAY = 5  # Number of classes in a task
N_SHOT = 5  # Number of images per class in the support set
N_QUERY = 10  # Number of images per class in the query set
N_EVALUATION_TASKS = 100

# Assuming 'test_set' contains the Omniglot dataset

# Extract images and corresponding labels from the test_set
test_images = [instance[0] for instance in test_set]
test_labels = [instance[1] for instance in test_set]

# Convert the lists to TensorFlow tensors
test_images = tf.convert_to_tensor(test_images)
test_labels = tf.convert_to_tensor(test_labels)

# Create a task sampler for episodic evaluation
def task_sampler(labels):
    selected_labels = tf.random.shuffle(labels)[:N_WAY]
    return selected_labels

# Create a function to sample support and query sets for each task
def sample_support_query(task_labels):
    selected_classes = task_sampler(task_labels)
    support_set_images = []
    query_set_images = []

    for class_label in selected_classes:
        class_indices = tf.where(tf.equal(task_labels, class_label))[:, 0]
        selected_indices = tf.random.shuffle(class_indices)[:N_SHOT + N_QUERY]
        support_set_images.extend(test_images[selected_indices[:N_SHOT]])
        query_set_images.extend(test_images[selected_indices[N_SHOT:]])

    return support_set_images, query_set_images

# Create support and query sets for each task in the evaluation
test_support_images, test_query_images = [], []
for _ in range(N_EVALUATION_TASKS):
    support_set, query_set = sample_support_query(test_labels)
    test_support_images.append(support_set)
    test_query_images.append(query_set)

# Convert the support and query sets to TensorFlow tensors
test_support_images = tf.convert_to_tensor(test_support_images)
test_query_images = tf.convert_to_tensor(test_query_images)

# Create TensorFlow datasets for support and query sets
test_support_loader = tf.data.Dataset.from_tensor_slices(test_support_images).batch(N_WAY * N_SHOT)
test_query_loader = tf.data.Dataset.from_tensor_slices(test_query_images).batch(N_WAY * N_QUERY)

# Prefetch the data to improve performance
test_support_loader = test_support_loader.prefetch(buffer_size=tf.data.AUTOTUNE)
test_query_loader = test_query_loader.prefetch(buffer_size=tf.data.AUTOTUNE)


ValueError: ignored