<a href="https://colab.research.google.com/github/hongqin/Python-CoLab-bootcamp/blob/master/Few_Shot_Learning_with_Prototypical_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Few-Shot Learning with Prototypical Networks on Omniglot Dataset

## Step 1: Import Libraries and Prepare Data
In this step, we will import the necessary libraries and prepare the Omniglot dataset for our few-shot learning task.

In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import defaultdict

# Ensure reproducibility
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Transformations for the Omniglot dataset
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load Omniglot dataset
train_dataset = datasets.Omniglot(root='./data', background=True, download=True, transform=transform)
test_dataset = datasets.Omniglot(root='./data', background=False, download=True, transform=transform)

# Extract characters and their classes
train_characters = defaultdict(list)
test_characters = defaultdict(list)

for idx, (image, target) in enumerate(train_dataset):
    train_characters[target].append(idx)

for idx, (image, target) in enumerate(test_dataset):
    test_characters[target].append(idx)

train_classes = list(train_characters.keys())
test_classes = list(test_characters.keys())


Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to ./data/omniglot-py/images_background.zip


100%|██████████| 9464212/9464212 [00:00<00:00, 74254768.15it/s]


Extracting ./data/omniglot-py/images_background.zip to ./data/omniglot-py
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to ./data/omniglot-py/images_evaluation.zip


100%|██████████| 6462886/6462886 [00:00<00:00, 74539228.69it/s]

Extracting ./data/omniglot-py/images_evaluation.zip to ./data/omniglot-py





## Step 2: Define the Prototypical Network
Here, we define the Prototypical Network model, which consists of a simple convolutional neural network (CNN) used for embedding the input images.

In [2]:

class PrototypicalNetwork(nn.Module):
    def __init__(self, input_channels, hidden_dim, output_dim):
        super(PrototypicalNetwork, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(hidden_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)


## Step 3: Helper Functions for Few-Shot Learning
We create helper functions to generate few-shot learning episodes and to compute Euclidean distances between embeddings.

In [3]:

def create_episode(dataset, characters, num_classes, num_support, num_query):
    selected_classes = random.sample(characters.keys(), num_classes)
    support_x = []
    query_x = []
    support_y = []
    query_y = []

    for i, cls in enumerate(selected_classes):
        indices = characters[cls]
        support_indices = random.sample(indices, num_support)
        query_indices = random.sample(list(set(indices) - set(support_indices)), num_query)

        support_x.append(torch.stack([dataset[idx][0] for idx in support_indices]))
        query_x.append(torch.stack([dataset[idx][0] for idx in query_indices]))
        support_y.append(torch.ones(num_support) * i)
        query_y.append(torch.ones(num_query) * i)

    return (torch.cat(support_x), torch.cat(support_y)), (torch.cat(query_x), torch.cat(query_y))

def euclidean_distance(x, y):
    return torch.sqrt(torch.sum((x - y) ** 2, dim=-1))


## Step 4: Train the Model
We define a function to train the Prototypical Network using the few-shot learning episodes generated in the previous step.

In [None]:

def train_protonet(model, optimizer, train_dataset, train_characters, num_classes, num_support, num_query, num_episodes):
    model.train()
    for episode in range(num_episodes):
        (support_x, support_y), (query_x, query_y) = create_episode(train_dataset, train_characters, num_classes, num_support, num_query)

        optimizer.zero_grad()
        support_embeddings = model(support_x)
        query_embeddings = model(query_x)

        prototypes = []
        for i in range(num_classes):
            prototypes.append(support_embeddings[support_y == i].mean(dim=0))
        prototypes = torch.stack(prototypes)

        distances = torch.stack([euclidean_distance(query_embeddings, prototype) for prototype in prototypes])
        predictions = torch.argmin(distances, dim=0)

        loss = nn.CrossEntropyLoss()(distances.T, query_y.long())
        loss.backward()
        optimizer.step()

        if (episode + 1) % 100 == 0:
            print(f'Episode [{episode + 1}/{num_episodes}], Loss: {loss.item():.4f}')

input_channels = 1
hidden_dim = 64
output_dim = 64
num_classes = 5
num_support = 5
num_query = 15
num_episodes = 1000

model = PrototypicalNetwork(input_channels, hidden_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=0.001)

train_protonet(model, optimizer, train_dataset, train_characters, num_classes, num_support, num_query, num_episodes)


since Python 3.9 and will be removed in a subsequent version.
  selected_classes = random.sample(characters.keys(), num_classes)


Episode [100/1000], Loss: 1.6305
Episode [200/1000], Loss: 1.6172
Episode [300/1000], Loss: 1.6211


## Step 5: Evaluate the Model
Finally, we evaluate the trained Prototypical Network on the test dataset to see how well it performs on few-shot learning tasks.

In [None]:

def evaluate_protonet(model, test_dataset, test_characters, num_classes, num_support, num_query, num_episodes):
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for _ in range(num_episodes):
            (support_x, support_y), (query_x, query_y) = create_episode(test_dataset, test_characters, num_classes, num_support, num_query)

            support_embeddings = model(support_x)
            query_embeddings = model(query_x)

            prototypes = []
            for i in range(num_classes):
                prototypes.append(support_embeddings[support_y == i].mean(dim=0))
            prototypes = torch.stack(prototypes)

            distances = torch.stack([euclidean_distance(query_embeddings, prototype) for prototype in prototypes])
            predictions = torch.argmin(distances, dim=0)

            total_correct += (predictions == query_y).sum().item()
            total_samples += len(query_y)

    accuracy = total_correct / total_samples
    print(f'Test Accuracy: {accuracy:.4f}')

evaluate_protonet(model, test_dataset, test_characters, num_classes, num_support, num_query, num_episodes=100)
