# Interactive Examples on Project Aria Tools

### Notebook stuck?
Note that because of Jupyter issues, the code may sometimes get stuck at visualization. We recommend **restarting the kernel** and trying again to see if the issue is resolved.

In [None]:
import sys
import os

# Specifics for Google Colab
google_colab_env = 'google.colab' in str(get_ipython())
if google_colab_env:
    print("Running from Google Colab, installing projectaria_tools and getting sample data")
    !pip install projectaria-tools
    # TODO: Update the data path here.
    !curl -O -J -L  "https://github.com/facebookresearch/projectaria_tools/raw/main/data/gen1/mps_sample/sample.vrs"
    vrsfile = "sample.vrs"
else:
    print("Using a pre-existing projectaria_tool github repository")
    # Define the paths to check
    possible_path_1 = "../../../data/mps_sample/sample.vrs"
    possible_path_2 = "../../../data/gen1/mps_sample/sample.vrs"
    # Check which path contains the actual data file
    if os.path.exists(possible_path_1):
        vrsfile = possible_path_1
        print(f"Using data from: {vrsfile}")
    elif os.path.exists(possible_path_2):
        vrsfile = possible_path_2
        print(f"Using data from: {vrsfile}")
    else:
        # Exit with an error message if no data file is found
        sys.exit("Error: No data file found in the specified paths.")

In [None]:
import sys
import os

# Resolve the directory two levels above the current working directory and
# put it at the front of sys.path so it is searched first on import.
repo_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
sys.path[:0] = [repo_path]
print(repo_path)

In [None]:
from projectaria_tools.core import data_provider, calibration
from projectaria_tools.core.image import InterpolationMethod
from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions
from projectaria_tools.core.stream_id import RecordableTypeId, StreamId
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

## Create data provider

In [None]:
import sys

print(f"Creating data provider from {vrsfile}")
provider = data_provider.create_vrs_data_provider(vrsfile)
if not provider:
    # Every following cell calls methods on `provider`, so stop right here with
    # a clear message rather than continuing and failing later with a less
    # helpful error.
    sys.exit("Invalid vrs data provider")


## Check device version
Create device-version specific variables. 

In [None]:
from typing import Optional
from projectaria_tools.core.calibration import DeviceVersion
# Print out the device version of the recording
device_version = provider.get_device_version()
print(f"Device version is {calibration.get_name(device_version)}")

# Example variables used in this notebook
rgb_stream_id = StreamId('214-1')

# Some example variables are different for Gen1 and Gen2,
# because they have different HW configs, sensor label names, etc.
if device_version == DeviceVersion.Gen1:
    # Sensor label -> stream id for the image streams shown in the examples
    example_stream_mappings = {
        "camera-slam-left": StreamId("1201-1"),
        "camera-slam-right": StreamId("1201-2"),
        "camera-rgb": StreamId("214-1"),
        "camera-eyetracking": StreamId("211-1"),
    }
    example_slam_stream_label = "camera-slam-left"

    # Gen1 images are rotated 90 degrees for better visualization
    ROTATE_90_FLAG = True

    # A linear camera model used in undistortion example: [width, height, focal]
    example_linear_rgb_camera_model_params = [512, 512, 150]
elif device_version == DeviceVersion.Gen2:
    # Sensor label -> stream id for the image streams shown in the examples
    example_stream_mappings = {
        "slam-front-left": StreamId("1201-1"),
        "slam-front-right": StreamId("1201-2"),
        "slam-side-left": StreamId("1201-3"),
        "slam-side-right": StreamId("1201-4"),
        "camera-rgb": StreamId("214-1"),
        "camera-et-left": StreamId("211-1"),
        "camera-et-right": StreamId("211-2"),
    }
    example_slam_stream_label = "slam-front-left"
    # Gen2 images are already in up-right orientation
    ROTATE_90_FLAG = False

    # A linear camera model used in undistortion example: [width, height, focal]
    example_linear_rgb_camera_model_params = [4032, 3024, 1600]
else:
    # Without this branch, a recording from any other device version would only
    # surface as a NameError on the variables that the branches above define.
    raise ValueError(
        f"Unsupported device version: {calibration.get_name(device_version)}"
    )

example_slam_stream_id = provider.get_stream_id_from_label(example_slam_stream_label)

# A helper function to auto rotate Aria image, if necessary
def auto_image_rotation(img: np.ndarray, stream_label: Optional[str] = None) -> np.ndarray:
    """Rotate an image for display when the notebook's rotation flag is set.

    Args:
        img: Image as a numpy array.
        stream_label: Label of the stream the image came from. Images labelled
            "camera-eyetracking" are never rotated.

    Returns:
        `np.rot90(img, -1)` (a 90-degree clockwise turn) when `ROTATE_90_FLAG`
        is truthy and the label is not "camera-eyetracking"; otherwise `img`
        itself, unchanged.
    """
    if stream_label != "camera-eyetracking" and ROTATE_90_FLAG:
        return np.rot90(img, -1)
    return img

# Retrieving image data

Goals:
- Learn how to retrieve Image data for a given Image stream

Key learnings:
- VRS data streams are each identified with a unique identifier: stream_id
- Learn what are the Stream Ids used by Aria data (Slam, Rgb, EyeTracking)
- Learn that image data can be retrieved by using a record Index or a timestamp
- For each stream_id, the index ranges over [0, get_num_data(stream_id)), and the same index in different streams can correspond to different timestamps
- Data from different sensors at the same timestamp can be queried through `get_image_data_by_time_ns`, `get_imu_data_by_time_ns`, etc

In [None]:
# One subplot per example stream
fig, axes = plt.subplots(1, len(example_stream_mappings), figsize=(12, 4))
fig.suptitle('Retrieving image data using Record Index')

# Query data with index
frame_index = 1
for idx, (stream_name, stream_id) in enumerate(example_stream_mappings.items()):
    # The query result is indexable: element 0 holds the pixel data
    image = provider.get_image_data_by_index(stream_id, frame_index)
    image_to_show = auto_image_rotation(image[0].to_numpy_array(), stream_name)
    axes[idx].imshow(image_to_show, cmap="gray", vmin=0, vmax=255)
    axes[idx].title.set_text(stream_name)
    axes[idx].tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
plt.show()

