# MPS Tutorial
This sample shows how to use Aria MPS data through the MPS APIs.
Please refer to the MPS wiki for more information about data formats and schemas.

### Notebook stuck?
Because of Jupyter and Plotly issues, the code may sometimes get stuck during visualization. If that happens, we recommend **restarting the kernel** and trying again.


## Download the MPS sample dataset locally
> By default, the sample dataset is downloaded to a **tmp** folder (`/tmp/mps_sample_data/`, or `./mps_sample_data/` when running on Google Colab). Please modify the path if necessary.

In [None]:
import os

from tqdm import tqdm
from urllib.request import urlretrieve
from zipfile import ZipFile

google_colab_env = 'google.colab' in str(get_ipython())
if google_colab_env:
    print("Running from Google Colab, installing projectaria_tools and getting sample data")
    !pip install projectaria-tools
    mps_sample_path = "./mps_sample_data/"
else:
    mps_sample_path = "/tmp/mps_sample_data/"

base_url = "https://www.projectaria.com/async/sample/download/?bucket=mps&filename="
os.makedirs(mps_sample_path, exist_ok=True)

filenames = [
    "sample.vrs",
    "trajectory.zip",
    "eye_gaze_v3.zip",
    "hand_tracking.zip"]

print("Downloading sample data")
for filename in tqdm(filenames):
    print(f"Processing: {filename}")
    full_path: str = os.path.join(mps_sample_path, filename)
    urlretrieve(f"{base_url}{filename}", full_path)
    if filename.endswith(".zip"):
        with ZipFile(full_path, 'r') as zip_ref:
            zip_ref.extractall(path=mps_sample_path)
            if "eye_gaze" in filename:
                eye_gaze_path = os.path.join(mps_sample_path, "eye_gaze")
                os.makedirs(eye_gaze_path, exist_ok=True)
                os.rename(os.path.join(mps_sample_path, "general_eye_gaze.csv"), os.path.join(eye_gaze_path, "general_eye_gaze.csv"))
                os.rename(os.path.join(mps_sample_path, "personalized_eye_gaze.csv"), os.path.join(eye_gaze_path, "personalized_eye_gaze.csv"))
                

## Load the trajectory, point cloud, eye gaze and hand tracking data using the MPS APIs

In [None]:
from projectaria_tools.core import data_provider, mps
from projectaria_tools.core.mps.utils import (
    filter_points_from_confidence,
    get_gaze_vector_reprojection,
    get_nearest_eye_gaze,
    get_nearest_pose,
)
from projectaria_tools.core.stream_id import StreamId
import numpy as np

# Load the VRS file
vrsfile = os.path.join(mps_sample_path, "sample.vrs")

# Trajectory and global points
closed_loop_trajectory = os.path.join(
    mps_sample_path, "trajectory", "closed_loop_trajectory.csv"
)
global_points = os.path.join(mps_sample_path, "trajectory", "global_points.csv.gz")

# Eye gaze
generalized_eye_gaze_path = os.path.join(
    mps_sample_path, "eye_gaze", "general_eye_gaze.csv"
)
calibrated_eye_gaze_path = os.path.join(
    mps_sample_path, "eye_gaze", "personalized_eye_gaze.csv"
)

# Hand tracking
wrist_and_palm_poses_path = os.path.join(
    mps_sample_path, "hand_tracking", "wrist_and_palm_poses.csv"
)

# Create data provider and get T_device_RGB
provider = data_provider.create_vrs_data_provider(vrsfile)
# Since we want to display the position of the RGB camera, we are querying its relative location
# from the device and will apply it to the device trajectory.
T_device_RGB = provider.get_device_calibration().get_transform_device_sensor(
    "camera-rgb"
)

## Load trajectory and global points
mps_trajectory = mps.read_closed_loop_trajectory(closed_loop_trajectory)
points = mps.read_global_point_cloud(global_points)

