In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import json
import random
from math import floor
from pathlib import Path
from typing import Callable

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.distributed as dist
from PIL import Image
from torch.utils.data import IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from torchvision.transforms.functional import crop

# Default resolution (dots per inch) for every figure drawn in this notebook.
mpl.rcParams["figure.dpi"] = 200

# Parent class for a Visual Genome (VG) dataset

In [None]:
def pil_loader(path: str) -> Image.Image:
    """Open the image file at ``path`` and return it converted to RGB mode.

    The conversion is performed before the file handle is closed.
    """
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")


class VisualGenomeDataset:
    """Base class holding Visual Genome images together with their objects.

    On construction the image records and the object annotations of ``split``
    are read from ``metadata_dir``, every image's objects are filtered with
    ``_filter_objs`` and only images left with more than two objects are kept
    in ``self.samples`` as ``(image_path, objects)`` pairs.

    Args:
        image_dir: root directory of the images. The path of an image is this
            directory joined with the last two components of its ``url``.
        metadata_dir: directory containing ``{split}_objects.json`` and
            ``{split}_image_data.json``.
        classes_path: text file with one class per line. The class name is the
            text before the first comma; its id is the zero-based line index.
        split: prefix selecting which pair of metadata files is read.
        transform: callable applied to each loaded PIL image; skipped if falsy.
        max_objects: stored on the instance for use by subclasses.
        image_size: side length of the square every object crop is resized to.
    """

    def __init__(
        self,
        image_dir: str,
        metadata_dir: str,
        classes_path: str = "/private/home/rdessi/EGG/egg/zoo/referential_language/utils/classes_1600.txt",
        split: str = "train",
        transform: Callable = transforms.ToTensor(),
        max_objects=20,
        image_size=64,
    ):
        path_images = Path(image_dir)
        path_metadata = Path(metadata_dir) / f"{split}_objects.json"
        path_image_data = Path(metadata_dir) / f"{split}_image_data.json"

        with open(path_image_data) as img_in, open(path_metadata) as metadata_in:
            img_data, img_metadata = json.load(img_in), json.load(metadata_in)
        # The two files are consumed in lockstep below, so they must line up.
        assert len(img_data) == len(img_metadata)

        with open(classes_path) as fin:
            self.class2id = {
                line.strip().split(",")[0]: idx for idx, line in enumerate(fin)
            }
        self.id2class = {v: k for k, v in self.class2id.items()}

        self.samples = []
        for img, objs_data in zip(img_data, img_metadata):
            assert img["image_id"] == objs_data["image_id"]
            img_path = path_images / "/".join(img["url"].split("/")[-2:])

            objs = self._filter_objs(img, objs_data["objects"])
            if len(objs) > 2:
                self.samples.append((img_path, objs))

        self.transform = transform
        self.max_objects = max_objects
        self.resizer = transforms.Resize(size=(image_size, image_size))

    def _filter_objs(self, img, objs):
        """Return the objects of ``img`` that have a known class and are big enough.

        An object is kept when at least one of its ``names`` is a key of
        ``self.class2id``, its bounding box covers more than 1% of the image
        area, and its width and height both exceed one pixel. The ``names`` of
        a kept object are rewritten in place to hold only the first known name.

        Args:
            img: image record providing ``width`` and ``height``.
            objs: object records providing ``names``, ``x``, ``y``, ``h``, ``w``.

        Returns:
            List with the retained object records, in their original order.
        """
        img_area = img["width"] * img["height"]

        filtered_objs = []
        for obj in objs:
            o_name = next((name for name in obj["names"] if name in self.class2id), None)
            if o_name is None:
                continue
            obj["names"] = [o_name]

            h, w = obj["h"], obj["w"]
            # Area of the bounding box itself. The box offset (x, y) must not
            # contribute, otherwise tiny boxes far from the image origin would
            # pass the size threshold.
            obj_area = w * h
            is_big = obj_area / img_area > 0.01 and w > 1 and h > 1
            if is_big:
                filtered_objs.append(obj)
        return filtered_objs

    def _extract_object(self, image, obj_data):
        """Crop one object out of ``image`` and resize it to the target size.

        Args:
            image: the (already transformed) image the object belongs to.
            obj_data: object record; its first name must be in ``self.class2id``.

        Returns:
            ``(crop, label)``: the resized crop and the integer class id.
        """
        label = self.class2id[obj_data["names"][0]]
        y, x, h, w = obj_data["y"], obj_data["x"], obj_data["h"], obj_data["w"]
        # torchvision's crop takes (top, left, height, width).
        obj = self.resizer(crop(image, y, x, h, w))
        return obj, label

    def __len__(self):
        """Number of retained images."""
        return len(self.samples)

    def _load_and_transform(self, img_path):
        """Load the image at ``img_path`` as RGB and apply ``self.transform`` if set."""
        image = pil_loader(img_path)
        if self.transform:
            image = self.transform(image)
        return image

