In [1]:
import yaml 
from torchvision import transforms, datasets
import random
import os
import numpy as np

### Creating the target classes (`tgt_classes`)


In [None]:



# Load the per-class lists of candidate target classes, indexed by class label.
with open('data/synset_closest_idx.yaml', 'r') as file:
    synset_closest_idx = yaml.safe_load(file)


data_path = '/misc/scratchSSD2/datasets/ILSVRC2012/val'
out_size = 256
transform_list = [
    transforms.Resize((out_size,out_size)),
    transforms.ToTensor()
]
transform = transforms.Compose(transform_list)
dataset = datasets.ImageFolder(data_path,  transform=transform)

# Dedicated, seeded generator: the random assignment becomes reproducible
# without touching the global `random` state.
TGT_CLASS_SEED = 0
tgt_class_rng = random.Random(TGT_CLASS_SEED)

# Only the labels are needed, so read them from `dataset.targets` instead of
# indexing `dataset[i]`, which would decode and resize every image on disk.
# The loop is now pure in-memory work, so no progress printing is needed.
idx_image_to_tgt_class = {
    i: tgt_class_rng.choice(synset_closest_idx[label])
    for i, label in enumerate(dataset.targets)
}

In [None]:
# Persist the image-index -> target-class mapping so later cells can reload it.
with open('data/image_idx_to_tgt_class_closest_5.yaml', 'w') as file:
    # With a stream argument yaml.dump writes to the file and returns None,
    # so there is no return value worth binding.
    yaml.dump(idx_image_to_tgt_class, file)

### Creating the new ImageNet wrapper

In [2]:
# Read the saved image-index -> target-class mapping back from disk.
with open('data/image_idx_to_tgt_class_closest_5.yaml') as mapping_file:
    image_idx_to_tgt_class_closest_5 = yaml.safe_load(mapping_file)

In [None]:
# Preview only the first few entries instead of dumping the whole mapping into the output.
dict(list(image_idx_to_tgt_class_closest_5.items())[:10])

In [4]:
from imagenet_classnames import name_map, folder_label_map