# Same example using Time
# (plt.subplots already creates its own figure, so no separate plt.figure() is needed)
fig, axes = plt.subplots(1, len(example_stream_mappings), figsize=(12, 4))
fig.suptitle('Retrieving image data using Time')

time_domain = TimeDomain.DEVICE_TIME  # query data based on device time
option = TimeQueryOptions.CLOSEST # get data whose time [in TimeDomain] is CLOSEST to query time
start_time = provider.get_first_time_ns(rgb_stream_id, time_domain)

for idx, (stream_name, stream_id) in enumerate(example_stream_mappings.items()):
    image = provider.get_image_data_by_time_ns(stream_id, start_time, time_domain, option)
    image_to_show = auto_image_rotation(image[0].to_numpy_array(), stream_name)
    axes[idx].imshow(image_to_show, cmap="gray", vmin=0, vmax=255)
    axes[idx].title.set_text(stream_name)
    axes[idx].tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
plt.show()


# Summarize a VRS using thumbnails

Goals:
- Summarize a VRS using 10 images side by side

Key learnings:
- Image streams are identified with a Unique Identifier: stream_id
- PIL images can be created from Numpy array

In [None]:
from PIL import Image, ImageOps
from tqdm import tqdm
# Import `display` explicitly; importing IPython.display.Image here would
# shadow the PIL `Image` class used to build the thumbnail.
from IPython.display import display

# Query settings, defined before their first use so this cell does not depend
# on an earlier cell having set them.
time_domain = TimeDomain.DEVICE_TIME  # query data based on device time
option = TimeQueryOptions.CLOSEST # get data whose time [in TimeDomain] is CLOSEST to query time

# Retrieve Start and End time for the given Sensor Stream Id
start_time = provider.get_first_time_ns(rgb_stream_id, time_domain)
end_time = provider.get_last_time_ns(rgb_stream_id, time_domain)

# Retrieve image size for the RGB stream
image_config = provider.get_image_configuration(rgb_stream_id)
width = image_config.image_width
height = image_config.image_height

sample_count = 10
resize_ratio = 10
# Canvas wide enough for `sample_count` downscaled frames placed side by side
thumbnail = Image.new(
    "RGB", (int(width * sample_count / resize_ratio), int(height / resize_ratio))
)
current_width = 0


# Samples 10 timestamps
sample_timestamps = np.linspace(start_time, end_time, sample_count)
for sample in tqdm(sample_timestamps):
    # np.linspace yields floats; the query is given an integer nanosecond timestamp
    image_tuple = provider.get_image_data_by_time_ns(rgb_stream_id, int(sample), time_domain, option)
    image_array = auto_image_rotation(image_tuple[0].to_numpy_array())
    image = Image.fromarray(image_array)
    new_size = (
        int(image.size[0] / resize_ratio),
        int(image.size[1] / resize_ratio),
    )
    image = image.resize(new_size)
    thumbnail.paste(image, (current_width, 0))
    # Advance the paste position by one downscaled frame width
    current_width = int(current_width + width / resize_ratio)

display(thumbnail)

# Obtain mapping between stream_id and sensor label
Goals:
- In a vrs file, each sensor data is identified through stream_id
- Learn mapping between stream_id and label for each sensor

Key learnings:
- VRS uses a unique identifier for each stream, called stream_id. 
- Each piece of sensor data is tagged with a stream_id, which contains two parts [RecordableTypeId, InstanceId]. 
- To get the actual readable name of each sensor,
we can use `get_label_from_stream_id`, and vice versa `get_stream_id_from_label`

In [None]:
# Round-trip every stream: stream_id -> label -> stream_id
streams = provider.get_all_streams()
for stream_id in streams:
    label = provider.get_label_from_stream_id(stream_id)
    stream_id_from_label = provider.get_stream_id_from_label(label)
    print(f"stream_id: [{stream_id}] convert to label: [{label}] and back: [{stream_id_from_label}]")

# Get sensor data in a sequence based on data capture time
Goal:
- Obtain sensor data sequentially based on timestamp

Key learnings
- The default options activate all sensors and play back the entire dataset from the VRS file
- Set up options to activate only certain streams, truncate the start/end time, and change the sample rate
- Obtain data from different sensor types
- `TimeDomain` are separated into four categories: `RECORD_TIME`, `DEVICE_TIME`, `HOST_TIME`, `TIME_CODE`

### Step 1: obtain default options that provides the whole dataset from VRS
* activates all sensor streams
* No truncation for first/last timestamp
* Subsample rate = 1 (do not skip any data per sensor)

In [None]:
# The default options activate all streams
options = provider.get_default_deliver_queued_options()

### Step 2: set preferred deliver options
* truncate first/last time: `set_truncate_first_device_time_ns/set_truncate_last_device_time_ns()`
* subselect sensor streams to play: `activate_stream(stream_id)`
* skip sensor data : `set_subsample_rate(stream_id, rate)`

In [None]:
# Trim the delivered time window at both ends
options.set_truncate_first_device_time_ns(100_000_000)  # 0.1 secs after vrs first timestamp
options.set_truncate_last_device_time_ns(200_000_000)  # 0.2 sec before vrs last timestamp

# Start with every sensor switched off, then opt in stream by stream
options.deactivate_stream_all()

slam_stream_ids = options.get_stream_ids(RecordableTypeId.SLAM_CAMERA_DATA)
imu_stream_ids = options.get_stream_ids(RecordableTypeId.SLAM_IMU_DATA)

# (stream ids, subsample rate): keep every slam camera record, every 10th imu record
for stream_ids, subsample_rate in ((slam_stream_ids, 1), (imu_stream_ids, 10)):
    for stream_id in stream_ids:
        options.activate_stream(stream_id)
        options.set_subsample_rate(stream_id, subsample_rate)

### Step 3: create iterator to deliver data
`TimeDomain` contains the following
* `RECORD_TIME`: timestamp stored in vrs index, fast to access, but not guaranteed which time domain
* `DEVICE_TIME`: capture time in device's timedomain, accurate
* `HOST_TIME`: arrival time in host computer's timedomain, may not be accurate
* `TIME_CODE`: capture in TimeSync server's timedomain


