## Specify dataset name

In [None]:
# Name of the PACO dataset to visualize. It is used below as the key into
# _PREDEFINED_PACO to look up the dataset JSON file and the image root directory.
# Available options (copy one into the assignment):
# dataset_name = "paco_lvis_v1_test"
# dataset_name = "paco_ego4d_v1_test"
dataset_name = "paco_lvis_v1_test"


## Load dataset and extract maps

In [None]:
import json
from paco.data.datasets.builtin import _PREDEFINED_PACO

# Derived parameters: the dataset JSON file and the directory that image file
# names are joined onto, both looked up by dataset name.
dataset_file_name, image_root_dir = _PREDEFINED_PACO[dataset_name]

# Load dataset. JSON text is UTF-8, so decode it explicitly instead of relying
# on the platform's locale-dependent default encoding.
with open(dataset_file_name, encoding="utf-8") as f:
    dataset = json.load(f)


In [None]:
import os
from collections import defaultdict

# Build the lookup tables used by the visualization code.

# Image ID -> full path of the image file.
image_id_to_image_file_name = {}
for image_info in dataset["images"]:
    image_path = os.path.join(image_root_dir, image_info["file_name"])
    image_id_to_image_file_name[image_info["id"]] = image_path

# Annotation ID -> annotation dict.
ann_id_to_ann = {}
for ann_info in dataset["annotations"]:
    ann_id_to_ann[ann_info["id"]] = ann_info

# Query ID -> query dict.
query_id_to_query_dict = {query["id"]: query for query in dataset["queries"]}

# First element of the structured query -> IDs of the queries that start
# with it, in dataset order.
cat_to_query_ids = {}
for query in dataset["queries"]:
    cat_to_query_ids.setdefault(query["structured_query"][0], []).append(query["id"])



## Visualize

In [None]:
import cv2
import numpy as np
from PIL import Image, ImageDraw
from copy import deepcopy

def resize_to_height(im, new_h):
    """
    Resizes an image array to height new_h, scaling the width by the
    same factor (rounded to the nearest pixel). Returns the resized array.
    """
    old_h, old_w = im.shape[:2]
    scaled_w = int(round(old_w * new_h / old_h))
    # cv2.resize takes the target size as (width, height).
    target_size = (scaled_w, new_h)
    return cv2.resize(im, target_size)

def add_border(im, num_px, num_px_left=None, value=(255, 255, 255)):
    """
    Paints a solid border over the outermost pixels of an image array.
    The array is modified in place and nothing is returned.

    Args:
        im: Image array indexed as [row, column, ...].
        num_px: Thickness in pixels of the top, bottom and right borders.
        num_px_left: Thickness in pixels of the left border. Defaults to
            num_px when None.
        value: Value assigned to every border pixel.

    Non-positive thicknesses draw nothing on the corresponding sides.
    """
    if num_px_left is None:
        num_px_left = num_px
    # A slice such as im[-0:] selects the whole axis, so a zero-width border
    # must be skipped explicitly or it would overwrite the entire image.
    if num_px > 0:
        im[:num_px, :] = value
        im[-num_px:, :] = value
        im[:, -num_px:] = value
    if num_px_left > 0:
        im[:, :num_px_left] = value

def _open_rgb_image(image_file_name):
    """
    Opens an image file and returns it as a 3-channel RGB PIL image,
    closing the underlying file handle.
    """
    with Image.open(image_file_name) as img:
        # Converting guarantees every tile has 3 channels, so grayscale or
        # RGBA files can still be concatenated with RGB ones.
        return img.convert("RGB")

def _draw_ann_boxes(draw, ann_ids, outline, width):
    """
    Draws the bounding box of each annotation in ann_ids using the global
    ann_id_to_ann map.
    """
    for ann_id in ann_ids:
        bbox = np.array(ann_id_to_ann[ann_id]["bbox"])
        # Turn [x, y, w, h] into the [x0, y0, x1, y1] corners that
        # ImageDraw.rectangle expects.
        bbox[2:] += bbox[:2]
        draw.rectangle(bbox.tolist(), outline=outline, width=width)

