In [12]:
from datasets import *
from torch.utils.data import DataLoader
from random import randint


In [19]:
import random
from typing import List, Tuple

import torch
from torch.utils.data import Sampler, Dataset
# torch.multiprocessing.set_start_method('spawn')# good solution !!!!

import clip
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device)

class CLIPTaskSampler(Sampler):
    """
    Samples batches in the shape of few-shot classification tasks. At each iteration, it will sample
    n_way classes, and then sample support and query images from these classes.
    """

    def __init__(
        self, dataset: Dataset, n_way: int, n_shot: int, n_query: int, n_tasks: int, phrase="This is a photo of a {}"
    ):
        """
        Args:
            dataset: dataset from which to sample classification tasks. Must have a field 'label': a
                list of length len(dataset) containing containing the labels of all images.
            n_way: number of classes in one task
            n_shot: number of support images for each class in one task
            n_query: number of query images for each class in one task
            n_tasks: number of tasks to sample
        """
        super().__init__(data_source=None)
        self.n_way = randint(1,n_way+1)
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_tasks = n_tasks
        self.classes = dataset.classes
        self.phrase = phrase

        self.items_per_label = {}
        self.n_way_hist = []
        assert hasattr(
            dataset, "labels"
        ), "TaskSampler needs a dataset with a field 'label' containing the labels of all images."
        for item, label in enumerate(dataset.labels):
            if label in self.items_per_label.keys():
                self.items_per_label[label].append(item)
            else:
                self.items_per_label[label] = [item]

    def __len__(self):
        return self.n_tasks

    def __iter__(self):
        for _ in range(self.n_tasks):
            n_way = randint(1,self.n_way+1)
            print(n_way)
            self.n_way_hist.append(n_way)
            yield torch.cat(
                [
                    # pylint: disable=not-callable
                    torch.tensor(
                        random.sample(
                            self.items_per_label[label], self.n_shot + self.n_query
                        )
                    )
                    # pylint: enable=not-callable
                    for label in random.sample(self.items_per_label.keys(), self.n_way)
                ]
            )

    def episodic_collate_fn(
        self, input_data: List[Tuple[torch.Tensor, int]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        """
        Collate function to be used as argument for the collate_fn parameter of episodic
            data loaders.
        Args:
            input_data: each element is a tuple containing:
                - an image as a torch Tensor
                - the label of this image
        Returns:
            tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
                - support images,
                - their labels,
                - query images,
                - their labels,
                - the dataset class ids of the class sampled in the episode
        """
        
        print(self.n_way)
        new_input_data  = []

        cache = {}
        
        for image, label in input_data:

            # Not normalizing image
            image_emb = clip_model.encode_image(image.unsqueeze(0).cuda())

            #speeeed up
            if label in cache:
                print("Using cache")
                class_embeddings = cache[label]
            else:
                # Normalizing Text
                class_name = self.classes[label]
                class_name = class_name.replace("_", " ")
                text = clip.tokenize(self.phrase.format(class_name))
                class_embeddings = clip_model.encode_text(text.cuda())
                class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
                cache[label] = class_embeddings

            final_input = torch.cat((image_emb,class_embeddings),dim=1)
            new_input_data.append((final_input,label))



        true_class_ids = list({x[1] for x in new_input_data})

        all_images = torch.cat([x[0].unsqueeze(0) for x in new_input_data])
        all_images = all_images.reshape(
            (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
        )
        # pylint: disable=not-callable
        all_labels = torch.tensor(
            [true_class_ids.index(x[1]) for x in new_input_data]
        ).reshape((self.n_way, self.n_shot + self.n_query))
        # pylint: enable=not-callable

        support_images = all_images[:, : self.n_shot].reshape(
            (-1, *all_images.shape[2:])
        )
        query_images = all_images[:, self.n_shot :].reshape((-1, *all_images.shape[2:]))
        support_labels = all_labels[:, : self.n_shot].flatten()
        query_labels = all_labels[:, self.n_shot :].flatten()
        return (
            support_images,
            support_labels,
            query_images,
            query_labels,
            true_class_ids,
        )

In [20]:
import clip
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device)
phrase = "This is a photo of a {}"

In [21]:
dataset = Cifar100(5, 3)
_,_, train_dataset, _ = dataset.get_train_loaders(transform_fn=clip_preprocess)
train_dataset.labels = train_dataset.targets

Files already downloaded and verified
50000
Files already downloaded and verified


In [25]:
task_sampler = CLIPTaskSampler(train_dataset, n_way=3, n_shot=4, n_query=5, n_tasks=2)

In [26]:
train_loader = DataLoader(
    train_dataset,
    batch_sampler=task_sampler,
    num_workers=0,
    collate_fn=task_sampler.episodic_collate_fn,
)

In [28]:
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(train_loader))

4
[1, 4]
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache
Using cache


In [None]:
example_query_images.shape

In [None]:
example_query_labels

In [11]:
example_query_labels.shape

torch.Size([15])

In [13]:
example_class_ids

[97, 99, 44]