In [None]:
# Iterate over the sensor data selected by `options` and report, for every
# record, its stream label, its sensor type and its timestamp in each queried
# time domain.
iterator = provider.deliver_queued_sensor_data(options)
for sensor_data in iterator:
    label = provider.get_label_from_stream_id(sensor_data.stream_id())
    sensor_type = sensor_data.sensor_data_type()
    device_timestamp = sensor_data.get_time_ns(TimeDomain.DEVICE_TIME)
    host_timestamp = sensor_data.get_time_ns(TimeDomain.HOST_TIME)
    timecode_timestamp = sensor_data.get_time_ns(TimeDomain.TIME_CODE)
    # TIME_CODE was previously fetched but never shown; print it alongside the others
    print(
        f"""obtain data from {label} of type {sensor_type} with
        DEVICE_TIME: {device_timestamp} nanoseconds
        HOST_TIME: {host_timestamp} nanoseconds
        TIME_CODE: {timecode_timestamp} nanoseconds
        """
    )

# Random access data
Goal
- Access data from a stream randomly using a data index or a timestamp

Key learnings
- Sensor data can be obtained through index within the range of [0, number of data for this stream_id)

  - `get_sensor_data_by_index(stream_id, index)`
  - `get_image_data_by_index(stream_id, index)`
  - Access other sensor data by index interface is available in core/python/VrsDataProviderPyBind.h
  
- `TimeQueryOptions` has three options: `TimeQueryOptions.BEFORE`, `TimeQueryOptions.AFTER`, `TimeQueryOptions.CLOSEST`
- Querying by index returns exactly that record, whereas querying by a timestamp that does not exactly match any record returns a nearby record selected according to `TimeQueryOptions`

In [None]:
# Walk the stream by record index, visiting every 20th frame
num_data = provider.get_num_data(example_slam_stream_id)

for index in range(0, num_data, 20):
    image_data = provider.get_image_data_by_index(example_slam_stream_id, index)
    # Element 1 of the result carries the capture timestamp
    capture_time_ns = image_data[1].capture_timestamp_ns
    print(f"Get image: {index} with timestamp {capture_time_ns}")

### Sensor data can be obtained by timestamp (nanoseconds)
* Get stream time range `get_first_time_ns` and `get_last_time_ns`
* Specify timedomain: `TimeDomain.DEVICE_TIME` (default)
* Query data by queryTime
  * `TimeQueryOptions.BEFORE` (default): sensor_dataTime <= queryTime
  * `TimeQueryOptions.AFTER` : sensor_dataTime >= queryTime
  * `TimeQueryOptions.CLOSEST` : sensor_dataTime closest to queryTime

In [None]:
time_domain = TimeDomain.DEVICE_TIME  # query data based on DEVICE_TIME
option = TimeQueryOptions.CLOSEST  # get data whose time [in TimeDomain] is CLOSEST to query time

start_time = provider.get_first_time_ns(example_slam_stream_id, time_domain)
end_time = provider.get_last_time_ns(example_slam_stream_id, time_domain)

# Step through the stream's time range one second (1e9 ns) at a time
one_second_ns = int(1e9)
for time in range(start_time, end_time, one_second_ns):
    image_data = provider.get_image_data_by_time_ns(example_slam_stream_id, time, time_domain, option)
    capture_time_ns = image_data[1].capture_timestamp_ns
    print(f"query time {time} and get capture image time {capture_time_ns} within range {start_time} {end_time}")

### Get sensor data configuration

In [None]:
def image_config_example(config):
    """Print selected fields of an image configuration, one `name value` pair per line.

    Args:
        config: Object exposing the attributes listed below (for example the
            result of `provider.get_image_configuration(stream_id)`).
    """
    field_names = (
        "device_type",
        "device_version",
        "device_serial",
        "sensor_serial",
        "nominal_rate_hz",
        "image_width",
        "image_height",
        "pixel_format",
        "gamma_factor",
    )
    for field_name in field_names:
        print(f"{field_name} {getattr(config, field_name)}")

In [None]:
# Fetch the image configuration of the example SLAM stream and print its fields
config = provider.get_image_configuration(example_slam_stream_id)
image_config_example(config)

# Calibration examples
Goal:
- Obtain camera extrinsics and intrinsics
- Learn to project a 3D point to camera frame

Key learnings
- Get calibration for different sensors using sensor labels
- Learn how to use extrinsics/intrinsics to project a 3D points to a given camera
- Reference frame convention

In [None]:
# Load the device calibration and list the labels of the sensors it has calibrations for
device_calib = provider.get_device_calibration()
all_sensor_labels = device_calib.get_all_labels()
print(f"device calibration contains calibrations for the following sensors \n {all_sensor_labels}")

### Project a 3D point to camera frame

In this section we will learn how to retrieve calibration data and how to use it.
Aria calibration is defined by two objects: one defining the intrinsics (`rgb_calib.project` and `rgb_calib.unproject`) and one defining the extrinsics as a SE3 pose (`device_calib.get_transform_device_sensor(sensor_label)`).

Intrinsics can be used to project a 3d point to the image plane or un-project a 2d point as a bearing vector. Extrinsics are used to set the camera in world coordinates at a given rotation and position in space.

### Reference frame convention

> `transform_sensor1_sensor3` = `transform_sensor1_sensor2` * `transform_sensor2_sensor3` \
> `point_in_sensor`: 3D point measured from sensor's reference frame \
> `point_in_sensor1` = `transform_sensor1_sensor` * `point_in_sensor`

Device Frame: `device_calib.get_origin_label() = camera-slam-left`\
Sensor extrinsics: `device_calib.get_transform_device_sensor(sensor_label)`

In [None]:
camera_name = "camera-rgb"
# Extrinsics as a matrix; its inverse takes points from the device frame into the camera frame
transform_device_camera = device_calib.get_transform_device_sensor(camera_name).to_matrix()
transform_camera_device = np.linalg.inv(transform_device_camera)
print(f"Device calibration origin label {device_calib.get_origin_label()}")
print(f"{camera_name} has extrinsics of \n {transform_device_camera}")

rgb_calib = device_calib.get_camera_calib(camera_name)
if rgb_calib is not None:
    # project a 3D point in device frame [camera-slam-left] to rgb camera
    point_in_device = np.array([0, 0, 10])
    # Apply the rigid transform by hand: rotate with the upper-left 3x3 block,
    # then add the translation stored in the last column.
    rotation_camera_device = transform_camera_device[:3, :3]
    translation_camera_device = transform_camera_device[:3, 3]
    point_in_camera = rotation_camera_device @ point_in_device + translation_camera_device

    # `project` yields None when there is no valid pixel for the point
    maybe_pixel = rgb_calib.project(point_in_camera)
    if maybe_pixel is not None:
        print(f"Get pixel {maybe_pixel} within image of size {rgb_calib.get_image_size()}")