def gen_pos_neg_im(query_dict, num_im_per_row, out_im_h, num_border_px):
    """
    Generates a row of one positive and N negative images for
    provided query dict. Uses global image_id_to_image_file_name
    and ann_id_to_ann maps.

    Args:
        query_dict: Query dict with "pos_ann_ids", "neg_ann_ids" and
            "neg_im_ids" entries.
        num_im_per_row: Total number of images in the row, including the
            positive one; at most num_im_per_row - 1 negatives are shown.
        out_im_h: Height in pixels that every image is resized to.
        num_border_px: Thickness in pixels of the colored image borders.

    Returns:
        A uint8 array with all images concatenated horizontally: the
        positive image (green border, positive boxes in green and negative
        boxes in red), a white spacer, then the negative images (red border).
    """
    # Positive image: the image that the first positive annotation belongs to.
    pos_im_id = ann_id_to_ann[query_dict["pos_ann_ids"][0]]["image_id"]
    pos_im = _open_rgb_image(image_id_to_image_file_name[pos_im_id])
    draw = ImageDraw.Draw(pos_im)
    _draw_ann_boxes(draw, query_dict["pos_ann_ids"], outline="green", width=10)
    _draw_ann_boxes(draw, query_dict["neg_ann_ids"], outline="red", width=8)
    pos_im = resize_to_height(np.asarray(pos_im), out_im_h)
    add_border(pos_im, num_border_px, None, (119, 172, 48))  # Green

    # White spacer separating the positive image from the negatives.
    spacer_shape = (pos_im.shape[0], 2 * num_border_px, pos_im.shape[2])
    imgs = [pos_im, np.full(spacer_shape, 255, dtype="uint8")]

    # Negative images. Clamp the count so that num_im_per_row == 0 does not
    # turn into a [:-1] slice that would select all but the last negative.
    num_neg = max(num_im_per_row - 1, 0)
    for idx, neg_im_id in enumerate(query_dict["neg_im_ids"][:num_neg]):
        neg_im = _open_rgb_image(image_id_to_image_file_name[neg_im_id])
        neg_im = resize_to_height(np.asarray(neg_im), out_im_h)
        # Only the first negative gets a left border; for the others the
        # right border of the previous image serves as the separator.
        left_px = num_border_px if idx == 0 else 0
        add_border(neg_im, num_border_px, left_px, (217, 83, 25))  # Red
        imgs.append(neg_im)
    return np.concatenate(imgs, axis=1)

# Visualization settings.
# Alternative choices for vis_cats:
# vis_cats = sorted({d["structured_query"][0] for d in dataset["queries"]})
# vis_cats = ["basket", "bench", "bottle", "chair", "mug", "scissors", "trash_can", "vase", "book", "dog"]
vis_cats = ["dog"]      # Categories whose queries are shown
vis_num_queries = 4     # How many queries to show for each category
vis_num_im_per_row = 4  # Images shown per query, counting the single positive image
vis_border_px = 10      # Border thickness in pixels around each image
vis_im_height = 480     # Height in pixels of the visualization images
random_seed = 93028477  # Seed for the query shuffle; None leaves the RNG unseeded

# Seed numpy's global RNG so the same queries are picked on every run.
if random_seed is not None:
    np.random.seed(random_seed)

# For each category, pick a random subset of its queries and show them in ID order.
for cat in vis_cats:
    # Shuffle a copy so that cat_to_query_ids keeps its original order.
    query_ids = deepcopy(cat_to_query_ids[cat])
    np.random.shuffle(query_ids)
    selected_query_ids = sorted(query_ids[:vis_num_queries])
    for query_id in selected_query_ids:
        query_dict = query_id_to_query_dict[query_id]
        im = gen_pos_neg_im(
            query_dict,
            num_im_per_row=vis_num_im_per_row,
            out_im_h=vis_im_height,
            num_border_px=vis_border_px,
        )
        print(query_dict["query_string"], "(query ID:", query_id, ")")
        display(Image.fromarray(im))
