In [None]:
import json
import os
import subprocess
from argparse import Namespace

import joblib
import numpy as np
import optuna
from mmcv import Config
from mmdet3d.utils import recursive_eval
from torchpack.utils.config import configs

from tools.train import train

In [None]:
# Switch the working directory to the parent folder; the config, tool and
# checkpoint paths used later in this notebook are relative.
# NOTE(review): this is not idempotent - re-running the cell moves up one more
# level, so run it exactly once per kernel session.
%cd ..

In [None]:
# Class names used to build the per-class metric keys read from the training log.
CLASSES = (
    "CAR", "TRAILER", "TRUCK", "VAN", "PEDESTRIAN",
    "BUS", "MOTORCYCLE", "BICYCLE", "EMERGENCY_VEHICLE",
    # "OTHER",
)

# Metric-name suffixes that are averaged per class.
AP_DISTS = ["ap_dist_0.5", "ap_dist_1.0", "ap_dist_2.0", "ap_dist_4.0"]

# Inclusive [low, high] integer bounds passed to `trial.suggest_int` for each class.
CLS_SEARCH_SPACE = {
    class_name: [low, high]
    for class_name, low, high in (
        ("CAR", 8, 15),
        ("TRAILER", 0, 2),
        ("TRUCK", 0, 4),
        ("VAN", 0, 5),
        ("PEDESTRIAN", 0, 8),
        ("BUS", 0, 2),
        ("MOTORCYCLE", 0, 4),
        ("BICYCLE", 0, 4),
        ("EMERGENCY_VEHICLE", 0, 2),
    )
}

In [None]:
def read_map(checkpoint_folder_path: str, classes: list, ap_dists: list) -> tuple:
    """Read the final metrics from the JSON-lines log of a training run.

    The first ``*.json`` file (in sorted name order) found in
    ``checkpoint_folder_path`` is opened and its last non-blank line is parsed
    as a JSON record. Only that record is used.

    Args:
        checkpoint_folder_path: Directory containing the JSON-lines log.
        classes: Class names used to build the ``object/{cls}_{ap_dist}`` keys.
        ap_dists: Metric suffixes whose values are averaged for every class.

    Returns:
        A ``(map, class_ap, epoch)`` tuple: the record's ``"object/map"``
        value, a dict mapping each class to the mean of its per-distance
        values, and the record's ``"epoch"`` value. When the log is missing,
        empty, malformed, or lacks an expected key, the best-effort fallback
        ``(0, {cls: 0, ...}, 0)`` is returned and the reason is printed.
    """
    try:
        # os.listdir returns entries in arbitrary order; sort so the chosen
        # file is deterministic when several JSON files are present.
        json_names = sorted(
            name for name in os.listdir(checkpoint_folder_path) if name.endswith(".json")
        )
        json_file = os.path.join(checkpoint_folder_path, json_names[0])

        # Keep only the last non-blank line: a trailing empty line must not
        # invalidate an otherwise complete log, and earlier records are unused.
        last_line = None
        with open(json_file, "r") as f:
            for line in f:
                if line.strip():
                    last_line = line
        if last_line is None:
            raise ValueError(f"{json_file} contains no records")
        data = json.loads(last_line)

        class_ap = {
            cls: np.mean([data[f"object/{cls}_{ap_dist}"] for ap_dist in ap_dists])
            for cls in classes
        }
        mean_ap = data["object/map"]
        epoch = data["epoch"]
    except (OSError, IndexError, KeyError, TypeError, ValueError) as err:
        # Best-effort: a failed or incomplete run scores zero instead of
        # aborting the caller, but the cause is reported rather than hidden.
        print(f"read_map: no usable metrics in {checkpoint_folder_path!r}: {err!r}")
        mean_ap = 0
        class_ap = {cls: 0 for cls in classes}
        epoch = 0
    return mean_ap, class_ap, epoch