### Get calibration data for other sensors
Aria is a multimodal capture device; each sensor's calibration can be retrieved using the same interface. 

For Aria Gen1, EyeTracking (`get_aria_et_camera_calib()`) and Audio calibration (`get_aria_microphone_calib()`) are a bit different since we have multiple sensors that share the same stream_id.

In [None]:
# Eye-tracking calibration is returned as a pair of camera calibrations (or None
# when unavailable); print the label and image size of each.
et_calib = device_calib.get_aria_et_camera_calib()
if et_calib is not None:
    print(f"Camera {et_calib[0].get_label()} has image size {et_calib[0].get_image_size()}")
    print(f"Camera {et_calib[1].get_label()} has image size {et_calib[1].get_image_size()}")

# IMU calibration is looked up by sensor label, like the camera calibrations
imu_calib = device_calib.get_imu_calib("imu-left")
if imu_calib is not None:
    print(f"{imu_calib.get_label()} has extrinsics transform_Device_Imu:\n {imu_calib.get_transform_device_imu().to_matrix3x4()}")

### Undistort an image
You can remove distortions in an image in three steps. 

First, use the provider to access the image and the camera calibration of the stream. Then create a linear camera model with `get_linear_camera_calibration`. The function allows you to specify the image size as well as the focal length of the model, assuming the principal point is at the image center. Finally, apply the `distort_by_calibration` function to re-project the image from the source camera model into the linear one, which removes the distortion.

In [None]:
# input: retrieve image as a numpy array
sensor_name = "camera-rgb"
sensor_stream_id = provider.get_stream_id_from_label(sensor_name)
image_data = provider.get_image_data_by_index(sensor_stream_id, 0)
image_array = image_data[0].to_numpy_array()
# input: retrieve image distortion
device_calib = provider.get_device_calibration()
src_calib = device_calib.get_camera_calib(sensor_name)

# create output calibration: a linear model described by
# example_linear_rgb_camera_model_params ([width, height, focal]).
# Invisible pixels are shown as black.
dst_width, dst_height, dst_focal = example_linear_rgb_camera_model_params
# Label the output model with this cell's own `sensor_name` rather than a
# variable defined in a different cell.
dst_calib = calibration.get_linear_camera_calibration(dst_width, dst_height, dst_focal, sensor_name)

# distort image
rectified_array = calibration.distort_by_calibration(image_array, dst_calib, src_calib, InterpolationMethod.BILINEAR)

# visualize input and results
# (plt.subplots already creates its own figure, so no separate plt.figure() is needed)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
fig.suptitle(f"Image undistortion (focal length = {dst_calib.get_focal_lengths()})")

axes[0].imshow(image_array, cmap="gray", vmin=0, vmax=255)
axes[0].title.set_text(f"sensor image ({sensor_name})")
axes[0].tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
axes[1].imshow(rectified_array, cmap="gray", vmin=0, vmax=255)
axes[1].title.set_text(f"undistorted image ({sensor_name})")
axes[1].tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
plt.show()

Note the rectified image shows a circular area of visible pixels. If you want the entire rectified image to be covered by pixels, you can increase the magnification.

# Retrieve on-device machine perception data: EyeGaze + HandTracking (Aria Gen2 only)

Goals:
- Learn how to retrieve on-device machine perception data from VRS

Key learnings:
- Learn what on-device MP data streams are available in Aria Gen2. 
- Learn how to query such data either by timestamp, or by index. 
- Learn how to match the on-device MP data with camera images.

In [None]:
# Helper functions for on device MP plotting
from typing import List
from projectaria_tools.core.mps import hand_tracking
from matplotlib.collections import LineCollection

def create_hand_skeleton_segments_from_landmarks(
    all_landmark_locations, segment_landmark_names):
    """Turn a chain of landmarks into the line segments joining consecutive ones.

    Args:
        all_landmark_locations: Container indexed by landmark, giving each
            landmark's location.
        segment_landmark_names: Ordered sequence of landmark indices/keys
            describing a polyline through the hand.

    Returns:
        A list of `[start_location, end_location]` pairs, one per consecutive
        pair in `segment_landmark_names` (empty if fewer than two are given).
    """
    consecutive_pairs = zip(segment_landmark_names, segment_landmark_names[1:])
    return [
        [all_landmark_locations[first], all_landmark_locations[second]]
        for first, second in consecutive_pairs
    ]


def create_hand_skeleton_from_landmarks(landmark_locations):
    """Build the line segments of a hand skeleton from per-landmark locations.

    Args:
        landmark_locations: Container indexed by `HandLandmark`, giving each
            landmark's location or None when it has no location.

    Returns:
        A list of `[start, end]` segments covering the palm outline and the
        five finger lines, with every segment that has a None endpoint removed.
    """
    HandLandmark = hand_tracking.HandLandmark

    # Each entry is one polyline through the hand, listed landmark by landmark
    landmark_chains = [
        # Palm shape
        [
            HandLandmark.WRIST,
            HandLandmark.THUMB_INTERMEDIATE,
            HandLandmark.INDEX_PROXIMAL,
            HandLandmark.MIDDLE_PROXIMAL,
            HandLandmark.RING_PROXIMAL,
            HandLandmark.PINKY_PROXIMAL,
            HandLandmark.WRIST,
            HandLandmark.PALM_CENTER,
        ],
        # Thumb line
        [
            HandLandmark.WRIST,
            HandLandmark.THUMB_INTERMEDIATE,
            HandLandmark.THUMB_DISTAL,
            HandLandmark.THUMB_FINGERTIP,
        ],
        # Index line
        [
            HandLandmark.WRIST,
            HandLandmark.INDEX_PROXIMAL,
            HandLandmark.INDEX_INTERMEDIATE,
            HandLandmark.INDEX_DISTAL,
            HandLandmark.INDEX_FINGERTIP,
        ],
        # Middle line
        [
            HandLandmark.WRIST,
            HandLandmark.MIDDLE_PROXIMAL,
            HandLandmark.MIDDLE_INTERMEDIATE,
            HandLandmark.MIDDLE_DISTAL,
            HandLandmark.MIDDLE_FINGERTIP,
        ],
        # Ring line
        [
            HandLandmark.WRIST,
            HandLandmark.RING_PROXIMAL,
            HandLandmark.RING_INTERMEDIATE,
            HandLandmark.RING_DISTAL,
            HandLandmark.RING_FINGERTIP,
        ],
        # Pinky line
        [
            HandLandmark.WRIST,
            HandLandmark.PINKY_PROXIMAL,
            HandLandmark.PINKY_INTERMEDIATE,
            HandLandmark.PINKY_DISTAL,
            HandLandmark.PINKY_FINGERTIP,
        ],
    ]

    all_segments = []
    for chain in landmark_chains:
        all_segments.extend(
            create_hand_skeleton_segments_from_landmarks(landmark_locations, chain)
        )

    # Remove segments that may contain empty pixels
    return [
        segment
        for segment in all_segments
        if segment[0] is not None and segment[1] is not None
    ]


