In [15]:
import os
import json
from torchvision.datasets import ImageFolder
from dataset.caption_dataset import pre_caption

class ImageNet100ValDataset(ImageFolder):
    def __init__(self, root, transform=None, max_words=30):
        image_path = os.path.join(root, "val")
        super().__init__(image_path, transform)

        label_path = os.path.join(root, "Labels.json")
        with open(label_path, "r") as f:
            self.labels = json.load(f)

        self.text = []
        self.image = []

        self.txt2img = {}
        self.img2txt = {}

        self.template = "a photo of a {}."
        self.max_words = max_words

        index = 0
        for path, label_idx in self.samples:
            self.image.append(path)
            self.img2txt[index] = []

            class_name = self.labels[self.classes[label_idx]]
            text = self.template.format(class_name)
            text = pre_caption(text, self.max_words)
            self.text.append(text)

            self.img2txt[index].append(index)
            self.txt2img[index] = index
            index += 1

    def __getitem__(self, index):
        path, _ = self.samples[index]
        original_image = self.loader(path)

        if self.transform is not None:
            image = self.transform(original_image)

        return {
            "image": image,
            "index": index,
        }

In [19]:
import torchvision.transforms as transforms
from PIL import Image

# Init the dataset
normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
    )

normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

test_transform = transforms.Compose(
    [
        transforms.Resize(
            (224, 224), interpolation=Image.BICUBIC
        ),
        transforms.ToTensor(),
        normalize,
    ]
)
dataset = ImageNet100ValDataset(root="/BS/dduka/work/data/imagenet100/", transform=test_transform, max_words=30)

In [20]:
from torch.utils.data import DataLoader

loader = DataLoader(
        dataset,
        batch_size=8,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
        prefetch_factor=4,
    )
    
for batch in loader:
    print(batch)
    break

{'image': tensor([[[[-0.0287,  0.1939,  0.7419,  ...,  0.9988,  1.3242,  1.1358],
          [-0.5938,  1.1015,  1.2899,  ...,  0.8961,  1.2899,  1.2043],
          [ 1.1015,  0.0056,  0.8618,  ...,  0.6221,  0.6906,  0.4851],
          ...,
          [ 0.1768, -0.0629, -0.0458,  ...,  0.7248,  0.2282, -0.2684],
          [-0.2684, -0.3027, -0.1486,  ...,  0.5193, -0.1486, -0.3712],
          [-0.5082, -0.5424, -0.2513,  ..., -0.0458, -0.3712, -0.4568]],

         [[ 0.2927,  0.2402,  0.8704,  ...,  1.2556,  1.6408,  1.4832],
          [-0.4251,  1.2906,  1.7283,  ...,  1.1331,  1.6232,  1.5532],
          [ 1.3081,  0.5028,  1.6758,  ...,  0.8880,  1.0630,  0.8704],
          ...,
          [ 1.0980,  0.9230,  0.6429,  ...,  1.4832,  0.5203, -0.1275],
          [ 0.7129,  0.5028,  0.3627,  ...,  0.8529, -0.0049, -0.1625],
          [ 0.2577,  0.2402,  0.5728,  ...,  0.1001, -0.1800, -0.2150]],

         [[-0.5670, -0.3404,  1.0365,  ...,  1.3677,  1.7337,  1.5245],
          [-0.4798, 

In [18]:
print(len(loader.dataset.text))

5000
