In [1]:
# NOTE: run this notebook in its own environment.
# It uses opencv-python-headless instead of opencv-python, which is why a separate environment is required.

# https://github.com/edavalosanaya/plot3d
import plot3d

# In a separate terminal, run to start the server:
# plot3d

# Imports
import cv2
import time
import pathlib
import os
from tqdm import tqdm
import numpy as np
import imutils
import trimesh # install pyembree for a ray tracing speedup of 50x
from scipy.spatial.transform import Rotation as R
import pandas as pd
from dataclasses import dataclass

# Constants
# All paths are derived from the directory the kernel was started in:
# os.path.abspath("") resolves to the current working directory.
CWD = pathlib.Path(os.path.abspath(""))
# The repository root is taken to be two levels above the working directory.
GIT_ROOT = CWD.parent.parent
# Root folder of the dataset; the process(...) call further down reads the
# 'trackings', 'videos', 'depths' and 'gaze_vectors' sub-folders from here.
DATA_DIR = GIT_ROOT / "data" / 'AIED2024'

# Append ZoeDepth to path
# The entry is relative, so it is resolved against the working directory at import time.
# NOTE(review): nothing in the visible cells imports from ZoeDepth - confirm this is still needed.
import sys
sys.path.append('ZoeDepth')

In [2]:
def get_intrinsics(H, W, fov=55.0):
    """
    Intrinsics for a pinhole camera model.

    The principal point is assumed to sit at the image centre and the pixels
    are assumed square, so a single focal length (derived from the horizontal
    field of view and the image width) is used for both axes.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        fov: horizontal field of view in degrees. Defaults to 55, the value
            that used to be hard-coded.

    Returns:
        3x3 numpy array ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]``.
    """
    # f = (W / 2) / tan(fov / 2), with fov converted from degrees to radians
    f = 0.5 * W / np.tan(0.5 * fov * np.pi / 180.0)
    cx = 0.5 * W
    cy = 0.5 * H
    return np.array([[f, 0, cx],
                     [0, f, cy],
                     [0, 0, 1]])

def depth_to_points(depth, R=None, t=None):
    """Back-project a batched depth map into a grid of 3D points.

    Args:
        depth: array indexed as (batch, height, width).
        R: optional 3x3 rotation applied to the points (identity when omitted).
        t: optional length-3 translation applied after the rotation (zero when omitted).

    Returns:
        Array of shape (height, width, 3) holding the points of the FIRST
        batch element only.
    """
    n_rows, n_cols = depth.shape[1:3]
    inv_K = np.linalg.inv(get_intrinsics(n_rows, n_cols))

    rotation = np.eye(3) if R is None else R
    translation = np.zeros(3) if t is None else t

    # Flips the X and Y axes; per the original author this converts from the
    # image convention to PyTorch3D's coordinate system.
    axis_flip = np.diag([-1.0, -1.0, 1.0])

    # Homogeneous pixel coordinates (u, v, 1) for every pixel: 1, h, w, 3
    u, v = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    pixels = np.stack((u, v, np.ones_like(u)), axis=-1).astype(np.float32)[None]

    # Scale the inverse intrinsics by the per-pixel depth, then apply them to
    # the pixel coordinates: bs, h, w, 3, 1
    scaled_inv_K = depth[:, :, :, None, None] * inv_K[None, None, None, ...]
    camera_pts = scaled_inv_K @ pixels[..., None]

    # Move into the flipped coordinate system
    flipped_pts = axis_flip[None, None, None, ...] @ camera_pts

    # From the reference to the target viewpoint
    target_pts = rotation[None, None, None, ...] @ flipped_pts
    target_pts = target_pts + translation[None, None, None, :, None]

    return target_pts[0, :, :, :3, 0]

def depth_edges_mask(depth, threshold=0.05):
    """Returns a mask of edges in the depth map.

    A pixel is flagged as an edge when the magnitude of the depth gradient at
    that pixel is strictly greater than ``threshold``.

    Args:
    depth: 2D numeric numpy array of shape (H, W).
    threshold: gradient magnitude above which a pixel counts as an edge.
        Defaults to 0.05, the value that used to be hard-coded.
    Returns:
    mask: 2D numpy array of shape (H, W) with dtype bool.
    """
    # np.gradient returns one array per axis, axis 0 (rows) first.
    grad_rows, grad_cols = np.gradient(depth)
    # Compute the gradient magnitude.
    depth_grad = np.sqrt(grad_rows ** 2 + grad_cols ** 2)
    # Compute the edge mask.
    return depth_grad > threshold

