# ATEK Data Preprocessing Example

### 1. Load the example ADT data

In [None]:
from projectaria_tools.projects.adt import (
    AriaDigitalTwinDataProvider,
    AriaDigitalTwinDataPathsProvider,
)

# Root directory of the sample ADT sequence used throughout this notebook.
example_adt_data_path = "/source/data/atek/adt_raw/Apartment_release_golden_skeleton_seq100_10s_sample"

# Resolve the per-device data paths of the sequence (the VRS recording and the
# trajectory file are read from `data_paths` by the following cells).
paths_provider = AriaDigitalTwinDataPathsProvider(example_adt_data_path)
all_device_serials = paths_provider.get_device_serial_numbers()
print(f"Device serials in this sequence: {all_device_serials}")

# Pick one device by its index into `all_device_serials`; fail early with a clear
# message instead of an opaque error from the paths provider.
selected_device_number = 0
if not 0 <= selected_device_number < len(all_device_serials):
    raise IndexError(
        f"selected_device_number={selected_device_number} is out of range; "
        f"the sequence has {len(all_device_serials)} device(s): {all_device_serials}"
    )
data_paths = paths_provider.get_datapaths_by_device_num(selected_device_number)

### 2. Frame data processor example

In [None]:
from projectaria_tools.core import calibration, data_provider
from atek.data_preprocess.frame_data_processor import FrameDataProcessor
from atek.data_preprocess.pose_data_processor import PoseDataProcessor
from projectaria_tools.core.stream_id import StreamId
from atek.data_preprocess.adt_gt_data_processor import AdtGtDataProcessor
import numpy as np

# ADT ground-truth processors, one per camera stream:
# "214-1" -> RGB camera, "1201-1" -> left SLAM camera, "1201-2" -> right SLAM camera.
rgb_adt_gt_data_processor = AdtGtDataProcessor("adt_gt", "ADT", StreamId("214-1"), data_paths)
slaml_adt_gt_data_processor = AdtGtDataProcessor("adt_gt", "ADT", StreamId("1201-1"), data_paths)
slamr_adt_gt_data_processor = AdtGtDataProcessor("adt_gt", "ADT", StreamId("1201-2"), data_paths)

# One pose processor, built from the device trajectory file, shared by every camera.
pose_data_processor = PoseDataProcessor(
    name="pose",
    trajectory_file=data_paths.aria_trajectory_filepath,
)

# Rotate every camera's images 90 degrees clockwise.
rotate_image_cw90deg = True


def _make_frame_data_processor(stream_id_str, target_linear_camera_params, gt_data_processor):
    """Create a FrameDataProcessor for one camera stream of the selected device.

    Args:
        stream_id_str: Aria stream id string, e.g. "214-1".
        target_linear_camera_params: forwarded unchanged to FrameDataProcessor;
            passing None keeps the original (non-rectified) image.
        gt_data_processor: ground-truth processor built for the same stream.
    """
    return FrameDataProcessor(
        video_vrs=data_paths.aria_vrs_filepath,
        stream_id=StreamId(stream_id_str),
        rotate_image_cw90deg=rotate_image_cw90deg,
        target_linear_camera_params=target_linear_camera_params,
        pose_data_processor=pose_data_processor,
        gt_data_processor=gt_data_processor,
    )


rgb_data_processor = _make_frame_data_processor(
    "214-1", np.array([512, 512]), rgb_adt_gt_data_processor
)
slaml_data_processor = _make_frame_data_processor(
    "1201-1", np.array([320, 240]), slaml_adt_gt_data_processor
)
slamr_data_processor = _make_frame_data_processor(
    "1201-2", np.array([320, 240]), slamr_adt_gt_data_processor
)

# Show which fields a processed frame carries, using frame 30 of the RGB stream.
# Note: asdict() recursively copies the frame (image array included) just to list keys.
from dataclasses import asdict

rgb_image_frame = rgb_data_processor.get_frame_by_index(30)
print(asdict(rgb_image_frame).keys())


# Visualize one frame of undistorted 2D bounding boxes
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from projectaria_tools.projects.adt import (
    bbox2d_to_image_coordinates,
)


fig, ax = plt.subplots(figsize=(10, 10))

# Draw on the explicit Axes instead of pyplot's implicit "current axes", so the
# image and the box patches are guaranteed to share one subplot.
# Single-channel (2-D) images are drawn with a gray colormap.
if rgb_image_frame.image.ndim == 2:
    ax.imshow(rgb_image_frame.image, cmap="gray")
else:
    ax.imshow(rgb_image_frame.image)

for bb2d in rgb_image_frame.bb2ds:
    # Convert the 2D box into polygon vertices in image coordinates.
    image_coords = bbox2d_to_image_coordinates(bb2d)
    rect = patches.Polygon(image_coords, closed=True, fill=False, edgecolor="r")

    # Add the patch to the axes
    ax.add_patch(rect)

### 3. Frameset Example

In [None]:
from atek.data_preprocess.frameset_aligner import FramesetAligner

# The three camera streams (RGB, left SLAM, right SLAM) to align into framesets.
frame_data_processors = [
    rgb_data_processor,
    slaml_data_processor,
    slamr_data_processor,
]

