In [None]:
# @title Install dependencies {form-width: "25%"}

!pip install mediapy

In [None]:
# @title Imports {form-width: "25%"}

import cv2
import einops
import matplotlib
from matplotlib import cm
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Line3DCollection
import mediapy as media
import numpy as np
import plotly.graph_objects as go
import seaborn as sns
from sklearn.decomposition import PCA
import tensorflow_datasets as tfds

In [None]:
# @title Load dataset {form-width: "25%"}

# Load the MOVi-F 512x512 variant directly from the public Kubric TFDS bucket.
# shuffle_files=False reads the underlying files in a fixed order.
ds = tfds.load('movi_f/512x512', data_dir='gs://kubric-public/tfds', shuffle_files=False)
ds = ds['train']
# Iterate examples as dicts of NumPy arrays; the next cell advances `ds_iter`.
dataset = tfds.as_numpy(ds)

ds_iter = iter(dataset)

In [None]:
# @title Load next batch {form-width: "25%"}

# Fetch the next example and preview its RGB frames. Re-running this cell moves
# to a different example; every later cell reads the global `sample`.
sample = next(ds_iter)
media.show_video(sample['video'], fps=10)

In [None]:
# @title Geometry functions {form-width: "25%"}

def get_intrinsic(focal_length, sensor_width, width, height):
  """Build a 3x3 pinhole intrinsic matrix in normalized image coordinates.

  The matrix maps camera-space points onto the [0, 1] x [0, 1] image plane with
  the principal point at the image center, independent of pixel resolution.

  Args:
    focal_length: Lens focal length, in the same units as `sensor_width`.
    sensor_width: Physical sensor width.
    width: Image width in pixels (only the aspect ratio is used).
    height: Image height in pixels (only the aspect ratio is used).

  Returns:
    (3, 3) np.ndarray intrinsic matrix.
  """
  # The sensor shares the image's aspect ratio: sensor_h / sensor_w = height / width.
  # (Using width / height here would give the wrong f_y for non-square images.)
  sensor_height = sensor_width * height / width
  f_x = focal_length / sensor_width
  f_y = focal_length / sensor_height
  p_x = p_y = 0.5
  return np.array([
      [f_x, 0, p_x],
      [0, f_y, p_y],
      [0, 0, 1],
  ])

def batch_quaternion_to_rotation_matrix(quaternions):
  """Turn (..., 4) quaternions into (..., 3, 3) float64 rotation matrices.

  Components are read scalar-first as (w, x, y, z). Inputs are normalized
  before conversion, so non-unit quaternions are accepted.
  """
  unit = quaternions / np.linalg.norm(quaternions, axis=-1, keepdims=True)
  w, x, y, z = np.moveaxis(unit, -1, 0)
  rows = (
      (1 - 2 * (y**2 + z**2), 2 * (x*y - w*z), 2 * (w*y + x*z)),
      (2 * (x*y + w*z), 1 - 2 * (x**2 + z**2), 2 * (y*z - w*x)),
      (2 * (x*z - w*y), 2 * (w*x + y*z), 1 - 2 * (x**2 + y**2)),
  )
  matrix = np.stack([np.stack(row, axis=-1) for row in rows], axis=-2)
  return matrix.astype(np.float64)

def get_matrix_world(rotation, translation):
  """Assemble one float64 4x4 homogeneous transform [[R, t], [0, 0, 0, 1]]."""
  matrix = np.identity(4)
  matrix[:3, :3], matrix[:3, 3] = rotation, translation
  return matrix

def batch_get_matrix_world(rotations, translations):
  """Stack (..., 3, 3) rotations and (..., 3) translations into (..., 4, 4) float32 transforms."""
  batch_shape = rotations.shape[:-2]
  out = np.zeros((*batch_shape, 4, 4), dtype=np.float32)
  out[..., 3, 3] = 1
  out[..., :3, :3] = rotations
  out[..., :3, 3] = translations
  return out

def image2camera(image_coords, depth, intrinsic, width, height):
  """Back-project pixel coordinates to 3D camera-space points.

  Args:
    image_coords: (..., 2) pixel coordinates in (x, y) order, in [0, width] x [0, height].
    depth: (...) factor multiplying each back-projected ray.
    intrinsic: (3, 3) intrinsic matrix in normalized [0, 1] image coordinates.
    width, height: image size used to normalize `image_coords`.

  Returns:
    (..., 3) camera-space points.
  """
  uv = image_coords / np.array((width, height))
  uv1 = np.concatenate([uv, np.ones_like(uv[..., :1])], axis=-1)
  inv_k_t = np.linalg.inv(intrinsic).T
  return (uv1 @ inv_k_t) * depth[..., None]

def camera2world(rotation, translation, points3d):
  """Map camera-space points to world space.

  `rotation`/`translation` describe the world-to-camera transform (the same
  arguments `world2camera` takes); this applies that transform's inverse.

  Args:
    rotation: (3, 3) rotation of the world-to-camera transform.
    translation: (3,) translation of the world-to-camera transform.
    points3d: (..., 3) points in camera coordinates.

  Returns:
    (..., 3) points in world coordinates.
  """
  cam_to_world = np.linalg.inv(get_matrix_world(rotation, translation))
  homogeneous = np.concatenate([points3d, np.ones_like(points3d[..., :1])], axis=-1)
  return (homogeneous @ cam_to_world.T)[..., :3]

def world2camera(rotation, translation, points3d):
  """Map world-space points into camera space.

  Args:
    rotation: (3, 3) rotation of the world-to-camera transform.
    translation: (3,) translation of the world-to-camera transform.
    points3d: (..., 3) points in world coordinates.

  Returns:
    (..., 3) camera-space points, i.e. rotation @ p + translation per point.
  """
  world_to_cam = get_matrix_world(rotation, translation)
  ones = np.ones_like(points3d[..., :1])
  homogeneous = np.concatenate([points3d, ones], axis=-1)
  return (homogeneous @ world_to_cam.T)[..., :3]