## Load eyegaze
generalized_eye_gazes = mps.read_eyegaze(generalized_eye_gaze_path)
calibrated_eye_gazes = mps.read_eyegaze(calibrated_eye_gaze_path)

## Load hand tracking
wrist_and_palm_poses = mps.hand_tracking.read_wrist_and_palm_poses(
    wrist_and_palm_poses_path
)

# Loaded data must not be empty. Check each dataset separately so that a
# failure names the dataset that came back empty.
for data_name, loaded_data in (
    ("closed loop trajectory", mps_trajectory),
    ("global point cloud", points),
    ("generalized eye gaze", generalized_eye_gazes),
    ("calibrated eye gaze", calibrated_eye_gazes),
    ("wrist and palm poses", wrist_and_palm_poses),
):
    assert len(loaded_data) != 0, f"Loaded {data_name} data is empty"

## Helper functions

In [None]:
import plotly.graph_objs as go
from matplotlib import pyplot as plt

# Helper that turns one device pose into a plotly mesh of the RGB camera frustum.
def build_cam_frustum(transform_world_device):
    """Build a (hidden) frustum mesh for the RGB camera at the given device pose.

    The frustum is defined in the camera frame: an apex at the origin plus the four
    corners of a square at z = 1, scaled by 0.6. It is moved to the world frame with
    `transform_world_device @ T_device_RGB` (module-level RGB extrinsics).

    Args:
        transform_world_device: world-from-device pose of the trajectory sample.

    Returns:
        A `go.Mesh3d` trace, created with `visible=False` so a slider can toggle it.
    """
    frustum_scale = 0.6
    vertices_camera = frustum_scale * np.array(
        [
            [0.0, 0.0, 0.0],  # apex (camera center)
            [0.5, 0.5, 1.0],
            [-0.5, 0.5, 1.0],
            [-0.5, -0.5, 1.0],
            [0.5, -0.5, 1.0],
        ]
    )
    # Four side triangles from the apex, then two triangles closing the far square.
    triangles = [(0, 1, 2), (0, 2, 3), (0, 3, 4), (0, 4, 1), (1, 2, 3), (1, 3, 4)]
    face_i, face_j, face_k = (list(column) for column in zip(*triangles))

    transform_world_camera = transform_world_device @ T_device_RGB
    vertices_world = transform_world_camera @ vertices_camera.T

    return go.Mesh3d(
        x=vertices_world[0, :],
        y=vertices_world[1, :],
        z=vertices_world[2, :],
        i=face_i,
        j=face_j,
        k=face_k,
        showscale=False,
        visible=False,
        colorscale="jet",
        # Color by camera-frame depth: the apex differs from the far square.
        intensity=vertices_camera[:, 2],
        opacity=1.0,
        hoverinfo="none",
    )

## Visualize the trajectory and point cloud in a 3D interactive plot
* Load trajectory
* Load global point cloud
* Render the dense trajectory (1 kHz) as points.
* Render subsampled 6DoF poses as camera frustums. Use the device calibration to transform the RGB camera pose into the world frame.
* Render the global point cloud, filtered by confidence.

_Please wait a minute for all the data to load. Zoom in to the point cloud and adjust your view. Then use the time slider to move the camera_

In [None]:
# Load all world positions from the trajectory.
# reshape(-1, 3) yields an (N, 3) array whether translation() returns a (3,) or (1, 3) array.
traj = np.array(
    [pose.transform_world_device.translation() for pose in mps_trajectory]
).reshape(-1, 3)

# Subsample trajectory for quick display
skip = 1000
mps_trajectory_subset = mps_trajectory[::skip]

# Load each subsampled pose as a camera frustum trace; only the first one starts visible.
cam_frustums = [
    build_cam_frustum(pose.transform_world_device) for pose in mps_trajectory_subset
]
cam_frustums[0].visible = True