def plot_single_hand(axes, hand_markers_in_device, rgb_calib, hand_label):
    """Project one hand's landmarks into the RGB image and draw them on `axes`.

    Landmarks are drawn as circle markers ("orangered" for the left hand,
    "yellow" otherwise) and the skeleton as green line segments. Nothing is
    drawn when no landmark projects to a pixel.

    Args:
        axes: Matplotlib axes to draw on.
        hand_markers_in_device: Landmark positions expressed in the device frame.
        rgb_calib: RGB camera calibration, used for its device-to-camera
            transform and its `project` method.
        hand_label: "left" selects the left-hand color; any other value is
            drawn with the right-hand color.
    """
    # The device->camera transform is the same for every landmark, so invert it
    # once instead of once per landmark.
    transform_camera_device = rgb_calib.get_transform_device_camera().inverse()

    # Project markers into RGB camera frame. Entries stay None where `project`
    # returns None, so that positions still line up with landmark indices.
    hand_markers_in_rgb = [
        rgb_calib.project(transform_camera_device @ marker_in_device)
        for marker_in_device in hand_markers_in_device
    ]

    # Create hand skeleton. This must use the unfiltered list because the
    # skeleton looks landmarks up by index.
    hand_skeleton = create_hand_skeleton_from_landmarks(hand_markers_in_rgb)

    # Remove "None" markers from hand joints in camera. This is intentionally done AFTER the hand skeleton creation
    visible_markers = [marker for marker in hand_markers_in_rgb if marker is not None]
    if not visible_markers:
        return

    hand_markers_x = [marker[0] for marker in visible_markers]
    hand_markers_y = [marker[1] for marker in visible_markers]

    # Plot hand markers
    color = "orangered" if hand_label == "left" else "yellow"
    axes.plot(hand_markers_x, hand_markers_y, 'o', markersize=5, color=color)  # 'o' is for circle markers

    axes.add_collection(LineCollection(hand_skeleton, linewidths=2, colors='g'))

def plot_hand_pose_data(axes, provider, timestamp, time_domain, time_tolerance, rgb_calib):
    """Look up hand pose data at `timestamp` and draw the detected hands on `axes`.

    Prints a status line describing whether the stream exists and whether the
    hand data is usable at this timestamp.

    Args:
        axes: Matplotlib axes to draw on.
        provider: Data provider used to resolve the "handtracking" stream and
            query hand pose data.
        timestamp: Query time in nanoseconds.
        time_domain: Time domain passed through to the query.
        time_tolerance: Largest accepted gap, in nanoseconds, between the hand
            data's tracking timestamp and `timestamp`.
        rgb_calib: RGB camera calibration forwarded to `plot_single_hand`.
    """
    hand_stream_id = provider.get_stream_id_from_label("handtracking")
    if hand_stream_id is None:
        print("Hand tracking stream not found in current VRS, skipping.")
        return

    # Query hand pose data
    hand_pose_data = provider.get_hand_pose_data_by_time_ns(hand_stream_id, timestamp, time_domain)

    # tracking_timestamp is converted from seconds to nanoseconds before comparing
    tracking_time_ns = hand_pose_data.tracking_timestamp.total_seconds() * 1e9
    within_tolerance = abs(tracking_time_ns - timestamp) <= time_tolerance
    if not within_tolerance or (
        hand_pose_data.left_hand is None and hand_pose_data.right_hand is None
    ):
        print("Hand data invalid at this timestamp")
        return

    print("Hand data valid at this timestamp")
    labelled_hands = (
        ("left", hand_pose_data.left_hand),
        ("right", hand_pose_data.right_hand),
    )
    for hand_label, hand in labelled_hands:
        if hand is not None:
            plot_single_hand(axes, hand.landmark_positions_device, rgb_calib, hand_label)
    plt.show()


def plot_eye_gaze_data(axes, provider, timestamp, time_domain, time_tolerance, rgb_calib, T_device_cpf):
    """Look up eye gaze data at `timestamp` and mark the gaze point on `axes`.

    Prints a status line describing whether the stream exists, whether the
    gaze point is usable at this timestamp, and whether it projects into the
    camera frame.

    Args:
        axes: Matplotlib axes to draw on.
        provider: Data provider used to resolve the "eyegaze" stream and query
            eye gaze data.
        timestamp: Query time in nanoseconds.
        time_domain: Time domain passed through to the query.
        time_tolerance: Largest accepted gap, in nanoseconds, between the gaze
            data's tracking timestamp and `timestamp`.
        rgb_calib: RGB camera calibration used to project the gaze point.
        T_device_cpf: Transform applied to the gaze point given in the CPF
            frame to express it in the device frame.
    """
    eyegaze_stream_id = provider.get_stream_id_from_label("eyegaze")
    if eyegaze_stream_id is None:
        print("eyegaze stream not found in current VRS, skipping.")
        return

    # Query eyegaze data
    eyegaze_data = provider.get_eye_gaze_data_by_time_ns(eyegaze_stream_id, timestamp, time_domain)

    # Usable only if flagged valid AND close enough in time to the query
    gaze_is_usable = (
        eyegaze_data.spatial_gaze_point_valid
        and abs(eyegaze_data.tracking_timestamp.total_seconds() * 1e9 - timestamp) <= time_tolerance
    )
    if not gaze_is_usable:
        print("spatial gaze point is not valid at this timestamp")
        return

    print("spatial gaze point is valid at this timestamp")
    gaze_point_in_device = T_device_cpf @ eyegaze_data.spatial_gaze_point_in_cpf

    # Project spatial gaze point into RGB frame
    gaze_point_in_camera = rgb_calib.get_transform_device_camera().inverse() @ gaze_point_in_device
    projected_gaze_point = rgb_calib.project(gaze_point_in_camera)
    if projected_gaze_point is None:
        print("eyegaze point projection out of camera frame")
        return

    # Plot a red circle marker as gaze point, with a text label to its right
    gaze_x = projected_gaze_point[0]
    gaze_y = projected_gaze_point[1]
    axes.plot(gaze_x, gaze_y, 'ro', linewidth = 3, markersize=8)
    axes.text(gaze_x+50, gaze_y, 'EyeGazePoint', color='red', fontsize=10,
        bbox=dict(facecolor='red', alpha=0, boxstyle='round,pad=0.5'))