def camera2image(point3d, intrinsic):
  """Project camera-space points onto the normalized [0, 1] image plane.

  Returns:
    image_coords: (..., 2) normalized (x, y) after the perspective divide.
    z: (...) third component of the projected point (the depth along the
      optical axis when the intrinsic's last row is (0, 0, 1)).
  """
  projected = point3d @ intrinsic.T
  depth = projected[..., 2]
  return projected[..., :2] / projected[..., 2:3], depth

def bilinear_interpolate(im, x, y):
  """Bilinearly sample `im` at fractional pixel coordinates (x, y).

  Coordinates outside the image are clamped to the nearest border pixel.

  Args:
    im: (H, W, ...) array; any trailing dims (e.g. channels) are interpolated
      independently.
    x: array of column coordinates in pixels.
    y: array of row coordinates in pixels, same shape as `x`.

  Returns:
    Array of shape x.shape + im.shape[2:].
  """
  x = np.asarray(x)
  y = np.asarray(y)
  height, width = im.shape[:2]
  x0 = np.floor(x).astype(int)
  y0 = np.floor(y).astype(int)
  # Fractional offsets come from the *unclamped* floor. Deriving weights from
  # clamped indices makes them cancel on the last row/column and returns 0.
  fx = x - x0
  fy = y - y0

  x0c = np.clip(x0, 0, width - 1)
  x1c = np.clip(x0 + 1, 0, width - 1)
  y0c = np.clip(y0, 0, height - 1)
  y1c = np.clip(y0 + 1, 0, height - 1)

  # Broadcast the weights over any trailing (channel) dims of `im`.
  extra = (1,) * (im.ndim - 2)
  fx = fx.reshape(fx.shape + extra)
  fy = fy.reshape(fy.shape + extra)

  return ((1 - fx) * (1 - fy) * im[y0c, x0c]
          + (1 - fx) * fy * im[y1c, x0c]
          + fx * (1 - fy) * im[y0c, x1c]
          + fx * fy * im[y1c, x1c])

def batch_bilinear_interpolate(im, x, y):
  """Bilinearly sample a batch of 2D maps; x[b], y[b] are sampled from im[b].

  Coordinates outside a map are clamped to its border. Weights use the
  unclamped fractional offsets, so samples on (or just past) the last row or
  column return the border value rather than 0.

  Args:
    im: (B, H, W) array of maps.
    x: (B, ...) column coordinates in pixels.
    y: (B, ...) row coordinates in pixels, same shape as `x`.

  Returns:
    (B, ...) interpolated values.
  """
  x = np.asarray(x)
  y = np.asarray(y)
  height, width = im.shape[-2:]
  x0 = np.floor(x).astype(int)
  y0 = np.floor(y).astype(int)
  fx = x - x0
  fy = y - y0

  x0c = np.clip(x0, 0, width - 1)
  x1c = np.clip(x0 + 1, 0, width - 1)
  y0c = np.clip(y0, 0, height - 1)
  y1c = np.clip(y0 + 1, 0, height - 1)

  # Batch index, broadcast against however many sample dims x / y have.
  b = np.arange(im.shape[0]).reshape((-1,) + (1,) * (x.ndim - 1))

  return ((1 - fx) * (1 - fy) * im[b, y0c, x0c]
          + (1 - fx) * fy * im[b, y1c, x0c]
          + fx * (1 - fy) * im[b, y0c, x1c]
          + fx * fy * im[b, y1c, x1c])

def sample_grid_points(height, width, stride=1):
  """Regular grid of integer pixel positions, shape [H/stride, W/stride, 2], (x, y) order.

  Samples start at stride // 2 on each axis and advance by `stride`.
  """
  start = stride // 2
  rows, cols = np.mgrid[start:height:stride, start:width:stride]
  return np.stack((cols, rows), axis=-1)

In [None]:
# @title Prepare groundtruth data {form-width: "25%"}

# Decode the current sample's ground truth into globals used by later cells.
# Several annotations are stored quantized to [0, 65535] and are mapped back
# linearly using per-video ranges from the metadata.
frames = sample['video']
num_frames, height, width = frames.shape[:3]
depth_range = sample['metadata']['depth_range']
depths = sample['depth'][..., 0] / 65535 * (depth_range[1] - depth_range[0]) + depth_range[0]
intrinsic = get_intrinsic(sample['camera']['focal_length'], sample['camera']['sensor_width'], width, height)
masks = sample['segmentations'][..., 0]
# Instance annotations are stored object-major ('n t ...'); make them frame-major ('t n ...').
bboxes_3d = einops.rearrange(sample['instances']['bboxes_3d'], 'n t ... -> t n ...')

# Per-frame camera pose from quaternion + position, then inverted below.
camera_rotations = batch_quaternion_to_rotation_matrix(sample['camera']['quaternions'])
# NOTE(review): presumably converts a -z-forward (Blender/OpenGL-style) camera to
# the +z-forward convention that camera2image / the z-buffer assume — confirm.
camera_rotations = camera_rotations @ np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])  # flip y and z axis
camera_positions = sample['camera']['positions']
camera_poses = batch_get_matrix_world(camera_rotations, camera_positions)
camera_poses = np.linalg.inv(camera_poses)  # world to camera extrinsics
object_quaternions = einops.rearrange(sample['instances']['quaternions'], 'n t ... -> t n ...')
object_rotations = batch_quaternion_to_rotation_matrix(object_quaternions)
object_positions = einops.rearrange(sample['instances']['positions'], 'n t ... -> t n ...')
object_poses = batch_get_matrix_world(object_rotations, object_positions)
# Identity pose at index 0 for the background, so pose index k lines up with
# segmentation id k when the masks are one-hot encoded later.
identity = np.tile(np.eye(4)[None, None], (num_frames, 1, 1, 1))
object_poses = np.concatenate((identity, object_poses), axis=1)  # add background to object poses

minv, maxv = sample["metadata"]["forward_flow_range"]
forward_flows = sample['forward_flow'] / 65535 * (maxv - minv) + minv
forward_flows = forward_flows[..., ::-1]  # switch to [x, y]
minv, maxv = sample["metadata"]["backward_flow_range"]
backward_flows = sample['backward_flow'] / 65535 * (maxv - minv) + minv
# NOTE(review): besides reordering, this also negates the flow (sign-convention fix?) — confirm.
backward_flows = -backward_flows[..., ::-1]  # switch to [x, y] and negate
# Normals mapped from [0, 65535] to [-1, 1]. NOTE(review): unused below — the
# visualization cell shows the raw sample['normal'] instead.
surface_normals = sample['normal'] / 65535 * 2.0 - 1

