# Demo 1: ATEK data preprocessing

* 3D object detection ML task.
* An example model: CubeRCNN. 
* An example data sequence with an Aria VRS recording: [Aria Digital Twin (ADT)](https://www.projectaria.com/datasets/adt/)


In [None]:
import faulthandler

import logging
import os
from logging import StreamHandler
import numpy as np
from typing import Dict, List, Optional
import torch
import sys
import subprocess
from tqdm import tqdm

from atek.data_preprocess.genera_atek_preprocessor_factory import (
    create_general_atek_preprocessor_from_conf,
)
from atek.viz.atek_visualizer import NativeAtekSampleVisualizer
from atek.data_preprocess.general_atek_preprocessor import GeneralAtekPreprocessor
from atek.data_loaders.atek_wds_dataloader import (
    create_native_atek_dataloader
)
from atek.data_preprocess.atek_data_sample import (
    create_atek_data_sample_from_flatten_dict,
)
from omegaconf import OmegaConf

# Have the interpreter dump Python tracebacks if it crashes hard.
faulthandler.enable()

# Send log records to stdout so that they appear in the notebook output
stdout_handler = StreamHandler(sys.stdout)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[stdout_handler],
)

# Root logger, used by the helper functions in this notebook
logger = logging.getLogger()

# Color palette (three-component lists; presumably RGB)
COLOR_GREEN = [42, 157, 143]
COLOR_RED = [231, 111, 81]

# -------------------- Helper functions --------------------#
def print_data_sample_dict_content(data_sample, if_pretty: bool = False):
    """
    Log a one-entry-per-record summary of an ATEK data sample dict.

    For every key/value pair the value's type is reported, followed by the
    dtype and shape for tensors, the length for lists, or the value itself
    for strings. Other value types only get their type reported.

    Args:
        data_sample: mapping whose items are summarized.
        if_pretty: when True, a key containing "#" is shown without the part
            up to and including its first "#".
    """
    logger.info("Printing the content in a ATEK data sample dict: ")
    for raw_key, value in data_sample.items():
        # Only the displayed name is shortened; the dict itself is untouched.
        shown_key = (
            raw_key.split("#", 1)[1] if if_pretty and "#" in raw_key else raw_key
        )

        # Type-specific detail appended after the generic type description.
        if isinstance(value, torch.Tensor):
            detail = f"\n \t\t\t\t with tensor dtype of {value.dtype}, and shape of : {value.shape}"
        elif isinstance(value, list):
            detail = f"with len of : {len(value)}"
        elif isinstance(value, str):
            detail = f"value is {value}"
        else:
            detail = ""

        logger.info(f"\t {shown_key}: is a {type(value)}, " + detail)

def run_command_and_display_output(command):
    """
    Run an external command and print its output live, line by line.

    stderr is merged into stdout, so both streams are printed in the order
    they are produced. Each line is printed with surrounding whitespace
    stripped.

    Args:
        command: the command to execute, passed unchanged to
            `subprocess.Popen` (e.g. a list of program name + arguments).

    Returns:
        The exit code of the finished process.
    """
    # The context manager closes the stdout pipe once we are done, even if
    # printing raises, so no file descriptor is leaked.
    with subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    ) as process:
        # readline() blocks until a full line is available and returns ''
        # only at EOF, so output is shown as soon as the child emits it.
        for output in iter(process.stdout.readline, ""):
            print(output.strip())
        # EOF on the pipe does not imply the child has exited; block until
        # it does instead of spinning on poll().
        rc = process.wait()
    return rc


###  Set up data and code paths

In [None]:
# Set up local data paths (everything lives under ~/Documents/atek_data)
data_dir = os.path.join(os.path.expanduser("~"), "Documents", "atek_data")
sequence_name = "Apartment_release_golden_skeleton_seq100_10s_sample_M1292"
adt_download_dir = os.path.join(data_dir, "adt_data")
example_adt_data_dir = os.path.join(adt_download_dir, sequence_name)
output_wds_path = os.path.join(data_dir, "wds_output")

