# ATEK Demo 1: ATEK data preprocessing -> visualization -> loading -> inference

This demo will walk through the steps of preparing an Aria data sequence with annotations ([AriaDigitalTwin (ADT)](https://www.projectaria.com/datasets/adt/)), for use in a 3D object detection ML task. 

We include the following 3 examples: 
* **Example 1**: How to pre-process Aria VRS + MPS + annotation data for training / inference, and save as WebDataset (WDS) format. 
* **Example 2**: How to load ATEK preprocessed WDS data into model-compatible DataLoader. 
* **Example 3**: How to run model inference with ATEK preprocessed data. 

In [None]:
import faulthandler

import logging
import os
from logging import StreamHandler
import numpy as np
from typing import Dict, List, Optional
import torch
import sys
from tqdm import tqdm

from atek.data_preprocess.genera_atek_preprocessor_factory import (
    create_general_atek_preprocessor_from_conf,
)
from atek.viz.atek_visualizer_base import NativeAtekSampleVisualizer
from atek.data_preprocess.general_atek_preprocessor import GeneralAtekPreprocessor
from atek.data_loaders.atek_wds_dataloader import (
    create_native_atek_dataloader
)
from atek.data_loaders.cubercnn_model_adaptor import (
    cubercnn_collation_fn,
    create_atek_dataloader_as_cubercnn
)
from atek.data_preprocess.atek_data_sample import (
    create_atek_data_sample_from_flatten_dict,
)
from cubercnn.config import get_cfg_defaults
from cubercnn.modeling.backbone import build_dla_from_vision_fpn_backbone  # noqa
from cubercnn.modeling.meta_arch import build_model  # noqa
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from omegaconf import OmegaConf

faulthandler.enable()

# Configure logging to display the log messages in the notebook.
# `force=True` (Python 3.8+) removes any handlers already attached to the root
# logger before installing ours. Without it, `basicConfig` silently does
# nothing whenever the root logger already has a handler (e.g. one attached
# during an earlier import), and the INFO messages below would not be shown.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)

# Root logger, shared by the helper functions and cells below.
logger = logging.getLogger()

# Prettier colors
COLOR_GREEN = [42, 157, 143]
COLOR_RED = [231, 111, 81]

# -------------------- Helper functions --------------------#
def print_data_sample_dict_content(data_sample, if_pretty: bool = False):
    """
    Log a one-line summary for every entry of an ATEK data sample dict.

    Args:
        data_sample: mapping from key to value, e.g. the result of
            `to_flatten_dict()` on an ATEK data sample.
        if_pretty: if True, a key that contains "#" is shortened to the text
            after its first "#".

    Tensors are summarized by their shape, lists by their length, and strings
    by their value; any other value only reports its type.
    """
    logger.info("Printing the content in a ATEK data sample dict: ")
    for raw_key, value in data_sample.items():
        shown_key = raw_key.split("#", 1)[1] if (if_pretty and "#" in raw_key) else raw_key

        if isinstance(value, torch.Tensor):
            detail = f"with shape of : {value.shape}"
        elif isinstance(value, list):
            detail = f"with len of : {len(value)}"
        elif isinstance(value, str):
            detail = f"value is {value}"
        else:
            # Other types: only the type name is reported.
            detail = ""

        logger.info(f"\t {shown_key}: is a {type(value)}, " + detail)

def create_inference_model(config_file, ckpt_dir, use_cpu_only=False):
    """
    Build the model for the inference pipeline from a model config file.

    Args:
        config_file: path to the model config yaml, merged on top of the
            CubeRCNN defaults.
        ckpt_dir: directory passed to `DetectionCheckpointer` as `save_dir`.
        use_cpu_only: if True, force `MODEL.DEVICE` to "cpu"; otherwise the
            device from the config is kept.

    Returns:
        A tuple `(model_config, model)`: the frozen config, and the model in
        eval mode with weights loaded through
        `DetectionCheckpointer.resume_or_load(MODEL.WEIGHTS, resume=True)`.
    """
    # Start from the detectron2 defaults extended with the CubeRCNN defaults.
    cfg = get_cfg()
    get_cfg_defaults(cfg)

    # Extra data-related keys, declared with placeholder values before
    # `merge_from_file`, presumably so the merge accepts them when present
    # in `config_file`.
    extra_top_level_keys = {
        "MAX_TRAINING_ATTEMPTS": 3,
        "TRAIN_LIST": "",
        "TEST_LIST": "",
        "TRAIN_WDS_DIR": "",
        "TEST_WDS_DIR": "",
        "ID_MAP_JSON": "",
        "OBJ_PROP_JSON": "",
        "CATEGORY_JSON": "",
    }
    for key, value in extra_top_level_keys.items():
        setattr(cfg, key, value)
    cfg.DATASETS.OBJECT_DETECTION_MODE = ""
    cfg.SOLVER.VAL_MAX_ITER = 0
    cfg.SOLVER.MAX_EPOCH = 0

    cfg.merge_from_file(config_file)
    if use_cpu_only:
        cfg.MODEL.DEVICE = "cpu"
    cfg.freeze()

    model = build_model(cfg, priors=None)
    checkpointer = DetectionCheckpointer(model, save_dir=ckpt_dir)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True)
    model.eval()

    return cfg, model


## Set up data and code paths

In [None]:
# Follow this guide to download the example ADT sequence to a local path
# (by default `~/Documents/projectaria_tools_adt_data`):
# https://facebookresearch.github.io/projectaria_tools/docs/open_datasets/aria_digital_twin_dataset/dataset_download

# The data, config, and checkpoint locations below default to the original
# author's setup. Set the corresponding environment variable to point them at
# your own copies instead of editing this cell.

# Set up local data paths (override the data root with ADT_DATA_DIR)
data_dir = os.environ.get(
    "ADT_DATA_DIR",
    os.path.join(os.path.expanduser("~"), "Documents", "projectaria_tools_adt_data"),
)
sequence_name = "Apartment_release_golden_skeleton_seq100_10s_sample_M1292"
example_adt_data_dir = os.path.join(data_dir, sequence_name)
output_wds_path = os.path.join(data_dir, "wds_output")

# Set up ATEK paths (override with ATEK_SRC_PATH / ATEK_PREPROCESS_CONFIG)
atek_src_path = os.environ.get(
    "ATEK_SRC_PATH", os.path.join(os.path.expanduser("~"), "atek_on_fbsource")
)
atek_preprocess_config_path = os.environ.get(
    "ATEK_PREPROCESS_CONFIG",
    "/home/louy/Calibration_data_link/Atek/2024_08_05_DryRun/adt_cubercnn_preprocess_config.yaml",
)
category_mapping_file = os.path.join(atek_src_path, "data", "adt_prototype_to_atek.csv")
preprocess_conf = OmegaConf.load(atek_preprocess_config_path)

# Set up trained model weight path (override with ATEK_MODEL_CKPT_DIR)
model_ckpt_path = os.environ.get(
    "ATEK_MODEL_CKPT_DIR",
    "/home/louy/Calibration_data_link/Atek/pre_trained_models/2024_08_28_AdtCubercnnWeights",
)

# Example 1: ATEK data preprocessing
In this example, we demonstrate how to preprocess Aria data sequences for ML training, and how to customize the preprocessing by simply changing a configuration file.  


The expected output is a set of iterable data samples, each containing time-aligned camera images, trajectory, and calibration info, along with annotation information.


## Step 1: Set up the ATEK data preprocessor
First, create a `GeneralAtekPreprocessor` (here via `create_general_atek_preprocessor_from_conf`), which provides high-level APIs for preprocessing. 

In [None]:
# Build the ATEK preprocessor from the loaded config; the config itself
# determines which type of sample gets built.
atek_preprocessor = create_general_atek_preprocessor_from_conf(
    # [required]
    conf=preprocess_conf,
    raw_data_folder=example_adt_data_dir,
    sequence_name=sequence_name,
    # [optional]
    category_mapping_file=category_mapping_file,
    output_wds_folder=output_wds_path,
    output_viz_file=os.path.join(example_adt_data_dir, "atek_preprocess_viz.rrd"),
)


## Step 2: Content of a preprocessed ATEK data sample
Users can directly get an ATEK data sample from `GeneralAtekPreprocessor`'s `[]` operator. Each data sample contains grouped sensor, MPS, and annotation data, and users can inspect its content as a flattened dictionary. `.process_all_samples()` allows users to preprocess the entire sequence, visualize it, and save the results as WebDataset (WDS) files to local disk for future use. 

In [None]:
# Fetch the first sample through the preprocessor's `[]` operator, then log
# what its flattened dictionary contains.
atek_data_sample = atek_preprocessor[0]
atek_data_sample_dict = atek_data_sample.to_flatten_dict()
print_data_sample_dict_content(data_sample=atek_data_sample_dict)

# Process the whole sequence: valid samples are written to local WDS tar files,
# with visualization enabled.
atek_preprocessor.process_all_samples(viz_flag=True, write_to_wds_flag=True)

## Step 3: Customization through config
Preprocessing requirements often differ between models. ATEK provides 2 levels of customization: 
1. Change the config yaml file, with no code change. 
2. Customize the preprocessor code.    

Below is an example of config yaml file: 
```
atek_config_name: "cubercnn"
camera_temporal_subsampler:
  main_camera_label: "camera-rgb"
  time_domain: "DEVICE_TIME"
  main_camera_target_freq_hz: 10.0
  sample_length_in_num_frames: 1
  stride_length_in_num_frames: 2
processors:
  rgb:
    selected: true
    sensor_label: "camera-rgb"
    time_domain: "DEVICE_TIME"
    tolerance_ns: 10_000_000
    undistort_to_linear_camera: true  # if set, undistort to a linear camera model
    target_camera_resolution: [] # if set, rescale to [image_width, image_height]
    # rescale_antialias: true[default] controls whether to perform antialiasing during image rescaling.
    rotate_image_cw90deg: true  # if set, rotate image by 90 degrees clockwise
  slam_left:
    selected: true
    sensor_label: "camera-slam-left"
    tolerance_ns: 10_000_000
    time_domain: "DEVICE_TIME"
    rotate_image_cw90deg: true  # if set, rotate image by 90 degrees clockwise
  slam_right:
    selected: true
    sensor_label: "camera-slam-right"
    tolerance_ns: 10_000_000
    time_domain: "DEVICE_TIME"
    rotate_image_cw90deg: true  # if set, rotate image by 90 degrees clockwise
  mps_traj:
    selected: true
    tolerance_ns: 10_000_000
  mps_semidense:
    selected: false
  mps_online_calib:
    tolerance_ns: 10_000_000
  rgb_depth:
    selected: false
    convert_zdepth_to_distance: false
    unit_scale: 0.001
  obb_gt:
    selected: true
    tolerance_ns : 10_000_000
    category_mapping_field_name: prototype_name # {prototype_name, category}
    bbox2d_num_samples_on_edge: 10
wds_writer:
  prefix_string: ""
  max_samples_per_shard: 32
  remove_last_tar_if_not_full: false
```


Below we show an example of how to preprocess the same dataset, but with different preprocessing settings.   

In [None]:
# Run with a different configuration: Fisheye camera + lower resolution on RGB,
# higher temporal subsample rate.
# The config path can be overridden with ATEK_PREPROCESS_CONFIG_2. By default it
# is `adt_cubercnn_preprocess_config_2.yaml` in the same directory as the
# Example 1 config (`atek_preprocess_config_path`), instead of a hard-coded
# user-specific absolute path.
new_preprocess_config_path = os.environ.get(
    "ATEK_PREPROCESS_CONFIG_2",
    os.path.join(
        os.path.dirname(atek_preprocess_config_path),
        "adt_cubercnn_preprocess_config_2.yaml",
    ),
)
new_preprocess_conf = OmegaConf.load(new_preprocess_config_path)
new_atek_preprocessor = create_general_atek_preprocessor_from_conf(
    conf=new_preprocess_conf,
    raw_data_folder=example_adt_data_dir,
    sequence_name=sequence_name,
    category_mapping_file=category_mapping_file,
    output_wds_folder="",  # no WDS output folder: this run only visualizes
)

# WDS writing is disabled; only the visualization is produced.
new_atek_preprocessor.process_all_samples(write_to_wds_flag=False, viz_flag=True)


# Example 2: load ATEK WDS files into model-compatible format
In this example, we demonstrate how to load ATEK preprocessed data into data formats that are compatible with specific ML models via a lightweight ModelAdaptor class. Here we use CubeRCNN as an example. 

In [None]:
import glob

logger.info(
    "-------------------- ATEK WDS data can loaded into Model-specific format --------------- "
)

# Load the preprocessed WDS shards written in Example 1. The number of shards
# depends on the sequence length and `wds_writer.max_samples_per_shard`, so
# discover the files instead of assuming exactly two shards named
# `shards-000{i}.tar`. Sorting keeps zero-padded shard names in order.
tar_file_urls = sorted(glob.glob(os.path.join(output_wds_path, "shards-*.tar")))
if not tar_file_urls:
    raise FileNotFoundError(
        f"No WDS shards matching 'shards-*.tar' found in {output_wds_path}; run Example 1 first."
    )

# The CubeRCNN ModelAdaptor class is wrapped in this function
cubercnn_dataloader = create_atek_dataloader_as_cubercnn(
    urls=tar_file_urls, batch_size=None, repeat_flag=False
)
first_cubercnn_sample = next(iter(cubercnn_dataloader))
logger.info(f"Loading WDS into CubeRCNN format, each sample contains the following keys: {first_cubercnn_sample.keys()}")


# Example 3: Run Object detection inference using pre-trained CubeRCNN model
In Example 2, we showed how users can load ATEK preprocessed data into a PyTorch DataLoader in a model-specific format, with the help of ModelAdaptors. In this example, we further demonstrate how to run model inference with this PyTorch DataLoader. 

## Step 1: load trained CubeRCNN model weights, and create a PyTorch DataLoader from ATEK WDS files
Use the same API as in Example 2 (`create_atek_dataloader_as_cubercnn`) to create a CubeRCNN-format PyTorch DataLoader. The same CubeRCNN-format samples are also used for visualization in Step 2. 

In [None]:
import glob

from atek.viz.cubercnn_visualizer import CubercnnVisualizer
from tqdm import tqdm

# parse in config file
model_config_file = os.path.join(model_ckpt_path, "config.yaml")
# NOTE(review): `conf` is not used by the cells shown in this notebook.
conf = OmegaConf.load(model_config_file)

# Set up config and model. With use_cpu_only=False, the device from the config is kept.
model_config, model = create_inference_model(
    model_config_file, model_ckpt_path, use_cpu_only=False
)

# Create the ATEK dataloader for the CubeRCNN model. Its batches are also what
# Step 2 visualizes. Discover all shards written in Example 1 rather than
# assuming exactly two.
tar_file_urls = sorted(glob.glob(os.path.join(output_wds_path, "shards-*.tar")))
if not tar_file_urls:
    raise FileNotFoundError(
        f"No WDS shards matching 'shards-*.tar' found in {output_wds_path}; run Example 1 first."
    )
batch_size = 6
cubercnn_dataloader = create_atek_dataloader_as_cubercnn(
    urls=tar_file_urls, batch_size=batch_size, num_workers=1
)

## Step 2: Run model inference over the dataset, and visualize results
Iterate through the Pytorch DataLoaders, perform inference, and visualize prediction vs GT results. 

In [None]:
# Run the model over every batch with gradient tracking disabled, keeping each
# (input batch, output batch) pair so the next cell can visualize them.
input_output_data_pairs = []

with torch.no_grad():
    for cubercnn_input_data in tqdm(cubercnn_dataloader, desc="Inference progress: "):
        cubercnn_model_output = model(cubercnn_input_data)
        input_output_data_pairs.append((cubercnn_input_data, cubercnn_model_output))

logger.info("Inference completed.")

In [None]:
# Replay the cached inference results: for every frame, draw the RGB image,
# then the model input (GT, green) and the model output (prediction, red).
logger.info("Visualizing inference results.")
viz_conf = preprocess_conf.visualizer
cubercnn_visualizer = CubercnnVisualizer(viz_prefix="inference_visualizer", conf=viz_conf)

for input_data_as_list, output_data_as_list in input_output_data_pairs:
    for single_cubercnn_input, single_cubercnn_output in zip(
        input_data_as_list, output_data_as_list
    ):
        timestamp_ns = single_cubercnn_input["timestamp_ns"]

        # RGB image of this frame
        cubercnn_visualizer.plot_cubercnn_img(
            single_cubercnn_input["image"], timestamp_ns=timestamp_ns
        )

        # Copy the camera pose from the input onto the output; this patch is
        # needed for visualization (presumably the model output lacks it).
        single_cubercnn_output["T_world_camera"] = single_cubercnn_input["T_world_camera"]

        # Plot GT (model input) and prediction (model output) in different colors.
        for cubercnn_dict, plot_color, suffix in (
            (single_cubercnn_input, cubercnn_visualizer.COLOR_GREEN, "_model_input"),
            (single_cubercnn_output, cubercnn_visualizer.COLOR_RED, "_model_output"),
        ):
            cubercnn_visualizer.plot_cubercnn_dict(
                cubercnn_dict=cubercnn_dict,
                timestamp_ns=timestamp_ns,
                plot_color=plot_color,
                suffix=suffix,
            )