In [None]:
# @title Visualize groundtruth z_buffer {form-width: "25%"}

# Convert `depths` into z-depth (the z-buffer):
#   1. Build each pixel's ray through the inverse intrinsic.
#   2. Normalize the ray to unit length and scale it by `depths`, so depth is
#      treated as distance along the ray.
#   3. Keep the z component.
num_frames, height, width = frames.shape[0:3]
x, y = np.meshgrid(np.arange(width), np.arange(height))
# NOTE(review): pixel (x, y) is normalized to (x / W, y / H), i.e. its corner
# rather than its center ((x + 0.5) / W); image2camera on integer query points
# uses the same convention, so the pipeline is self-consistent.
projected_pt = np.stack([x, y, np.ones_like(x)], axis=-1) / np.array([width, height, 1])  # Shape: (height, width, 3)
camera_plane = projected_pt @ np.linalg.inv(intrinsic).T # Shape: (height, width, 3)
camera_ball = camera_plane / np.sqrt(np.sum(np.square(camera_plane), axis=-1, keepdims=True)) # Shape: (height, width, 3)
camera_coords = camera_ball[None] * depths[..., None]  # Shape: (num_frames, height, width, 3)
z_buffers = camera_coords[..., 2]  # Shape: (num_frames, height, width)


fig = go.Figure()

# Inverse z-depth of frame 0; rows are reversed, presumably because the heatmap's
# y-axis increases upward.
fig.add_trace(go.Heatmap(
    z=1/z_buffers[0, ::-1],
    colorscale='magma',
    colorbar=dict(title='Value'),
    hovertemplate='X: %{x}<br>Y: %{y}<br>Pixel Value: %{z:.2f}<extra></extra>'
))

# Figure size in pixels equals the image size
fig.update_layout(title='1/z_buffer', width=width, height=height)

fig.show()

In [None]:
# @title Visualize different groundtruth data {form-width: "25%"}

def segmentations_to_video(masks):
  """Converts a sequence of segmentation masks to color code video.

  Id 0 is drawn black; ids 1..max(masks) get distinct seaborn palette colors.

  Args:
    masks: [num_frames, height, width], np.uint8, [0, num_objects]

  Returns:
    video: [num_frames, height, width, 3], np.uint8, [0, 255]
  """
  # Python int: np.max of a uint8 array is np.uint8, and `max + 1` can wrap to 0.
  num_objects = int(np.max(masks))
  colors = list(sns.color_palette(n_colors=num_objects)) if num_objects else []
  palette = [(0, 0, 0)] + colors
  lut = np.array([[int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)] for c in palette],
                 dtype=np.uint8)
  # One vectorized table lookup instead of one boolean-mask pass per id.
  return lut[masks]

def colored_depthmap(depth, d_min=None, d_max=None, invalid_value=None, colormap='Spectral'):
  """Converts a depth map to a colored RGB image.

  Args:
    depth: (H, W) depth map (numpy array).
    d_min: Value mapped to the low end of the colormap. Defaults to the
      minimum non-NaN value in `depth`.
    d_max: Value mapped to the high end of the colormap. Defaults to the
      maximum non-NaN value in `depth`.
    invalid_value: Pixels equal to this value are painted black. May be
      np.nan, in which case NaN pixels are painted black.
    colormap: Name of a matplotlib colormap (default 'Spectral').

  Returns:
    The colored depth map as a numpy array of uint8 representing RGB channels.
  """
  if d_min is None:
    d_min = np.nanmin(depth)
  if d_max is None:
    d_max = np.nanmax(depth)
  span = d_max - d_min
  # A constant map (span == 0) would otherwise produce 0/0 = NaN everywhere.
  depth_relative = (depth - d_min) / (span if span != 0 else 1.0)
  cmap = plt.get_cmap(colormap)
  depth_colored = (255 * cmap(depth_relative)[..., :3]).astype(np.uint8)  # H, W, C
  if invalid_value is not None:
    # `depth == np.nan` is always False, so a NaN sentinel needs np.isnan.
    if isinstance(invalid_value, (float, np.floating)) and np.isnan(invalid_value):
      invalid = np.isnan(depth)
    else:
      invalid = depth == invalid_value
    depth_colored[invalid] = 0
  return depth_colored

def depths_to_video(depths):
  """Converts a sequence of depths to color code video.

  Colors are normalized to the 5th-95th percentile of all depths in the
  sequence; values outside that range get the colormap's end colors.

  Args:
    depths: [num_frames, height, width], np.float32, [0, inf]

  Returns:
    video: [num_frames, height, width, 3], np.uint8, [0, 255]
  """
  vmin, vmax = np.percentile(depths, [5, 95])
  normalizer = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
  mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
  video = np.zeros((*depths.shape[:3], 3), dtype=np.uint8)
  for i, depth in enumerate(depths):
    video[i] = (mapper.to_rgba(depth)[:, :, :3] * 255).astype(np.uint8)
  return video

def plot_depth_prism(depthmaps):
  """Colorize a sequence of depth maps by inverse depth with the 'turbo' colormap.

  Zero depths are treated as invalid (NaN) and come out black. The color range
  is the 5th-95th percentile of the valid depths over the whole sequence.

  Args:
    depthmaps: (T, H, W) depth maps. The caller's array is not modified.

  Returns:
    (T, H, W, 3) np.uint8 video.
  """
  # np.where returns a new array, so the caller's data (e.g. the global
  # z_buffers, reused by later cells) is not overwritten with NaNs.
  depthmaps = np.where(depthmaps == 0, np.nan, depthmaps)
  vmin, vmax = np.nanpercentile(depthmaps, [5, 95])
  colored_d = np.stack([colored_depthmap(1/d, 1/vmax, 1/vmin, np.nan, 'turbo') for d in depthmaps], axis=0)
  return colored_d

