In [12]:
# NOTE: run this notebook in its own environment, with opencv-python-headless
# installed instead of opencv-python.

# https://github.com/edavalosanaya/plot3d
import plot3d

# Before running the cells below, start the plot3d server in a separate terminal:
#     plot3d

# Imports
import cv2
import time
import pathlib
import os
from tqdm import tqdm
import numpy as np
import imutils
import open3d as o3d
import matplotlib
import trimesh
from scipy.spatial.transform import Rotation as R

# Project paths, all derived from the directory the kernel was started in.
CWD = pathlib.Path.cwd()
GIT_ROOT = CWD.parent.parent  # two levels above the working directory
DATA_DIR = GIT_ROOT / "data" / "AIED2024"

# Put the ZoeDepth directory (path relative to the working directory) on the
# module search path.
import sys
sys.path.append("ZoeDepth")

In [13]:
# Create the plot3d handle that the later cells reset and draw meshes/images into.
# NOTE(review): presumably this talks to the plot3d server that the imports cell
# says must be started in a separate terminal - confirm.
plot = plot3d.Plot()

In [14]:
def get_intrinsics(H, W, fov_deg=55.0):
    """Build the 3x3 intrinsic matrix of an ideal pinhole camera.

    The principal point is placed at the image centre and the same focal
    length (in pixels) is used for both axes. The focal length is derived
    from the horizontal field of view.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        fov_deg: horizontal field of view in degrees; must lie strictly
            between 0 and 180. Defaults to 55, the value that was
            previously hard-coded.

    Returns:
        A (3, 3) numpy array ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]``.

    Raises:
        ValueError: if ``fov_deg`` is not in the open interval (0, 180).
    """
    if not 0.0 < fov_deg < 180.0:
        raise ValueError(f"fov_deg must be in (0, 180), got {fov_deg!r}")

    # Half the image width subtends half the field of view: tan(fov/2) = (W/2) / f.
    f = 0.5 * W / np.tan(0.5 * fov_deg * np.pi / 180.0)
    cx = 0.5 * W
    cy = 0.5 * H
    return np.array([[f, 0, cx],
                     [0, f, cy],
                     [0, 0, 1]])

def depth_to_points(depth, R=None, t=None):
    """Back-project a batched depth map into 3D points.

    Each pixel (x, y) is lifted to ``depth * K^-1 @ [x, y, 1]`` using the
    intrinsics from ``get_intrinsics``, the x and y axes are negated, and the
    result is moved by the rigid transform ``R @ p + t``.

    Args:
        depth: array indexed as (batch, height, width).
        R: optional (3, 3) rotation matrix; identity when omitted. Note that
            inside this function the name shadows any module-level ``R``.
        t: optional length-3 translation; zeros when omitted.

    Returns:
        Array of shape (height, width, 3) holding the points of the FIRST
        batch element only.
    """
    img_h = depth.shape[1]
    img_w = depth.shape[2]
    cam_inv = np.linalg.inv(get_intrinsics(img_h, img_w))

    rotation = np.eye(3) if R is None else R
    translation = np.zeros(3) if t is None else t

    # Negates x and y; the original comment describes this as converting to
    # PyTorch3D's coordinate system.
    axis_flip = np.diag([-1.0, -1.0, 1.0])

    # Homogeneous pixel coordinates (x, y, 1) for every pixel: (h, w, 3).
    px, py = np.meshgrid(np.arange(img_w), np.arange(img_h))
    homog = np.stack((px, py, np.ones_like(px)), axis=-1).astype(np.float32)
    homog = homog[None]  # add batch axis -> (1, h, w, 3)

    # Per-pixel depth shaped (b, h, w, 1, 1) so it broadcasts against 3x3 matrices.
    per_pixel_depth = depth[:, :, :, None, None]

    # Scale the inverse intrinsics by depth, then apply to the pixel column vectors.
    cam_pts = (per_pixel_depth * cam_inv[None, None, None, ...]) @ homog[:, :, :, :, None]
    # Flip axes.
    cam_pts = axis_flip[None, None, None, ...] @ cam_pts
    # From reference to target viewpoint.
    out_pts = rotation[None, None, None, ...] @ cam_pts + translation[None, None, None, :, None]

    # Drop the trailing column axis and keep only the first batch element.
    return out_pts[:, :, :, :3, 0][0]

