In [1]:
import torch

In [2]:
# Learning pytorch cdist
a = torch.tensor([[3, 4], [-0.3108, -2.4423], [-0.4821,  1.059]])
b = torch.tensor([[3, 4], [-0.6986,  1.3702]])
dists = torch.cdist(a, b, p=2) # p=2 means norm 2, or euclidian distance
print(dists)

# row i of dists is the distances of element i in a
# col j of dists is the distances of element i in a compared to element j of b

tensor([[0.0000, 4.5382],
        [7.2432, 3.8322],
        [4.5579, 0.3791]])


# Prototypical Network

## Sizes in original version:

support_features: [25, 512]

query_features: [25, 512]

prototypes: [5, 512]

dists: [50, 5]

In [3]:
# Prototypical Network
support_features = torch.rand((3,10))
query_features = torch.rand((5,10))

print("support:", support_features)

n_way = 3

support_labels = torch.tensor([0,1,2])

# Prototype i is the mean of all instances of features corresponding to labels == i
prototypes = torch.cat([support_features[torch.nonzero(support_labels==i)].mean(0) 
                        for i in range(n_way)])

print("prototypes:",prototypes)

dists = torch.cdist(query_features, prototypes)
# print(dists)

scores = -dists
# print(scores)

# scores are negative of dists because the largest (least negative)
# score value is the answer, and the smallest distance is the answer

support: tensor([[0.0747, 0.8336, 0.3018, 0.1586, 0.0862, 0.6559, 0.0182, 0.6178, 0.2983,
         0.6495],
        [0.9465, 0.0946, 0.6247, 0.2405, 0.1311, 0.0223, 0.2917, 0.3001, 0.5003,
         0.0394],
        [0.0486, 0.1500, 0.6379, 0.0133, 0.7661, 0.4694, 0.5304, 0.4993, 0.2591,
         0.3628]])
prototypes: tensor([[0.0747, 0.8336, 0.3018, 0.1586, 0.0862, 0.6559, 0.0182, 0.6178, 0.2983,
         0.6495],
        [0.9465, 0.0946, 0.6247, 0.2405, 0.1311, 0.0223, 0.2917, 0.3001, 0.5003,
         0.0394],
        [0.0486, 0.1500, 0.6379, 0.0133, 0.7661, 0.4694, 0.5304, 0.4993, 0.2591,
         0.3628]])


# Dataset (wrapper) Class

Wrap a dataset in a FewShotDataset.

Args:

dataset: dataset to wrap

image_position_in_get_item_output: position of the image in the tuple returned
    by dataset.__getitem__(). Default: 0
    
label_position_in_get_item_output: position of the label in the tuple returned
    by dataset.__getitem__(). Default: 1

In [4]:
from typing import Tuple, List, Union, Iterator
import torch
import random
from torch import Tensor
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from few_shot_dataset import FewShotDataset
from pathlib import Path

class Dummy_dataset(FewShotDataset):
    def __init__(
        self,
        dataset: Dataset,
        image_position_in_get_item_output: int = 0,
        label_position_in_get_item_output: int = 1,
    ):
        """
        Wrap a dataset in a FewShotDataset.
        Args:
            dataset: dataset to wrap
            image_position_in_get_item_output: position of the image in the tuple returned
                by dataset.__getitem__(). Default: 0
            label_position_in_get_item_output: position of the label in the tuple returned
                by dataset.__getitem__(). Default: 1
        """

        self.source_dataset = dataset
        self.image_position_in_get_item_output = image_position_in_get_item_output
        self.label_position_in_get_item_output = label_position_in_get_item_output
        self.n_shot = 5
        self.n_way = 5
        self.n_query = 10
        self.n_tasks = 100

        self.labels = []

        for [_,label] in dataset:
            self.labels.append(label)


    def __getitem__(self, item: int) -> Tuple[Tensor, int]:
        return (
            self.source_dataset[item][self.image_position_in_get_item_output],
            self.source_dataset[item][self.label_position_in_get_item_output],
        )

    def get_labels(self) -> List[int]:
        return self.labels

# Dummy Sampler

In [5]:
from typing import Iterator, Optional, Sized, Dict
from torch.utils.data import Sampler

