In [None]:
%load_ext autoreload
%autoreload 2

import polars as pl  # type: ignore
from matplotlib import pyplot as plt
import seaborn_polars as snl
import seaborn as sns
import metric_aggregation_utils
from typing import Optional
import yaml

import pathlib
import os
from omegaconf import OmegaConf
  
from rich import print

from eval.aggregation import eval_aggregation


# Notice
This notebook aggregates `KPI` and `AVMF` metrics. For `Eval` metrics,
see `eval_aggregation.py`.

Over time, `KPI` should be phased out and `AVMF` should be aggregated in the
same way as `Eval`.


In [None]:
# Cell tagged with `parameters`: the papermill library uses it to set the
# notebook's input arguments. The values below are the defaults.

# Name of the run whose metrics are queried from the AVMF and KPI tables.
run_name: str = "2025_03_11_test1"
# Optional table-name overrides. Any value left as None is filled in from
# wizard/configs/base_config.yaml by the config-loading cell.
avmf_sim_execution_table_name: Optional[str] = None
avmf_metrics_table_name: Optional[str] = None
kpi_sim_execution_table_name: Optional[str] = None
kpi_metrics_table_name: Optional[str] = None

# Threshold handed to `metric_aggregation_utils.filter_route_deviation`
# (units are not visible in this notebook - TODO confirm).
route_deviation_th: float = 16.0

In [None]:
# Resolve the database table names.
#
# Any table name left as None in the papermill parameters falls back to the
# value stored in wizard/configs/base_config.yaml. The file is located by
# walking up the directory tree, starting at the parent of the current working
# directory.
config_path = "wizard/configs/base_config.yaml"
search_root = pathlib.Path(os.getcwd()).parent

# Iterate over the finite chain of ancestors instead of looping on `.parent`:
# the parent of the filesystem root is the root itself, so an open-ended
# `while` loop never terminates (and never reaches the FileNotFoundError) when
# the config file is missing.
current_dir = next(
    (
        candidate
        for candidate in (search_root, *search_root.parents)
        if os.path.exists(candidate / config_path)
    ),
    None,
)
if current_dir is None:
    raise FileNotFoundError(f"Could not find {config_path} in any parent directory")

config = OmegaConf.load(current_dir / config_path)

# `or` keeps an explicitly supplied name and only falls back to the config
# value when the parameter is None (or empty).
avmf_sim_execution_table_name = (
    avmf_sim_execution_table_name or config.avmf.database.sim_execution_table_name
)
avmf_metrics_table_name = (
    avmf_metrics_table_name or config.avmf.database.metrics_table_name
)
kpi_sim_execution_table_name = (
    kpi_sim_execution_table_name or config.kpi.database.sim_execution_table_name
)
kpi_metrics_table_name = (
    kpi_metrics_table_name or config.kpi.database.metrics_table_name
)

print(
    f"Using for AVMF: \n\t{avmf_sim_execution_table_name=}, \n\t{avmf_metrics_table_name=}"
)
print(
    f"Using for KPI: \n\t{kpi_sim_execution_table_name=}, \n\t{kpi_metrics_table_name=}"
)

In [4]:
# NOTE(review): the commented-out `metrics` list below is not referenced by any
# code visible in this notebook; it looks like a leftover selection list.
# metrics = [
#     "collision",
#     # "comfort_lon_accel",
#     # "comfort_lat_accel",
#     # "comfort_lon_jerk",
#     # "comfort_jerk",
#     # "comfort_yaw_rate",
#     # "comfort_yaw_accel",
#     # "offroad",
#     # "wrong_lane",
#     "stop_sign",
#     # Combined metrics:
#     # "comfort",
#     "offroad_or_wrong_lane",
# ]

# Human-readable labels for integer collision-type codes (presumably the codes
# emitted by the collision metric - TODO confirm against the metric source).
collision_type_map = {
    -1: "No Collision",
    0: "Stopped Ego Collision",
    1: "Active Front Collision",
    2: "Active Rear Collision",
    3: "Active Lateral Collision",
}

# Collision-type codes handed to `metric_aggregation_utils.filter_collision_type`.
# Per `collision_type_map`, 0 is "Stopped Ego Collision" and 2 is
# "Active Rear Collision".
collisions_to_filter = [0, 2]

# Reduction applied to each metric across timesteps; passed to
# `metric_aggregation_utils.average_metrics_across_timesteps` and
# `metric_aggregation_utils.save_metrics_results_txt`.
# Regular expressions are allowed but must start with ^ and end with $.
# Non-regular expressions are automatically excluded from regex metrics, e.g.
# 'kpi_route_deviation' is not included in the regex metrics for min.
metric_averaging_across_timesteps = {
    "max": ["kpi_route_deviation"],
    # KPI metrics are all "lower is worse", with best value = 1
    "min": ["^kpi_.*$"],
    "mean": ["^avmf_.*$"],
}

In [None]:
# Fetch the AVMF and KPI metrics for `run_name`. Both queries go through the
# same helper and differ only in the tables they read.
_query_metrics = metric_aggregation_utils.query_kratos_metrics

