In [2]:
import torch

In [9]:
# Learning pytorch cdist
a = torch.tensor([[3, 4], [-0.3108, -2.4423], [-0.4821,  1.059]])
b = torch.tensor([[3, 4], [-0.6986,  1.3702]])
dists = torch.cdist(a, b, p=2) # p=2 means norm 2, or euclidian distance
print(dists)

# row i of dists is the distances of element i in a
# col j of dists is the distances of element i in a compared to element j of b

tensor([[0.0000, 4.5382],
        [7.2432, 3.8322],
        [4.5579, 0.3791]])


# Prototypical Network

## Sizes in original version:

support_features: [25, 512]

query_features: [25, 512]

prototypes: [5, 512]

dists: [50, 5]

In [31]:
# Prototypical Network
support_features = torch.rand((3,10))
query_features = torch.rand((5,10))

print("support:", support_features)

n_way = 3

support_labels = torch.tensor([0,1,2])

# Prototype i is the mean of all instances of features corresponding to labels == i
prototypes = torch.cat([support_features[torch.nonzero(support_labels==i)].mean(0) 
                        for i in range(n_way)])

print("prototypes:",prototypes)

dists = torch.cdist(query_features, prototypes)
# print(dists)

scores = -dists
# print(scores)

# scores are negative of dists because the largest (least negative)
# score value is the answer, and the smallest distance is the answer

support: tensor([[0.7511, 0.7107, 0.5242, 0.4270, 0.9622, 0.1226, 0.5837, 0.5515, 0.6753,
         0.2185],
        [0.5593, 0.3586, 0.8593, 0.5987, 0.7437, 0.3188, 0.4466, 0.5877, 0.4988,
         0.9259],
        [0.1332, 0.8960, 0.2055, 0.8800, 0.7347, 0.6509, 0.8275, 0.7082, 0.2440,
         0.2309]])
prototypes: tensor([[0.7511, 0.7107, 0.5242, 0.4270, 0.9622, 0.1226, 0.5837, 0.5515, 0.6753,
         0.2185],
        [0.5593, 0.3586, 0.8593, 0.5987, 0.7437, 0.3188, 0.4466, 0.5877, 0.4988,
         0.9259],
        [0.1332, 0.8960, 0.2055, 0.8800, 0.7347, 0.6509, 0.8275, 0.7082, 0.2440,
         0.2309]])
tensor([[1.0895, 1.0045, 1.0592],
        [1.0871, 0.8147, 1.7394],
        [1.3743, 1.2679, 1.2579],
        [1.1231, 1.0116, 1.3322],
        [1.2640, 1.1765, 1.3550]])
tensor([[-1.0895, -1.0045, -1.0592],
        [-1.0871, -0.8147, -1.7394],
        [-1.3743, -1.2679, -1.2579],
        [-1.1231, -1.0116, -1.3322],
        [-1.2640, -1.1765, -1.3550]])


# Dataset (wrapper) Class

Wrap a dataset in a FewShotDataset.

Args:

dataset: dataset to wrap

image_position_in_get_item_output: position of the image in the tuple returned
    by dataset.__getitem__(). Default: 0
    
label_position_in_get_item_output: position of the label in the tuple returned
    by dataset.__getitem__(). Default: 1

In [77]:
from typing import Tuple, List, Union, Iterator
import torch
import random
from torch import Tensor
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from few_shot_dataset import FewShotDataset
from pathlib import Path