class DummySampler(Sampler):
    
    def __init__(self, dataset, n_way, n_shot, n_query, n_tasks, data_source: Sized | None = None) -> None:
        super().__init__(data_source)
        self.n_way = n_way # number of classes in one task
        self.n_shot = n_shot # number of support images for each class in one task
        self.n_query = n_query # number of query images for each class in one task
        self.n_tasks = n_tasks # number of tasks to sample (each bundle of support and query)

        self.items_per_label: Dict[int,List[int]] = {}

        for item, label in enumerate(dataset.get_labels()):
            if label not in self.items_per_label:
                self.items_per_label[label] = [item]
            else:
                self.items_per_label[label].append(item)

    def __len__(self):
        return self.n_tasks
    
    def __iter__(self) -> Iterator[List[int]]:
        yield torch.cat([
            torch.tensor(
                    random.sample(
                        self.items_per_label[label], self.n_shot + self.n_query # getting n_shot + n_query random images from selected label
                    )
                )
                for label in random.sample(
                    sorted(self.items_per_label.keys()), self.n_way # samples n_way classes
                )
        ]).tolist() # returns a list of n_way*(n_shot+n_query) numbers representing image indexes within entire dataset

    def episodic_collate_fn(
        self, input_data: List[Tuple[Tensor, Union[Tensor, int]]]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[int]]:
        
        # cast label of each data image to an integer
        input_data_with_int_labels = self._cast_input_data_to_tensor_int_tuple(input_data)

        # true_class_ids is a list from [0-max_class_label] ex: [0,1,2,3,4,5]
        true_class_ids = list({x[1] for x in input_data_with_int_labels})

        # unsqueeze adds additional dimension, so go from torch.Size([3, 28, 28]) to torch.Size([1, 3, 28, 28]) for each image
        # this way, the concatenated result dimensions are torch.Size([420, 3, 28, 28]) rather than torch.Size([1260, 28, 28])
        all_images = torch.cat([x[0].unsqueeze(0) for x in input_data_with_int_labels])

        # reshape(*var) unpacks the dimensions of var, so reshape(*all_images.shape[1:]) where each image is size (1,2,3) is the same as reshape(1,2,3)
        # this code puts each image in a matrix where each row is a different class (n_way) with each row containing n_shot+n_query images from the class
        all_images = all_images.reshape(
            (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
        )

        # puts list of labels into matrix form matching dimensions above matrix
        # ex: [2,2,1,1] => [[2,2],[1,1]] (all rows are the same number/class)
        all_labels = torch.tensor(
            [true_class_ids.index(x[1]) for x in input_data_with_int_labels]
        ).reshape((self.n_way, self.n_shot + self.n_query))

        # take the first n_shot images and combine them into a group (support)
        support_images = all_images[:, : self.n_shot].reshape(
            (-1, *all_images.shape[2:])
        )

        # take the rest of the images and combine them into a group (query)
        query_images = all_images[:, self.n_shot :].reshape((-1, *all_images.shape[2:]))

        # group the support and query labels and flatten so they are a list
        # classes in the rows are the same between them, the difference is the length of each row (but the rows go away since they are flattened)
        support_labels = all_labels[:, : self.n_shot].flatten()
        query_labels = all_labels[:, self.n_shot :].flatten()
        
        return (
            support_images,
            support_labels,
            query_images,
            query_labels,
            true_class_ids,
        )

    @staticmethod
    def _cast_input_data_to_tensor_int_tuple(
        input_data: List[Tuple[Tensor, Union[Tensor, int]]]
    ) -> List[Tuple[Tensor, int]]:
        return [(image, int(label)) for (image, label) in input_data]

In [6]:
image_size = 28

# Setup path to data folder
data_path = Path("data")
image_path = data_path / "UCMerced-Test"

# Check if image folder exists
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory")
    exit()

# Setup train and testing paths
test_dir = image_path / "Test"

# Write transform for image
test_transform = transforms.Compose([
    transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
    transforms.CenterCrop(image_size),
    transforms.ToTensor()
])

test_set = datasets.ImageFolder(
    root=test_dir,
    transform=test_transform,
)

test_set = Dummy_dataset(test_set)

data/UCMerced-Test directory exists.


# Episodic Collate Function

Collate function to be used as argument for the collate_fn parameter of episodic data loaders.

Args:

input_data: each element is a tuple containing:
- an image as a torch Tensor of shape (n_channels, height, width)
- the label of this image as an int or a 0-dim tensor
        
Returns:

tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
- support images of shape (n_way * n_shot, n_channels, height, width),
- their labels of shape (n_way * n_shot),
- query images of shape (n_way * n_query, n_channels, height, width)
- their labels of shape (n_way * n_query),
- the dataset class ids of the class sampled in the episode

In [7]:
from torch.utils.data import DataLoader

test_sampler = DummySampler(
    test_set, n_way=5, n_shot=5, n_query=10, n_tasks=100
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=8,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

# Evaluate Model

In [22]:
from src.Prototypical_networks import PrototypicalNetworks
from torchvision.models import resnet18
from torch import nn

convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()

model = PrototypicalNetworks(convolutional_network)

def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> [int, int]:
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    print(query_labels)
    print(torch.max(model(support_images, support_labels, query_images).detach().data,1)[1])
    print(torch.max(model(support_images, support_labels, query_images).detach().data,1)[1]==query_labels)
    print(type((torch.max(model(support_images, support_labels, query_images).detach().data,1)[1]==query_labels).sum()))
    print(type((torch.max(model(support_images, support_labels, query_images).detach().data,1)[1]==query_labels).sum().item()))

    return (
        torch.max(
            model(support_images, support_labels, query_images)
            .detach()
            .data,
            1,
        )[1] # this returns a tensor/list of indexes (labels) of highest scores across each row in matrix
        == query_labels # this compares tensor with real labels to tensor of predicted labels and returns a tensor of True and False
    ).sum().item(), len(query_labels)
    # sum adds up number of "True"s, item turns it from tensor to int

print(evaluate_on_one_task(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels
))



tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3])
tensor([3, 2, 2, 2, 4, 2, 3, 2, 2, 2, 0, 4, 0, 4, 0, 0, 0, 0, 4, 0, 4, 4, 0, 4,
        4, 0, 0, 1, 4, 3, 3, 0, 1, 4, 1, 2, 1, 1, 0, 1, 4, 3, 3, 0, 0, 2, 2, 0,
        4, 2])
tensor([False,  True,  True,  True, False,  True, False,  True,  True,  True,
         True, False,  True, False,  True,  True,  True,  True, False,  True,
         True,  True, False,  True,  True, False, False, False,  True, False,
        False, False,  True, False,  True, False,  True,  True, False,  True,
        False,  True,  True, False, False, False, False, False, False, False])
<class 'torch.Tensor'>
<class 'int'>
(26, 50)