def flow_to_rgb(vec, flow_mag_range=None, white_bg=False):
  """Color-code optical flow: hue encodes direction, intensity encodes magnitude.

  Args:
    vec: (..., H, W, 2) flow vectors, e.g. a single (H, W, 2) field or a
      (T, H, W, 2) video.
    flow_mag_range: Optional (min, max) magnitudes; defaults to the min/max
      norm over `vec`. Only the lower bound is used: magnitude is
      (norm - min) * 50 / image_diagonal, clipped to [0, 1].
    white_bg: If True, magnitude drives saturation (smallest flow -> white);
      otherwise it drives value (smallest flow -> black).

  Returns:
    (..., H, W, 3) float RGB array in [0, 1].
  """
  # Spatial size is the two dims before the vector dim. For a (T, H, W, 2)
  # video, shape[:2] would be (T, H) and give the wrong image diagonal.
  height, width = vec.shape[-3:-1]
  scaling = 50. / (height**2 + width**2)**0.5
  direction = (np.arctan2(vec[..., 0], vec[..., 1]) + np.pi) / (2 * np.pi)
  norm = np.linalg.norm(vec, axis=-1)
  if flow_mag_range is None:
    flow_mag_range = norm.min(), norm.max()
  magnitude = np.clip((norm - flow_mag_range[0]) * scaling, 0., 1.)
  ones = np.ones_like(direction)
  if white_bg:
    hsv = np.stack([direction, magnitude, ones], axis=-1)
  else:
    hsv = np.stack([direction, ones, magnitude], axis=-1)
  return matplotlib.colors.hsv_to_rgb(hsv)

def get_colormap(height, width):
  """Per-pixel rainbow colors for visualizing dense points.

  Every pixel in row i gets the 'hsv' colormap color at i / (height - 1), so
  color varies top-to-bottom and is constant along a row.

  Returns:
    (height, width, 3) float array of RGB values in [0, 255].
  """
  hsv = matplotlib.colormaps.get_cmap('hsv')
  row_norm = matplotlib.colors.Normalize(vmin=0, vmax=height - 1)
  row_rgb = hsv(row_norm(np.arange(height)))[:, :3] * 255
  return np.repeat(row_rgb[:, np.newaxis, :], width, axis=1)

def render_pose(image,
                object_pose,
                camera_matrix,  # (3, 4)
                arrow_length=0.05,
                arrow_thickness=4,
                arrow_tip_length=0.1):
  """Draw an object's local x/y/z axes as red/green/blue arrows on `image`.

  `image` is drawn on in place (via cv2) and also returned.

  Args:
    image: (H, W, C) image to draw on.
    object_pose: (4, 4) pose; its rotation columns are the drawn axes and its
      translation is the arrow origin.
    camera_matrix: (3, 4) projection to normalized [0, 1] image coordinates.
    arrow_length: Arrow length in the pose's units.
    arrow_thickness: Thickness in pixels of the colored arrows.
    arrow_tip_length: Arrow-head length as a fraction of the arrow length.

  Returns:
    `image`, with the arrows drawn.
  """
  rotation, position = object_pose[:3, :3], object_pose[:3, 3]
  height, width = image.shape[:2]
  pixel_scale = np.array([width, height, 1])

  def project(point):
    # Perspective-project a 3D point and scale it to pixel units.
    p = camera_matrix @ np.r_[point, 1.0]
    return p / p[-1] * pixel_scale

  origin = project(position)
  pt1 = (round(origin[0]), round(origin[1]))
  colors = {'r': (255, 0, 0, 255), 'g': (0, 255, 0, 255), 'b': (0, 0, 255, 255)}
  for i, c in enumerate('rgb'):
    tip = project(position + rotation[:, i] * arrow_length)
    pt2 = (round(tip[0]), round(tip[1]))
    # Black arrow underneath (1px thicker) forms a border. It must use the same
    # tipLength, or its head won't line up with the colored arrow's head.
    cv2.arrowedLine(image, pt1, pt2, color=(0, 0, 0, 255), thickness=arrow_thickness + 1,
                    line_type=cv2.LINE_AA, tipLength=arrow_tip_length)
    # Red/green/blue arrow for the x/y/z axis.
    cv2.arrowedLine(image, pt1, pt2, color=colors[c], thickness=arrow_thickness,
                    line_type=cv2.LINE_AA, tipLength=arrow_tip_length)
  return image

def poses_to_video(frames, object_poses, camera_poses, intrinsic):
  """Overlay every object's pose axes (1 unit long) on every frame.

  Args:
    frames: (T, H, W, C) video; each frame is copied before drawing.
    object_poses: (T, N, 4, 4) per-frame object poses.
    camera_poses: (T, 4, 4) per-frame extrinsics, composed with `intrinsic`.
    intrinsic: (3, 3) normalized intrinsic matrix.

  Returns:
    (T, H, W, C) video with the axes drawn.
  """
  projection = np.hstack([intrinsic, np.zeros((3, 1))])  # (3, 4)
  rendered = []
  for t in range(object_poses.shape[0]):
    camera_matrix = projection @ camera_poses[t]
    canvas = frames[t].copy()
    for pose in object_poses[t]:
      canvas = render_pose(canvas, pose, camera_matrix, arrow_length=1)
    rendered.append(canvas)
  return np.stack(rendered)

