In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: spopov@google.com (Stefan Popov)

# Root of the local cad_estate checkout; replace this placeholder before
# running. All data paths below are derived from it.
WORKDIR = '/path/to/cad_estate/'

import asyncio
import json
import logging
import os
import sys

# Make the `cad_estate` package (under `<WORKDIR>/src`) importable. The
# membership check keeps repeated runs of this cell from appending duplicate
# entries to `sys.path`.
_CAD_ESTATE_SRC = os.path.join(WORKDIR, 'src')
if _CAD_ESTATE_SRC not in sys.path:
  sys.path.append(_CAD_ESTATE_SRC)

import ipywidgets
import nest_asyncio
import torch as t
import torchvision.transforms.functional as tvtF

from cad_estate import debug_helpers
from cad_estate import file_system as fs
from cad_estate import frames as frame_lib
from cad_estate import objects as obj_lib
from cad_estate.notebooks import objects_notebook_helper as helper_lib

# Notebook display helpers from cad_estate.debug_helpers.
debug_helpers.better_jupyter_display()
debug_helpers.better_tensor_display()
# Allow `asyncio.run()` (used when loading scenes) to be called while the
# notebook's own event loop is already running.
nest_asyncio.apply()
disp = debug_helpers.display_images

logging.basicConfig(level=logging.INFO)

# Root directory of the annotations. Each scene lives in a subdirectory named
# after the scene, containing `objects.json` and `frames.json`.
ANNOTATIONS_DIR = os.path.join(WORKDIR, "data/annotations")

# Scenes passed to the interactive widgets (presumably the choices offered by
# the scene selector).
EXAMPLE_SCENES = [
    "-4qd6olV30Y_57424033", "-4czxXKNrnI_179946000", "5LToC4KInNM_84017000",
    "5RD3EAlBS9w_180880000", "5TJZHqkmTlo_86921000", "-4czxXKNrnI_199665000",
    "-23esP--xK8_30497133", "-AplGqzOF5Y_147046900", "5RD3EAlBS9w_112012000"
]  # yapf: ignore

# Directory with the video frames, passed to `helper_lib.download_frames`.
# If set to None, frames will be downloaded on the fly.
# FRAMES_DIR = None
FRAMES_DIR = os.path.join(WORKDIR, "data/frames")

# Path to the processed ShapeNet dataset
SHAPE_NET_DIR = os.path.join(WORKDIR, "data/shape_net_npz")

# How many regularly sampled frames to show. Ignored when `show_tracks` is
# enabled: all frames with manual track annotations are shown instead.
NUM_FRAMES = 6

# The width of the displayed images, in pixels. The height is derived from
# the frame's aspect ratio.
IMAGE_WIDTH = 800

# One-time setup: convert the directories to absolute paths and load the
# ShapeNet metadata that `obj_lib.load_objects` needs for every scene.
ANNOTATIONS_DIR = fs.abspath(ANNOTATIONS_DIR)
SHAPE_NET_DIR = fs.abspath(SHAPE_NET_DIR)
shapenet_meta = obj_lib.load_shapenet_metadata(SHAPE_NET_DIR)

In [None]:
# Build the widgets used as the defaults of `visualize_scene`'s `scene_name`,
# `frame_index` and `show_tracks` parameters; `ipywidgets.interact` turns them
# into the notebook controls.
[w_scene_name, w_frame_index, w_show_tracks] = (
    helper_lib.create_interactive_widgets(
        EXAMPLE_SCENES, NUM_FRAMES, ANNOTATIONS_DIR))

# Some objects in the CAD-Estate videos don't have a 3D shape aligned to them,
# because automatic tracking (Sec. 3.1 of the paper), object selection
# (Sec. 3.2), or pose estimation (Sec. 3.5) can fail.
# When `show_tracks` is true, this visualization shows the 2D box tracks of
# such objects.


@ipywidgets.interact
def visualize_scene(scene_name: str = w_scene_name,
                    frame_index: int = w_frame_index,
                    show_tracks: bool = w_show_tracks):
  """Renders the objects of one scene over one of its frames and displays it.

  Loads the scene's `objects.json` and `frames.json`, selects a subset of the
  video frames, renders the objects with `helper_lib.render_objects`, fills
  the empty (all-zero) pixels of the rendering with the RGB frame, and
  displays the result resized to `IMAGE_WIDTH` pixels wide.

  Args:
    scene_name: Name of the scene subdirectory inside `ANNOTATIONS_DIR`.
    frame_index: Frame to render; presumably an index into the selected
      frames (not the full video) — TODO confirm against `helper_lib`.
    show_tracks: If true, select the frames with manual track annotations
      instead of `NUM_FRAMES` regularly sampled frames, and additionally draw
      the 2D box tracks with `helper_lib.render_tracks`.

  Raises:
    ValueError: If the frame selection is empty.
  """
  # Load the scene annotations (both files are read concurrently).
  ann_dir = fs.join(ANNOTATIONS_DIR, scene_name)
  obj_bytes, frames_bytes = asyncio.run(
      fs.read_all_bytes_async(
          [fs.join(ann_dir, v) for v in ["objects.json", "frames.json"]]))
  obj_json, frames_json = json.loads(obj_bytes), json.loads(frames_bytes)
  frames = frame_lib.load_metadata(frames_json)
  objects = asyncio.run(obj_lib.load_objects(obj_json, shapenet_meta))

  frames_dir = helper_lib.download_frames(frames, FRAMES_DIR)

  # Select which frames to display.
  if show_tracks:
    num_frames = int(frames.manual_track_annotations.to(t.int32).sum())
    if num_frames <= 0:
      raise ValueError("The video has no manual track annotations.")
    frames = frame_lib.filter(frames, frames.manual_track_annotations)
  else:
    sample_mask = frame_lib.sample_regular(frames, NUM_FRAMES)
    frames = frame_lib.filter(frames, sample_mask)
    # Explicit check rather than `assert`: it survives `python -O` and raises
    # the same error type as the `show_tracks` branch.
    if frames.frame_timestamps.shape[0] == 0:
      raise ValueError("No frames were sampled from the video.")

  frames = asyncio.run(frame_lib.load_images(frames, frames_dir))

  synth, rgb = helper_lib.render_objects(objects, frames, frame_index)

  if show_tracks:
    track_boxes = obj_lib.load_track_boxes(obj_json, frames)
    synth = helper_lib.render_tracks(synth, objects, track_boxes, frame_index)

  # Pixels that are zero in every channel of the rendering are treated as
  # background and taken from the RGB frame instead. Assumes `synth` and `rgb`
  # are (channels, height, width) — consistent with the unpacking below.
  background = (synth == 0).all(dim=0)
  synth[:, background] = rgb[:, background]

  # Resize to IMAGE_WIDTH pixels wide, preserving the aspect ratio.
  _, sh, sw = synth.shape
  h, w = sh * IMAGE_WIDTH // sw, IMAGE_WIDTH
  synth = tvtF.resize(synth, (h, w), antialias=True)

  disp(synth)