In [56]:
class ImageNet(datasets.ImageFolder):
    """ImageFolder wrapper for the ImageNet validation split with optional per-image target classes.

    On top of ``datasets.ImageFolder`` this class

    * optionally restricts the dataset to ``class_idcs`` and relabels the kept
      classes ``0 .. k-1`` (in sorted order of the original indices),
    * reorders the samples in rounds (the first image of every class, then the
      second image of every class, ...) and keeps only the rounds
      ``start_sample .. end_sample - 1``,
    * optionally appends a target class, looked up per image, to every item.

    The round reordering assumes the samples are grouped by class with
    ``50000 // 1000`` images per class.

    Args:
        root: Dataset root. If it already ends in ``"val"`` or ``"train"`` it is
            used as is, otherwise ``split`` is appended.
        split: ``"val"`` or ``"train"``; only ``"val"`` is implemented.
        transform: Forwarded to ``datasets.ImageFolder``.
        target_transform: Forwarded to ``datasets.ImageFolder``.
        class_idcs: Optional iterable of class indices to keep.
        start_sample: First round (image-per-class position) to keep, inclusive.
        end_sample: Last round to keep, exclusive; at most ``50000 // 1000``.
        return_tgt_cls: If true, ``__getitem__`` appends the target class of the
            image to the sample tuple.
        idx_to_tgt_cls_path: Path to a YAML file holding the target class of each
            image, either as a list or as a dict keyed ``0 .. n-1``, in the
            original (pre-filtering, pre-reordering) sample order.
        idx_to_tgt_cls: The same mapping given in memory (list or dict). Only
            used when ``idx_to_tgt_cls_path`` is ``None``.
        **kwargs: Ignored; accepted for consistency with other datasets.

    Raises:
        AssertionError: On an unknown ``split`` or an invalid sample range.
        NotImplementedError: For ``split="train"``.
        ValueError: If ``return_tgt_cls`` is set but no mapping is available, or
            if the mapping length differs from the number of images found.
    """

    # Class-level defaults; `datasets.ImageFolder.__init__` sets an instance-level
    # `classes` (the folder names) that shadows the list below on instances.
    classes = [name_map[i] for i in range(1000)]
    name_map = name_map

    def __init__(
            self, 
            root:str, 
            split:str="val", 
            transform=None, 
            target_transform=None, 
            class_idcs=None, 
            start_sample: int = 0, 
            end_sample: int = 50000//1000,
            return_tgt_cls: bool = False,
            idx_to_tgt_cls_path = None, 
            idx_to_tgt_cls = None,
            **kwargs
    ):
        _ = kwargs  # Just for consistency with other datasets.
        assert split in ["train", "val"]
        assert start_sample < end_sample and start_sample >= 0 and end_sample <= 50000//1000
        # Fail before scanning the image folder: only the validation split is supported.
        if split != "val":
            raise NotImplementedError
        path = root if root[-3:] == "val" or root[-5:] == "train" else os.path.join(root, split)
        super().__init__(path, transform=transform, target_transform=target_transform)

        # A path, when given, wins over the in-memory mapping (which used to be
        # swallowed by **kwargs), so existing calls behave as before.
        if idx_to_tgt_cls_path is not None:
            with open(idx_to_tgt_cls_path, 'r') as file:
                idx_to_tgt_cls = yaml.safe_load(file)
        if isinstance(idx_to_tgt_cls, dict):
            idx_to_tgt_cls = [idx_to_tgt_cls[i] for i in range(len(idx_to_tgt_cls))]
        if idx_to_tgt_cls is None:
            if return_tgt_cls:
                raise ValueError(
                    "return_tgt_cls=True requires idx_to_tgt_cls_path or idx_to_tgt_cls"
                )
        else:
            idx_to_tgt_cls = list(idx_to_tgt_cls)
            # The mapping is indexed by position in the sample list, so a length
            # mismatch would silently pair images with the wrong target class.
            if len(idx_to_tgt_cls) != len(self.samples):
                raise ValueError(
                    f"target-class mapping has {len(idx_to_tgt_cls)} entries "
                    f"but {len(self.samples)} images were found"
                )
        self.idx_to_tgt_cls = idx_to_tgt_cls
        self.return_tgt_cls = return_tgt_cls

        if class_idcs is not None:
            class_idcs = sorted(class_idcs)
            tgt_to_tgt_map = {c: i for i, c in enumerate(class_idcs)}
            self.classes = [self.classes[c] for c in class_idcs]
            kept = [i for i, (_, t) in enumerate(self.samples) if t in tgt_to_tgt_map]
            # Filter the mapping and the samples with the same indices so they stay
            # aligned (previously the filtered samples were never stored).
            if self.idx_to_tgt_cls is not None:
                self.idx_to_tgt_cls = [self.idx_to_tgt_cls[i] for i in kept]
            self.samples = [
                (self.samples[i][0], tgt_to_tgt_map[self.samples[i][1]]) for i in kept
            ]
            self.class_to_idx = {k: tgt_to_tgt_map[v] for k, v in self.class_to_idx.items() if v in tgt_to_tgt_map}

        # Reorder into rounds: position 0 of every class, then position 1, ...
        per_class = 50000 // 1000
        order = [
            i
            for first in range(per_class)
            for i in range(first, len(self.samples), per_class)
        ]
        # One round holds one image per class, so the window is measured in
        # classes actually present (1000 for the full dataset, fewer for a subset).
        round_size = len(self.classes)
        window = slice(start_sample * round_size, end_sample * round_size)
        self.samples = [self.samples[i] for i in order][window]
        if self.idx_to_tgt_cls is not None:
            self.idx_to_tgt_cls = [self.idx_to_tgt_cls[i] for i in order][window]
        # ImageFolder exposes `imgs` as an alias of `samples`; keep it in sync.
        self.imgs = self.samples

        self.class_labels = {i: folder_label_map[folder] for i, folder in enumerate(self.classes)}
        # Build the targets from the labels only: np.array over the (path, label)
        # tuples would coerce everything, labels included, to strings.
        self.targets = np.array([t for _, t in self.samples], dtype=np.int64)
    
    def __getitem__(self, index):
        """Return the ImageFolder sample, extended by the image's target class if requested."""
        sample = super().__getitem__(index)
        if self.return_tgt_cls:
            return (*sample, self.idx_to_tgt_cls[index])
        return sample

