# ATEK Demo 3: Model training in ATEK

In [1]:
import faulthandler

import logging
import os
from logging import StreamHandler
import numpy as np
from typing import Dict, List, Optional
import torch
import sys
import subprocess
from tqdm import tqdm

from atek.viz.atek_visualizer import NativeAtekSampleVisualizer
from atek.data_loaders.atek_wds_dataloader import (
    create_native_atek_dataloader
)
from atek.util.file_io_utils import load_yaml_and_extract_tar_list
from omegaconf import OmegaConf

# Dump Python tracebacks if the interpreter hits a fatal error.
faulthandler.enable()

# Configure logging to display the log messages in the notebook.
# force=True replaces any handlers that were already attached to the root logger
# (for example by a library imported above); without it basicConfig would silently
# do nothing in that case and INFO messages would not show up in the cell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)

# Root logger, shared by the rest of the notebook.
logger = logging.getLogger()

def run_command_and_display_output(command):
    """Run `command` as a subprocess and echo its output line by line as it is produced.

    stderr is merged into stdout, so both streams appear interleaved in the
    printed output.

    Args:
        command: Argument list (program followed by its arguments) handed to
            `subprocess.Popen`.

    Returns:
        The exit code of the finished process.
    """
    # The context manager closes the stdout pipe and waits for the child when the
    # block exits, so no file descriptor is leaked.
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    ) as process:
        # Iterating the pipe yields each line as soon as it is available and stops at EOF.
        for line in process.stdout:
            # Trim only trailing whitespace so indentation in the child's output survives.
            print(line.rstrip())
    return process.returncode

### Set up data and code paths

In [2]:
# Root folder for this demo's data: the download-URL JSON, the downloaded WDS files,
# the training config and the training outputs all live under it.
# Defaults to the original location; set ATEK_DATA_DIR to use another folder without
# editing the notebook.
data_dir = os.environ.get(
    "ATEK_DATA_DIR", "/home/louy/Calibration_data_link/Atek/2024_08_05_DryRun"
)
# Local checkout of the ATEK source tree; set ATEK_SRC_PATH to override the default
# location under the home directory.
atek_src_path = os.environ.get(
    "ATEK_SRC_PATH", os.path.join(os.path.expanduser("~"), "atek_on_fbsource")
)
# Visualization config loaded from the ATEK source tree.
viz_conf = OmegaConf.load(os.path.join(atek_src_path, "atek", "configs", "obb_viz.yaml"))
# JSON file passed to the download tool as --input-json-path.
atek_json_path = os.path.join(data_dir, "AriaDigitalTwin_ATEK_download_urls.json")

### Step 1: Download preprocessed data from ATEK Data Store

In [3]:
# Command line that downloads preprocessed data from the ATEK Data Store.
# Each flag sits on its own line, followed by its value.
download_data_command = [
    "python3",
    atek_src_path + "/tools/dataverse_url_parser.py",
    "--config-name",
    "cubercnn",
    "--input-json-path",
    atek_json_path,
    "--output-folder-path",
    data_dir + "/downloaded_local_wds/",
    "--max-num-sequences",
    "5",
    "--train-val-split-ratio",
    "0.8",
]

# The download is disabled by default; uncomment the next line to run it.
# return_code = run_command_and_display_output(download_data_command)

## ATEK Training example with CubeRCNN
Users can call the `tools/train_cubercnn.py` script to train on ATEK WDS data downloaded from the ATEK Data Store. Here we run it on a local machine as a mini demonstration.

Core code snippets in the script (check out the script for full details): 
```
model.train()
tar_file_urls = load_yaml_and_extract_tar_list(train_list_yaml)
data_loader = create_atek_dataloader_as_cubercnn(urls = tar_file_urls, ...)

# Loop over cubercnn-format data samples
for sample_data in data_loader:
    # Training step
    loss_dict = model(sample_data)
    losses = sum(loss_dict.values())
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
    ...
```

In [4]:
# Example training command.
# The script is launched with the kernel's own interpreter (sys.executable) so it runs
# in the same Python environment this notebook imports ATEK from, rather than whichever
# `python` happens to be first on PATH.
mini_training_command = [
  sys.executable, f"{atek_src_path}/tools/train_cubercnn.py",
  "--config-file", f"{data_dir}/cubercnn_train_config_mini_example.yaml",
  "--num-gpus", "1",
  # The remaining KEY VALUE pairs are forwarded to the script as config overrides.
  "OUTPUT_DIR", f"{data_dir}/mini_test_1",
  "TRAIN_LIST", f"{data_dir}/downloaded_local_wds/local_train_tars.yaml",
  "TEST_LIST", f"{data_dir}/downloaded_local_wds/local_validation_tars.yaml",
  # "TRAIN_LIST",f"{data_dir}/streamable_yamls/streamable_train_tars.yaml",
  # "TEST_LIST", f"{data_dir}/streamable_yamls/streamable_validation_tars.yaml",
  "CATEGORY_JSON", f"{atek_src_path}/data/atek_id_to_name.json",
  "ID_MAP_JSON", f"{atek_src_path}/data/atek_name_to_id.json",
  "MODEL.WEIGHTS_PRETRAIN", "/home/louy/Calibration_data_link/Atek/cubercnn_DLA34_FPN.pth",
]
return_code = run_command_and_display_output(mini_training_command)
# Surface a failed run instead of silently discarding the exit code.
if return_code != 0:
    logger.error("Mini training command exited with non-zero return code %s", return_code)

Command Line Args: Namespace(config_file='/home/louy/Calibration_data_link/Atek/2024_08_05_DryRun/cubercnn_train_config_mini_example.yaml', resume=False, eval_only=False, num_gpus=1, num_machines=1, machine_rank=0, dist_url='tcp://127.0.0.1:50152', opts=['OUTPUT_DIR', '/home/louy/Calibration_data_link/Atek/2024_08_05_DryRun/mini_test_1', 'TRAIN_LIST', '/home/louy/Calibration_data_link/Atek/2024_08_05_DryRun/downloaded_local_wds/local_train_tars.yaml', 'TEST_LIST', '/home/louy/Calibration_data_link/Atek/2024_08_05_DryRun/downloaded_local_wds/local_validation_tars.yaml', 'CATEGORY_JSON', '/home/louy/atek_on_fbsource/data/atek_id_to_name.json', 'ID_MAP_JSON', '/home/louy/atek_on_fbsource/data/atek_name_to_id.json', 'MODEL.WEIGHTS_PRETRAIN', '/home/louy/Calibration_data_link/Atek/cubercnn_DLA34_FPN.pth'])
[09/11 21:57:26 detectron2]: Rank of current process: 0. World size: 1
[09/11 21:57:27 detectron2]: Environment info:
-------------------------------  ------------------

In [None]:
# [Optional] Inspect training progress using tensorboard.
# TensorBoard keeps serving until it is interrupted, so this cell keeps running until then.
tensorboard_command = [
    "tensorboard",
    f"--logdir={data_dir}/full_train_tensorboard",
    "--port",
    "6007",
    "--samples_per_plugin=images=1000",
]
return_code = run_command_and_display_output(tensorboard_command)

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
"--load_fast=false" and report issues on GitHub. More details:
https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.17.1 at http://localhost:6007/ (Press CTRL+C to quit)