def depth_edges_mask(depth, threshold=0.05):
    """Return a boolean mask marking depth discontinuities (edges).

    A pixel is flagged when the magnitude of the depth gradient at that
    pixel exceeds ``threshold``.

    Args:
        depth: 2D array of shape (H, W). Integer inputs (e.g. uint8 frames)
            are converted to float64 first so that differences between
            neighbouring pixels cannot wrap around; floating inputs are used
            as-is.
        threshold: gradient magnitude above which a pixel counts as an edge.
            Defaults to 0.05, the value that was previously hard-coded.

    Returns:
        2D boolean numpy array of shape (H, W); True where an edge is detected.

    Raises:
        ValueError: propagated from ``np.gradient`` when either dimension has
            fewer than 2 elements.
    """
    depth = np.asarray(depth)
    if not np.issubdtype(depth.dtype, np.floating):
        # Guard against modular arithmetic on unsigned/integer depth values.
        depth = depth.astype(np.float64)

    # np.gradient returns one array per axis, in axis order: axis 0 (rows, i.e.
    # the y direction) first, then axis 1 (columns, i.e. the x direction).
    depth_dy, depth_dx = np.gradient(depth)
    # Gradient magnitude (symmetric in dx/dy).
    depth_grad = np.sqrt(depth_dx ** 2 + depth_dy ** 2)
    return depth_grad > threshold

def create_triangles(h, w, mask=None):
    """Generate triangle vertex indices for a regular h x w pixel grid.

    Vertices are numbered row-major (index = row * w + col). Every 2x2 pixel
    cell contributes two triangles: (top-left, bottom-left, top-right) and
    (bottom-right, top-right, bottom-left). The indices are fixed, so this
    function is not, and need not be, differentiable.

    Args:
        h: (int) height of the image.
        w: (int) width of the image.
        mask: optional boolean array with h * w elements (any shape; it is
            flattened). When given, only triangles whose three vertices are
            all True in the mask are kept.

    Returns:
        2D numpy array of integer indices with shape (2 * (w - 1) * (h - 1), 3),
        or fewer rows when ``mask`` removes triangles.
    """
    col, row = np.meshgrid(range(w - 1), range(h - 1))

    # Flat indices of the four corners of each grid cell.
    top_left = row * w + col
    top_right = row * w + col + 1
    bottom_left = (row + 1) * w + col
    bottom_right = (row + 1) * w + col + 1

    # Six indices per cell (two triangles), laid out on the last axis.
    per_cell = np.stack(
        (top_left, bottom_left, top_right, bottom_right, top_right, bottom_left),
        axis=-1,
    )
    faces = per_cell.reshape(((w - 1) * (h - 1) * 2, 3))

    if mask is None:
        return faces

    vertex_ok = mask.reshape(-1)
    keep = vertex_ok[faces].all(1)
    return faces[keep]

def get_mesh(image, depth, keep_edges=False):
    """Build a vertex-coloured triangle mesh from an image and its depth map.

    Every pixel becomes one vertex (back-projected with ``depth_to_points``)
    and neighbouring pixels are joined into triangles by ``create_triangles``.
    Vertex colours are the pixel colours after a BGR -> RGB conversion.

    Args:
        image: colour image in BGR channel order (it is passed through
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``).
        depth: 2D depth map; a batch axis is added before back-projection.
            Assumed to have the same height/width as ``image`` - the triangle
            grid is sized from ``image`` while the vertices come from ``depth``.
        keep_edges: when False, triangles touching a pixel flagged by
            ``depth_edges_mask`` are dropped; when True, all triangles are kept.

    Returns:
        A ``trimesh.Trimesh`` with per-vertex colours. Nothing is written to disk.
    """
    rgb_pixels = np.array(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    n_rows, n_cols = rgb_pixels.shape[0], rgb_pixels.shape[1]

    # One 3D vertex per pixel, flattened in row-major order.
    vertices = depth_to_points(depth[None]).reshape(-1, 3)

    # Unless edges are kept, restrict faces to pixels away from depth discontinuities.
    non_edge = None if keep_edges else ~depth_edges_mask(depth)
    faces = create_triangles(n_rows, n_cols, mask=non_edge)

    return trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=rgb_pixels.reshape(-1, 3),
    )