def create_triangles(h, w, mask=None):
    """Build the triangle index list for an h-by-w grid of vertices.

    Vertices are assumed to be numbered row by row (index = row * w + col).
    Every grid cell is split into two triangles: (top-left, bottom-left,
    top-right) and (bottom-right, top-right, bottom-left).

    Args:
    h: (int) denoting the height of the image.
    w: (int) denoting the width of the image.
    mask: optional boolean array with h*w entries; when given, only triangles
        whose three vertices are all True in the mask are kept.
    Returns:
    triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3)
        before masking.
    """
    cols, rows = np.meshgrid(np.arange(w - 1), np.arange(h - 1))

    # Flat vertex index of each corner of every cell
    top_left = rows * w + cols
    top_right = top_left + 1
    bottom_left = top_left + w
    bottom_right = bottom_left + 1

    # Six indices per cell -> two consecutive triangles per cell
    per_cell = np.stack(
        (top_left, bottom_left, top_right, bottom_right, top_right, bottom_left),
        axis=-1)
    faces = per_cell.reshape(((w - 1) * (h - 1) * 2, 3))

    if mask is None:
        return faces

    vertex_ok = mask.reshape(-1)
    keep = vertex_ok[faces].all(1)
    return faces[keep]

def get_mesh(image, depth, keep_edges=False):
    """Turn a colour image and its depth map into a vertex-coloured triangle mesh.

    Every pixel becomes one vertex (back-projected with depth_to_points) and
    neighbouring pixels are joined into triangles (create_triangles).

    Args:
        image: BGR image; it is converted to RGB for the vertex colours.
        depth: 2D depth map with the same height and width as ``image``.
        keep_edges: when False, triangles touching a depth discontinuity
            (as flagged by depth_edges_mask) are dropped.

    Returns:
        trimesh.Trimesh with one vertex per pixel.
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # One 3D vertex per pixel, flattened row by row
    vertices = depth_to_points(depth[None]).reshape(-1, 3)

    rgb = np.array(rgb)
    n_rows, n_cols = rgb.shape[0], rgb.shape[1]

    # Without a mask every grid cell contributes its two triangles
    vertex_mask = None if keep_edges else ~depth_edges_mask(depth)
    faces = create_triangles(n_rows, n_cols, mask=vertex_mask)

    return trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=rgb.reshape(-1, 3),
    )

def compute_3D_point(x, y, Z, H, W, fov=55.0):
    """
    Compute the 3D point in the camera coordinate system from an image coordinate and depth.

    Uses the same pinhole model as get_intrinsics: principal point at the image
    centre and a single focal length, derived from the horizontal field of view
    and the image width, for both axes.

    Parameters:
    - x, y: The image coordinates (pixels)
    - Z: The depth value (distance along the camera's viewing axis)
    - H, W: The image height and width (pixels)
    - fov: Horizontal field of view in degrees (default 55, the value that
      used to be hard-coded)

    Returns:
    A numpy array [X, Y, Z] representing the 3D point in the camera coordinate system.
    """
    # Square pixels: the focal length is the same along both axes
    f = 0.5 * W / np.tan(0.5 * fov * np.pi / 180.0)
    cx = 0.5 * W
    cy = 0.5 * H

    # Normalize the 2D coordinates
    x_prime = (x - cx) / f
    y_prime = (y - cy) / f

    # Apply the depth to get the 3D point
    X = x_prime * Z
    Y = y_prime * Z

    return np.array([X, Y, Z])


def draw_gaze(x, y, length, img, pitchyaw, thickness=2, color=(255, 255, 0),sclae=2.0):
    """Draw a gaze arrow on ``img`` starting at the eye position ``(x, y)``.

    The arrow offset is ``(-length*sin(a)*cos(b), -length*sin(b))`` with
    ``a = pitchyaw[0]`` and ``b = pitchyaw[1]``.

    A single-channel image is first converted to BGR, in which case the arrow
    is drawn on (and returned as) the converted image rather than the input.
    The ``sclae`` argument is accepted but not used.

    Returns:
        The image the arrow was drawn on.
    """
    anchor = (int(x), int(y))

    is_single_channel = len(img.shape) == 2 or img.shape[2] == 1
    if is_single_channel:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # Projected arrow offset in pixels
    offset_x = -length * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
    offset_y = -length * np.sin(pitchyaw[1])

    tail = tuple(np.round(anchor).astype(np.int32))
    tip = tuple(np.round([anchor[0] + offset_x, anchor[1] + offset_y]).astype(int))

    cv2.arrowedLine(img, tail, tip, color, thickness, cv2.LINE_AA, tipLength=0.18)
    return img

def create_arrow(shaft_length=1.0, shaft_radius=0.01, head_length=0.2, head_radius=0.1, thickness=1):
    """
    Build an arrow mesh (cylinder shaft plus cone head) along the Z axis.

    Parameters:
    - shaft_length: Length of the shaft.
    - shaft_radius: Radius of the shaft, before scaling by ``thickness``.
    - head_length: Length of the head, before scaling by ``thickness``.
    - head_radius: Radius of the head, before scaling by ``thickness``.
    - thickness: Multiplier applied to both radii and to the head length;
      the shaft length is not affected.

    Returns:
    - A trimesh object representing the arrow.
    """
    scaled_shaft_radius = shaft_radius * thickness
    scaled_head_radius = head_radius * thickness
    scaled_head_length = head_length * thickness

    # Shaft: shifted up by half its length along Z
    shaft_mesh = trimesh.creation.cylinder(
        radius=scaled_shaft_radius, height=shaft_length, sections=32)
    shaft_mesh.apply_translation((0, 0, shaft_length/2))

    # Head: cone moved to the far end of the shaft
    head_mesh = trimesh.creation.cone(
        radius=scaled_head_radius, height=scaled_head_length, sections=32)
    head_mesh.apply_translation((0, 0, shaft_length))

    # Merge both parts into a single mesh and give it one uniform colour
    arrow_mesh = trimesh.util.concatenate([shaft_mesh, head_mesh])
    arrow_mesh.visual.face_colors = [1, 1, 1, 0.5]
    arrow_mesh.visual.vertex_colors = [1, 1, 1, 0.5]

    return arrow_mesh


def create_oriented_arrow(origin, pitch, yaw, length=1.0, thickness=1):
    """Create an arrow mesh that starts at ``origin`` and points along the gaze.

    The arrow is built along +Z by create_arrow, rotated by the Euler 'xyz'
    angles ``[0, pitch, yaw]`` and then translated to ``origin``.

    Args:
        origin: length-3 position of the arrow tail.
        pitch, yaw: gaze angles. They are handed to scipy's ``from_euler``
            without ``degrees=True``, i.e. they are interpreted as radians.
        length: length of the arrow shaft.
        thickness: NOTE(review): currently ignored - the arrow is always built
            with ``thickness=5``. Left as is because the arrow volume drives
            the intersection test in process(), so honouring this argument
            would change the ray-tracing results; confirm the intended value.

    Returns:
        (arrow, direction): the transformed trimesh arrow and the unit vector
        the arrow points along.
    """
    # Rotation from the Euler angles (the first, x, component is fixed to 0)
    rotation = R.from_euler('xyz', [0, pitch, yaw])

    # The un-rotated arrow points along +Z, so the pointing direction is the
    # rotated +Z axis. It does not depend on where the arrow is placed: the
    # tip sits at origin + direction * length. (Previously the origin was
    # subtracted from this already-relative vector, which skewed the result.)
    direction = rotation.apply(np.array([0.0, 0.0, 1.0]))
    direction /= np.linalg.norm(direction) # Normalize the direction vector

    # Create an arrow mesh
    arrow = create_arrow(length, thickness=5)

    # Rotate the arrow, then move its tail to the origin
    rt = np.eye(4)
    rt[:3,:3] = rotation.as_matrix()
    rt[:3,-1] = origin
    arrow.apply_transform(rt)

    return arrow, direction

def find_closest_intersected_mesh(scene, origin, direction):
    """
    Find the mesh in the scene whose ray intersection is nearest to the ray origin.

    Every geometry in the scene is tested; nothing is excluded here, so the
    caller has to leave out any mesh that must not be hit.

    Args:
    - scene: The trimesh.Scene containing all meshes.
    - origin: The starting point of the ray.
    - direction: The direction vector of the ray

    Returns:
    - The name of the closest mesh intersected by the ray, or None if no intersection is found.
    """
    best_name = None
    best_distance = np.inf

    for name, geometry in scene.geometry.items():
        # Cast the single ray against this geometry
        hit_points, _ray_ids, _tri_ids = geometry.ray.intersects_location(
            ray_origins=[origin],
            ray_directions=[direction]
        )

        # Keep the geometry if any of its hits beats the best one so far
        for hit in hit_points:
            hit_distance = np.linalg.norm(hit - origin)
            if not (hit_distance < best_distance):
                continue
            best_name, best_distance = name, hit_distance

    return best_name

def find_closest_intersected_mesh_via_intersection(scene, arrow, origin):
    """
    Find the mesh in the scene that overlaps the arrow volume closest to ``origin``.

    For every geometry the boolean intersection with ``arrow`` is computed;
    non-empty intersections are ranked by the distance between their centroid
    and ``origin``.

    Args:
    - scene: The trimesh.Scene containing the candidate meshes.
    - arrow: The arrow mesh to intersect with each candidate.
    - origin: The point distances are measured from.

    Returns:
    - The name of the closest overlapping mesh, or None if nothing overlaps.
    """
    best_name = None
    best_distance = np.inf

    for name, candidate in scene.geometry.items():
        # Volume shared by the arrow and this candidate
        overlap = arrow.intersection(candidate)

        if not overlap.is_empty:
            # Rank by how far the shared volume is from the origin
            overlap_distance = np.linalg.norm(overlap.centroid - origin)
            if overlap_distance < best_distance:
                best_name, best_distance = name, overlap_distance

    return best_name

In [27]:
# Master switch for the live 3D view: process() only talks to `plot` when this is True.
VISUALIZE = True

if VISUALIZE:
    # Create a plot
    # `plot` is only defined when VISUALIZE is True; process() guards every use with the same flag.
    plot = plot3d.Plot(port=9001)

    # Reset the 3D Plot
    plot.reset()

# Blue sphere template.
# NOTE(review): `sphere` is not referenced anywhere else in the visible cells - possibly a leftover.
sphere = trimesh.creation.uv_sphere(radius=0.25)
sphere.visual.face_colors = [0, 0, 1, 0.5]
sphere.visual.vertex_colors = [0, 0, 1, 0.5]

# Template box standing in for a person. process() copies it for every
# tracked face and translates the copy to the face's 3D position.
HUMAN_BOX_RATIO = 1.2
bbox = trimesh.creation.box(extents=np.array([4, 10, 4])*HUMAN_BOX_RATIO)
# Tilt the box by -10 degrees about X and offset it, so the copy placed at a
# face position is shifted relative to that position by `t`.
r = R.from_euler('xyz', np.radians(np.array([-10,0,0])))
t = np.array([0.2, 1.3, -0.2])*4
# Assemble the 4x4 homogeneous transform [R | t]
rt = np.eye(4)
rt[:3, :3] = r.as_matrix()
rt[:3, 3] = t
bbox.apply_transform(rt)


@dataclass
class PersonGaze:
    """Per-frame gaze record for one tracked person, as built in process()."""
    # Running index of the person within the current frame; used to name the
    # "bbox-<id>" / "arrow-<id>" visuals in the 3D plot.
    id: int
    # Identifier taken from the tracking CSV's "Student_ID" column.
    tracked_id: int
    # Copy of the person box template, translated to the face position.
    bbox: trimesh.Trimesh
    # Gaze arrow mesh; process() replaces it with a shorter one after the intersection test.
    arrow: trimesh.Trimesh
    # 3D position of the face centroid (also the tail of the arrow).
    origin: np.ndarray
    # Direction vector returned by create_oriented_arrow.
    direction: np.ndarray
    # Gaze angles from the gaze CSV; they are passed to scipy's from_euler
    # without degrees=True, i.e. interpreted as radians.
    pitch: float
    yaw: float


# RGBA colour given to a person's arrow depending on the name of the mesh the
# gaze hits. process() paints the arrow red ([1, 0, 0, 0.5]) for any hit whose
# name is not listed here, i.e. another person's box.
arrow_color_map = {
    "display": [0, 1, 0, 0.5],
    "floor": [0, 0, 1, 0.5],
}


def process(
        tracking_file: pathlib.Path, 
        vid_file: pathlib.Path, 
        depth_file: pathlib.Path,
        gaze_file: pathlib.Path,
        display_r: np.ndarray,
        display_t: np.ndarray
    ):
    """Ray-trace each tracked person's gaze against the display and the other people.

    For every frame a box is placed around each tracked face (positioned with
    the depth video), a gaze arrow is built from that person's pitch/yaw, and
    the arrow is intersected with the display and with everybody else's box to
    decide what the person is looking at. When ``VISUALIZE`` is true the scene
    is also streamed to the module-level ``plot``.

    NOTE: the frame loop currently stops after the first frame and the CSV
    export at the end is commented out, so in its present state the function
    only produces the visualisation of frame ``video_index``.

    Args:
        tracking_file: CSV with one row per face per frame. Columns read here:
            ``Frame``, ``Student_ID``, ``X``, ``Y``, ``Width``, ``Height``.
        vid_file: RGB video.
        depth_file: Depth video. Frames are converted to single-channel
            grayscale and looked up with the pixel coordinates of the face
            (presumably both videos share one resolution - TODO confirm).
        gaze_file: CSV with columns ``frame``, ``tracked_id``, ``pitch``, ``yaw``.
        display_r: Euler 'xyz' rotation of the display box, in degrees.
        display_t: Translation of the display box.

    Raises:
        AssertionError: if an input file does not exist, or if a person's gaze
            is reported as hitting that person's own box.
    """
    assert tracking_file.exists()
    assert vid_file.exists()
    assert depth_file.exists()
    assert gaze_file.exists()

    DISPLAY_RATIO = 80
    display = trimesh.creation.box(extents=np.array([0.8, 0.25, 0.01])*DISPLAY_RATIO)
    display.visual.face_colors = [0, 1, 0, 0.5]
    display.visual.vertex_colors = [0, 1, 0, 0.5]

    # Add the monitor rectangle
    # X -> blue, Y -> green, Z -> yellow
    r = R.from_euler('xyz', np.radians(display_r))
    t = display_t
    rt = np.eye(4)
    rt[:3, :3] = r.as_matrix()
    rt[:3, 3] = t
    display.apply_transform(rt)
    if VISUALIZE: plot.add_mesh('display', display)

    FLOOR_RATIO = 70
    floor = trimesh.creation.box(extents=np.array([2, 2, 0.01])*FLOOR_RATIO)
    floor.visual.face_colors = [0, 0, 1, 0.5]
    floor.visual.vertex_colors = [0, 0, 1, 0.5]

    # Add the floor rectangle
    # NOTE: the floor is built but currently neither plotted nor ray-traced.
    r = R.from_euler('xyz', np.radians(np.array([82,0,0])))
    t = np.array([0, 0, -22])*4
    rt = np.eye(4)
    rt[:3, :3] = r.as_matrix()
    rt[:3, 3] = t
    floor.apply_transform(rt)
    # if VISUALIZE: plot.add_mesh("floor", floor)

    # Output file (only used by the commented-out export at the end)
    output_file = gaze_file.parent / f"{gaze_file.stem}_raytraced.csv"
    output_container = {'frame': [], 'src_tracked_id': [], 'dst_tracked_id': []}

    # Load the gaze vectors and the corresponding CSV with BBox info
    faces_df = pd.read_csv(tracking_file)
    gaze_df = pd.read_csv(gaze_file)

    # Load the RGB and depth videos
    cap = cv2.VideoCapture(str(vid_file))
    LENGTH = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    depth_cap = cv2.VideoCapture(str(depth_file))

    # Starting point
    video_index = 3150 # 0 # 9450 
    cap.set(cv2.CAP_PROP_POS_FRAMES, video_index)
    depth_cap.set(cv2.CAP_PROP_POS_FRAMES, video_index)

    try:

        for i in tqdm(range(video_index, LENGTH), total=LENGTH-video_index):

            print(i)

            # Load frame
            r_ret, rgb = cap.read()
            d_ret, depth = depth_cap.read()

            if not r_ret or not d_ret:
                break

            depth = cv2.cvtColor(depth, cv2.COLOR_BGR2GRAY)

            # Rows of both tables that belong to this frame
            faces = faces_df[faces_df['Frame'] == i]
            gaze_vectors = gaze_df[gaze_df['frame'] == i]

            # Mesh containers
            persons = []
            objs = {}

            j = 0
            for (_, face) in faces.iterrows():
                tracked_id = int(face["Student_ID"])

                # Compute the centroid of the face
                centroid = (face.X + face.Width/2, face.Y + face.Height/2)

                # Clamp the lookup to the depth frame: a box touching the frame
                # border would otherwise raise IndexError or, for negative
                # values, silently wrap around to the opposite side.
                depth_row = min(max(int(centroid[1]), 0), depth.shape[0] - 1)
                depth_col = min(max(int(centroid[0]), 0), depth.shape[1] - 1)
                centroid_depth = depth[depth_row, depth_col]

                face_t = compute_3D_point(centroid[0], centroid[1], centroid_depth, depth.shape[0], depth.shape[1])
                # Flip X and Z so the face lives in the same frame as the scene
                # mesh, which depth_to_points flips in X/Y and which is then
                # rotated by pi about X further down.
                face_t[-1] = face_t[-1]*-1
                face_t[0] = face_t[0]*-1
                line_start = face_t.copy()

                # 2D Data
                # Draw in the 2D image
                cv2.circle(rgb, (int(centroid[0]), int(centroid[1])), 5, (0, 255, 0), -1)

                # Draw the gaze vector
                # draw_gaze(centroid[0], centroid[1], 100, rgb, [pitch, yaw], color=(0,255,0))

                # 3D Data
                # Make a copy of the person box and move it to the face
                hbbox = bbox.copy()
                hbbox.apply_translation(face_t)
                # Registered before the gaze lookup so that a person without a
                # gaze sample can still be the target of somebody else's gaze.
                objs[tracked_id] = hbbox

                # Get the gaze vector; skip the arrow when this person has no
                # gaze sample for the frame (used to raise IndexError).
                gaze_rows = gaze_vectors[gaze_vectors['tracked_id'] == face["Student_ID"]]
                if gaze_rows.empty:
                    continue
                gaze_vector = gaze_rows.iloc[0]
                pitch = gaze_vector['pitch']
                yaw = gaze_vector['yaw']

                arrow, direction = create_oriented_arrow(line_start, pitch, yaw, length=30, thickness=15)

                persons.append(
                    PersonGaze(
                        id=j,
                        tracked_id=tracked_id,
                        bbox=hbbox, 
                        arrow=arrow, 
                        origin=face_t, 
                        direction=direction,
                        pitch=pitch,
                        yaw=yaw
                    )
                )
                j += 1

            # Perform raytracing
            for person in persons:
                # Candidate targets: everybody else's box plus the display
                objs_exclude_person = {str(k):v for k,v in objs.items() if k != person.tracked_id}
                objs_exclude_person.update({'display': display})
                scene = trimesh.Scene()
                for k,v in objs_exclude_person.items():
                    scene.add_geometry(v, geom_name=k)

                mesh_name = find_closest_intersected_mesh_via_intersection(
                    scene,
                    arrow=person.arrow,
                    origin=person.origin
                )
                # scene.add_geometry(person.arrow, geom_name=f"arrow-{person.id}")
                # scene.show(viewer="gl", axis=True)
                # print(f"{person.id} -> {mesh_name}")

                # Scene keys are strings, so compare against the string form of
                # the id (comparing with the int could never fail).
                assert mesh_name != str(person.tracked_id), "The person is intersecting with itself"

                # Save to output
                output_container['frame'].append(i)
                output_container['src_tracked_id'].append(person.tracked_id)
                output_container['dst_tracked_id'].append(mesh_name)

                # Reduce the length of the arrow for display
                person.arrow = create_oriented_arrow(person.origin, person.pitch, person.yaw, length=5, thickness=15)[0]

                # Colour the arrow by target; arrows that hit nothing keep
                # the default colour from create_arrow.
                if mesh_name and VISUALIZE:
                    if mesh_name in arrow_color_map:
                        person.arrow.visual.face_colors = arrow_color_map[mesh_name]
                        person.arrow.visual.vertex_colors = arrow_color_map[mesh_name]
                    else:
                        person.arrow.visual.face_colors = [1, 0, 0, 0.5]
                        person.arrow.visual.vertex_colors = [1, 0, 0, 0.5]

            # Draw the arrows and bboxes later
            if VISUALIZE:
                for person in persons:
                    # Draw the person boxes in the 3D plot
                    if f"bbox-{person.id}" in plot.client.visuals:
                        plot.update_mesh(f"bbox-{person.id}", person.bbox, drawFaces=False, drawEdges=True)
                    else:
                        plot.add_mesh(f"bbox-{person.id}", person.bbox, drawFaces=False, drawEdges=True)

                    if f"arrow-{person.id}" in plot.client.visuals:
                        plot.update_mesh(f"arrow-{person.id}", person.arrow)
                    else:
                        plot.add_mesh(f"arrow-{person.id}", person.arrow)

                # Resize
                sm_rgb = imutils.resize(rgb, width=500)
                sm_depth = imutils.resize(depth, width=500)

                mesh = get_mesh(sm_rgb, sm_depth, keep_edges=True)
                mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [1,0,0]))

                # Plot the frame
                plot.plot_image(sm_rgb)
                if i == video_index:
                    plot.add_mesh('mesh', mesh)
                else:
                    plot.update_mesh('mesh', mesh)

            # Save to PLY
            # 400 camera distance
            # filepath = DATA_DIR / 'mesh' / 'g1d1.ply'
            # mesh.apply_translation(-mesh.centroid)
            # mesh.export(str(filepath), file_type='ply')

            # NOTE: debugging stop - only the first frame is processed.
            # Remove this break to run over the whole video.
            break

    except KeyboardInterrupt:
        # Deliberate: lets the user stop a long run without a traceback.
        pass
    finally:
        # Always give the video handles back, also on errors or interrupts
        cap.release()
        depth_cap.release()

    # Save the output
    # output_df = pd.DataFrame(output_container)
    # output_df = output_df.sort_values(by=['frame', 'src_tracked_id'])
    # output_df.to_csv(output_file, index=False)

# Run the pipeline on day 1 / group 1 / camera 2.
# Arguments, in order: tracking CSV, RGB video, depth video, gaze CSV,
# display rotation (Euler 'xyz', degrees) and display translation.
process(
    DATA_DIR / 'trackings' / 'Day1Group1Camera2_with_student_IDs.csv',
    DATA_DIR / "videos" / "day1" / "block-a-blue-day1-first-group-cam2.mp4",
    DATA_DIR / 'depths' / 'day1' / "group1-depth-cam2.mp4",
    DATA_DIR / 'gaze_vectors' / "gaze_vector_d1g1.csv",
    np.array([30,110,30]),
    np.array([-3, 0, -9])*4
)

# process(
#     DATA_DIR / 'trackings' / 'Day1Group2Camera2_with_student_IDs.csv',
#     DATA_DIR / "videos" / "day1" / "block-a-blue-day1-second-group-cam2.mp4",
#     DATA_DIR / 'depths' / 'day1' / "group2-depth-cam2.mp4",
#     DATA_DIR / 'gaze_vectors' / "gaze_vector_d1g2.csv",
#     np.array([30,110,30]),
#     np.array([-3, 0, -9])*4
# )

# process(
#     DATA_DIR / 'trackings' / 'Day2Group1Camera2_with_student_IDs.csv',
#     DATA_DIR / "videos" / "day2" / "block-a-blue-day2-first-group-cam2.mp4",
#     DATA_DIR / 'depths' / 'day2' / "group1-depth-cam2.mp4",
#     DATA_DIR / 'gaze_vectors' / "gaze_vector_d2g1.csv",
#     np.array([30,110,30]),
#     np.array([-3, 0, -9])*4
# )

# process(
#     DATA_DIR / 'trackings' / 'Day2Group2Camera2_with_student_IDs.csv',
#     DATA_DIR / "videos" / "day2" / "block-a-blue-day2-second-group-cam2.mp4",
#     DATA_DIR / 'depths' / 'day2' / "group2-depth-cam2.mp4",
#     DATA_DIR / 'gaze_vectors' / "gaze_vector_d2g2.csv",
#     np.array([30,110,30]),
#     np.array([-3, 0, -9])*4
# )

  0%|          | 0/10314 [00:00<?, ?it/s]

3150


  0%|          | 0/10314 [00:00<?, ?it/s]