# One slider step per frustum. The figure's traces are all frustums followed by the
# trajectory and the global points, so step i shows frustum i plus those last two traces.
steps = []
for i, pose in enumerate(mps_trajectory_subset):
    visible = [False] * len(cam_frustums) + [True] * 2
    visible[i] = True  # Toggle i'th trace to "visible"
    steps.append(
        dict(
            method="update",
            args=[{"visible": visible}, {"title": "Trajectory and Point Cloud"}],
            label=pose.tracking_timestamp.total_seconds(),
        )
    )

# Filter the point cloud by confidence (inv depth and depth). The result is kept in a
# separate variable so the loaded `points` stays intact and this cell can be re-run.
filtered_points = filter_points_from_confidence(points)
# Retrieve point positions; reshape(-1, 3) keeps a valid (0, 3) array even if no point
# survives the filter (np.stack would raise on an empty list).
point_cloud = np.array([it.position_world for it in filtered_points]).reshape(-1, 3)

# Create slider to allow scrubbing and set the layout
sliders = [dict(currentvalue={"suffix": " s", "prefix": "Time :"}, pad={"t": 5}, steps=steps)]
layout = go.Layout(
    sliders=sliders,
    scene=dict(
        bgcolor="lightgray",
        dragmode="orbit",
        aspectmode="data",
        xaxis_visible=False,
        yaxis_visible=False,
        zaxis_visible=False,
    ),
)

# Plot trajectory and point cloud
trajectory = go.Scatter3d(
    x=traj[:, 0],
    y=traj[:, 1],
    z=traj[:, 2],
    mode="markers",
    marker={"size": 2, "opacity": 0.8, "color": "red"},
    name="Trajectory",
    hoverinfo="none",
)
# Color the global points by their z coordinate. The trace gets its own name so it does
# not overwrite `global_points`, the point-cloud file path used when loading the data.
global_points_trace = go.Scatter3d(
    x=point_cloud[:, 0],
    y=point_cloud[:, 1],
    z=point_cloud[:, 2],
    mode="markers",
    marker={"size": 1.5, "color": point_cloud[:, 2], "cmin": -1.5, "cmax": 2, "colorscale": "viridis"},
    name="Global Points",
    hoverinfo="none",
)

# draw
plot_figure = go.Figure(data=cam_frustums + [trajectory, global_points_trace], layout=layout)
plot_figure.show()

## Visualize generalized and calibrated eye gaze projections on an RGB image
* Load the eye gaze MPS output
* Select an RGB frame near the end of the recording
* Find the closest eye gaze data for the RGB frame
* Project the eye gaze into the RGB frame **using its estimated depth if available, or a fixed depth of 1 m otherwise**
* Show the gaze cross on the RGB image

In [None]:
rgb_stream_id = StreamId("214-1")
rgb_stream_label = provider.get_label_from_stream_id(rgb_stream_id)
num_rgb_frames = provider.get_num_data(rgb_stream_id)
# Use the fifth frame from the end, clamped so a very short recording never yields a negative index.
rgb_frame = provider.get_image_data_by_index(rgb_stream_id, max(num_rgb_frames - 5, 0))
assert rgb_frame[0] is not None, "no rgb frame"

image = rgb_frame[0].to_numpy_array()
capture_timestamp_ns = rgb_frame[1].capture_timestamp_ns
generalized_eye_gaze = get_nearest_eye_gaze(generalized_eye_gazes, capture_timestamp_ns)
calibrated_eye_gaze = get_nearest_eye_gaze(calibrated_eye_gazes, capture_timestamp_ns)
# get projection function
device_calibration = provider.get_device_calibration()
cam_calibration = device_calibration.get_camera_calib(rgb_stream_label)
assert cam_calibration is not None, "no camera calibration"