# Align the streams at target_hz=10, sharing the same pose processor.
# NOTE(review): require_objects=True presumably keeps only framesets that have
# object annotations — confirm against FramesetAligner.
frameset_aligner = FramesetAligner(
    target_hz=10,
    frame_data_processors=frame_data_processors,
    pose_data_processor=pose_data_processor,
    require_objects=True,
)

# Fetch one aligned frameset and list its fields.
frameset = frameset_aligner.get_frameset_by_index(10)
print(asdict(frameset).keys())


# Visualize one frameset: one subplot per camera frame, with its 2D boxes overlaid.
# The column count comes from the frameset instead of assuming three cameras, and
# squeeze=False keeps `ax` a 2-D array even when there is only a single frame.
num_frames = max(1, len(frameset.frames))
fig, ax = plt.subplots(1, num_frames, figsize=(10 * num_frames, 10), squeeze=False)

for i, frame in enumerate(frameset.frames):
    frame_ax = ax[0, i]
    if frame.image.ndim == 2:
        # Single-channel image: draw it with a gray colormap.
        frame_ax.imshow(frame.image, cmap="gray")
    else:
        frame_ax.imshow(frame.image)

    for bb2d in frame.bb2ds:
        image_coords = bbox2d_to_image_coordinates(bb2d)
        rect = patches.Polygon(image_coords, closed=True, fill=False, edgecolor="r")

        # Add the patch to the axes
        frame_ax.add_patch(rect)

### 4. Frameset Group

In [None]:
from atek.data_preprocess.frameset_group_generator import FramesetGroupGenerator, FramesetSelectionConfig

# Groups of 3 framesets, no framesets skipped at either end, stride 1.
# The time/translation/rotation/FOV-overlap thresholds are left unset (None).
# NOTE(review): far_clipping_distance=4.0 and local_selection=0 are passed as-is;
# their units/semantics are defined by FramesetSelectionConfig — confirm there.
frameset_selection_config = FramesetSelectionConfig(
    num_framesets_per_group=3,
    skip_first_n_framesets=0,
    skip_last_n_framesets=0,
    stride=1,
    time_duration_ns_threshold=None,
    translation_m_threshold=None,
    rotation_deg_threshold=None,
    fov_overlapping_ratio_threshold=None,
    far_clipping_distance=4.0,
    local_selection=0,
)

frameset_group_generator = FramesetGroupGenerator(
    frameset_aligner=frameset_aligner,
    frameset_selection_config=frameset_selection_config,
    require_objects=True,
)

# Fetch one frameset group and list its fields.
fg = frameset_group_generator.get_frameset_group_by_index(10)
print(asdict(fg).keys())

# Visualize a frameset group: one row per frameset, one column per camera frame.
# squeeze=False keeps `ax` 2-D even for a single-row grid (e.g. one frameset per
# group), and the column count comes from the data instead of assuming 3 cameras.
n = len(fg.framesets)
n_rows = max(1, n)
n_cols = max([1] + [len(group_frameset.frames) for group_frameset in fg.framesets])
fig, ax = plt.subplots(n_rows, n_cols, figsize=(10 * n_cols, 10 * n_rows), squeeze=False)

# A distinct loop name keeps the module-level `frameset` from being overwritten.
for i, group_frameset in enumerate(fg.framesets):
    for j, frame in enumerate(group_frameset.frames):
        frame_ax = ax[i, j]
        if frame.image.ndim == 2:
            # Single-channel image: draw it with a gray colormap.
            frame_ax.imshow(frame.image, cmap="gray")
        else:
            frame_ax.imshow(frame.image)

        for bb2d in frame.bb2ds:
            image_coords = bbox2d_to_image_coordinates(bb2d)
            rect = patches.Polygon(image_coords, closed=True, fill=False, edgecolor="r")

            # Add the patch to the axes
            frame_ax.add_patch(rect)

# Lay out once everything is drawn, so spacing reflects the final contents.
plt.tight_layout()

### 5. Write WebDataset (WDS) files

In [None]:
from atek.data_preprocess.webdataset_writer import (
    DataSelectionSettings,
    convert_frameset_group_to_wds_dict,
    AtekWdsWriter,
)

# Turn on every trajectory / obb2d / obb3d ground-truth requirement at the frame,
# frameset and frameset-group levels.
settings = DataSelectionSettings(
    require_traj_for_frame=True,
    require_obb2d_gt_for_frame=True,
    require_obb3d_gt_for_frame=True,
    require_traj_for_frameset=True,
    require_obb3d_gt_for_frameset=True,
    require_traj_for_frameset_group=True,
    require_obb3d_gt_for_frameset_group=True,
)

# Convert the frameset group from the previous cell; kept for interactive inspection.
wds_dict = convert_frameset_group_to_wds_dict(0, fg, settings)

# NOTE(review): positional args are (output dir, settings, 32, None); the meaning of
# 32 and None is not visible here — confirm against AtekWdsWriter's signature.
atek_wds_writer = AtekWdsWriter("/tmp/test_wds", settings, 32, None)

num_samples = 4
try:
    num_groups_to_write = min(num_samples, frameset_group_generator.frameset_group_number())
    # A distinct loop name keeps the module-level `fg` from being overwritten.
    for group_index in range(num_groups_to_write):
        group = frameset_group_generator.get_frameset_group_by_index(group_index)
        atek_wds_writer.add_sample(group)
finally:
    # Close the writer even if fetching or adding a sample raises, so it is never
    # left open with a partially written output.
    atek_wds_writer.close()