def draw_projected_3d_bbox(image, proj_corners, proj_centers=None):
  """Draw projected 3D boxes (and optionally their centers) onto `image`.

  Args:
    image: (H, W, C) image; cv2 draws on it in place and it is returned.
    proj_corners: (N, 8, >=2) box corners in normalized [0, 1] image
      coordinates; only (x, y) are used.
    proj_centers: Optional (N, >=2) normalized box centers, drawn as filled dots.

  Returns:
    `image`.
  """
  height, width = image.shape[0:2]
  # The 12 box edges, as pairs of corner indices.
  edges = ((0, 1), (0, 2), (2, 3), (1, 3),
           (4, 5), (4, 6), (6, 7), (5, 7),
           (0, 4), (1, 5), (2, 6), (3, 7))

  def to_pixel(pt):
    return (round(pt[0] * width), round(pt[1] * height))

  corners_xy = proj_corners[:, :, :2]
  centers_xy = None if proj_centers is None else proj_centers[:, :2]
  for box_idx, box in enumerate(corners_xy):  # box: [8, 2]
    for a, b in edges:
      cv2.line(image, to_pixel(box[a]), to_pixel(box[b]), color=(0, 255, 0), thickness=1)
    if centers_xy is not None:
      cv2.circle(image, to_pixel(centers_xy[box_idx]), height // 100, color=(255, 0, 0), thickness=-1)
  return image

def bboxes_3d_to_video(frames, bboxes_3d, camera_poses, intrinsic):
  """Project per-frame 3D bounding boxes into each frame and draw them.

  Args:
    frames: [T, H, W, C] video; a copy is drawn on.
    bboxes_3d: [T, N, 8, 3] world-space box corners.
    camera_poses: [T, 4, 4] world-to-camera extrinsics.
    intrinsic: (3, 3) normalized intrinsic matrix.

  Returns:
    [T, H, W, C] video with box edges and box centers drawn.
  """
  centers_3d = bboxes_3d.mean(-2)  # [T, N, 3]: mean of the 8 corners
  rotations = camera_poses[..., :3, :3]
  translations = camera_poses[..., :3, 3]

  video = frames.copy()
  for t in range(frames.shape[0]):
    corners_cam = world2camera(rotations[t], translations[t], bboxes_3d[t])
    corners_img, _ = camera2image(corners_cam, intrinsic)
    centers_cam = world2camera(rotations[t], translations[t], centers_3d[t])
    centers_img, _ = camera2image(centers_cam, intrinsic)
    video[t] = draw_projected_3d_bbox(image=video[t], proj_corners=corners_img, proj_centers=centers_img)
  return video

# Preview every ground-truth modality of the current sample side by side, as GIFs, 5 per row.
media.show_videos({"rgb": frames,
                   "segmentation": segmentations_to_video(masks),
                   "bboxes_3d": bboxes_3d_to_video(frames, bboxes_3d, camera_poses, intrinsic),
                   "poses": poses_to_video(frames, object_poses, camera_poses, intrinsic),
                   "object_coordinates": sample["object_coordinates"],  # passed straight from the dataset, not decoded
                   "depth": depths_to_video(depths),
                   "z_buffer": plot_depth_prism(z_buffers),
                   "forward_flow": flow_to_rgb(forward_flows, white_bg=False),
                   "backward_flow": flow_to_rgb(backward_flows, white_bg=False),
                   "surface_normal": sample["normal"],  # NOTE(review): raw encoded normals; decoded `surface_normals` is unused
                   },
                fps=10,
                columns=5,
                codec='gif',
)

In [None]:
# @title Get forward point tracks in 2D and 3D {form-width: "25%"}

def forward_point_tracks_to_video(frames, point_tracks, visibles, show_occ=False):
  """Converts a sequence of points to color code video.

  Each query pixel (i, j) is drawn as a radius-1 dot colored by its source row
  (see get_colormap). Dots are filled when visible; when `show_occ` is set,
  occluded points are drawn as outlines.

  Args:
    frames: [num_frames, height, width, 3], np.uint8, [0, 255]
    point_tracks: [num_frames, height, width, 2], np.float32, [0, width / height]
    visibles: [num_frames, height, width], bool
    show_occ: If True, also draw occluded points (as outlines).

  Returns:
    video: [num_frames, height, width, 3], np.uint8, [0, 255]
  """
  num_frames, height, width = frames.shape[0:3]
  colormap = get_colormap(height, width)
  # Round every coordinate at once instead of once per pixel inside the loop.
  pixel_xy = np.round(point_tracks[..., :2]).astype(np.int32)

  video = frames.copy()
  for t in range(num_frames):
    visible_t = visibles[t]
    # Iterate only over points that will be drawn. np.nonzero yields them in
    # row-major order, the same order as a nested (i, j) loop, so overlapping
    # dots are painted in the same sequence.
    draw_mask = np.ones(visible_t.shape, dtype=bool) if show_occ else visible_t.astype(bool)
    for i, j in zip(*np.nonzero(draw_mask)):
      x, y = pixel_xy[t, i, j]
      thickness = -1 if visible_t[i, j] else 1
      cv2.circle(video[t], (int(x), int(y)), radius=1, color=colormap[i, j], thickness=thickness)
  return video

def visualize_pca(feature_maps):
  """Project per-pixel features onto 3 PCA components and show them as RGB.

  PCA is fit on the features of all frames jointly. Each frame's projection is
  then min-max normalized per component, separately per frame, to [0, 1].

  Args:
    feature_maps: (T, H, W, C) array.

  Returns:
    (T, H, W, 3) float array in [0, 1].
  """
  _, height, width, channels = feature_maps.shape
  pca = PCA(n_components=3).fit(feature_maps.reshape(-1, channels))  # fit on all frames

  pca_video = []
  for frame_features in feature_maps:
    projected = pca.transform(frame_features.reshape(-1, channels)).reshape(height, width, 3)
    min_value = projected.min(axis=(0, 1))
    max_value = projected.max(axis=(0, 1))
    # A component that is constant over the frame would give 0/0 = NaN; map it to 0.
    value_range = np.where(max_value > min_value, max_value - min_value, 1.0)
    pca_video.append((projected - min_value) / value_range)
  return np.stack(pca_video, axis=0)


# Dense forward tracks: lift every frame-0 pixel to 3D, move it rigidly with the
# object it belongs to at frame 0, and reproject into every frame.
num_frames, height, width = frames.shape[0:3]
query_points = sample_grid_points(height, width)  # Shape: (height, width, 2)

# poses[t, k] = world-to-camera @ object pose, i.e. object k expressed in camera t.
poses = camera_poses[:, None] @ object_poses  # Shape: (num_frames, num_objects, 4, 4)
# relative_poses[t, k] = poses[t, k] @ inv(poses[0, k]): carries a point on object k
# from frame-0 camera coordinates to frame-t camera coordinates.
relative_poses = np.einsum('tkcd, kde -> tkce', poses, np.linalg.inv(poses[0]))
num_objects = relative_poses.shape[1]
# Assign each pixel to the object id it shows in frame 0 (index 0 = background pose).
one_hot_masks = (masks[0][None, None, ...] == np.arange(num_objects)[None, :, None, None])
dense_poses = np.einsum('tkhw, tkce -> thwce', one_hot_masks, relative_poses)  # Shape: (num_frames, height, width, 4, 4)

# Lift frame-0 pixels to 3D using the frame-0 z-buffer, then apply per-pixel motion.
camera_coords = image2camera(query_points, z_buffers[0], intrinsic, width, height)  # Shape: (height, width, 3)
points4d = np.concatenate([camera_coords, np.ones_like(camera_coords[..., :1])], axis=-1)  # Homogeneous
proj_camera_coords = np.einsum('thwcd, hwd -> thwc', dense_poses, points4d)  # Shape: (num_frames, height, width, 4)

image_coords_xy, image_coords_z = camera2image(proj_camera_coords[..., :3], intrinsic)  # Shape: (num_frames, height, width, 2)
image_coords_xy *= np.array([width, height])  # Scale to pixel dimensions

# Visibility check: the reprojected point must lie inside the image, and its z
# must be no more than 1% behind the z-buffer sampled at its new location.
interpolate_z_buffers = batch_bilinear_interpolate(z_buffers, image_coords_xy[..., 0], image_coords_xy[..., 1])
visible = (image_coords_z <= interpolate_z_buffers * 1.01 ) & \
    (image_coords_xy[..., 0] >= 0) & (image_coords_xy[..., 0] < width) & \
    (image_coords_xy[..., 1] >= 0) & (image_coords_xy[..., 1] < height)

gt_tracks_forward = image_coords_xy
gt_visibles_forward = visible
gt_tracks_xyz_forward = proj_camera_coords[..., :3]

media.show_videos({"forward_point_tracks": forward_point_tracks_to_video(frames, gt_tracks_forward, gt_visibles_forward),
                   "dense_relative_poses": visualize_pca(dense_poses.reshape(num_frames, height, width, -1)),
                   },
                  fps=10,
                  codec='gif',
)

In [None]:
# @title Visualize groundtruth 2D tracks {form-width: "25%"}

def plot_2d_tracks(video, points, visibles, infront_cameras=None, tracks_leave_trace=16, show_occ=False):
  """Visualize 2D point trajectories.

  For every frame t, draws a fading trail over the last `tracks_leave_trace`
  steps, then an end-point dot for each track.

  Args:
    video: (T, H, W, 3) frames; each frame is copied before drawing.
    points: (T, N, 2) pixel (x, y) track positions.
    visibles: (T, N) visibility flags.
    infront_cameras: optional (T, N) flags gating occluded drawing under
      `show_occ`; defaults to all True.
    tracks_leave_trace: number of past steps in the trail.
    show_occ: also draw occluded segments and (hollow) end points.

  Returns:
    (T, H, W, 3) rendered frames.
  """
  num_frames, num_points = points.shape[:2]

  # One hsv-colormap color per track, in [0, 255].
  hsv = matplotlib.colormaps.get_cmap('hsv')
  norm = matplotlib.colors.Normalize(vmin=0, vmax=num_points - 1)
  point_colors = np.zeros((num_points, 3))
  for idx in range(num_points):
    point_colors[idx] = np.array(hsv(norm(idx)))[:3] * 255

  if infront_cameras is None:
    infront_cameras = np.ones_like(visibles).astype(bool)

  def pixel(track_xy):
    return int(round(track_xy[0])), int(round(track_xy[1]))

  rendered = []
  for t in range(num_frames):
    frame = video[t].copy()
    start = max(0, t - tracks_leave_trace)
    trail = points[start:t + 1]
    trail_vis = visibles[start:t + 1]
    trail_front = infront_cameras[start:t + 1]
    num_segments = trail.shape[0] - 1

    # Draw one trail segment at a time and blend it over the previous result,
    # so older segments fade out.
    for s in range(num_segments):
      before = frame.copy()
      for idx in range(num_points):
        seen = trail_vis[s, idx] and trail_vis[s + 1, idx]
        occluded_ok = show_occ and trail_front[s, idx] and trail_front[s + 1, idx]
        if seen or occluded_ok:
          cv2.line(frame, pixel(trail[s, idx]), pixel(trail[s + 1, idx]), point_colors[idx], 1, cv2.LINE_AA)
      alpha = (s + 1) / num_segments
      frame = cv2.addWeighted(frame, alpha, before, 1 - alpha, 0)

    # End points: filled if visible, hollow if occluded (and show_occ).
    for idx in range(num_points):
      if visibles[t, idx]:
        cv2.circle(frame, pixel(points[t, idx]), 2, point_colors[idx], -1, cv2.LINE_AA)
      elif show_occ and infront_cameras[t, idx]:
        cv2.circle(frame, pixel(points[t, idx]), 2, point_colors[idx], 1, cv2.LINE_AA)

    rendered.append(frame)
  return np.stack(rendered)

num_frames, height, width = frames.shape[:3]

# Subsample the dense tracks on an 8-pixel grid of query points.
grid = sample_grid_points(height, width, 8)
grid = grid.reshape(-1, 2)

# grid is (x, y): index the (frame, row, col) track arrays with rows = y, cols = x.
tracks = gt_tracks_forward[:, grid[:, 1], grid[:, 0]]
visibles = gt_visibles_forward[:, grid[:, 1], grid[:, 0]]
tracks = tracks.reshape(num_frames, -1, 2)
visibles = visibles.reshape(num_frames, -1)

video2d_viz = plot_2d_tracks(frames, tracks, visibles)
media.show_video(video2d_viz, fps=10)

In [None]:
# @title Visualize camera coordinate 3D tracks {form-width: "25%"}

def plot_3d_tracks(points, visibles, infront_cameras=None, tracks_leave_trace=16, show_occ=False):
  """Visualize 3D point trajectories.

  Renders one 5.12x5.12-inch matplotlib 3D plot per frame. Data axes are
  remapped for display (x -> plot x, z -> plot y, y -> plot z, with the plot
  z-axis inverted). All frames share one cubic extent around the drawable points.

  Args:
    points: (T, N, 3) positions.
    visibles: (T, N) bool; a track is drawn at frame t when visible then.
    infront_cameras: optional (T, N) bool used with `show_occ`; defaults to all True.
    tracks_leave_trace: number of past steps drawn as a trail.
    show_occ: also draw tracks that are in front of the camera but occluded.

  Returns:
    (T, H, W, 3) RGB frames taken from the Agg canvas buffer.
  """
  num_frames, num_points = points.shape[0:2]
  hsv = matplotlib.colormaps.get_cmap('hsv')
  norm = matplotlib.colors.Normalize(vmin=0, vmax=num_points - 1)

  if infront_cameras is None:
    infront_cameras = np.ones_like(visibles).astype(bool)

  # Bounds over the points that may be drawn; plot axes map to data (x, z, y).
  selected = points[infront_cameras if show_occ else visibles]
  lo = selected.min(axis=0)
  hi = selected.max(axis=0)
  x_min, x_max = lo[0], hi[0]
  y_min, y_max = lo[2], hi[2]
  z_min, z_max = lo[1], hi[1]

  # Make the extent cubic (same length on every axis) around each axis center.
  interval = np.max([x_max - x_min, y_max - y_min, z_max - z_min])

  def centered(low, high):
    start = (low + high) / 2 - interval / 2
    return start, start + interval

  x_min, x_max = centered(x_min, x_max)
  y_min, y_max = centered(y_min, y_max)
  z_min, z_max = centered(z_min, z_max)

  rendered = []
  for t in range(num_frames):
    fig = Figure(figsize=(5.12, 5.12))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111, projection='3d', computed_zorder=False)

    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])
    ax.set_zlim([z_min, z_max])
    for set_labels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):
      set_labels([])
    ax.invert_zaxis()
    ax.view_init()

    start = max(0, t - tracks_leave_trace)
    for idx in range(num_points):
      if not (visibles[t, idx] or (show_occ and infront_cameras[t, idx])):
        continue
      color = hsv(norm(idx))
      trail = points[start:t + 1, idx]
      ax.plot(xs=trail[:, 0], ys=trail[:, 2], zs=trail[:, 1], color=color, linewidth=1)
      tip = points[t, idx]
      ax.scatter(xs=tip[0], ys=tip[2], zs=tip[1], color=color, s=3)

    fig.subplots_adjust(left=-0.05, right=1.05, top=1.05, bottom=-0.05)
    fig.canvas.draw()
    rendered.append(canvas.buffer_rgba())
    plt.close(fig)
  return np.array(rendered)[..., :3]