def _plot_eye_gaze_on_image(ax, eye_gaze, title):
    """Project `eye_gaze` into the RGB image and draw it on `ax` as a red cross.

    The gaze's own depth is used when set; otherwise a 1 m proxy depth is used.
    Relies on the module-level `image`, `rgb_stream_label`, `device_calibration`
    and `cam_calibration` defined in this cell.

    Returns:
        The projected pixel coordinates, or None when the reprojection returns
        None (in which case a message is printed and `ax` is left empty).
    """
    depth_m = eye_gaze.depth or 1.0
    gaze_center_in_pixels = get_gaze_vector_reprojection(
        eye_gaze, rgb_stream_label, device_calibration, cam_calibration, depth_m=depth_m
    )
    if gaze_center_in_pixels is None:
        print(f"{title} center at depth {depth_m} m projects outside the camera sensor plane.")
        return None
    ax.imshow(image)
    ax.plot(gaze_center_in_pixels[0], gaze_center_in_pixels[1], '+', c="red", mew=1, ms=20)
    ax.grid(False)
    ax.axis(False)
    ax.set_title(title)
    return gaze_center_in_pixels


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 10))

# Draw a cross at the projected gaze center location on the RGB image at available depth or if unavailable a 1m proxy.
# Both panels now use their own gaze depth (the calibrated panel previously forced 1 m).
generalized_gaze_center_in_pixels = _plot_eye_gaze_on_image(ax1, generalized_eye_gaze, "Generalized Eye Gaze")
calibrated_gaze_center_in_pixels = _plot_eye_gaze_on_image(ax2, calibrated_eye_gaze, "Calibrated Eye Gaze")

plt.show()


## Visualize wrist and palm pose projection on RGB and SLAM images

In [None]:
from typing import Dict, List, Optional

from projectaria_tools.core.calibration import CameraCalibration, DeviceCalibration
from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions

time_domain: TimeDomain = TimeDomain.DEVICE_TIME
time_query_closest: TimeQueryOptions = TimeQueryOptions.CLOSEST

# Get the device calibration first: the camera calibrations below and the
# get_T_device_sensor helper read from it. Fetching it here keeps this cell
# independent of the variable having been set by an earlier cell.
device_calibration = provider.get_device_calibration()

# Get stream ids, stream labels, stream timestamps, and camera calibrations for RGB and SLAM cameras
stream_ids: Dict[str, StreamId] = {
    "rgb": StreamId("214-1"),
    "slam-left": StreamId("1201-1"),
    "slam-right": StreamId("1201-2"),
}
stream_labels: Dict[str, str] = {
    key: provider.get_label_from_stream_id(stream_id)
    for key, stream_id in stream_ids.items()
}
stream_timestamps_ns: Dict[str, List[int]] = {
    key: provider.get_timestamps_ns(stream_id, time_domain)
    for key, stream_id in stream_ids.items()
}
cam_calibrations: Dict[str, Optional[CameraCalibration]] = {
    key: device_calibration.get_camera_calib(stream_label)
    for key, stream_label in stream_labels.items()
}
# Use a distinct loop variable so the RGB `cam_calibration` from the eye gaze cell is not overwritten.
for key, calibration in cam_calibrations.items():
    assert calibration is not None, f"no camera calibration for {key}"


def get_T_device_sensor(key: str):
    """Return the device-from-sensor transform of the camera registered under `key`.

    `key` indexes `stream_labels` ("rgb", "slam-left" or "slam-right"); the
    transform is looked up on the module-level `device_calibration`.
    """
    sensor_label = stream_labels[key]
    return device_calibration.get_transform_device_sensor(sensor_label)


# Get a sample frame for each of the RGB, SLAM left, and SLAM right streams.
# Use RGB frame 120, clamped to the last frame for shorter recordings.
sample_rgb_index = min(120, len(stream_timestamps_ns["rgb"]) - 1)
sample_timestamp_ns: int = stream_timestamps_ns["rgb"][sample_rgb_index]
sample_frames = {
    key: provider.get_image_data_by_time_ns(
        stream_id, sample_timestamp_ns, time_domain, time_query_closest
    )[0]
    for key, stream_id in stream_ids.items()
}

