In [98]:
import typing
from torch.utils.data import Dataset, DataLoader
import torch
import torch.cuda
import torchvision as tv
from torch import nn
from tqdm import tqdm

In [99]:
tv.datasets.Omniglot(root="./data/raw", background=True, download=True)

Files already downloaded and verified


Dataset Omniglot
    Number of datapoints: 18480
    Root location: ./data/raw\omniglot-py

In [100]:
class FewShotDataset(Dataset):
  """ A custom Dataset class for Few-Shot Learning tasks.
    This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """
  def __init__(self, dataset, indices, transform, mode="support"):
    """ Args:
        dataset (list): List of (feature, label) pairs.
        indices (list): List of indices to be used for the dataset.
        transform (callable): Transform to be applied to the features.
        mode (str): Mode of operation, either "support" or "query". Default is "support". """
    self.dataset, self.indices = [] if not dataset else dataset, indices # Initialize dataset (empty if not provided)
    self.mode, self.transform = mode, transform
    self.classes = dataset.classes
  # __init__():

  def __getitem__(self, index):
    if index >= len(self.indices): raise IndexError("Index out of bounds")
    feature, label = self.dataset[self.indices[index]]
    if self.mode == "query":
      one_hot_vector = torch.zeros(len(self.classes))
      one_hot_vector[label] = 1.
      label = one_hot_vector.requires_grad_(True)
    return self.transform(feature), label
  # __getitem__():

  def __len__(self): return len(self.indices)
# FSLDataset()

In [101]:
class FewShotEpisoder:
  """ A class to generate episodes for Few-Shot Learning.
  Each episode consists of a support set and a query set. """
  def __init__(self, dataset: tv.datasets.ImageFolder, k_shot: int, n_query: int, transform: typing.Callable):
    """ Args:
        dataset (Dataset): The base dataset to generate episodes from.
        k_shot (int): Number of support samples per class.
        n_query (int): Number of query samples per class.
        transform (callable): Transform to be applied to the features. """
    self.k_shot, self.n_query = k_shot, n_query  # define n-way/k-hot framework parameters
    self.dataset, self.transform = dataset, transform  # init dataset and apply transformer
    self.indices_c = self.get_indices()
  # __init__()

  def get_indices(self):
    """ Initialize the class indices for the dataset.
        * Returns: tuple of Number of classes and a list of indices grouped by class. """
    indices_c = [[] for _ in range(len(self.dataset.classes))]
    for index, (feature, label) in enumerate(self.dataset): indices_c[label].append(index)
    return indices_c
  # get_indices():

  def get_episode(self):  # select classes using list of chosen indexes
    """ Generate an episode consisting of a support set and a query set.
        Returns: tuple of A FewShotDataset for the support set and a FewShotDataset for the query set. """
    buffer_indices_c = self.indices_c.copy()
    support_examples, query_examples = [], []
    # select support examples
    for index_indices, indices in enumerate(buffer_indices_c):
      for index, x_index in enumerate(indices):
        support_examples.append(x_index)
        buffer_indices_c[index_indices].pop(index)
    # select query examples
    query_examples = support_examples.copy()
    for index_indices, indices in enumerate(buffer_indices_c):
      for index, x_index in enumerate(indices):
        support_examples.append(x_index)
        buffer_indices_c[index_indices].pop(index)
    # init datasets
    support_set = FewShotDataset(self.dataset, support_examples, mode="support", transform=self.transform)
    query_set = FewShotDataset(self.dataset, query_examples, mode="query", transform=self.transform)

    return support_set, query_set
  # get_episode()
# Episoder()

In [119]:
class ProtoNet(nn.Module):
  def __init__(self):
    super(ProtoNet, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, stride=1, padding=1)
    self.relu = nn.ReLU()
    self.flatten = nn.Flatten()
    self.softmax = nn.LogSoftmax(dim=0)
  # __init__():

  def prototyping(self, prototypes): self.prototypes = prototypes

  def cdist(self, x):
    dists = torch.cdist(x, self.prototypes, p=2).squeeze(0)  # Efficient batch-wise L2 distance computation
    return dists
  # cdist()

  def forward(self, x):
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.relu(x)
    x = self.flatten(x)
    x = self.cdist(x)
    x = self.softmax(x)
    return x
  # forward()
# ProtoNet()

In [120]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # init device

# create FSL episode generator
transform = tv.transforms.Compose([
  tv.transforms.Resize((224, 224)),
  tv.transforms.ToTensor(),
  tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]) # transform
imageset = tv.datasets.ImageFolder(root="./data/raw/omniglot-py/images_background/Futurama")

# init episoder
k_shot, n_query, iters, epochs = 5, 2, 5, 1
episoder = FewShotEpisoder(imageset, k_shot, n_query, transform)

# init learning
model = ProtoNet().to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# start episodes
loss = float()
for _ in tqdm(range(epochs), desc="epochs/episodes"):
  support_set, query_set = episoder.get_episode() # create support set and query set
  # compute prototype from support examples
  prototypes = list()
  embedded_features_list = [[] for _ in range(len(support_set.classes))]
  for embedded_feature, label in support_set: embedded_features_list[label].append(embedded_feature)
  for embedded_features in embedded_features_list:
    sum = torch.zeros_like(embedded_features[0])
    for embedded_feature in embedded_features: sum += embedded_feature
    sum /= len(embedded_features)
    prototypes.append(sum.flatten())
  prototypes = torch.stack(prototypes)
  model.prototyping(prototypes)
  # update loss for given iters
  for _ in tqdm(range(iters), desc="\titerations/queries"):
    for feature, label in DataLoader(query_set, shuffle=True):
      loss = criterion(model.forward(feature), label.squeeze(dim=0))
      optim.zero_grad()
      loss.backward()
      optim.step()
print(f"loss: {loss:.6f}") # print final value of loss

epochs/episodes:   0%|          | 0/1 [00:00<?, ?it/s]
	iterations/queries:   0%|          | 0/5 [00:00<?, ?it/s][A
	iterations/queries:  20%|██        | 1/5 [00:06<00:26,  6.56s/it][A
	iterations/queries:  40%|████      | 2/5 [00:12<00:18,  6.33s/it][A
	iterations/queries:  60%|██████    | 3/5 [00:18<00:12,  6.24s/it][A
	iterations/queries:  80%|████████  | 4/5 [00:24<00:06,  6.20s/it][A
	iterations/queries: 100%|██████████| 5/5 [00:31<00:00,  6.22s/it][A
epochs/episodes: 100%|██████████| 1/1 [00:31<00:00, 31.68s/it]

loss: 0.000082





In [121]:
eval_episoder = FewShotEpisoder(imageset, 4, 4, transform)
eval_support_set, eval_query_set = eval_episoder.get_episode()
correct, n_problem = 0, len(eval_query_set)
# compute prototype from support examples
prototypes = list()
embedded_features_list = [[] for _ in range(len(support_set.classes))]
for embedded_feature, label in support_set: embedded_features_list[label].append(embedded_feature)
for embedded_features in embedded_features_list:
  sum = torch.zeros_like(embedded_features[0])
  for embedded_feature in embedded_features: sum += embedded_feature
  sum /= len(embedded_features)
  prototypes.append(sum.flatten())
eval_prototypes = torch.stack(prototypes)
model.prototyping(eval_prototypes)
for feature, label in DataLoader(eval_query_set, shuffle=True):
  if torch.argmax(model.forward(feature)) == torch.argmax(label): correct += 1
print(f"accuracy: {correct/n_problem:.2f}({correct}/{n_problem})") # print final accuracy

accuracy: 0.72(188/260)