# Datasets for random distractors

In [None]:
class TrainVisualGenomeDatasetRandomDistractors(VisualGenomeDataset):
    """Map-style dataset pairing a target object with random distractors.

    Item ``index`` consists of the first object of sample ``index`` (placed at
    position 0) followed by the first object of ``max_objects - 1`` samples
    drawn without replacement from ``self.samples``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Build the game for sample ``index``.

        Returns:
            ``(game_input, labels, torch.zeros(1), aux_input)`` where
            ``game_input`` stacks the ``max_objects`` crops, ``labels`` holds
            their class ids and ``aux_input`` carries an all-true ``mask``,
            ``game_labels`` (``0 .. max_objects - 1``) and the chance-level
            ``baseline`` ``1 / max_objects``.
        """
        n_objs = self.max_objects

        target_path, target_boxes = self.samples[index]
        target_image = self._load_and_transform(target_path)
        extracted = [self._extract_object(target_image, target_boxes[0])]

        # NOTE(review): the pool is all of self.samples, so the target image
        # itself can be drawn as one of its own distractors.
        for other_path, other_boxes in random.sample(self.samples, k=n_objs - 1):
            other_image = self._load_and_transform(other_path)
            extracted.append(self._extract_object(other_image, other_boxes[0]))

        game_input = torch.stack([patch for patch, _ in extracted])
        labels = torch.Tensor([class_id for _, class_id in extracted])

        aux_input = {
            "mask": torch.ones(n_objs).bool(),
            "game_labels": torch.arange(n_objs),
            "baseline": torch.Tensor([1 / n_objs]),
        }
        return game_input, labels, torch.zeros(1), aux_input