num_frames, height, width = frames.shape[:3]

# Subsample the dense camera-space 3D tracks on an 8-pixel grid (grid is (x, y)).
grid = sample_grid_points(height, width, 8)
grid = grid.reshape(-1, 2)

tracks_xyz = gt_tracks_xyz_forward[:, grid[:, 1], grid[:, 0]]
tracks_xyz = tracks_xyz.reshape(num_frames, -1, 3)
# Every track is marked visible, so occluded points are drawn too.
visibles =  np.ones(tracks_xyz.shape[0:2]).astype(bool)

video3d_viz = plot_3d_tracks(tracks_xyz, visibles, show_occ=True)
media.show_video(video3d_viz, fps=10)

In [None]:
# @title Get forward point tracks in world coordinates with first frame as world frame {form-width: "25%"}

# Re-express everything with the frame-0 camera frame as "world":
# camera_poses_adjusted[t] = camera_poses[t] @ inv(camera_poses[0]) maps
# camera-0 coordinates to camera-t coordinates (so index 0 is ~identity).
camera_poses_adjusted = np.einsum('tcd, de -> tce', camera_poses, np.linalg.inv(camera_poses[0]))

num_frames, height, width = frames.shape[0:3]
query_points = sample_grid_points(height, width)  # Shape: (height, width, 2)

