# Data loading example

In [12]:
from atek.data_loaders.atek_wds_dataloader import load_atek_wds_dataset 
import yaml
import os
import webdataset as wds
import numpy as np
TRAJECTORY_COLOR = [30, 100, 30]
GT_COLOR = [30, 200, 30]
PRED_COLOR = [200, 30, 30]

tar_dir = "/home/louy/Calibration_data_link/Atek/2024_05_28_CubeRcnnTest/wds_output/test_cubercnn_adt/"

tars = [
    os.path.join(tar_dir, f"shards-000{i}.tar") for i in range(7)
]

## Default native loading

In [13]:
wds_dataset = load_atek_wds_dataset(urls = tars)

# Print dict keys
for obj in wds_dataset:
    for key, value in obj.items():
        print(f"{key} has type of {type(value)}")
    break

__key__ has type of <class 'str'>
__url__ has type of <class 'str'>
gtdata has type of <class 'dict'>
mfcd#camera-rgb+t_device_camera has type of <class 'torch.Tensor'>
mfcd#camera-rgb+camera_label has type of <class 'str'>
mfcd#camera-rgb+camera_model_name has type of <class 'str'>
mfcd#camera-rgb+capture_timestamps_ns has type of <class 'torch.Tensor'>
mfcd#camera-rgb+exposure_durations_s has type of <class 'torch.Tensor'>
mfcd#camera-rgb+frame_ids has type of <class 'torch.Tensor'>
mfcd#camera-rgb+gains has type of <class 'torch.Tensor'>
mfcd#camera-rgb+origin_camera_label has type of <class 'str'>
mfcd#camera-rgb+projection_params has type of <class 'torch.Tensor'>
mtd#ts_world_device has type of <class 'torch.Tensor'>
mtd#capture_timestamps_ns has type of <class 'torch.Tensor'>
mtd#gravity_in_world has type of <class 'torch.Tensor'>
mfcd#camera-rgb+images has type of <class 'torch.Tensor'>


In [9]:
print("RGB image tensor shape: ", obj['mfcd#camera-rgb+images'].shape)
print(obj['mfcd#camera-rgb+projection_params'])

RGB image tensor shape:  torch.Size([1, 3, 1408, 1408])
tensor([610.9410, 610.9410, 703.5000, 703.5000])


## Visualization for native data

In [15]:
import rerun as rr
from projectaria_tools.core.sophus import SE3
from projectaria_tools.utils.rerun_helpers import ToTransform3D

# Data visualization
rr.init("ATEK Data Loader Viewer", spawn=True)
rr.serve(web_port=8888, ws_port=8877)

