In [None]:
import sys
sys.path.append("../evaluation")
sys.path.append("../utils")


import pandas as pd
import os
from tqdm import tqdm
import numpy as np
import copy
import importlib
from sklearn.metrics import mean_squared_error
from LUT import LUT_1D
from matplotlib import pyplot as plt
import plot_helpers
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning

# Suppress scikit-learn ConvergenceWarning and pandas PerformanceWarning
# messages for every cell that runs after this one.
simplefilter("ignore", category=ConvergenceWarning)
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

## Helper Functions

In [None]:
def get_filtered_dataset(df, max_d2cl=0.8, min_d2cl=0, filter_rural=False):
    """Return the rows of ``df`` whose absolute lane offset lies in an open interval.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a ``Dist_To_Center_Lane`` column and, when ``filter_rural`` is
        true, a ``road_type`` column.
    max_d2cl : float
        Exclusive upper bound on ``abs(Dist_To_Center_Lane)``.
    min_d2cl : float
        Exclusive lower bound on ``abs(Dist_To_Center_Lane)``. With the default
        of 0, rows whose offset is exactly 0 are dropped.
    filter_rural : bool
        When true, additionally keep only rows with ``road_type == "rural"``.

    Returns
    -------
    pandas.DataFrame
        A new frame holding the selected rows with their original index
        labels. ``df`` itself is never modified.
    """
    abs_d2cl = df["Dist_To_Center_Lane"].abs()
    keep = (abs_d2cl < max_d2cl) & (abs_d2cl > min_d2cl)

    if filter_rural:
        keep &= df["road_type"] == "rural"

    # Copy only the surviving rows instead of deep-copying the whole input
    # up front; the caller still receives a frame independent of ``df``.
    return df.loc[keep].copy()

## Params

In [None]:
# Bounds on |Dist_To_Center_Lane| handed to get_filtered_dataset.
MAX_D2CL_TRAIN = 1.5  # upper bound for the training split
MAX_D2CL_EVAL = 0.8   # upper bound for the evaluation split
MIN_D2CL_EVAL = 0     # lower bound for the evaluation split

SEGMENT_KEY = "frame"  # accepted values: "frame" or "segment"
# N_SPLITS = 10

VERBOSE = False

# DOMAIN becomes part of the file name of the iterative-training split pickle.
FILTER_RURAL = False
DOMAIN = "rural_only" if FILTER_RURAL else "all"

## Prepare the Data

### Load

In [None]:
# Experiment to analyse and the number of iterative-training steps it used.
run = "val-pretrain"
step_count = 10

# Root folders: input datasets and stored clustering / prediction results.
dataset_dir = "/root/sadc/data/datasets"
data_lake = "/root/sadc/results"

# Per-run settings: clustering result folder and file prefix, the cluster
# count whose id column is used, and the prediction file per step count.
config = {
    "val-pretrain": {
        "DSC": "sadc_clustering_resnet_18_val_all_13-01-2024_15-45-37",
        "DSC_run_name": "sadc_clustering_resnet_18_val_all_13-01-2024_15-45-37",
        "best_c_id": 500,
        "NN-IT": {
            10: "resnet18_all_r1/heads/mlp_stepwise_10/predictions_stepwise/val_val/driving_data_predictions.pkl",
        },
    },
}

# NOTE: every input below is a pickle; unpickling executes code, so only load
# files that come from a trusted source.
iterative_training_splits_df = pd.read_pickle(
    f"./dataset_val_train_iterative_training_splits_{DOMAIN}_{step_count}_splits.pkl"
)

val_train_df = pd.read_pickle(os.path.join(dataset_dir, "dataset_val_train.pkl"))
val_val_df = pd.read_pickle(os.path.join(dataset_dir, "dataset_val_val.pkl"))

sadc_clustering_val_train_df = pd.read_pickle(
    os.path.join(
        data_lake,
        config[run]["DSC"],
        f"{config[run]['DSC_run_name']}_val_train.pkl",
    )
)
sadc_clustering_val_val_df = pd.read_pickle(
    os.path.join(
        data_lake,
        config[run]["DSC"],
        f"{config[run]['DSC_run_name']}_val_val.pkl",
    )
)
mlp_stepwise_predictions_df = pd.read_pickle(
    os.path.join(data_lake, config[run]["NN-IT"][step_count])
)

### Filter

In [None]:
# Training uses MAX_D2CL_TRAIN as its upper lane-offset bound and the helper's
# default lower bound; evaluation passes both of its bounds explicitly.
val_train_df = get_filtered_dataset(
    df=val_train_df,
    max_d2cl=MAX_D2CL_TRAIN,
    filter_rural=FILTER_RURAL,
)
val_val_df = get_filtered_dataset(
    df=val_val_df,
    max_d2cl=MAX_D2CL_EVAL,
    min_d2cl=MIN_D2CL_EVAL,
    filter_rural=FILTER_RURAL,
)