In [None]:
if device_version == DeviceVersion.Gen2:
    # Use a slider to get a certain RGB frame, and try to plot the corresponding EyeGaze and HandPose data in RGB image.
    rgb_stream_id = StreamId("214-1")
    time_domain = TimeDomain.DEVICE_TIME
    num_rgb_frames = provider.get_num_data(rgb_stream_id)

    # Get RGB calibration
    device_calib = provider.get_device_calibration()
    rgb_calib = device_calib.get_camera_calib("camera-rgb")
    T_device_cpf = device_calib.get_transform_device_cpf()

    # Create a widget with slider to choose an RGB frame to plot
    import ipywidgets as widgets
    from IPython.display import display
    from functools import partial

    # Get the very first frame (frame=0) so we can initialize the image.
    initial_rgb_record = provider.get_image_data_by_index(rgb_stream_id, 0)
    initial_rgb_array  = initial_rgb_record[0].to_numpy_array()
    # Normalize [0,255] → [0,1]
    initial_norm = (initial_rgb_array - 0) / 255.0
    initial_norm = np.clip(initial_norm, 0, 1)

    # Create figure & axes just once:
    fig, axes = plt.subplots(figsize=(6, 6))
    img_handle = axes.imshow(initial_norm, cmap="gray", vmin=0, vmax=1)
    axes.axis("off")  # hide ticks
    # Close the figure so it is only shown through the explicit display() calls
    plt.close(fig)

    output = widgets.Output()
    with output:
        display(fig)

    slider = widgets.IntSlider(value=0, min=0, max=num_rgb_frames-1, continuous_update = False)

    def on_slider_change(change, output, provider, time_domain, rgb_stream_id,  rgb_calib, T_device_cpf):
        """Redraw the figure for the RGB frame selected on the slider.

        Args:
            change: Widget change notification; `change['new']` is the
                selected RGB frame index.
            output: Output widget the figure and status text are rendered into.
            provider: Data provider to read the RGB frame and MP data from.
            time_domain: Time domain used for the eye gaze / hand pose queries.
            rgb_stream_id: Stream id of the RGB camera.
            rgb_calib: RGB camera calibration used for projection.
            T_device_cpf: Transform from the CPF frame to the device frame.
        """
        with output: # you need this for Bento Next
            output.clear_output(wait=True)
            rgb_frame_index = change['new']
            print(f"Selecting RGB frame {rgb_frame_index}")

            # Plot RGB image
            rgb_image_and_record = provider.get_image_data_by_index(
                rgb_stream_id, rgb_frame_index)
            rgb_image_array = rgb_image_and_record[0].to_numpy_array()
            rgb_timestamp = rgb_image_and_record[1].capture_timestamp_ns
            min_val, max_val = 0, 255  # Set your desired min and max values
            normalized_rgb_image = (rgb_image_array - min_val) / (max_val - min_val)
            normalized_rgb_image = np.clip(normalized_rgb_image, 0, 1)  # Ensure values are within [0, 1]

            img_handle.set_data(normalized_rgb_image)

            # Remove any old overlays (eye gaze / hand pose) from previous call.
            # Snapshot all overlay artists into one list first: removing an
            # artist changes the axes' own containers, and a plain list must
            # not be mutated while it is being iterated.
            stale_overlays = [*axes.artists, *axes.lines, *axes.collections, *axes.texts]
            for artist in stale_overlays:
                artist.remove()

            # tolerance time to ensure the MP data is close to the query time.
            time_tolerance = 500e6

            # Plot Eye gaze data
            plot_eye_gaze_data(axes, provider, rgb_timestamp, time_domain, time_tolerance, rgb_calib, T_device_cpf)

            # Plot hand pose data
            plot_hand_pose_data(axes, provider, rgb_timestamp, time_domain, time_tolerance, rgb_calib)

            display(fig)


    # Attach the function to the slider
    print("Please select a RGB Frame ID, note that plotting may be slow in Bento notebook")
    # Bind the extra arguments so the observer callback only receives `change`
    wrapped_function = partial(on_slider_change, output = output, provider = provider, time_domain=time_domain, rgb_stream_id=rgb_stream_id, rgb_calib=rgb_calib, T_device_cpf=T_device_cpf)
    slider.observe(wrapped_function, names='value')

    display(slider, output)
else:
    print("On-device machine perception data is only available in Aria Gen2. ")

## Retrieve on-device machine perception data (VIO high frequency and VIO)

Goals:
- Learn how to retrieve on-device machine perception data (VIO, VIO high frequency) from VRS

Key learnings:
- Learn how to query VIO pose information from VRS.

In [None]:
import plotly.graph_objs as go
from matplotlib import pyplot as plt
from projectaria_tools.core.sophus import SE3f
from projectaria_tools.core.sensor_data import TrackingQuality

# Tune this parameter to control the plotted camera frustum size
CAMERA_FRUSTUM_SIZE = 0.1