class TestVisualGenomeDatasetRandomDistractors(VisualGenomeDataset, IterableDataset):
    """Iterable dataset yielding one game per object of every image in its shard.

    For each image, up to ``max_objects`` of its objects are used in turn as
    the target (position 0). Each target is accompanied by the first object of
    ``max_objects - 1`` images drawn without replacement from the same shard.

    The shard is the slice of ``self.samples`` assigned to the current
    distributed rank and, when data-loading workers are used, to the current
    worker. Remainders of the integer divisions are dropped.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Until __iter__ computes the real shard, iterate over everything.
        self._shard = self.samples
        self._reset_cursor()

    def _reset_cursor(self):
        """Point the iteration state at the first object of the first image."""
        self.curr_idx = 0
        self.curr_obj_idx = 0
        self.curr_img = None  # loaded lazily by __next__
        self.curr_obj_data = None

    def __iter__(self):
        """Restart iteration and select the shard of this rank / worker."""
        self._reset_cursor()

        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        per_gpu = len(self.samples) // world_size

        iter_start = per_gpu * rank
        iter_end = iter_start + per_gpu

        worker_info = torch.utils.data.get_worker_info()
        if worker_info:  # num_workers is > 0
            per_worker = per_gpu // worker_info.num_workers
            iter_start = iter_start + worker_info.id * per_worker
            iter_end = iter_start + per_worker

        # Keep self.samples intact: overwriting it with the shard would make
        # every later call to __iter__ slice an already-sliced list.
        self._shard = self.samples[iter_start:iter_end]
        return self

    def _load_new_sample(self):
        """Load the image and object list that ``self.curr_idx`` points at."""
        img_path, obj_data = self._shard[self.curr_idx]
        self.curr_img = self._load_and_transform(img_path)
        self.curr_obj_data = obj_data

    def __next__(self):
        """Return the game for the next target object.

        Returns:
            ``(game_input, labels, torch.zeros(1), aux_input)`` with the same
            layout as the training dataset: ``labels`` holds the class ids of
            all ``max_objects`` crops, target first.

        Raises:
            StopIteration: when every image of the shard has been consumed
                (including immediately, for an empty shard).
        """
        # Move the cursor to the next available target, loading each image
        # only once instead of on every call.
        while True:
            if self.curr_idx >= len(self._shard):
                raise StopIteration
            if self.curr_img is None:
                self._load_new_sample()
            if self.curr_obj_idx < min(self.max_objects, len(self.curr_obj_data)):
                break
            self.curr_obj_idx = 0
            self.curr_idx += 1
            self.curr_img = None

        target_data = self.curr_obj_data[self.curr_obj_idx]
        target_obj, target_label = self._extract_object(self.curr_img, target_data)
        self.curr_obj_idx += 1

        cropped_objs, labels = [target_obj], [target_label]
        distractors = random.sample(self._shard, k=self.max_objects - 1)
        for distractor_path, distractor_bboxes in distractors:
            distractor_img = self._load_and_transform(distractor_path)

            cropped_obj, distractor_label = self._extract_object(
                distractor_img, distractor_bboxes[0]
            )
            labels.append(distractor_label)
            cropped_objs.append(cropped_obj)

        game_input = torch.stack(cropped_objs)
        labels = torch.Tensor(labels)

        mask = torch.ones(self.max_objects).bool()
        game_labels = torch.arange(self.max_objects)
        baseline = torch.Tensor([1 / self.max_objects])
        aux_input = {"mask": mask, "game_labels": game_labels, "baseline": baseline}
        # Return the full label tensor, not the class id of the last distractor.
        return game_input, labels, torch.zeros(1), aux_input

# Datasets for contextual distractors

In [None]:
class VisualGenomeDatasetCtxDistractors(VisualGenomeDataset, torch.utils.data.Dataset):
    """Map-style dataset whose items are the objects found in a single image.

    Each item holds the crops of at most ``max_objects`` objects of one image,
    so every object is surrounded by distractors from its own image.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Return ``(agent_input, labels)`` for the image at ``index``.

        ``agent_input`` stacks the object crops and ``labels`` holds their
        class ids, both in annotation order.
        """
        img_path, bboxes = self.samples[index]
        image = self._load_and_transform(img_path)

        n_kept = min(self.max_objects, len(bboxes))
        extracted = [self._extract_object(image, box) for box in bboxes[:n_kept]]

        agent_input = torch.stack([patch for patch, _ in extracted])
        labels = torch.Tensor([class_id for _, class_id in extracted])
        return agent_input, labels


def collate(batch):
    """Pad a batch of variable-length ``(objects, labels)`` items to a common length.

    Args:
        batch: iterable of ``(objects, labels)`` pairs. The first dimension of
            ``objects`` indexes the objects of one item and ``labels`` has one
            entry per object.

    Returns:
        ``(inp, lab, None, aux_input)`` where ``inp`` and ``lab`` are the
        inputs padded with ``-1`` along the object dimension (batch first) and
        ``aux_input`` contains:

        * ``mask``: bool tensor ``(batch, max_objs)``, true for real objects;
        * ``game_labels``: ``0 .. max_objs - 1`` repeated for every item;
        * ``baseline``: ``1 / number of real objects`` for every item.
    """
    inp, lab = [], []
    for objs, obj_labels in batch:
        inp.append(objs)
        lab.append(obj_labels)

    # Number of real objects per item, recorded before padding.
    n_objs = [objs.shape[0] for objs in inp]

    inp = torch.nn.utils.rnn.pad_sequence(inp, batch_first=True, padding_value=-1)
    lab = torch.nn.utils.rnn.pad_sequence(lab, batch_first=True, padding_value=-1)

    bsz, max_objs = inp.shape[:2]
    lengths = torch.tensor(n_objs, device=inp.device)
    # Build the mask from the true lengths rather than by looking for the
    # padding value in the data: a real object whose first pixel happens to
    # equal -1 (e.g. after normalisation) must not be treated as padding.
    positions = torch.arange(max_objs, device=inp.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    baseline = 1 / lengths
    game_labels = torch.arange(max_objs).repeat(bsz, 1)

    aux_input = {"mask": mask, "game_labels": game_labels, "baseline": baseline}
    return inp, lab, None, aux_input

In [None]:
def get_dataloader(
    image_dir: str = "/datasets01/VisualGenome1.2/061517/",
    metadata_dir: str = "/private/home/rdessi/visual_genome/train_val_test_split_clean",
    batch_size: int = 32,
    split: str = "train",
    image_size: int = 32,
    max_objects: int = 20,
    random_distractors: bool = False,
    seed: int = 111,
    num_workers: int = 6,
):
    """Build the ``DataLoader`` for one split of the Visual Genome game data.

    Args:
        image_dir: root directory of the images.
        metadata_dir: directory with the ``{split}_*.json`` metadata files.
        batch_size: number of items per batch; incomplete batches are dropped.
        split: metadata split to load. ``"test"`` disables shuffling and, with
            ``random_distractors``, selects the iterable test dataset.
        image_size: side length every object crop is resized to.
        max_objects: maximum number of objects per game.
        random_distractors: if true, distractors are objects of other randomly
            drawn images; otherwise they are the other objects of the same
            image and batches are padded by ``collate``.
        seed: seed of the ``DistributedSampler``; only used when
            ``torch.distributed`` is initialised.
        num_workers: number of data-loading worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected dataset.
    """
    dataset_kwargs = {
        "image_dir": image_dir,
        "metadata_dir": metadata_dir,
        "split": split,
        "max_objects": max_objects,
        "image_size": image_size,
    }

    is_iterable_dataset = random_distractors and split == "test"
    collate_fn = None
    if is_iterable_dataset:
        dataset = TestVisualGenomeDatasetRandomDistractors(**dataset_kwargs)
    elif random_distractors:
        dataset = TrainVisualGenomeDatasetRandomDistractors(**dataset_kwargs)
    else:
        collate_fn = collate
        dataset = VisualGenomeDatasetCtxDistractors(**dataset_kwargs)

    # One shuffling policy for both the distributed and the single-process
    # path: every split except "test" is shuffled.
    shuffle = split != "test"

    sampler = None
    # An iterable dataset shards itself, so it must not receive a sampler.
    if dist.is_initialized() and not is_iterable_dataset:
        sampler = DistributedSampler(
            dataset, shuffle=shuffle, drop_last=True, seed=seed
        )

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        # DataLoader accepts shuffle=True only without a sampler and only for
        # map-style datasets.
        shuffle=(shuffle and sampler is None and not is_iterable_dataset),
        pin_memory=True,
        drop_last=True,
    )


In [None]:
# Arguments forwarded to get_dataloader; image_dir and metadata_dir keep their
# defaults. Keys must be strings to be unpacked as keyword arguments.
kwargs = {
    "batch_size": 32,
    "split": "train",
    "image_size": 32,
    "max_objects": 20,
    "random_distractors": False,
    "seed": 111,
}

dloader = get_dataloader(**kwargs)

In [None]:
# Visual sanity check: for the first item of a few batches, show up to
# `last_obj` of its object crops side by side, titled with their class names.

# Reuse the dataset built by get_dataloader, wrapped with a small batch size.
ds = dloader.dataset
dl = torch.utils.data.DataLoader(
    ds,
    batch_size=4,
    num_workers=6,
    collate_fn=collate,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

id2class = dl.dataset.id2class

n_figures = 9  # number of batches to visualise
last_obj = 8  # maximum number of crops shown per figure
for shown, (inp, labels, _, aux_input) in enumerate(dl, start=1):
    # Keep only real objects: padded slots are filled with -1 and are also
    # skipped in the title below.
    real_objs = inp[0][aux_input["mask"][0]][:last_obj]
    # (C, H, W) -> (H, W, C), then concatenate the crops along the width.
    all_objs = torch.cat([obj.permute(1, 2, 0) for obj in real_objs], dim=1)

    lab = labels[0].tolist()
    title = " ".join(id2class[int(elem)] for elem in lab[:last_obj] if elem >= 0)

    fig, ax = plt.subplots()
    ax.set_title(title)
    ax.imshow(all_objs.numpy())
    plt.show()

    if shown == n_figures:
        break