[2024-06-28T19:34:48Z INFO  re_ws_comms::server] Shutting down Rerun server on ws://localhost:8877
[2024-06-28T19:34:48Z INFO  re_web_viewer_server] Shutting down web server on http://localhost:38111
[2024-06-28T19:34:48Z INFO  winit::platform_impl::platform::x11::window] Guessed window scale factor: 1.1041666666666667
[2024-06-28T19:34:48Z INFO  tracing::span] perform;
[2024-06-28T19:34:48Z INFO  zbus::handshake] write_command; command=Auth(Some(External), Some([49, 48, 48, 48]))
[2024-06-28T19:34:48Z INFO  tracing::span] read_command;
[2024-06-28T19:34:48Z INFO  zbus::handshake] write_command; command=NegotiateUnixFD
[2024-06-28T19:34:48Z INFO  tracing::span] read_command;
[2024-06-28T19:34:48Z INFO  zbus::handshake] write_command; command=Begin
[2024-06-28T19:34:48Z INFO  tracing::span] socket reader;
[2024-06-28T19:34:48Z INFO  tracing::span] perform;
[2024-06-28T19:34:48Z INFO  zbus::handshake] write_command; command=Auth(Some(External), Some([49, 48, 48, 48]))
[2024-06-28T19:34:4

In [17]:
from projectaria_tools.utils.rerun_helpers import ToTransform3D

def log_pred_3d_2d_bbox(atek_wds_dict_all):
    i_frame = 0
    for atek_wds_dict in atek_wds_dict_all:
        T_world_device = SE3.from_matrix3x4(atek_wds_dict["mtd#ts_world_device"][i_frame, :, :])
        T_device_cam = SE3.from_matrix3x4(atek_wds_dict["mfcd#camera-rgb+t_device_camera"])
        # HWC -> CWH
        image = atek_wds_dict["mfcd#camera-rgb+images"][i_frame].detach().cpu().permute(1, 2, 0).numpy()

        # log device and camera locations
        rr.log(
            f"world", 
            ToTransform3D(SE3(), False),
        )
        
        rr.log(
            f"world/device", 
            ToTransform3D(T_world_device, False),
        )
        
        rr.log(
            f"world/camera-rgb",
            ToTransform3D(T_world_device @ T_device_cam, False),
        )
        
        

        # log images
        rr.log(
            f"image",
            rr.Image(image),
        )

        # For testing only
        img_timestamp = atek_wds_dict["mfcd#camera-rgb+capture_timestamps_ns"][i_frame].item()
        rr.set_time_seconds("frame_time_ns", img_timestamp)
        pose_timestamp = atek_wds_dict["mtd#capture_timestamps_ns"][i_frame].item()
        # gt_timestamp = int(list(atek_wds_dict["gtdata"].keys())[i_frame])
        print(f"img_time: {img_timestamp}, pose_time: {pose_timestamp}, difference in us: {(img_timestamp - pose_timestamp)/1e3}")

        # Log 3d bbox
        bb3ds_centers_infer = []
        bb3ds_quats_xyzw_infer = []
        bb3ds_sizes_infer = []
        labels_infer = []
        objs = list(atek_wds_dict["gtdata"]["obb3_gt"]["bbox3d_all_instances"].values())
        for obj_gt_dict in objs:
            # Only plot chair
            if obj_gt_dict["category_id"] not in [1, 4]:
                continue
            T_world_obj = SE3.from_matrix3x4(obj_gt_dict["T_World_Object"])
            bb3ds_centers_infer.append(T_world_obj.translation()[0])
            wxyz = T_world_obj.rotation().to_quat()[0]
            bb3ds_quats_xyzw_infer.append([wxyz[3], wxyz[0], wxyz[1], wxyz[2]])
            bb3ds_sizes_infer.append(np.array(obj_gt_dict["object_dimensions"]))
            labels_infer.append(obj_gt_dict["category_name"])            
        
        # log 3D bounding boxes
        rr.log(
            f"world/bb3d_infer",
            rr.Boxes3D(
                sizes=bb3ds_sizes_infer,
                centers=bb3ds_centers_infer,
                rotations=bb3ds_quats_xyzw_infer,
                radii=0.01,
                colors=PRED_COLOR,
                labels=labels_infer,
            ),
        )

        # Log 2d bbox
        bb2ds_all = []
        for obj_2d_dict in atek_wds_dict["gtdata"]["obb2_gt"]["camera-rgb"].values():
            # Only plot coffee table
            if obj_2d_dict["category_id"] not in [1, 4]:
                continue
            bb2d = obj_2d_dict["box_range"]
            bb2ds_XYXY = np.array([bb2d[0], bb2d[2], bb2d[1], bb2d[3]])
            bb2ds_all.append(bb2ds_XYXY)
        
        if len(bb2ds_all) == 0:
            print(f" ---- -- debug: no 2d bboxes found for frame {i_frame}")
        
        rr.log(
            f"image/bb2d_gt",
            rr.Boxes2D(
                array=bb2ds_all,
                array_format=rr.Box2DFormat.XYXY,
                radii=1,
                colors=GT_COLOR,
                # labels=labels_infer,
            ),
        )


    i_frame += 1
    

for data_dict in wds_dataset:
    log_pred_3d_2d_bbox(data_dict)

TypeError: string indices must be integers

## Load as CubeRCNN data example

In [None]:
import torch

from detectron2.data import detection_utils
from detectron2.structures import Boxes, BoxMode, Instances

from atek.data_loaders.cubercnn_model_adaptor import load_atek_wds_dataset_as_cubercnn
from tqdm import tqdm

dataset = load_atek_wds_dataset_as_cubercnn(tars)

sample = next(iter(dataset))

print(sample.keys())
print("Image shape: ", sample['image'].shape)
print("K: ", sample['K'])