In [None]:
%matplotlib inline

from typing import List

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

from egg.zoo.referential_language.dataloaders import get_dataloader

import torchvision.transforms.functional as F
from torchvision.io import read_image

plt.rcParams["savefig.bbox"] = 'tight'


from IPython.core.debugger import set_trace

mpl.rcParams['figure.dpi']= 200

In [None]:
def resize_boxes(boxes: torch.Tensor, original_size: List[int], new_size: List[int]) -> torch.Tensor:
    """Rescale box coordinates from one image size to another.

    Args:
        boxes: tensor whose second dimension has exactly four entries
            (typically shape ``(N, 4)``). Columns 0 and 2 are scaled by the
            width ratio, columns 1 and 3 by the height ratio.
        original_size: ``[height, width]`` the coordinates currently refer to.
        new_size: ``[height, width]`` the coordinates should refer to afterwards.

    Returns:
        A new tensor, stacked along dimension 1, holding the scaled values.
    """
    target_device = boxes.device

    # Per-axis scale factors (height first, then width) as float32 scalars
    # living on the same device as the boxes.
    scale_h, scale_w = (
        torch.tensor(new_dim, dtype=torch.float32, device=target_device)
        / torch.tensor(old_dim, dtype=torch.float32, device=target_device)
        for new_dim, old_dim in zip(new_size, original_size)
    )

    left, top, right, bottom = boxes.unbind(1)
    return torch.stack(
        (left * scale_w, top * scale_h, right * scale_w, bottom * scale_h),
        dim=1,
    )

def show(imgs):
    """Display one image tensor, or a list of them, side by side in a single row.

    Args:
        imgs: a single image tensor or a list of image tensors. Anything that
            is not a list is treated as one image.
    """
    images = imgs if isinstance(imgs, list) else [imgs]

    # squeeze=False keeps the axes array 2-D even when there is a single column.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for ax, tensor in zip(axes[0], images):
        pil_image = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image))
        # Strip ticks and tick labels so only the picture is visible.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Dataset to visualise: "vg" selects the Visual Genome paths below, any other
# value falls back to the Flickr30k paths.
dataset_name = "flickr"

# NOTE(review): these are absolute, user-specific paths; adapt them when the
# notebook is run on another machine.
if dataset_name == "vg":
    image_dir = "/private/home/rdessi/visual_genome"
    metadata_dir = "/private/home/rdessi/visual_genome/filtered_splits"
else:
    image_dir = "/private/home/rdessi/flickr30k/Images"
    metadata_dir = "/private/home/rdessi/flickr30k/Annotations"


data_kwargs = {
    # Forward the selected name so it can never disagree with the directories
    # chosen above (it used to be hard-coded to "flickr").
    "dataset_name": dataset_name,
    "image_dir": image_dir,
    "metadata_dir": metadata_dir,
    "batch_size": 32,
    "split": "val",
    "image_size": 64,
    "max_objects": 9,
    "use_augmentation": False,
    "seed": 111,
}

dl = get_dataloader(**data_kwargs)


In [None]:
# Visualisation settings for this cell.
CROP_SIZE = (64, 64)  # every object crop is resized to this before tiling
BOX_FRAME_SIZE = (128, 128)  # frame the box coordinates are rescaled into
MAX_CROPS = 6  # at most this many object crops are shown in the grid

resizer = torchvision.transforms.Resize(size=CROP_SIZE)
for batch_id, batch in enumerate(dl):
    sender_input, labels, recv_input, aux_input = batch

    # Only the first sample of the batch is visualised.
    bboxes = aux_input["bboxes"][0]
    sender_image = aux_input["sender_image"][0]

    # A row whose first value is -1 ends the usable boxes (presumably padding
    # added by the dataloader -- TODO confirm). Default to keeping every row:
    # previously a sample without any -1 row lost its last box, and an empty
    # tensor left the index undefined.
    stop_idx = len(bboxes)
    for idx, elem in enumerate(bboxes):
        if elem[0] == -1:
            stop_idx = idx
            break

    bboxes = bboxes[:stop_idx]
    # NOTE(review): boxes are rescaled into a 128x128 frame; this assumes
    # aux_input["sender_image"] has that spatial size -- confirm against the
    # dataloader.
    n_bboxes = resize_boxes(
        bboxes,
        original_size=aux_input["image_sizes"][0][1:].tolist(),
        new_size=BOX_FRAME_SIZE,
    )

    # Boxes are treated as (x, y, w, h) here; crop() takes
    # (top, left, height, width), hence the reordered arguments.
    img_list = []
    for b in n_bboxes[:MAX_CROPS]:
        x, y, w, h = (int(coord) for coord in b.tolist())
        obj = resizer(torchvision.transforms.functional.crop(sender_image, y, x, h, w))
        img_list.append(obj)

    # Convert (x, y, w, h) into (xmin, ymin, xmax, ymax) corners for
    # draw_bounding_boxes.
    n_bboxes[:, 2] = n_bboxes[:, 0] + n_bboxes[:, 2]
    n_bboxes[:, 3] = n_bboxes[:, 1] + n_bboxes[:, 3]

    # Scale to 0-255 and cast to uint8 for drawing (assumes pixel values are
    # in [0, 1] -- TODO confirm).
    image = (sender_image * 255).byte()

    result = torchvision.utils.draw_bounding_boxes(image, n_bboxes, width=1)

    # Move the first dimension last, the layout imshow expects for colour images.
    result = result.permute(1, 2, 0)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.imshow(result.numpy())
    plt.show()

    # make_grid cannot stack an empty list, so skip the crop grid when no box
    # survived the filtering above.
    if img_list:
        show(torchvision.utils.make_grid(img_list))

    # A single batch is enough for this visual sanity check.
    if batch_id == 0:
        break