# Helper function to build the frustum
def build_camera_frustum(T_world_camera):
    """Create a plotly mesh of a camera frustum placed at the given pose.

    Args:
        T_world_camera: Transform applied (via `@`) to the frustum vertices,
            which are defined in the camera frame.

    Returns:
        A `go.Mesh3d` built from the transformed vertices, created with
        `visible=False`.
    """
    # Apex at the camera origin plus four corners on the z = 1 plane
    unit_frustum_vertices = np.array(
        [[0, 0, 0], [0.5, 0.5, 1], [-0.5, 0.5, 1], [-0.5, -0.5, 1], [0.5, -0.5, 1]]
    )
    points = unit_frustum_vertices * CAMERA_FRUSTUM_SIZE

    # Transform expects one vertex per column, hence the transpose
    points_transformed = T_world_camera @ points.transpose()
    xs = points_transformed[0, :]
    ys = points_transformed[1, :]
    zs = points_transformed[2, :]

    return go.Mesh3d(
        x=xs,
        y=ys,
        z=zs,
        # (i, j, k) are the vertex indices of each triangle: four faces that
        # share the apex (vertex 0) and two that use only the corners
        i=[0, 0, 0, 0, 1, 1],
        j=[1, 2, 3, 4, 2, 3],
        k=[2, 3, 4, 1, 3, 4],
        showscale=False,
        visible=False,
        colorscale="jet",
        # Color by the z coordinate of each vertex in the camera frame
        intensity=points[:, 2],
        opacity=1.0,
        hoverinfo="none",
    )

# helper function to cast from double to float
def cast_SE3_to_SE3f(se3_double):
    """Convert a single double-precision SE3 into an `SE3f`.

    Args:
        se3_double: SE3 object whose `len()` must be 1.

    Returns:
        The `SE3f` built from the input's matrix cast to float32.

    Raises:
        ValueError: If `len(se3_double)` is not 1.
    """
    # Only a single transform is supported by this helper
    if len(se3_double) != 1:
        raise ValueError("Expected SE3 of size 1 for this cast helper")
    matrix_f32 = se3_double.to_matrix().astype(np.float32)
    return SE3f.from_matrix(matrix_f32)

# Plot the VIO high-frequency trajectory (with a scrubbable camera frustum)
# together with the lower-rate VIO trajectory, both expressed at the RGB camera.
vio_high_freq_stream_id = provider.get_stream_id_from_label("vio_high_frequency")
vio_stream_id = provider.get_stream_id_from_label("vio")
if vio_high_freq_stream_id is not None and vio_stream_id is not None:
    T_device_rgb = device_calib.get_transform_device_sensor("camera-rgb")

    vio_high_freq_data_num = provider.get_num_data(vio_high_freq_stream_id)

    # Record RGB locations in the vio-high-freq trajectory (subsample by 20)
    vio_high_freq_subsample_rate = 20
    sample_indices = range(0, vio_high_freq_data_num, vio_high_freq_subsample_rate)
    # Size the trajectory from the indices that are actually sampled. Using
    # `num // rate + 1` over-allocates by one row whenever `num` is a multiple
    # of the rate, which leaves an uninitialized point in the plot and makes
    # the frustum loop index past the end of the pose list.
    num_samples = len(sample_indices)
    vio_high_freq_trajectory = np.empty([num_samples, 3])
    print(f"--- vio high freq: num of data is {vio_high_freq_data_num}, size of traj is {vio_high_freq_trajectory.shape}")

    # For every sampled pose: store the RGB camera position, build its camera
    # frustum, and create the slider step that makes only that frustum visible.
    all_high_freq_poses = []
    cam_frustums = []
    steps = []
    for j, i in enumerate(sample_indices):
        vio_high_freq_pose = provider.get_vio_high_freq_data_by_index(vio_high_freq_stream_id, i)
        T_odometry_rgb = vio_high_freq_pose.transform_odometry_device @ T_device_rgb
        vio_high_freq_trajectory[j, :] = T_odometry_rgb.translation()
        all_high_freq_poses.append(vio_high_freq_pose)
        cam_frustums.append(build_camera_frustum(T_odometry_rgb))

        timestamp = vio_high_freq_pose.tracking_timestamp.total_seconds()
        # All frustums hidden except the j'th one; the two trailing entries
        # keep the two trajectory traces (added after the frustums) visible.
        visibility = [False] * num_samples + [True] * 2
        visibility[j] = True
        steps.append(
            dict(
                method="update",
                args=[{"visible": visibility}, {"title": "Trajectory and Point Cloud"}],
                label=timestamp,
            )
        )
    # Show the first frustum initially (nothing to show for an empty stream).
    if cam_frustums:
        cam_frustums[0].visible = True

    # Record RGB poses in the vio trajectory, check validity
    valid_vio_poses = []
    vio_data_num = provider.get_num_data(vio_stream_id)
    # The cast does not depend on the VIO sample, so do it once up front.
    T_device_rgb_f = cast_SE3_to_SE3f(T_device_rgb)
    for i in range(vio_data_num):
        vio_data = provider.get_vio_data_by_index(vio_stream_id, i)
        # Only keep poses whose quality is GOOD
        if vio_data.pose_quality == TrackingQuality.GOOD:
            T_odometry_rgb = (vio_data.transform_odometry_bodyimu @
                              vio_data.transform_bodyimu_device @
                              T_device_rgb_f)
            valid_vio_poses.append(T_odometry_rgb.translation().transpose())

    # Convert the list of good poses to an (N, 3) array. reshape keeps the
    # array 2-D even when there are zero or one good poses, where squeeze()
    # would collapse it and break the column indexing below.
    vio_trajectory = np.array(valid_vio_poses).reshape(-1, 3)

    # Create slider to allow scrubbing and set the layout
    sliders = [dict(currentvalue={"suffix": " s", "prefix": "Time :"}, pad={"t": 5}, steps=steps)]
    layout = go.Layout(
        sliders=sliders,
        scene=dict(
            bgcolor='lightgray',
            dragmode='orbit',
            aspectmode='data',
            xaxis_visible=False,
            yaxis_visible=False,
            zaxis_visible=False,
            camera=dict(
                eye=dict(x=0.5, y=0.5, z=0.5),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1),
            ),
        ),
        width=1100,
        height=1000,
    )

    # Plot trajectory
    plotter_vio_high_freq_trajectory = go.Scatter3d(
        x=vio_high_freq_trajectory[:, 0],
        y=vio_high_freq_trajectory[:, 1],
        z=vio_high_freq_trajectory[:, 2],
        mode="markers",
        marker={"size": 2, "opacity": 0.8, "color": "red"},
        name="Vio High Freq Trajectory",
    )
    plotter_vio_trajectory = go.Scatter3d(
        x=vio_trajectory[:, 0],
        y=vio_trajectory[:, 1],
        z=vio_trajectory[:, 2],
        mode="markers",
        marker={"size": 4, "opacity": 0.8, "color": "green"},
        name="Vio Trajectory",
    )

    # draw: frustum traces first, then the two trajectory traces (the slider
    # visibility lists above rely on this ordering)
    plot_figure = go.Figure(data=cam_frustums + [plotter_vio_high_freq_trajectory, plotter_vio_trajectory], layout=layout)
    plot_figure.show()