# Set up ATEK paths
# This is the path that you cloned ATEK into; adjust it if you cloned elsewhere
atek_src_path = os.path.join(data_dir, "ATEK")
atek_preprocess_config_path = os.path.join(atek_src_path, "examples", "data", "adt_cubercnn_preprocess_config.yaml")
category_mapping_file = os.path.join(atek_src_path, "data", "adt_prototype_to_atek.csv")
preprocess_conf = OmegaConf.load(atek_preprocess_config_path)

# Download ADT data, only if the example sequence is not on disk yet
if not os.path.exists(example_adt_data_dir):
    adt_data_sequence_url = "https://www.projectaria.com/async/sample/download/?bucket=adt&filename=aria_digital_twin_test_data_v2.zip"
    adt_sample_zip_path = os.path.join(adt_download_dir, "adt_sample_data.zip")

    os.makedirs(adt_download_dir, exist_ok=True)

    # Commands are argument lists executed without a shell, so paths that
    # contain spaces or shell metacharacters are passed through intact.
    download_command_list = [
        # Download sample data: "-C -" resumes a partial download, "-L"
        # follows redirects. "-o" already names the output file, so the
        # extra "-O" output option is not needed for a single URL.
        ["curl", "-o", adt_sample_zip_path, "-C", "-", "-L", adt_data_sequence_url],
        # Unzip the sample data, overwriting existing files without prompting
        ["unzip", "-o", adt_sample_zip_path, "-d", adt_download_dir],
    ]
    for command in download_command_list:
        # check=True aborts the cell as soon as one step fails
        subprocess.run(command, check=True)

###  Data preprocessing requirements
ADT sequence has: 
1. Aria recording (VRS).
2. MPS trajectory file (CSV).
3. Object detection annotation files (3 csv files + 1 json file).

The CubeRCNN model needs synchronized data frames, each containing: 
1. Upright RGB camera image.
2. Linear camera calibration matrix.
3. Object bounding box annotations in 2D + 3D.
4. Camera-to-object poses.

**Before ATEK**, users need to implement all of the following to prepare an ADT sequence for the CubeRCNN model: 
1. Parse in ADT sequence data using `projectaria_tools` lib.   
2. Properly synchronize sensor + annotation data into training samples.
3. Perform additional image & data processing:
    1. Undistort image + camera calibration.
    2. Rescale camera resolution. 
    3. Rotate image + camera calibration.
    4. Undistort + rescale + rotate object 2D bounding boxes accordingly.

## Step 1: Set up and run ATEK data preprocessor
**With ATEK**, all of the above preprocessing steps can be handled by a simple [configurable YAML file](https://www.internalfb.com/phabricator/paste/view/P1581100261). 

In [None]:
# Build the ATEK preprocessor from the loaded config; the type of sample to
# build is chosen automatically.
atek_preprocessor = create_general_atek_preprocessor_from_conf(
    conf=preprocess_conf,  # [required]
    raw_data_folder=example_adt_data_dir,  # [required]
    sequence_name=sequence_name,  # [required]
    output_wds_folder=output_wds_path,  # [optional]
    category_mapping_file=category_mapping_file,  # [optional]
)

# Go through every sample, writing the valid ones to local tar files
# (visualization is switched on as well).
atek_preprocessor.process_all_samples(
    write_to_wds_flag=True,
    viz_flag=True,
)

## Step 2: Preprocessed ATEK data sample content
* Preprocessing input: VRS + csv + jsons
* Preprocessing output (in memory): ATEK data samples: `Dict[torch.Tensor, str, or Dict]`
* Preprocessing output (on local disk): WebDataset (WDS) tar files. 

In [None]:
# Show what the first in-memory ATEK data sample contains, using its
# flattened dict form
atek_data_sample = atek_preprocessor[0]
flattened_sample = atek_data_sample.to_flatten_dict()
print_data_sample_dict_content(flattened_sample)

# List the preprocessed files that were written to disk
print("\nPrinting the preprocessed results that are saved to disk as WebDataset files (.tar)")
listing_command = ["ls", output_wds_path]
return_code = run_command_and_display_output(listing_command)