avmf_metrics_table = _query_metrics(
    run_name=run_name,
    metrics_table_name=avmf_metrics_table_name,
    sim_execution_table_name=avmf_sim_execution_table_name,
)
kpi_metrics_table = _query_metrics(
    run_name=run_name,
    metrics_table_name=kpi_metrics_table_name,
    sim_execution_table_name=kpi_sim_execution_table_name,
)

In [None]:
# Reduce the raw query results to the per-timestep columns used downstream.
kpi_metrics_table_df = metric_aggregation_utils.pick_relevant_kpi_columns(
    kpi_metrics_table
)
events_df, per_timestep_df, per_scene_df = metric_aggregation_utils.parse_avmf_json(
    avmf_metrics_table
)
avmf_per_timestep_df = metric_aggregation_utils.pick_relevant_avmf_columns(
    per_timestep_df
)

# Prefix every metric name with its source so KPI and AVMF metrics stay
# distinguishable after concatenation (the `^kpi_.*$` / `^avmf_.*$` patterns in
# `metric_averaging_across_timesteps` match on these prefixes).
# A native expression is used rather than `map_elements` with a Python lambda,
# which is evaluated row by row in the interpreter. `alias` is required because
# the formatted expression would otherwise not be named "variable".
kpi_metrics_table_df = kpi_metrics_table_df.with_columns(
    pl.format("kpi_{}", pl.col("variable")).alias("variable")
)

avmf_per_timestep_df = avmf_per_timestep_df.with_columns(
    pl.format("avmf_{}", pl.col("variable")).alias("variable")
)

# Stack both sources into one frame and attach rollout / trajectory ids.
combined_per_timestep_df = pl.concat(
    [kpi_metrics_table_df, avmf_per_timestep_df], how="vertical"
)
combined_per_timestep_df, trajectory_uid_df = (
    metric_aggregation_utils.add_rollout_and_trajectory_uids(combined_per_timestep_df)
)

# Wide layout (one column per metric) is easier for some computations. The
# sort by trajectory and timestamp is important, so it happens before the
# filtering / aggregation helpers run.
df_wide = (
    combined_per_timestep_df.pivot(
        on="variable",
        index=["trajectory_uid", "rel_timestamp"],
        values="value",
    )
    .sort(["trajectory_uid", "rel_timestamp"])
    .pipe(metric_aggregation_utils.filter_collision_type, collisions_to_filter)
    .pipe(metric_aggregation_utils.filter_route_deviation, route_deviation_th)
    .pipe(metric_aggregation_utils.add_aggregate_metrics)
)

# Collapse the time axis using the per-metric reduction rules.
df_wide_avg_over_time = df_wide.pipe(
    metric_aggregation_utils.average_metrics_across_timesteps,
    metric_averaging_across_timesteps,
)

# Join "run_uuid" and "rollout_uid" back into the dataframe, then discard the
# per-trajectory key, which is not needed once clips are averaged.
# `trajectory_uid_right` only exists when the key columns were not coalesced;
# `strict=False` drops it when present instead of raising when it is absent.
df_wide_avg_over_time = df_wide_avg_over_time.join(
    trajectory_uid_df.select(
        pl.col("trajectory_uid"), pl.col("rollout_uid"), pl.col("run_uuid")
    ),
    on="trajectory_uid",
    how="left",
).drop(["trajectory_uid", "trajectory_uid_right"], strict=False)

# Average over clips, `rollout_uid` is unique for each [run, batch, rollout].
# `num_clips` records how many rows were averaged into each group.
df_wide_avg = df_wide_avg_over_time.group_by("run_uuid", "rollout_uid").agg(
    pl.all().mean(),
    pl.col("run_uuid").count().alias("num_clips"),
)

# Report how many (run, rollout) groups remain, then write the text summary.
print("Length after averaging over timesteps and clips: ", df_wide_avg.height)

metrics_results_filename = "metrics_results.txt"
metric_aggregation_utils.save_metrics_results_txt(
    df_wide_avg,
    trajectory_uid_df,
    metric_averaging_across_timesteps,
    metrics_results_filename,
)

In [None]:
# Metrics shown in the bar plot, one group of bars per metric.
metrics_to_plot = [
    "kpi_comfort",
    "kpi_collision",
    "kpi_offroad_or_wrong_lane",
    "kpi_stop_sign",
    "avmf_accel_x",
]

# Back to long layout: one row per (run, rollout, metric).
id_columns = ["run_uuid", "rollout_uid"]
df_long_avg = df_wide_avg.unpivot(on=metrics_to_plot, index=id_columns)

# Renaming from run_uuid back to run_name is done last and only for plotting
# (the original comment was cut off here; presumably this keeps `run_uuid` as
# the key everywhere else - TODO confirm).
fig, ax = plt.subplots()

snl.barplot(
    df_long_avg,
    x="variable",
    y="value",
    hue="run_uuid",
    errorbar="sd",
    ax=ax,
)
metric_aggregation_utils.rename_legend_handles(ax, trajectory_uid_df)