# Rigid motion of object k from frame 0 to frame t, conjugated into the camera-0 frame:
# camera_poses[0] @ object_poses[t, k] @ inv(object_poses[0, k]) @ inv(camera_poses[0]).
relative_object_poses = np.einsum('tkcd, kde -> tkce', object_poses, np.linalg.inv(object_poses[0]))
relative_object_poses = np.einsum('tkcd, de -> tkce', relative_object_poses, np.linalg.inv(camera_poses[0]))
relative_object_poses = np.einsum('cd, tkde -> tkce', camera_poses[0], relative_object_poses)
num_objects = relative_object_poses.shape[1]
# Assign each pixel to the object id it shows in frame 0 (index 0 = background).
one_hot_masks = (masks[0][None, None, ...] == np.arange(num_objects)[None, :, None, None])
dense_object_poses = np.einsum('tkhw, tkce -> thwce', one_hot_masks, relative_object_poses)  # Shape: (num_frames, height, width, 4, 4)

camera_coords = image2camera(query_points, z_buffers[0], intrinsic, width, height)  # Shape: (height, width, 3)
points4d = np.concatenate([camera_coords, np.ones_like(camera_coords[..., :1])], axis=-1)  # Homogeneous
points4d = np.einsum('cd, hwd -> hwc', np.linalg.inv(camera_poses_adjusted[0]), points4d)  # World (= camera-0) coordinates
proj_world_coords = np.einsum('thwcd, hwd -> thwc', dense_object_poses, points4d)  # Shape: (num_frames, height, width, 4)
proj_camera_coords = np.einsum('tcd, thwd -> thwc', camera_poses_adjusted, proj_world_coords)  # Camera-t coordinates

image_coords_xy, image_coords_z = camera2image(proj_camera_coords[..., :3], intrinsic)  # Shape: (num_frames, height, width, 2)
image_coords_xy *= np.array([width, height])  # Scale to pixel dimensions

# Visibility check: inside the image and no more than 1% behind the sampled z-buffer.
interpolate_z_buffers = batch_bilinear_interpolate(z_buffers, image_coords_xy[..., 0], image_coords_xy[..., 1])
visible = (image_coords_z <= interpolate_z_buffers * 1.01 ) & \
    (image_coords_xy[..., 0] >= 0) & (image_coords_xy[..., 0] < width) & \
    (image_coords_xy[..., 1] >= 0) & (image_coords_xy[..., 1] < height)