else:
    print("Vio high-freq stream does not exist in the current VRS file. ")

# Image color correction and devignetting examples (Aria Gen1 only)
## Correcting Color Distortion in Older Aria Captures
Videos and images captured with earlier versions of the Aria OS may exhibit color distortion due to inconsistent gamma curves and unconventional color temperatures. This can result in colors appearing inconsistent across images and overly blue.
This issue has been resolved in the new OS update V1.13. For images and videos captured before this update, we offer a Color Correction API to address the distortion. The images will be corrected to a reference color temperature of 5000K. 

Below, we demonstrate how to apply color correction: 
1. Call `provider.set_color_correction(True)` to enable color correction (it is disabled by default).
2. The output of `provider.get_image_data_by_index` will then be color corrected.

In [None]:
# Compare the first camera-rgb frame with and without color correction.
if device_version == DeviceVersion.Gen1:
    stream_id = provider.get_stream_id_from_label("camera-rgb")

    # Save the source image for comparison (both corrections disabled)
    provider.set_color_correction(False)
    provider.set_devignetting(False)
    src_image_array = provider.get_image_data_by_index(stream_id, 0)[0].to_numpy_array()

    # Fetch the same frame again with only color correction enabled
    provider.set_color_correction(True)
    provider.set_devignetting(False)
    color_corrected_image_array = provider.get_image_data_by_index(stream_id, 0)[0].to_numpy_array()

    # Visualize input and results. plt.subplots() already creates the figure;
    # a preceding plt.figure() would only open an extra, empty figure.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle("Color Correction")

    axes[0].imshow(src_image_array, vmin=0, vmax=255)
    axes[0].set_title("before color correction")
    axes[1].imshow(color_corrected_image_array, vmin=0, vmax=255)
    axes[1].set_title("after color correction")

    plt.show()
else:
    print("Color correction feature is Gen1 only")

#### Devignetting

Devignetting corrects uneven lighting, enhancing image uniformity and clarity. We provide devignetting masks for the camera-rgb full-size image [2880, 2880], the camera-rgb half-size image [1408, 1408], and the SLAM image [640, 480].
1. Aria devignetting masks can be downloaded from [Link](https://www.projectaria.com/async/sample/download/?bucket=core&filename=devignetting_masks_bin.zip). It contains the following files:

```
devignetting_masks_bin
|- new_isp
   |- slam_devignetting_mask.bin
   |- rgb_half_devignetting_mask.bin
   |- rgb_full_devignetting_mask.bin
|- old_isp
   |- slam_devignetting_mask.bin
   |- rgb_half_devignetting_mask.bin
   |- rgb_full_devignetting_mask.bin
```
2. Turn on devignetting. Set devignetting mask folder path with the local aria camera devignetting masks folder path.
   `set_devignetting(True)`
   `mask_folder_path = "devignetting_masks_bin"`
   `set_devignetting_mask_folder_path(mask_folder_path)`
3. The image data from `get_image_data_by_index` will be devignetted. 
4. (Optional) If you don't want the devignetting feature, turn it off by calling `set_devignetting(False)`.

In [None]:
# Download the devignetting masks and compare a camera-rgb frame with and
# without devignetting.
if device_version == DeviceVersion.Gen1:
    # ==============================================================================
    # Step 1: Download devignetting mask
    # ==============================================================================
    from urllib.request import urlretrieve
    import zipfile
    import ssl
    # NOTE(review): this disables HTTPS certificate verification for every
    # later urllib request in this kernel, not just the download below.
    # Kept for compatibility; confirm whether it is still required.
    ssl._create_default_https_context = ssl._create_unverified_context

    # Download from url
    devignetting_mask_folder_path = os.path.join(repo_path, "devignetting_masks")
    downloaded_devignetting_mask_zip = os.path.join(devignetting_mask_folder_path, "aria_camera_devignetting_masks.zip")
    os.makedirs(devignetting_mask_folder_path, exist_ok=True)
    urlretrieve("https://www.projectaria.com/async/sample/download/?bucket=core&filename=devignetting_masks_bin.zip", downloaded_devignetting_mask_zip)

    # unzip the mask files, with cross-platform compatibility
    with zipfile.ZipFile(downloaded_devignetting_mask_zip, 'r') as zip_ref:
        # Extract all files
        zip_ref.extractall(devignetting_mask_folder_path)

        # Print out the filenames
        print("Successfully downloaded and extracted the following files for devignetting:")
        for file_info in zip_ref.infolist():
            print(file_info.filename)

    # ==============================================================================
    # Step 2: Turn on devignetting and set devignetting mask folder path
    # ==============================================================================
    # Look up the stream here so this cell does not depend on `stream_id`
    # left behind by an earlier cell.
    stream_id = provider.get_stream_id_from_label("camera-rgb")
    index = 1
    # Capture the uncorrected frame first, for the side-by-side comparison
    provider.set_devignetting(False)
    provider.set_color_correction(False)
    src_image_array = provider.get_image_data_by_index(stream_id, index)[0].to_numpy_array()
    provider.set_devignetting(True)
    provider.set_devignetting_mask_folder_path(devignetting_mask_folder_path)

    # ==============================================================================
    # Step 3: Retrieve Image from stream
    # ==============================================================================
    devignetted_image_array = provider.get_image_data_by_index(stream_id, index)[0].to_numpy_array()

    # Visualize input and results. plt.subplots() already creates the figure;
    # a preceding plt.figure() would only open an extra, empty figure.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle("Image devignetting (camera-rgb)")

    axes[0].imshow(src_image_array, vmin=0, vmax=255)
    axes[0].set_title("before devignetting")
    axes[1].imshow(devignetted_image_array, vmin=0, vmax=255)
    axes[1].set_title("after devignetting")

    plt.show()
else:
    print("Devignetting is only supported on Gen1.")