# ATEK Demo 1: ATEK data preprocessing + model inference

This demo walks through three steps: preparing an annotated Aria data sequence ([AriaDigitalTwin (ADT)](https://www.projectaria.com/datasets/adt/)) for use with CubeRCNN, a 3D object detection model; running model inference on the preprocessed data; and evaluating the model's performance.


In [None]:
import faulthandler

import logging
import os
from logging import StreamHandler
import numpy as np
from typing import Dict, List, Optional
import torch
import sys
import subprocess
from tqdm import tqdm

from atek.data_preprocess.genera_atek_preprocessor_factory import (
    create_general_atek_preprocessor_from_conf,
)
from atek.viz.atek_visualizer import NativeAtekSampleVisualizer
from atek.data_preprocess.general_atek_preprocessor import GeneralAtekPreprocessor
from atek.data_loaders.atek_wds_dataloader import (
    create_native_atek_dataloader
)
from atek.data_loaders.cubercnn_model_adaptor import (
    cubercnn_collation_fn,
    create_atek_dataloader_as_cubercnn
)
from atek.data_preprocess.atek_data_sample import (
    create_atek_data_sample_from_flatten_dict,
)
from cubercnn.config import get_cfg_defaults
from cubercnn.modeling.backbone import build_dla_from_vision_fpn_backbone  # noqa
from cubercnn.modeling.meta_arch import build_model  # noqa
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from omegaconf import OmegaConf

# Dump Python tracebacks if the process hits a fatal error (e.g. a segfault).
faulthandler.enable()

# Send INFO-and-above log records to stdout so that they appear in the notebook output.
_stdout_handler = StreamHandler(sys.stdout)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[_stdout_handler],
)

# Root logger, shared by all the cells below.
logger = logging.getLogger()

# Prettier colors, given as three-integer lists
COLOR_GREEN = [42, 157, 143]
COLOR_RED = [231, 111, 81]

# -------------------- Helper functions --------------------#
def print_data_sample_dict_content(data_sample, if_pretty: bool = False):
    """
    Log a one-line summary for every entry of a data sample dict.

    Each line shows the key and the Python type of its value, followed by the
    shape (for `torch.Tensor`), the length (for `list`), or the value itself
    (for `str`). Values of any other type get no extra detail.

    When `if_pretty` is True, a key containing "#" is shortened to the part
    after its first "#".
    """
    logger.info("Printing the content in a ATEK data sample dict: ")
    for raw_key, value in data_sample.items():
        shown_key = raw_key
        if if_pretty and "#" in raw_key:
            shown_key = raw_key.split("#", 1)[1]

        # Pick the type-specific detail that is appended to the summary line
        if isinstance(value, torch.Tensor):
            detail = f"with shape of : {value.shape}"
        elif isinstance(value, list):
            detail = f"with len of : {len(value)}"
        elif isinstance(value, str):
            detail = f"value is {value}"
        else:
            detail = ""

        logger.info(f"\t {shown_key}: is a {type(value)}, " + detail)

def create_inference_model(config_file, ckpt_dir, use_cpu_only=False):
    """
    Build the model used by the inference pipeline from a model config file.

    The config starts from detectron2's `get_cfg()` with cubercnn's
    `get_cfg_defaults` applied, is extended with a few placeholder entries,
    and is then merged with `config_file` and frozen. When `use_cpu_only` is
    True, `MODEL.DEVICE` is forced to "cpu" before freezing. Weights are
    loaded through a `DetectionCheckpointer` rooted at `ckpt_dir`, and the
    model is switched to eval mode.

    Returns:
        A `(model_config, model)` tuple.
    """
    # Default model configuration
    cfg = get_cfg()
    get_cfg_defaults(cfg)

    # Extra data-related entries, registered before the file is merged
    # (presumably so that `merge_from_file` accepts these keys -- TODO confirm).
    top_level_placeholders = (
        ("MAX_TRAINING_ATTEMPTS", 3),
        ("TRAIN_LIST", ""),
        ("TEST_LIST", ""),
        ("TRAIN_WDS_DIR", ""),
        ("TEST_WDS_DIR", ""),
        ("ID_MAP_JSON", ""),
        ("OBJ_PROP_JSON", ""),
        ("CATEGORY_JSON", ""),
    )
    for entry_name, placeholder in top_level_placeholders:
        setattr(cfg, entry_name, placeholder)
    cfg.DATASETS.OBJECT_DETECTION_MODE = ""
    cfg.SOLVER.VAL_MAX_ITER = 0
    cfg.SOLVER.MAX_EPOCH = 0

    # Values from the file override the defaults; the device override must
    # happen before the config is frozen.
    cfg.merge_from_file(config_file)
    if use_cpu_only:
        cfg.MODEL.DEVICE = "cpu"
    cfg.freeze()

    model = build_model(cfg, priors=None)

    # Load the weights; the object returned by `resume_or_load` is not needed.
    checkpointer = DetectionCheckpointer(model, save_dir=ckpt_dir)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True)
    model.eval()

    return cfg, model