def find_gtp_in_pipeline(pipeline: list) -> int:
    """Locate the first ``ObjectPaste`` step in a pipeline.

    Args:
        pipeline: Sequence of step mappings, each with a ``"type"`` key.

    Returns:
        Index of the first step whose ``"type"`` equals ``"ObjectPaste"``,
        or ``-1`` when no step matches.
    """
    matching_indices = (
        position for position, step in enumerate(pipeline) if step["type"] == "ObjectPaste"
    )
    # The generator is lazy, so scanning stops at the first match.
    return next(matching_indices, -1)


def objective(
    trial,
    source_config_path: str,
    tune_target_folder_path: str,
    max_epochs: int,
    n_gpus: int,
    target_classes: list,
    ap_dists: list,
    cls_search_space: dict,
) -> float:
    """Run one tuning trial and return the mAP read from its training log.

    One integer per class is sampled from ``cls_search_space`` and written to
    the ``db_sampler.sample_groups`` field of the ``ObjectPaste`` step in the
    training pipeline. The modified config is dumped to ``tmp_config.yaml``
    inside ``tune_target_folder_path``, training is launched through
    ``torchpack dist-run`` with its output discarded, and the resulting
    metrics are read with ``read_map``.

    Args:
        trial: Optuna trial used for sampling and reporting.
        source_config_path: Config file used as the base for every trial.
        tune_target_folder_path: Folder receiving the temporary config and the
            per-trial ``trial_{number}`` run directories.
        max_epochs: Value written to ``cfg.runner.max_epochs``.
        n_gpus: Process count passed to ``torchpack dist-run -np``.
        target_classes: Class names forwarded to ``read_map``.
        ap_dists: Metric suffixes forwarded to ``read_map``.
        cls_search_space: Maps class name to ``[low, high]`` bounds for
            ``trial.suggest_int``.

    Returns:
        The mAP returned by ``read_map`` for this trial's run directory.

    Raises:
        ValueError: If the training pipeline contains no ``ObjectPaste`` step.
    """
    params = {
        cls_name: trial.suggest_int(cls_name, bounds[0], bounds[1])
        for cls_name, bounds in cls_search_space.items()
    }
    print(f"Trial {trial.number} - Params: {params}")

    run_dir = os.path.join(tune_target_folder_path, f"trial_{trial.number}")

    configs.load(source_config_path, recursive=True)

    cfg = Config(recursive_eval(configs), filename=source_config_path)

    gtp_idx = find_gtp_in_pipeline(cfg.data.train.dataset.pipeline)
    if gtp_idx < 0:
        # -1 is the "not found" sentinel; using it as an index would silently
        # target the last pipeline step instead of ObjectPaste.
        raise ValueError(
            f"No 'ObjectPaste' step found in the training pipeline of {source_config_path}"
        )
    cfg.data.train.dataset.pipeline[gtp_idx].db_sampler.sample_groups = params

    tmp_config_path = os.path.join(tune_target_folder_path, "tmp_config.yaml")
    if os.path.exists(tmp_config_path):
        os.remove(tmp_config_path)

    cfg.run_dir = run_dir
    # NOTE(review): presumably disables keeping checkpoints for trial runs - confirm.
    cfg.checkpoint_config.max_keep_ckpts = 0
    cfg.runner.max_epochs = max_epochs
    # Every trial overrides the source config's learning rate with this fixed value.
    cfg.optimizer.lr = 2.0e-4
    cfg.dump(tmp_config_path)

    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through unchanged.
    command = [
        "torchpack",
        "dist-run",
        "-np",
        str(n_gpus),
        "python",
        "tools/train.py",
        tmp_config_path,
        "--run-dir",
        run_dir,
    ]
    try:
        completed = subprocess.run(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError as err:
        # Keep the study going (the trial will score whatever read_map
        # returns), but make the launch failure visible.
        print(f"Trial {trial.number} - could not launch training: {err!r}")
    else:
        if completed.returncode != 0:
            print(f"Trial {trial.number} - training exited with code {completed.returncode}")

    mean_ap, _, epoch = read_map(run_dir, target_classes, ap_dists)

    trial.report(mean_ap, epoch)

    return mean_ap


def save_results(df, study, save_path: str):
    """Write the trial table and three Optuna plots into ``save_path``.

    Args:
        df: Object with a ``to_csv`` method; written as ``results.csv``.
        study: Study passed to each ``optuna.visualization`` plot function.
        save_path: Existing directory that receives the CSV and PNG files.
    """
    df.to_csv(os.path.join(save_path, "results.csv"))

    # (plot function name, output file name), rendered in this order.
    plot_targets = (
        ("plot_param_importances", "param_importances.png"),
        ("plot_optimization_history", "optimization_history.png"),
        ("plot_parallel_coordinate", "parallel_coordinate.png"),
    )
    for plot_name, image_name in plot_targets:
        # Resolve the plot function only when it is needed, one plot at a time.
        make_figure = getattr(optuna.visualization, plot_name)
        make_figure(study).write_image(os.path.join(save_path, image_name))


def tune(
    source_config_path: str,
    tune_target_folder_path: str,
    n_trials: int,
    n_gpus: int,
    max_epochs: int,
    classes: list,
    ap_dists: list,
    cls_search_space: dict,
    seed: int = None,
):
    """Run an Optuna study over ``cls_search_space`` and save its results.

    A maximizing study with a TPE sampler is created, ``objective`` is run
    ``n_trials`` times, the study is pickled to ``study.pkl``, the best
    trial's parameters are printed, and the trial table and plots are written
    by ``save_results``. All outputs go to ``tune_target_folder_path``.

    Args:
        source_config_path: Base config forwarded to ``objective``.
        tune_target_folder_path: Output folder, created if missing; its last
            path component is used as the study name.
        n_trials: Number of trials to run.
        n_gpus: Forwarded to ``objective``.
        max_epochs: Forwarded to ``objective``.
        classes: Forwarded to ``objective`` as ``target_classes``.
        ap_dists: Forwarded to ``objective``.
        cls_search_space: Forwarded to ``objective``.
        seed: Optional seed for the TPE sampler. ``None`` (the default) keeps
            the sampler unseeded.
    """
    os.makedirs(tune_target_folder_path, exist_ok=True)
    # normpath drops a trailing separator, which would otherwise make
    # basename return an empty study name.
    run_id = os.path.basename(os.path.normpath(tune_target_folder_path))

    study = optuna.create_study(
        study_name=run_id,
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=seed),
    )

    study.optimize(
        lambda trial: objective(
            trial,
            source_config_path,
            tune_target_folder_path,
            max_epochs,
            n_gpus,
            classes,
            ap_dists,
            cls_search_space,
        ),
        n_trials=n_trials,
    )

    # Persist the study before any reporting step, so a failure while reading
    # the best trial or rendering plots cannot lose the finished trials.
    joblib.dump(study, os.path.join(tune_target_folder_path, "study.pkl"))

    best_trial = study.best_trial

    for key, value in best_trial.params.items():
        print(f"{key}: {value}")

    df = study.trials_dataframe()
    save_results(df, study, tune_target_folder_path)

In [None]:
# Settings for this tuning run.
source_config_path = "configs/tumtraf-i/baseline/transfusion/lidar/voxelnet-1600g-0xy1-0z20-gtp15.yaml"
tune_target_folder_path = "checkpoints/tune/tumtraf-i"
max_epochs = 5
n_trials = 20
n_gpus = 2

# Launch the study; keyword arguments make the argument order explicit.
tune(
    source_config_path=source_config_path,
    tune_target_folder_path=tune_target_folder_path,
    n_trials=n_trials,
    n_gpus=n_gpus,
    max_epochs=max_epochs,
    classes=CLASSES,
    ap_dists=AP_DISTS,
    cls_search_space=CLS_SEARCH_SPACE,
)