### Merge with Cluster IDs and Training Iterations

In [None]:
# Name of the cluster-id column to use, built from the run's `best_c_id` setting.
cid = f"fkm_{config[run]['best_c_id']}_cluster_id"

In [None]:
# Every table is joined on the (alias, frame) pair.
_join_keys = ["alias", "frame"]

# Keep only the join keys and the selected cluster-id column of the clustering output.
sadc_clustering_val_train_df = sadc_clustering_val_train_df[[*_join_keys, cid]]
sadc_clustering_val_val_df = sadc_clustering_val_val_df[[*_join_keys, cid]]

# Training rows: attach the training-split table, then the cluster id.
val_train_df = (
    val_train_df
    .merge(iterative_training_splits_df, how="left", on=_join_keys)
    .merge(sadc_clustering_val_train_df, how="left", on=_join_keys)
)

# Evaluation rows: attach the cluster id, then every prediction column of the MLP output.
mlp_pred_cols = [c for c in mlp_stepwise_predictions_df.columns if "predictions" in c]
val_val_df = (
    val_val_df
    .merge(sadc_clustering_val_val_df, how="left", on=_join_keys)
    .merge(
        mlp_stepwise_predictions_df[[*_join_keys, *mlp_pred_cols]],
        how="left",
        on=_join_keys,
    )
)

## Train

In [None]:
trained_luts = {}
DEFAULT_VALUE = 0.0
VERBOSE = False  # re-declared here with the same value as in the Params cell
LUT_STORE_ALL_DATA = True

# The LUT of an alias is trained cumulatively: the snapshot stored for
# iteration k must contain exactly the samples of iterations <= k.
# `Series.unique()` returns values in order of appearance, so the iterations
# are sorted explicitly. NaN (rows without a split assignment after the left
# merge) is dropped because it never matches any row anyway.
# The same iteration list is used for every alias, so each alias gets a
# snapshot for every iteration even when it contributes no samples to one.
train_iterations = np.sort(val_train_df["train_iter"].dropna().unique())

for alias in tqdm(val_train_df.alias.unique(), desc="Training"):
    trained_luts[alias] = {}
    _df_alias = val_train_df[val_train_df.alias == alias]

    _lut = LUT_1D(
        default_value=DEFAULT_VALUE,
        verbose=VERBOSE,
        store_all_data=LUT_STORE_ALL_DATA,
    )

    for it in train_iterations:
        _df_alias_it = _df_alias[_df_alias["train_iter"] == it]

        # Iterate the two needed columns directly instead of building a
        # Series per row with iterrows().
        for cluster_id, d2cl in zip(
            _df_alias_it[cid], _df_alias_it["Dist_To_Center_Lane"]
        ):
            _lut.train_sample(index=cluster_id, key="D2CL", value=d2cl)

        # Freeze the state reached after this iteration; training continues
        # on `_lut` itself.
        trained_luts[alias][it] = copy.deepcopy(_lut)

## Predict

In [None]:
def get_predictions(alias, index, key, training_iteration, trained_luts):
    """Query one LUT snapshot for a single sample.

    Picks ``trained_luts[alias][training_iteration]`` and returns the
    ``(mean, std)`` pair that its ``get_mean_std(index=..., key=...)`` yields.
    """
    snapshot = trained_luts[alias][training_iteration]
    mu, sigma = snapshot.get_mean_std(index=index, key=key)
    return (mu, sigma)


# Element-wise version of the lookup: `alias` and `index` are broadcast, the
# remaining keyword arguments are passed through unchanged to every call.
v_get_predictions = np.vectorize(
    get_predictions,
    excluded=["key", "training_iteration", "trained_luts"],
)

In [None]:
# For every training iteration, look up the LUT mean/std of each evaluation
# row and store them in two new columns of val_val_df.
for training_iteration in tqdm(iterative_training_splits_df.train_iter.unique()):
    _lut_mean, _lut_std = v_get_predictions(
        alias=val_val_df.alias,
        index=val_val_df[cid],
        key="D2CL",
        training_iteration=training_iteration,
        trained_luts=trained_luts,
    )
    val_val_df[f"lut_D2CL_it_{training_iteration}_mean"] = _lut_mean
    val_val_df[f"lut_D2CL_it_{training_iteration}_std"] = _lut_std

## Eval