def run_command_and_display_output(command):
    """
    Run `command` as a subprocess and print its output live, line by line.

    stderr is merged into stdout, so both streams are printed in the order
    they are produced. Each line is stripped of surrounding whitespace before
    being printed.

    Args:
        command: the command handed to `subprocess.Popen` (here, a list of
            argument strings).

    Returns:
        The exit code of the finished process.
    """
    # The context manager closes the stdout pipe and waits for the child on
    # exit, even if printing is interrupted; without it the pipe stays open.
    with subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    ) as process:
        # Iterating the pipe yields each line as soon as it is available and
        # stops at EOF, i.e. once the child has closed its output.
        for line in process.stdout:
            print(line.strip())

    # Leaving the `with` block has waited for the child, so the code is set.
    return process.returncode


## Set up data and code paths

In [None]:
# Before running this cell, download the example ADT sequence to the local path
# `~/Documents/projectaria_tools_adt_data` by following this guide:
# https://facebookresearch.github.io/projectaria_tools/docs/open_datasets/aria_digital_twin_dataset/dataset_download.

home_dir = os.path.expanduser("~")

# Local data paths: the raw ADT sequence (input) and the WDS folder (output)
sequence_name = "Apartment_release_golden_skeleton_seq100_10s_sample_M1292"
data_dir = os.path.join(home_dir, "Documents", "projectaria_tools_adt_data")
example_adt_data_dir = os.path.join(data_dir, sequence_name)
output_wds_path = os.path.join(data_dir, "wds_output")

# ATEK paths: source checkout, category mapping file, and preprocessing config.
# NOTE: the config path is an absolute, machine-specific path; adjust it to your setup.
atek_src_path = os.path.join(home_dir, "atek_on_fbsource")
category_mapping_file = os.path.join(atek_src_path, "data", "adt_prototype_to_atek.csv")
atek_preprocess_config_path = "/home/louy/Calibration_data_link/Atek/2024_08_05_DryRun/adt_cubercnn_preprocess_config.yaml"
preprocess_conf = OmegaConf.load(atek_preprocess_config_path)

# Folder holding the trained model weights (also machine-specific)
model_ckpt_path = "/home/louy/Calibration_data_link/Atek/pre_trained_models/2024_08_28_AdtCubercnnWeights"

# Step 1: ATEK data preprocessing
An ADT sequence contains: 
1. Aria recording (VRS).
2. MPS trajectory file (CSV).
3. Object detection annotation files (3 csv files + 1 json file).

CubeRCNN model data samples (all in `torch.Tensor`): 
1. Upright RGB camera image.
2. Linear camera calibration matrix.
3. Object bounding box annotations in 2D + 3D.
4. Camera-to-object poses.

**Before ATEK**, users needed to implement all of the following steps to prepare an ADT sequence for the CubeRCNN model: 
1. Parse in ADT sequence data using `projectaria_tools` lib.   
2. Linearize image + camera calibration.
3. Rescale camera resolution. 
4. Rotate image + camera calibration.
5. Linearize + rescale + rotate object 2D bounding boxes accordingly.
6. Synchronize sensor + annotation data into training samples.