In [57]:
# Reset the 3D plot before adding anything in this cell.
plot.reset()

# Add the monitor rectangle: a thin box placed with a rigid transform.
# SIZE scales the box extents below.
SIZE = 125
# Rotation built from a rotation vector (here R is scipy's Rotation class).
r = R.from_rotvec(np.pi/2 * np.array([-0.2, 0.75, -0.1]))
# Translation applied to the box.
t = np.array([4, 0, -10])*4
# 4x4 homogeneous transform: rotation in the top-left 3x3, translation in the last column.
rt = np.eye(4)
rt[:3, :3] = r.as_matrix()
rt[:3, 3] = t

# Box with extents (0.5, 0.2, 0.01) * SIZE - much thinner along the third axis.
rect = trimesh.creation.box(extents=np.array([0.5, 0.2, 0.01])*SIZE)
# NOTE(review): presumably RGBA floats in [0, 1], i.e. half-transparent red - confirm
# how trimesh interprets float colours. Both face and vertex colours are assigned.
rect.visual.face_colors = [1, 0, 0, 0.5]
rect.visual.vertex_colors = [1, 0, 0, 0.5]

# Move the box into place and register it with the plot under the key 'monitor'.
rect.apply_transform(rt)
plot.add_mesh('monitor', rect)
# Alternative left by the author for an already-registered mesh:
# plot.update_mesh('monitor', rect)

# Load the RGB and depth videos.
# Check both files up front so nothing is opened if either one is missing.
vid_file = DATA_DIR / "block-a-blue-day1-first-group-cam1.mp4"
depth_file = DATA_DIR / "depth_test.mp4"
assert vid_file.exists(), f"RGB video not found: {vid_file}"
assert depth_file.exists(), f"Depth video not found: {depth_file}"

FRAME_WIDTH = 500      # both streams are resized to this width before meshing
FRAME_DELAY_S = 0.1    # pause between frames

# 180-degree rotation about the X axis applied to every mesh; it does not
# depend on the frame, so it is built once.
flip_x = trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0])

cap = cv2.VideoCapture(str(vid_file))
depth_cap = cv2.VideoCapture(str(depth_file))

try:
    # A capture that failed to open reports a frame count of 0 and the loop
    # would silently do nothing, so fail loudly instead.
    assert cap.isOpened(), f"Could not open RGB video: {vid_file}"
    assert depth_cap.isOpened(), f"Could not open depth video: {depth_file}"

    LENGTH = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    for i in tqdm(range(LENGTH), total=LENGTH):

        # Load the next frame from each stream; stop when either one runs out.
        r_ret, rgb = cap.read()
        d_ret, depth = depth_cap.read()

        if not r_ret or not d_ret:
            break

        # The depth video is decoded as 3-channel BGR; collapse to one channel.
        depth = cv2.cvtColor(depth, cv2.COLOR_BGR2GRAY)

        # Resize both frames to the same width.
        rgb = imutils.resize(rgb, width=FRAME_WIDTH)
        depth = imutils.resize(depth, width=FRAME_WIDTH)

        mesh = get_mesh(rgb, depth, keep_edges=True)
        mesh.apply_transform(flip_x)

        # Plot the frame: the mesh is registered on the first frame and
        # updated in place afterwards.
        plot.plot_image(depth)
        if i == 0:
            plot.add_mesh('mesh', mesh)
        else:
            plot.update_mesh('mesh', mesh)
        time.sleep(FRAME_DELAY_S)
finally:
    # Always release the decoders, including when the cell is interrupted
    # (KeyboardInterrupt) or an exception is raised mid-loop.
    cap.release()
    depth_cap.release()

  2%|▏         | 242/13464 [00:52<48:00,  4.59it/s]


KeyboardInterrupt: 