In [None]:
import json
import os
import random
from operator import itemgetter
from collections import defaultdict

import numpy as np
import pandas as pd

from PIL import Image
from tqdm.auto import tqdm

from ego4d.research.util.masks import (
    decode_mask,
    blend_mask,
)
from ego4d.research.readers import TorchAudioStreamReader, PyAvReader
# Reader class used by the cells below to open the sampled video.
# NOTE(review): PyAvReader is imported alongside it -- presumably an
# alternative backend that can be assigned here instead; confirm before swapping.
VideoReader = TorchAudioStreamReader

In [None]:
RELEASE_DIR = "/placeholder/path"  # NOTE: changeme
# RELEASE_DIR = "/large_experiments/egoexo/v2/"

# Names of the metadata tables stored as "<name>.json" at the release root.
METADATA_TABLES = (
    "takes",
    "captures",
    "physical_setting",
    "participants",
    "visual_objects",
)

# Load every table into `egoexo`, keyed by table name. Each file is opened in
# a `with` block so its handle is closed as soon as the JSON has been parsed.
egoexo = {}
for table_name in METADATA_TABLES:
    table_path = os.path.join(RELEASE_DIR, f"{table_name}.json")
    with open(table_path) as table_file:
        egoexo[table_name] = json.load(table_file)

takes = egoexo["takes"]
captures = egoexo["captures"]
# Index takes by their uid for direct lookup in later cells.
takes_by_uid = {x["take_uid"]: x for x in takes}

In [None]:
annotation_dir = os.path.join(RELEASE_DIR, "annotations/")
# Open the annotation file in a `with` block so the handle is closed after parsing.
with open(os.path.join(annotation_dir, "relations_train.json")) as relation_file:
    relation_ann = json.load(relation_file)
relation_objs = relation_ann["annotations"]
# Keys of `relation_objs` whose entry has at least one object mask.
relation_takes = {
    uid for uid, ann in relation_objs.items() if len(ann["object_masks"]) > 0
}
len(relation_takes)

In [None]:
# `relation_takes` is a set, and random.sample() only accepts sequences
# (passing a set was deprecated in Python 3.9 and raises TypeError since 3.11).
# Sorting yields a list and also makes the draw reproducible under a fixed seed.
take_uid = random.sample(sorted(relation_takes), 1)[0]
take_uid

In [None]:
# Annotations recorded for the sampled take.
annotation = relation_objs[take_uid]
object_masks = annotation['object_masks']

# Pair every object key with the text that precedes its first underscore.
object_names = [(key, key.split("_")[0]) for key in object_masks]
object_names

In [None]:
# sample an object & camera/viewpoint
object_name, object_annotations = random.sample(list(object_masks.items()), 1)[0]
camera_name, mask_annotations = random.sample(list(object_annotations.items()), 1)[0]

# Camera keys are either "<cam_id>" or "<cam_id>_<stream_id>"; when no stream
# id is present, stream "0" is used.
cam_id_sid = camera_name.split("_")
stream_id = "0"
cam_id = cam_id_sid[0]
if len(cam_id_sid) > 1:
    cam_id, stream_id = cam_id_sid
    if stream_id == "214-1":  # TODO(suyog, miguel): fix inconsistency
        # The annotations use "214-1" where the take metadata uses "rgb".
        stream_id = "rgb"

# Look the take record up once and resolve the frame-aligned video for the
# sampled camera/stream relative to the take's root directory.
sampled_take = takes_by_uid[take_uid]
rel_path = sampled_take["frame_aligned_videos"][cam_id][stream_id]["relative_path"]
video_path = os.path.join(RELEASE_DIR, sampled_take["root_dir"], rel_path)
# Fail early, and say which file is missing, instead of a bare AssertionError.
assert os.path.exists(video_path), f"video not found: {video_path}"

# Options for the video reader, gathered in one place so they are easy to tweak.
# NOTE(review): the meaning of each option is defined by the reader class in
# ego4d.research.readers -- e.g. gpu_idx=-1 presumably selects CPU decoding and
# axis_order="thwc" presumably yields (time, height, width, channel) frames;
# confirm against that module.
reader_options = dict(
    path=video_path,
    frame_window_size=1,
    stride=1,
    gpu_idx=-1,
    resize=None,
    crop=None,
    mean=None,
    std=None,
    axis_order="thwc",
    uint8_scale=True,
)
reader = VideoReader(**reader_options)

# Show which object / camera pair was sampled.
object_name, camera_name

In [None]:
# Draw one (frame number, mask annotation) pair for the object + camera chosen above.
frame_candidates = list(mask_annotations['annotation'].items())
frame_number, annotation_obj = random.sample(frame_candidates, 1)[0]

# Show the full sampling context.
take_uid, object_name, camera_name, frame_number

In [None]:
# Decode the sampled mask annotation.
mask = decode_mask(annotation_obj)

# Read the annotated frame; frame numbers are stored as strings, hence int().
frame = reader[int(frame_number)]
# Take the first entry of the "video" tensor as a numpy array (the reader was
# built with frame_window_size=1, so presumably there is exactly one frame).
input_img = frame["video"][0].numpy()

# Overlay the mask on the frame and display the result as the cell output.
blended_img = blend_mask(input_img, mask, alpha=0.7)
pil_img = Image.fromarray(blended_img)
pil_img