### Set up and run ATEK data preprocessor
**With ATEK**, all of the preprocessing steps above can be handled by a simple [configurable YAML file](https://www.internalfb.com/phabricator/paste/view/P1581100261). 

In [None]:
# Build the ATEK preprocessor from the loaded config.
# It will automatically choose which type of sample to build.
atek_preprocessor = create_general_atek_preprocessor_from_conf(
    conf=preprocess_conf,  # [required]
    raw_data_folder=example_adt_data_dir,  # [required]
    sequence_name=sequence_name,  # [required]
    output_wds_folder=output_wds_path,  # [optional]
    category_mapping_file=category_mapping_file,  # [optional]
)

# Go through all samples and write the valid ones to local tar files,
# with both the WDS-writing and the visualization flags switched on.
atek_preprocessor.process_all_samples(
    write_to_wds_flag=True,
    viz_flag=True,
)

## Preprocessed ATEK data sample content
Converted from VRS + CSV + JSON files -> a `Dict` whose values are `torch.Tensor`, `str`, or `Dict`.

In [None]:
# Log the content of the first ATEK data sample (held in memory)
atek_data_sample = atek_preprocessor[0]
flattened_sample = atek_data_sample.to_flatten_dict()
print_data_sample_dict_content(flattened_sample)

# List the preprocessed files that were saved to disk
print("\nPrinting the preprocessed results that are saved to disk")
listing_command = ["ls", output_wds_path]
return_code = run_command_and_display_output(listing_command)

# Step 2: Run Object detection inference using pre-trained CubeRCNN model
In this example, we demonstrate how to run model inference with preprocessed ATEK data. 

### Create a PyTorch DataLoader from ATEK WDS files
Load WDS files as ATEK format: 

In [None]:
import glob

# Collect every shard found in the output folder instead of assuming that
# exactly two exist (`shards-0000.tar` and `shards-0001.tar`): a fixed count
# fails when fewer shards are present and silently skips data when there are more.
# Sorting keeps the shard order deterministic.
tar_file_urls = sorted(glob.glob(os.path.join(output_wds_path, "shards-*.tar")))
if not tar_file_urls:
    raise FileNotFoundError(
        f"No WDS shard files (shards-*.tar) found in {output_wds_path}; run the preprocessing step first."
    )

# Create ATEK dataloader with native ATEK format, and peek at the first sample it yields.
atek_dataloader = create_native_atek_dataloader(urls=tar_file_urls, batch_size=None, num_workers=1)
first_atek_sample = next(iter(atek_dataloader))
logger.info(f"Loading WDS into ATEK native format, each sample contains the following keys: {first_atek_sample.keys()}")

### Create PyTorch DataLoader, converted to CubeRCNN format
Data converter from ATEK format -> CubeRCNN format: 
1. Dict key renaming.
2. Tensor reshaping & reordering. 
3. Other data transformations.

Example data converter impl for CubeRCNN model: [src code](https://www.internalfb.com/code/fbsource/[a5c3831c045bc718862d1c512e84d4ed6f79d722]/fbcode/surreal/data_services/atek/atek/data_loaders/cubercnn_model_adaptor.py?lines=44-74)

The `create_atek_dataloader_as_cubercnn` API is a thin wrapper on top of a `CubeRCNN` data converter class. 

In [None]:
# Load the same WDS shards, this time through the CubeRCNN dataloader wrapper
cubercnn_dataloader = create_atek_dataloader_as_cubercnn(
    urls=tar_file_urls,
    batch_size=6,
    num_workers=1,
)

# Peek at the first item; it is indexable, and its first element is a dict
first_cubercnn_sample = next(iter(cubercnn_dataloader))
logger.info(f"Loading WDS into CubeRCNN format, each sample contains the following keys: {first_cubercnn_sample[0].keys()}")

### Run model inference

In [None]:
from tqdm import tqdm

# The model config file sits in the same folder as the pre-trained CubeRCNN weights
model_config_file = os.path.join(model_ckpt_path, "config.yaml")
conf = OmegaConf.load(model_config_file)

# Build the model config and the model; the device is not forced to CPU
model_config, model = create_inference_model(
    config_file=model_config_file,
    ckpt_dir=model_ckpt_path,
    use_cpu_only=False,
)

# (input, output) pairs, kept in memory for the later visualization and evaluation cells
input_output_data_pairs = []

# Run the model over everything the PyTorch dataloader yields, without tracking gradients
with torch.no_grad():
    inference_progress = tqdm(cubercnn_dataloader, desc="Inference progress: ")
    for cubercnn_input_data in inference_progress:
        cubercnn_model_output = model(cubercnn_input_data)
        input_output_data_pairs.append((cubercnn_input_data, cubercnn_model_output))

logger.info("Inference completed.")

### Visualize inference results

In [None]:
from atek.viz.cubercnn_visualizer import CubercnnVisualizer

# Replay the cached inference results through the visualizer
logger.info("Visualizing inference results.")
viz_conf = preprocess_conf.visualizer
cubercnn_visualizer = CubercnnVisualizer(viz_prefix="inference_visualizer", conf=viz_conf)

for input_data_as_list, output_data_as_list in input_output_data_pairs:
    # Inputs and outputs of one batch are paired element by element
    for single_cubercnn_input, single_cubercnn_output in zip(input_data_as_list, output_data_as_list):
        timestamp_ns = single_cubercnn_input["timestamp_ns"]

        # RGB image first
        cubercnn_visualizer.plot_cubercnn_img(single_cubercnn_input["image"], timestamp_ns=timestamp_ns)

        # Patch needed for visualization: copy the camera pose from the input into the output dict
        single_cubercnn_output["T_world_camera"] = single_cubercnn_input["T_world_camera"]

        # GT (model input) in green, prediction (model output) in red
        for cubercnn_dict, plot_color, suffix in (
            (single_cubercnn_input, cubercnn_visualizer.COLOR_GREEN, "_model_input"),
            (single_cubercnn_output, cubercnn_visualizer.COLOR_RED, "_model_output"),
        ):
            cubercnn_visualizer.plot_cubercnn_dict(
                cubercnn_dict=cubercnn_dict,
                timestamp_ns=timestamp_ns,
                plot_color=plot_color,
                suffix=suffix,
            )

# Step 3: Evaluate model performance
It is hard to compare model performance within the Aria community, because different projects use: 
1. Different file formats. 
2. Different metric implementations.

ATEK provides: 
1. Standardized prediction file format for each ML task.
2. Standardized benchmarking scripts. 

### Write inference results into ATEK-format csv files


In [None]:
from atek.evaluation.static_object_detection.obb3_csv_io import AtekObb3CsvWriter

# One csv file for the ground truth, one for the model predictions
gt_csv_path = os.path.join(data_dir, "gt_obbs.csv")
prediction_csv_path = os.path.join(data_dir, "prediction_obbs.csv")
gt_writer = AtekObb3CsvWriter(output_filename=gt_csv_path)
prediction_writer = AtekObb3CsvWriter(output_filename=prediction_csv_path)

for input_data_as_list, output_data_as_list in input_output_data_pairs:
    # Inputs and outputs of one batch are paired element by element
    for single_cubercnn_input, single_cubercnn_output in zip(input_data_as_list, output_data_as_list):
        timestamp_ns = single_cubercnn_input["timestamp_ns"]

        # Copy the camera pose from the input into the output dict before writing it
        single_cubercnn_output["T_world_camera"] = single_cubercnn_input["T_world_camera"]

        # The model input is written as ground truth, the model output as prediction
        gt_writer.write_from_cubercnn_dict(cubercnn_dict=single_cubercnn_input, timestamp_ns=timestamp_ns)
        prediction_writer.write_from_cubercnn_dict(cubercnn_dict=single_cubercnn_output, timestamp_ns=timestamp_ns)

# NOTE(review): nothing here explicitly flushes or closes the writers before the
# csv files are read by the benchmarking step -- confirm that AtekObb3CsvWriter
# flushes on every write.

### Call ATEK's benchmarking script to evaluate the results

In [None]:
# Launch the benchmarking script with the interpreter that runs this notebook.
# A bare "python3" is resolved through PATH and may point to a different
# environment than the one this notebook's imports come from.
# "python3" is kept as a fallback in case `sys.executable` is empty.
python_executable = sys.executable or "python3"

benchmarking_command = [
    python_executable, f"{atek_src_path}/tools/benchmarking_static_object_detection.py",
    "--pred-csv", f"{data_dir}/prediction_obbs.csv",
    "--gt-csv", f"{data_dir}/gt_obbs.csv",
    "--output-file", f"{data_dir}/atek_metrics.json",
]
return_code = run_command_and_display_output(benchmarking_command)

# Surface a failed run instead of silently ignoring the exit code
if return_code != 0:
    logger.error(f"Benchmarking script failed with return code {return_code}")