In [None]:
def get_predictions(df, key, zero_shape):
    """Return column ``key`` of ``df``, or zeros shaped like ``zero_shape``.

    The zeros fallback is used when ``df`` has no column named ``key``.

    NOTE: this definition rebinds the name ``get_predictions`` that the
    Predict section uses for the LUT lookup; ``v_get_predictions`` keeps
    working because it captured the earlier function object.
    """
    has_column = key in df.columns
    return df[key] if has_column else np.zeros_like(zero_shape)

In [None]:
results = []

# The per-alias split does not depend on the training iteration, so build it
# once instead of once per iteration.
_frames_by_alias = {
    alias: val_val_df[val_val_df.alias == alias]
    for alias in val_val_df.alias.unique()
}

for training_iteration in tqdm(iterative_training_splits_df.train_iter.unique()):
    for alias, _d in _frames_by_alias.items():
        _d_y_true = _d["Dist_To_Center_Lane"]
        _d_y_predicted_lut = _d[f"lut_D2CL_it_{training_iteration}_mean"]
        # `get_predictions` is the column-or-zeros helper of the Eval section:
        # when no MLP prediction column exists for this iteration, the MLP
        # error is measured against an all-zero prediction.
        _d_y_predicted_mlp = get_predictions(
            _d, f"predictions_{training_iteration}", _d_y_true
        )

        # Compute each MSE once and derive the RMSE from it. The `squared=False`
        # argument of mean_squared_error is deprecated in scikit-learn 1.4 and
        # removed in 1.6; sqrt(MSE) gives the same value on every version.
        _mse_lut = mean_squared_error(_d_y_true, _d_y_predicted_lut)
        _mse_mlp = mean_squared_error(_d_y_true, _d_y_predicted_mlp)

        results.append(
            {
                "training_iteration": training_iteration,
                "alias": alias,
                "rmse_lut": np.sqrt(_mse_lut),
                "rmse_mlp": np.sqrt(_mse_mlp),
                "mse_lut": _mse_lut,
                "mse_mlp": _mse_mlp,
            }
        )

# One row per (training iteration, alias) pair.
results_df = pd.DataFrame(results)

# Plots

## Iterative Training

In [None]:
# Re-import plot_helpers so edits to the module are picked up without a kernel restart.
importlib.reload(plot_helpers)
plt.close()

# Value handed to the helper's `target` argument (presumably the output file
# of the figure - confirm in plot_helpers).
t = f"./final_iterative_plots/{run}_{step_count}_steps.pdf"

plot_helpers.plot_iterative_training(
    results_df=results_df,
    target=t,
    figsize=(2, 2),
    marker_size=1,
    y_lim=(0.05, 0.95),
    legend_y_pad=-0.32,
    # Six ticks spread evenly from 0 to step_count.
    x_ticks=[int(k * step_count / 5) for k in range(6)],
)

## Cluster-wise Histograms

In [None]:
RANGE = (-0.8, 0.8)
N_BINS = 40
FIGSIZE = (2, 2)
KED_LINEWIDTH = 1.5
context = ["science", "ieee", "no-latex"]
legend_y_pad = -0.3

# LUT snapshot (training iteration) whose stored samples are plotted.
HIST_TRAIN_ITER = 9
HIST_OUT_DIR = "./hist_plots_resnext50"

# One curve per alias; only colours and line style differ between them.
HIST_STYLES = {
    "001": dict(color="k", fill_color="k", edgecolor="k", kde_linestyle="-"),
    "002": dict(color="r", fill_color="r", edgecolor="r", kde_linestyle="-"),
    "003": dict(color="k", fill_color="k", edgecolor="k", kde_linestyle="--"),
    "004": dict(color="r", fill_color="r", edgecolor="r", kde_linestyle="--"),
    "005": dict(color="k", fill_color="w", edgecolor="k", kde_linestyle=":"),
}

# savefig does not create missing directories.
os.makedirs(HIST_OUT_DIR, exist_ok=True)

for cluster_id in tqdm(trained_luts["001"][HIST_TRAIN_ITER]._lut.keys()):

    # Collect the samples of all aliases first: a cluster is plotted only if
    # every alias has stored data for it. Only the "entry missing" case is
    # skipped; any other error is a real problem and should surface.
    try:
        data_by_alias = {
            alias: trained_luts[alias][HIST_TRAIN_ITER]._lut[cluster_id]["D2CL"]["data"]
            for alias in HIST_STYLES
        }
    except KeyError:
        continue

    with plt.style.context(context):
        fig, ax = plt.subplots(figsize=FIGSIZE)

        for alias, style in HIST_STYLES.items():
            plot_helpers.plot_hist(
                data_by_alias[alias],
                ax=ax,
                range=RANGE,
                n_bins=N_BINS,
                bar_alpha=0.25,
                kde_linewidth=KED_LINEWIDTH,
                show_bar=False,
                label=alias,
                **style,
            )

        plt.xlabel(r"$d_{\mathrm{CL}}$ in $m$")

        plt.legend(ncols=5, loc="center", bbox_to_anchor=(0.5, legend_y_pad))
        # Upper limit first: the x axis is drawn reversed.
        plt.xlim(RANGE[1], RANGE[0])
        plt.savefig(f"{HIST_OUT_DIR}/hist_plot_{cluster_id}.pdf")
        plt.savefig(f"{HIST_OUT_DIR}/hist_plot_{cluster_id}.png")
        plt.close()