class Dummy_dataset(FewShotDataset):
    def __init__(
        self,
        dataset: Dataset,
        image_position_in_get_item_output: int = 0,
        label_position_in_get_item_output: int = 1,
    ):
        """
        Wrap a dataset in a FewShotDataset.
        Args:
            dataset: dataset to wrap
            image_position_in_get_item_output: position of the image in the tuple returned
                by dataset.__getitem__(). Default: 0
            label_position_in_get_item_output: position of the label in the tuple returned
                by dataset.__getitem__(). Default: 1
        """

        self.source_dataset = dataset
        self.image_position_in_get_item_output = image_position_in_get_item_output
        self.label_position_in_get_item_output = label_position_in_get_item_output
        self.n_shot = 5
        self.n_way = 5
        self.n_query = 10
        self.n_tasks = 100

        self.labels = []

        for [_,label] in dataset:
            self.labels.append(label)


    def __getitem__(self, item: int) -> Tuple[Tensor, int]:
        return (
            self.source_dataset[item][self.image_position_in_get_item_output],
            self.source_dataset[item][self.label_position_in_get_item_output],
        )
    
    def __iter__(self) -> Iterator[List[int]]:
        """
        Sample n_way labels uniformly at random,
        and then sample n_shot + n_query items for each label, also uniformly at random.
        Yields:
            a list of indices of length (n_way * (n_shot + n_query))
        """
        for _ in range(self.n_tasks):
            yield torch.cat(
                [
                    torch.tensor(
                        random.sample(
                            self.items_per_label[label], self.n_shot + self.n_query
                        )
                    )
                    for label in random.sample(
                        sorted(self.items_per_label.keys()), self.n_way
                    )
                ]
            ).tolist()
    
    def episodic_collate_fn(
        self, input_data: List[Tuple[Tensor, Union[Tensor, int]]]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[int]]:
        
        # cast label of each data image to an integer
        input_data_with_int_labels = self._cast_input_data_to_tensor_int_tuple(input_data)

        # true_class_ids is a list from [0-max_class_label] ex: [0,1,2,3,4,5]
        true_class_ids = list({x[1] for x in input_data_with_int_labels})

        # unsqueeze adds additional dimension, so go from torch.Size([3, 28, 28]) to torch.Size([1, 3, 28, 28]) for each image
        # this way, the end result dimensions are torch.Size([420, 3, 28, 28]) rather than torch.Size([1260, 28, 28])
        all_images = torch.cat([x[0].unsqueeze(0) for x in input_data_with_int_labels])
        print(all_images.reshape(
            (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
        ))

    def get_labels(self) -> List[int]:
        return self.labels
    
    @staticmethod
    def _cast_input_data_to_tensor_int_tuple(
        input_data: List[Tuple[Tensor, Union[Tensor, int]]]
    ) -> List[Tuple[Tensor, int]]:
        return [(image, int(label)) for (image, label) in input_data]

# Dummy Sampler

In [111]:
from typing import Iterator, Optional, Sized, Dict
from torch.utils.data import Sampler

class DummySampler(Sampler):
    
    def __init__(self, dataset, n_way, n_shot, n_query, n_tasks, data_source: Sized | None = None) -> None:
        super().__init__(data_source)
        self.n_way = n_way # number of classes in one task
        self.n_shot = n_shot # number of support images for each class in one task
        self.n_query = n_query # number of query images for each class in one task
        self.n_tasks = n_tasks # number of tasks to sample (each bundle of support and query)

        self.items_per_label: Dict[int,List[int]] = {}

        for item, label in enumerate(dataset.get_labels()):
            if label not in self.items_per_label:
                self.items_per_label[label] = [item]
            else:
                self.items_per_label[label].append(item)

    def __len__(self):
        return self.n_tasks
    
    def __iter__(self) -> Iterator[List[int]]:
        yield torch.cat([
            torch.tensor(
                    random.sample(
                        self.items_per_label[label], self.n_shot + self.n_query # getting n_shot + n_query random images from selected label
                    )
                )
                for label in random.sample(
                    sorted(self.items_per_label.keys()), self.n_way # samples n_way classes
                )
        ]).tolist() # returns a list of n_way*(n_shot+n_query) numbers representing image indexes within entire dataset

In [113]:
image_size = 28

# Setup path to data folder
data_path = Path("data")
image_path = data_path / "UCMerced-Test"

# Check if image folder exists
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory")
    exit()

# Setup train and testing paths
test_dir = image_path / "Test"

# Write transform for image
test_transform = transforms.Compose([
    transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
    transforms.CenterCrop(image_size),
    transforms.ToTensor()
])

test_set = datasets.ImageFolder(
    root=test_dir,
    transform=test_transform,
)

test_set = Dummy_dataset(test_set)

test_sampler = DummySampler(
    test_set, n_way=5, n_shot=5, n_query=10, n_tasks=100
)

test_set.episodic_collate_fn(test_set)

data/UCMerced-Test directory exists.


AttributeError: 'Dummy_dataset' object has no attribute 'items_per_label'

# Episodic Collate Function

Collate function to be used as argument for the collate_fn parameter of episodic data loaders.

Args:

input_data: each element is a tuple containing:
- an image as a torch Tensor of shape (n_channels, height, width)
- the label of this image as an int or a 0-dim tensor
        
Returns:

tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
- support images of shape (n_way * n_shot, n_channels, height, width),
- their labels of shape (n_way * n_shot),
- query images of shape (n_way * n_query, n_channels, height, width)
- their labels of shape (n_way * n_query),
- the dataset class ids of the class sampled in the episode