In [37]:
# Flatten the mapping into a list ordered by image index (keys 0..n-1 are looked up in turn).
n_images = len(image_idx_to_tgt_class_closest_5)
image_idx_to_tgt_class_closest_5_list = list(
    map(image_idx_to_tgt_class_closest_5.__getitem__, range(n_images))
)

In [None]:
# Preview only the first few entries instead of dumping the whole list into the output.
image_idx_to_tgt_class_closest_5_list[:10]

In [57]:
# Validation dataset without a transform; every item also carries its target class.
ds = ImageNet(
    root='/misc/scratchSSD2/datasets/ILSVRC2012',
    return_tgt_cls=True,
    idx_to_tgt_cls_path='data/image_idx_to_tgt_class_closest_5.yaml',
)

In [59]:
# Spot-check a single item; with return_tgt_cls=True the target class is appended to the sample.
ds[3]

(<PIL.Image.Image image mode=RGB size=500x332>, 3, 389)

In [48]:
# Preprocessing pipeline: resize to a fixed square, then convert to a tensor.
out_size = 256
resize_to_square = transforms.Resize((out_size, out_size))
transform_list = [resize_to_square, transforms.ToTensor()]
transform = transforms.Compose(transform_list)

In [52]:
# Point the wrapper at the YAML mapping via `idx_to_tgt_cls_path`, the parameter
# its constructor reads the target classes from.
ds = ImageNet(
    '/misc/scratchSSD2/datasets/ILSVRC2012',
    split="val",
    return_tgt_cls=True,
    idx_to_tgt_cls_path='data/image_idx_to_tgt_class_closest_5.yaml',
    transform=transform,
)

In [53]:
# Spot-check a single item of the dataset built with the resize + ToTensor transform.
ds[100]

(tensor([[[0.6000, 0.5961, 0.5922,  ..., 0.5412, 0.5294, 0.5176],
          [0.5686, 0.5647, 0.5608,  ..., 0.5569, 0.5529, 0.5412],
          [0.5333, 0.5255, 0.5294,  ..., 0.5529, 0.5529, 0.5490],
          ...,
          [0.4980, 0.5333, 0.5804,  ..., 0.1529, 0.2588, 0.2784],
          [0.5176, 0.5294, 0.5686,  ..., 0.0353, 0.0275, 0.0431],
          [0.5098, 0.5216, 0.5294,  ..., 0.0510, 0.0118, 0.0078]],
 
         [[0.6863, 0.6824, 0.6784,  ..., 0.6627, 0.6510, 0.6392],
          [0.6549, 0.6510, 0.6471,  ..., 0.6824, 0.6784, 0.6667],
          [0.6196, 0.6118, 0.6157,  ..., 0.6824, 0.6863, 0.6824],
          ...,
          [0.4549, 0.4902, 0.5373,  ..., 0.1255, 0.2196, 0.2157],
          [0.4588, 0.4745, 0.5176,  ..., 0.0235, 0.0235, 0.0353],
          [0.4510, 0.4627, 0.4706,  ..., 0.0353, 0.0118, 0.0039]],
 
         [[0.7765, 0.7725, 0.7686,  ..., 0.7333, 0.7255, 0.7137],
          [0.7451, 0.7412, 0.7373,  ..., 0.7569, 0.7529, 0.7451],
          [0.7098, 0.7020, 0.7059,  ...,