## Plot predicted Trajectories

In [None]:
# Plot the LUT state after the last training iteration. The value is derived
# from the split table instead of being hard-coded: a fixed iteration number
# only matches one particular step count and otherwise refers to a
# `lut_D2CL_it_<IT>_*` column that was never created.
IT = iterative_training_splits_df.train_iter.max()
CID = cid
n_frames = 6

# Work on a copy so the helper columns below do not end up in val_val_df.
v_df = val_val_df.copy()
v_df[f"{CID}_d2cl_mean"] = v_df[f"lut_D2CL_it_{IT}_mean"]
v_df[f"{CID}_d2cl_std"] = v_df[f"lut_D2CL_it_{IT}_std"]


for alias in v_df.alias.unique():
    _v_a = v_df[v_df.alias == alias]

    for s in tqdm(_v_a.segment.unique()):
        _v_a_s = _v_a[_v_a.segment == s]

        # n_frames evenly spaced positions between the first and last frame
        # of the segment, each snapped to the nearest frame that is present.
        segment_frames = _v_a_s.frame.tolist()
        first_frame = _v_a_s.frame.min()
        step_size = (_v_a_s.frame.max() - first_frame) / (n_frames - 1)
        frame_markers = [int(first_frame + i * step_size) for i in range(n_frames)]
        frame_markers_closest = [
            min(segment_frames, key=lambda x: abs(x - f)) for f in frame_markers
        ]
        frame_cluster_annotations = _v_a_s[
            _v_a_s.frame.isin(frame_markers_closest)
        ][CID].tolist()

        # Markers that collapse onto the same frame (short segments) or frames
        # that occur more than once leave the two lists with different
        # lengths; such segments are skipped.
        if len(frame_markers_closest) != len(frame_cluster_annotations):
            print(f"Skipping {alias}-->{s}")
            continue

        plot_helpers.plotSituationPredictions(
            df=_v_a_s,
            cID=CID,
            target=f"./sit_plots_with_nr_nc/{alias}_{s}.pdf",
            # NOTE(review): needs a column named exactly `predictions`; the
            # merge cell only attaches columns whose name contains that word -
            # confirm the column exists.
            mlp_predictions=_v_a_s.predictions.to_numpy(),
            std_alpha=0.1,
            cluster_marker_size=10,
            figsize=(1.8, 1.5),
            legend_below_plot=True,
            fill_cluster_marker=False,
            cluster_marker_color="r",
            annotate_cluster_markers=False,
            legend_y_pad=-0.35,
            legend_n_cols=4,
            frame_markers=frame_markers_closest,
            frame_cluster_annotations=frame_cluster_annotations,
            y_lim=0.8,
        )

## Save Situation Images

In [None]:
n_frames = 6
collageSize = (600, 400)
imagesRoot = "/root/sadc/data/images/val/"
targetFolder = "./sit_set_cards_with_nr/"

# Create the output folder up front so saving the first collage does not
# fail on a missing directory.
os.makedirs(targetFolder, exist_ok=True)

# `v_df` is the evaluation frame prepared in the trajectory-plot cell.
for alias in v_df.alias.unique():
    _v_a = v_df[v_df.alias == alias]

    for s in tqdm(_v_a.segment.unique()):
        _v_a_s = _v_a[_v_a.segment == s]

        # n_frames evenly spaced frame numbers between the first and last
        # frame of the segment.
        first_frame = _v_a_s.frame.min()
        step_size = (_v_a_s.frame.max() - first_frame) / (n_frames - 1)
        frames = [
            plot_helpers.get_frame(
                imagesRoot=imagesRoot,
                alias=alias,
                frame=int(first_frame + i * step_size),
            )
            for i in range(n_frames)
        ]

        c = plot_helpers.createCollage(frames, collageSize)
        # os.path.join avoids the doubled "/" that plain string formatting
        # produced with the trailing slash of targetFolder.
        c.save(
            os.path.join(targetFolder, f"{alias}_{s}_samples.jpg"),
            optimize=True,
            quality=95,
        )