# Get the wrist and palm pose closest to the sample timestamp
mps_data_paths_provider = mps.MpsDataPathsProvider(mps_sample_path)
mps_data_paths = mps_data_paths_provider.get_data_paths()
mps_data_provider = mps.MpsDataProvider(mps_data_paths)
wrist_and_palm_pose = mps_data_provider.get_wrist_and_palm_pose(
    sample_timestamp_ns, time_query_closest
)
# Fail here with a clear message instead of an AttributeError in the plotting helpers.
assert wrist_and_palm_pose is not None, "no wrist and palm pose"

# Helper functions for reprojection and plotting
def get_point_reprojection(
    point_position_device: np.ndarray, key: str
) -> Optional[np.ndarray]:
    """Project a 3D point expressed in the device frame into the image of camera `key`.

    Args:
        point_position_device: 3D point in the device frame.
        key: camera key ("rgb", "slam-left" or "slam-right") used to look up the
            extrinsics (via get_T_device_sensor) and the intrinsics (cam_calibrations).

    Returns:
        The pixel coordinates returned by the camera calibration's `project`, or None
        when it cannot project the point.
    """
    # Move the point from the device frame into the camera (sensor) frame first.
    point_position_camera = get_T_device_sensor(key).inverse() @ point_position_device
    return cam_calibrations[key].project(point_position_camera)


def get_wrist_and_palm_pixels(
    key: str,
) -> "tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]":
    """Reproject both hands' wrist and palm positions into the image of camera `key`.

    Reads the module-level `wrist_and_palm_pose`.

    Returns:
        A 4-tuple (left_wrist, left_palm, right_wrist, right_palm) of pixel
        coordinates; each entry is None when that point does not project.
    """
    left_hand = wrist_and_palm_pose.left_hand
    right_hand = wrist_and_palm_pose.right_hand
    return (
        get_point_reprojection(left_hand.wrist_position_device, key),
        get_point_reprojection(left_hand.palm_position_device, key),
        get_point_reprojection(right_hand.wrist_position_device, key),
        get_point_reprojection(right_hand.palm_position_device, key),
    )


def plot_wrists_and_palms(plt, left_wrist, left_palm, right_wrist, right_palm):
    """Draw every projected wrist/palm point that is not None as a large dot.

    Left-hand points are drawn in blue and right-hand points in red, in the
    order left wrist, left palm, right wrist, right palm.

    Args:
        plt: object exposing a matplotlib-style `plot` (e.g. `matplotlib.pyplot`).
        left_wrist, left_palm, right_wrist, right_palm: pixel coordinates or None.
    """
    colored_points = (
        (left_wrist, "blue"),
        (left_palm, "blue"),
        (right_wrist, "red"),
        (right_palm, "red"),
    )
    for pixel, color in colored_points:
        if pixel is None:
            continue  # point did not project into this camera
        plt.plot(*pixel, ".", c=color, mew=1, ms=20)


# Display wrist and palm positions on RGB, SLAM left, and SLAM right images
def _show_frame_with_wrists_and_palms(key: str, **imshow_kwargs) -> None:
    """Show the sample frame of camera `key` on the current axes and overlay its wrist/palm points.

    Extra keyword arguments are forwarded to `plt.imshow`.
    """
    plt.grid(False)
    plt.axis("off")
    plt.imshow(sample_frames[key].to_numpy_array(), **imshow_kwargs)
    plot_wrists_and_palms(plt, *get_wrist_and_palm_pixels(key))


plt.figure()
_show_frame_with_wrists_and_palms("rgb")

plt.figure()
for subplot_index, key in enumerate(("slam-left", "slam-right"), start=1):
    plt.subplot(1, 2, subplot_index)
    # Same fixed 0-255 gray mapping for both SLAM panels so their brightness is
    # directly comparable (imshow would otherwise auto-scale each image separately).
    _show_frame_with_wrists_and_palms(key, cmap="gray", vmin=0, vmax=255)
plt.show()