# Overwrites the globals set by the earlier forward-tracks cell.
gt_tracks_forward = image_coords_xy
gt_visibles_forward = visible
gt_tracks_xyz_forward = proj_camera_coords[..., :3]
gt_tracks_xyz_world_forward = proj_world_coords[..., :3]

media.show_videos({"forward_point_tracks": forward_point_tracks_to_video(frames, gt_tracks_forward, gt_visibles_forward),
                   "dense_relative_poses": visualize_pca(dense_object_poses.reshape(num_frames, height, width, -1)),
                   },
                  fps=10,
                  codec='gif',
)

In [None]:
# @title visualize world coordinate 3D tracks with first frame as world frame

def draw_camera_pose(ax, M, fov=60, aspect=1.0, far=0.1):
  """Draw a camera frustum (apex plus far-plane rectangle) on a 3D matplotlib axis.

  Args:
    ax: matplotlib 3D axis.
    M: (4, 4) transform; M[:3, 3] is the apex position and M[:3, :3] rotates
      the camera-frame frustum corners into plot space.
    fov: Vertical field of view in degrees.
    aspect: Width / height ratio of the far plane.
    far: Distance of the far plane along the camera's +z axis.

  Line endpoints are emitted in (x, z, y) order to match the plot's y/z swap.
  """
  apex, rotation = M[:3, 3], M[:3, :3]
  tan_half = np.tan(np.deg2rad(fov / 2))
  far_h = 2 * tan_half * far
  far_w = 2 * tan_half * far * aspect
  hw, hh = far_w / 2, far_h / 2
  corners_cam = np.array([[hw, hh, far], [-hw, hh, far],
                          [-hw, -hh, far], [hw, -hh, far]])
  corners_world = (rotation @ corners_cam.T).T + apex
  spokes = [[apex, corners_world[i]] for i in range(4)]
  rim = [[corners_world[i], corners_world[(i + 1) % 4]] for i in range(4)]
  segments = np.array(spokes + rim)[:, :, [0, 2, 1]]  # swap y and z for display
  ax.add_collection3d(Line3DCollection(segments, colors=[0.0, 0.0, 0.7], linewidths=1))

def plot_3d_world_tracks(points, visibles, cameras, tracks_leave_trace=16):
  """Visualize 3D point trajectories with camera trajectory.

  Args:
    points: (T, N, 3) track positions in the world frame.
    visibles: (T, N) bool; only tracks visible at frame t are drawn at t.
    cameras: (T, 4, 4) world-to-camera extrinsics; inverted here to get the
      camera positions and frustums.
    tracks_leave_trace: number of past steps drawn as a trail.

  Returns:
    (T, H, W, 3) float32 frames in [0, 1], one 5.12x5.12-inch 100-dpi figure per frame.
  """
  num_frames, num_points = points.shape[:2]
  points = points[..., [0, 2, 1]]  # display order (x, z, y)
  hsv = matplotlib.colormaps.get_cmap('hsv')
  color_norm = matplotlib.colors.Normalize(vmin=0, vmax=num_points - 1)

  # Robust cubic extent: 5th-95th percentile per axis, enlarged by 11%.
  lows, highs = [], []
  for axis in range(3):
    lo, hi = np.nanpercentile(points[visibles, axis], [5, 95])
    lows.append(lo)
    highs.append(hi)
  interval = max(highs[0] - lows[0], highs[1] - lows[1], highs[2] - lows[2]) * 1.11
  mids = [(lo + hi) / 2 for lo, hi in zip(lows, highs)]
  (x_min, x_max), (y_min, y_max), (z_min, z_max) = [(m - interval / 2, m + interval / 2) for m in mids]

  cam_to_world = np.linalg.inv(cameras)  # camera to world extrinsics
  cam_path = cam_to_world[:, :3, 3][:, [0, 2, 1]]

  rendered = []
  for t in range(num_frames):
    fig = Figure(figsize=(5.12, 5.12), dpi=100)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111, projection='3d', computed_zorder=False)
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])
    ax.set_zlim([z_min, z_max])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    ax.invert_zaxis()
    ax.view_init()

    # Dashed camera path (up to 20 frames past t) and the current frustum.
    shown_path = cam_path[:min(t + 20, num_frames)]
    ax.plot(shown_path[..., 0], shown_path[..., 1], shown_path[..., 2], color=[0.0, 0.0, 0.7], linestyle='dashed')
    draw_camera_pose(ax, cam_to_world[t], far=interval / 10)

    active = np.flatnonzero(visibles[t, :])
    if active.size > 0:
      colors = hsv(color_norm(active))
      trails = np.transpose(points[max(0, t - tracks_leave_trace):t + 1, active], (1, 0, 2))
      ax.add_collection3d(Line3DCollection(trails, colors=colors, linewidths=1))
      ax.scatter(points[t, active, 0], points[t, active, 1], points[t, active, 2], c=colors, s=3)
    fig.subplots_adjust(left=-0.05, right=1.05, top=1.05, bottom=-0.05)
    fig.canvas.draw()
    rendered.append(np.array(canvas.buffer_rgba(), dtype=np.float32) / 255.)
    plt.close(fig)
  return np.array(rendered)[..., :3]

num_frames, height, width = frames.shape[:3]

# Subsample the dense world-frame (camera-0 frame) 3D tracks on an 8-pixel grid (grid is (x, y)).
grid = sample_grid_points(height, width, 8)
grid = grid.reshape(-1, 2)

tracks_xyz = gt_tracks_xyz_world_forward[:, grid[:, 1], grid[:, 0]]
tracks_xyz = tracks_xyz.reshape(num_frames, -1, 3)
# Every track is marked visible, so occluded points are drawn too.
visibles =  np.ones(tracks_xyz.shape[0:2]).astype(bool)

video3d_viz = plot_3d_world_tracks(tracks_xyz, visibles, camera_poses_adjusted)
media.show_video(video3d_viz, fps=10)