In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, models
import numpy as np

In [2]:
class ResnetFSLModel(nn.Module):
    """Prototypical-network style few-shot classifier on a ResNet-50 backbone.

    The backbone's final ``fc`` layer is replaced by
    ``nn.Linear(in_features, num_classes)`` and its output is used as the
    embedding space. Each class prototype is the mean support embedding of that
    class; queries are scored by negative Euclidean distance to every prototype.

    Args:
        num_classes: Output width of the replacement ``fc`` layer (i.e. the
            embedding dimension) and the default number of classes ("ways")
            per episode assumed by ``forward``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # forward() needs this to split the support set into per-class groups.
        self.num_classes = num_classes
        self.model = models.resnet50(weights='DEFAULT')
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, support_images, query_images, n_way=None):
        """Return per-query log-probabilities over the episode's classes.

        Args:
            support_images: Batch of ``n_way * n_support`` images grouped by
                class (all shots of the first class, then all shots of the
                second, ...). Prototype ``i`` comes from the i-th group, so
                query labels must use the same class order.
            query_images: Batch of query images.
            n_way: Number of classes in the episode; defaults to
                ``self.num_classes``.

        Returns:
            Tensor of shape ``(num_queries, n_way)`` holding
            ``log_softmax(-distance)``; pair it with ``F.nll_loss``.

        Raises:
            ValueError: if the support batch size is not a multiple of ``n_way``.
        """
        if n_way is None:
            n_way = self.num_classes

        support_embeddings = self.model(support_images)
        num_support_total = support_embeddings.shape[0]
        embed_dim = support_embeddings.shape[-1]
        if num_support_total % n_way != 0:
            raise ValueError(
                f"support batch of {num_support_total} images cannot be split "
                f"into {n_way} equally sized classes"
            )
        # (n_way * n_support, D) -> (n_way, n_support, D), then average over the
        # shots so each class keeps a full D-dimensional prototype vector.
        prototypes = support_embeddings.reshape(n_way, -1, embed_dim).mean(dim=1)

        query_embeddings = self.model(query_images)
        distances = torch.cdist(query_embeddings, prototypes)  # (num_queries, n_way)
        return F.log_softmax(-distances, dim=1)

In [3]:
def _gather_rows(images, indices):
    """Select ``images[indices]`` (integer index array) and return it as a tensor."""
    if torch.is_tensor(images):
        return images[torch.as_tensor(indices, dtype=torch.long)]
    return torch.as_tensor(np.asarray(images)[indices])


def get_episode_data(images, labels, num_classes_per_episode, num_support, num_query, rng=None):
    """Sample one N-way K-shot episode with disjoint support and query sets.

    Args:
        images: Array-like or tensor indexable by an integer index array; row
            ``i`` is the image whose class is ``labels[i]``.
        labels: One class label per image (list, numpy array or CPU tensor).
        num_classes_per_episode: N, the number of classes sampled (without
            replacement) for the episode.
        num_support: Support examples drawn per class.
        num_query: Query examples drawn per class, disjoint from the support ones.
        rng: Optional ``numpy.random.Generator``; defaults to the global
            ``np.random`` state.

    Returns:
        ``(support_set, query_set, query_labels)``. The support and query rows
        are grouped by class in the same (sampled) class order. ``query_labels``
        is a long tensor of episode-relative indices ``0..N-1`` in that order,
        so it lines up with the per-class prototypes built from the support set.

    Raises:
        ValueError: if a sampled class has fewer than ``num_support + num_query``
            images, or (from numpy) if N exceeds the number of distinct labels.
    """
    rng = np.random if rng is None else rng
    labels_np = np.asarray(labels)
    needed = num_support + num_query

    # Randomly select classes for the episode
    selected_classes = rng.choice(np.unique(labels_np), num_classes_per_episode, replace=False)

    support_chunks = []
    query_chunks = []
    for class_label in selected_classes:
        class_indices = np.flatnonzero(labels_np == class_label)
        if len(class_indices) < needed:
            raise ValueError(
                f"class {class_label!r} has {len(class_indices)} images; "
                f"need {needed} ({num_support} support + {num_query} query)"
            )
        # One shuffle yields non-overlapping support/query picks without an
        # O(n*k) membership scan.
        shuffled = rng.permutation(class_indices)
        support_chunks.append(shuffled[:num_support])
        query_chunks.append(shuffled[num_support:needed])

    support_set = _gather_rows(images, np.concatenate(support_chunks))
    query_set = _gather_rows(images, np.concatenate(query_chunks))
    # Episode-relative labels: the i-th sampled class becomes label i.
    query_labels = torch.arange(num_classes_per_episode).repeat_interleave(num_query)

    return support_set, query_set, query_labels

In [4]:
# TODO: Add dataset here. Later cells expect `trainset` and `testset` objects that expose
# `.data` (indexable by an integer index array) and `.targets` (one integer class label
# per item) — e.g. torchvision's CIFAR10. NOTE(review): confirm against the chosen dataset.

# Hyperparameters
seed = 0  # seeds numpy (episode sampling) and torch (new fc-layer initialisation)
num_classes_per_episode = 5  # N-way: classes sampled per episode
num_support = 5  # Support examples per class
num_query = 15  # Query examples per class
num_episodes = 100
learning_rate = 0.001

# Seed early so a Restart & Run All reproduces the same episodes and initial weights.
np.random.seed(seed)
torch.manual_seed(seed)

In [5]:
# The forward pass splits the support set into `num_classes` per-class groups, so the
# model's class count must equal the number of classes sampled per episode.
model = ResnetFSLModel(num_classes=num_classes_per_episode)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/davidroot/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 59.2MB/s]


In [None]:
# Episodic training: every `eval_every` episodes, report the training loss and the
# accuracy over `num_val_episodes` freshly sampled test episodes.
eval_every = 10
num_val_episodes = 10

# The dataset arrays do not change between episodes, so convert them once.
# NOTE(review): for torchvision CIFAR-style datasets `.data` is raw uint8 HWC, while the
# ResNet expects float NCHW (normalised) input — TODO add that preprocessing.
train_images, train_labels = trainset.data, torch.tensor(trainset.targets)
test_images, test_labels = testset.data, torch.tensor(testset.targets)
device = next(model.parameters()).device

for episode in range(num_episodes):
    model.train()
    support_set, query_set, query_labels = get_episode_data(
        train_images, train_labels, num_classes_per_episode, num_support, num_query
    )
    support_set, query_set, query_labels = (
        t.to(device) for t in (support_set, query_set, query_labels)
    )

    # Training Step
    optimizer.zero_grad()
    log_probs = model(support_set, query_set)
    # The model already returns log-probabilities, so use NLL; cross_entropy would
    # apply log_softmax to them a second time.
    loss = F.nll_loss(log_probs, query_labels)
    loss.backward()
    optimizer.step()

    if (episode + 1) % eval_every == 0:
        print(f"Episode {episode + 1}: Training Loss {loss.item():.4f}")

        # eval() makes BatchNorm use its running statistics instead of updating them
        # from validation batches; no_grad() skips autograd bookkeeping.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for _ in range(num_val_episodes):
                val_support, val_query, val_labels = get_episode_data(
                    test_images, test_labels, num_classes_per_episode, num_support, num_query
                )
                val_support, val_query, val_labels = (
                    t.to(device) for t in (val_support, val_query, val_labels)
                )
                predicted = model(val_support, val_query).argmax(dim=1)

                total += val_labels.size(0)
                correct += (predicted == val_labels).sum().item()

        print(f"Validation Accuracy: {100 * correct / total:.2f}%")