In [None]:
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union

import IPython
import matplotlib.pyplot as plt
import numpy as np
import logging
import pandas as pd
import seaborn as sns
from IPython.display import display
from matplotlib.axes import Axes as Axes

# Locate the project root (two levels above this notebook) and make `src.*` importable.
# `__vsc_ipynb_file__` appears to be injected by VS Code's Jupyter integration; fail with a
# clear message instead of a bare KeyError when it is not available.
_notebook_locals = IPython.extract_module_locals()[1]
if "__vsc_ipynb_file__" not in _notebook_locals:
    raise RuntimeError(
        "Cannot locate this notebook: '__vsc_ipynb_file__' is not defined. "
        "Run the notebook from VS Code or set `project_dir` manually."
    )
notebook_path = Path(_notebook_locals["__vsc_ipynb_file__"])
project_dir = notebook_path.parent.parent
# Guard so that re-running this cell does not pile up duplicate sys.path entries.
if str(project_dir) not in sys.path:
    sys.path.append(str(project_dir))
import src.utils.custom_log as custom_log
from src.utils.Csv import Csv
from src.utils.set_rcparams import set_rcparams

# Work from the project root so relative paths resolve against it.
os.chdir(project_dir)

# Notebook-wide logger, configured through the project's custom_log helper.
LOG: logging.Logger = logging.getLogger(name=__name__)
custom_log.init_logger(logging.INFO)
LOG.info("Log initialized")

# Apply the project's plot styling.
set_rcparams()

In [2]:
# Root folder of the A/B-testing simulation results. Defaults to the shared-drive location;
# set the AB_TESTING_DIR environment variable to point the notebook at another copy.
B_PATH: Path = Path(os.environ.get("AB_TESTING_DIR", "/mnt/q/Val_Chain_Sims/AB_Testing"))
# ISO 18571 rating matrices (*.csv.zip) live here; generated figures go to ISO_PATH / "Figures".
ISO_PATH: Path = B_PATH / "ISO18571"
ISO_PATH.mkdir(exist_ok=True, parents=True)

In [3]:
# Simulation setups compared in the A/B study, in the order used for the heatmap axes.
# Annotated as a variable-length tuple: the previous fixed 12-element annotation did not
# match the 13 entries below.
CASES: Tuple[str, ...] = (
    "000_Base_Model",
    "100_Guided_BIW",
    "200_PAB_Simplified",
    "300_Seat_Simplified",
    "400_HIII",
    "400_900_NoIntrusion",
    "500_NoCAB",
    "600_NoDoor",
    "700_Simplified_Belt",
    "800_Simplified_BIW",
    "900_NoIntrusion",
    "950_Dash_Rigid",
    "990_Carpet_Rigid",
)

In [4]:
def label(idx: int, setting: str) -> str:
    """Return ``"<case>_<setting>"`` for the case at position ``idx`` of ``CASES``."""
    case_name = CASES[idx]
    return f"{case_name}_{setting}"

In [5]:
# Load cases for which ISO 18571 rating figures are produced.
LOAD_CASES: Tuple[str, ...] = (
    "Full Frontal",
    "Moderate Overlap Left",
    "Moderate Overlap Right",
    "Oblique Right",
    "Oblique Left",
)

In [None]:
# Load every ISO 18571 rating matrix, keyed by load case and then by channel.
# File names are parsed as "<Load_Case_Words>_<CHANNEL>.csv.zip". "DM" in the channel name
# is replaced by "??" (presumably a dummy-type code, so channels share one key — confirm).
# The glob is sorted because Path.glob yields files in arbitrary filesystem order; this
# keeps the load order, and with it every downstream dict/plot order, reproducible.
_iso_data: Dict[str, Dict[str, pd.DataFrame]] = defaultdict(dict)
for iso_file in sorted(ISO_PATH.glob("*.csv.zip")):
    LOG.info("ISO file %s", iso_file)
    data = Csv(csv_path=iso_file, compress=True).read()
    LOG.info("ISO data %s", data.shape)
    # stem still carries ".csv" (only ".zip" is stripped), hence the split on ".".
    *case_words, channel_part = iso_file.stem.split("_")
    channel = channel_part.split(".")[0].replace("DM", "??")
    _iso_data[" ".join(case_words)][channel] = data
ISO_DATA = dict(_iso_data)
# Show which channels were loaded per load case instead of dumping every matrix.
{case: sorted(channels) for case, channels in ISO_DATA.items()}

In [None]:
# Inspect one raw (ungrouped) rating matrix with its values annotated.
raw_matrix = ISO_DATA["Full Frontal"]["03CHST0000??50ACZD"]
fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(raw_matrix, annot=True, ax=ax)

In [None]:
def isolate(names: Union[List[str], pd.Index], suffix_len: int = 6) -> Dict[str, List[str]]:
    """Group labels by their base name.

    Labels ending in ``"Report"`` form a group of their own. Every other label is grouped
    under the label with its last ``suffix_len`` characters removed, e.g. with the default
    of 6, ``"000 Base Model SIM01"`` -> ``"000 Base Model"``. Labels that are not longer than
    ``suffix_len`` are kept whole instead of collapsing into an empty-string group.

    Args:
        names: Labels to group, e.g. the index or columns of a rating matrix.
        suffix_len: Number of trailing characters that distinguish members of one group.

    Returns:
        Mapping from group name to its labels, both in first-seen order.
    """
    groups: Dict[str, List[str]] = defaultdict(list)

    for name in names:
        if name.endswith("Report") or len(name) <= suffix_len:
            key = name
        else:
            # Slice by explicit end index so suffix_len == 0 keeps the full label.
            key = name[: len(name) - suffix_len]
        groups[key].append(name)

    return dict(groups)


def _median_by_group(db: pd.DataFrame) -> pd.DataFrame:
    """Reduce one rating matrix to the median rating per (row group, column group) pair.

    Row labels (references) and column labels (comparisons) are grouped with ``isolate``.
    The result has one column per reference group and one row per comparison group, so
    ``result.loc[comparison, reference]`` is the median of that block (NaNs skipped).
    """
    ref_groups = isolate(db.index)
    chal_groups = isolate(db.columns)

    medians: Dict[str, Dict[str, float]] = defaultdict(dict)
    for ref_group, ref_labels in ref_groups.items():
        for chal_group, chal_labels in chal_groups.items():
            selection = db.loc[ref_labels, chal_labels]
            LOG.debug("Selection of %s to %s has shape %s", ref_group, chal_group, selection.shape)
            # Median over every cell of the block. Flattening first yields a scalar on any
            # pandas version; DataFrame.median(axis=None) only reduces over both axes
            # since pandas 2.0.
            medians[ref_group][chal_group] = pd.Series(selection.to_numpy().ravel()).median()
    return pd.DataFrame(medians)


def get_grouped_isos(
    iso_data: Optional[Dict[str, Dict[str, pd.DataFrame]]] = None,
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """Collapse every raw rating matrix into a matrix of group medians.

    Args:
        iso_data: ``{load case: {channel: rating matrix}}``; defaults to ``ISO_DATA``.

    Returns:
        ``{load case: {channel: grouped matrix}}``. Each grouped matrix has the reference
        groups as columns and the comparison groups as rows.
    """
    source = ISO_DATA if iso_data is None else iso_data
    all_grouped_isos: Dict[str, Dict[str, pd.DataFrame]] = {}

    for case, channel_frames in source.items():
        for channel, raw in channel_frames.items():
            LOG.info("Processing %s %s", case, channel)
            # Non-report labels become space-separated (the form show_isos derives from
            # CASES) with "CAE " removed. rename() returns a new frame, so the loaded data
            # is left unmodified.
            db = raw.rename(
                columns={x: x.replace("_", " ").replace("CAE ", "") for x in raw.columns if "Report" not in x},
                index={x: x.replace("_", " ").replace("CAE ", "") for x in raw.index if "Report" not in x},
            )
            grouped = _median_by_group(db)
            all_grouped_isos.setdefault(case, {})[channel] = grouped
            LOG.info("Grouped ISOs for %s %s with shape %s", case, channel, grouped.shape)

    return all_grouped_isos


ALL_GROUPED_ISOS = get_grouped_isos()

In [None]:
def show_isos(
    case: str,
    channels: Optional[List[str]] = None,
    ax: Optional[Axes] = None,
    cmap: Optional[Union[str, "matplotlib.colors.Colormap"]] = None,
    norm: Optional["matplotlib.colors.Normalize"] = None,
) -> None:
    """Draw the grouped ISO 18571 ratings of one load case as lower-triangular heatmaps.

    For every channel, a square matrix is built with the two TH reports followed by all
    ``CASES`` on both axes. Columns are references and rows are comparisons, and each cell is
    ``ALL_GROUPED_ISOS[case][channel].loc[comparison, reference]``. Pairs without data and
    self-comparisons are NaN and drawn empty. Case labels are shown without their numeric
    prefix. The last reference column is dropped because, with the upper triangle masked,
    it could only hold an empty self-comparison.

    Args:
        case: Load case key of ``ALL_GROUPED_ISOS`` (e.g. any entry of ``LOAD_CASES``).
        channels: Channels to draw; all channels of ``case`` when None.
        ax: Axes to draw into. When None, a new constrained-layout figure is created for
            every channel. When given, all channels are drawn into this one axes, so
            callers should pass a single channel.
        cmap: Colormap name or object for the heatmap; ``"magma"`` when None.
        norm: Optional normalization forwarded to ``seaborn.heatmap`` (e.g. a BoundaryNorm).
    """
    labels = ["HW TH Report", "CAE TH Report"] + [x.replace("_", " ") for x in CASES]
    # "000 Base Model" -> "Base Model"; report labels are kept unchanged.
    display_names = {lbl: " ".join(lbl.split()[1:]) for lbl in labels if "Report" not in lbl}
    self_pairs = np.eye(len(labels), dtype=bool)

    for channel in ALL_GROUPED_ISOS[case].keys() if channels is None else channels:
        LOG.info("Processing %s %s", case, channel)

        # Grouped matrix: columns are reference groups, rows are comparison groups.
        grouped = ALL_GROUPED_ISOS[case][channel]
        # Align to the fixed label order (missing pairs -> NaN), blank self-pairs,
        # shorten the labels and drop the last reference column.
        to_plot = (
            grouped.reindex(index=labels, columns=labels)
            .mask(self_pairs)
            .rename(index=display_names, columns=display_names)
            .iloc[:, :-1]
        )

        # One axes per channel when none is supplied, so channels are not drawn on top
        # of each other.
        channel_ax = ax if ax is not None else plt.subplots(layout="constrained")[1]
        sns.heatmap(
            to_plot,
            annot=True,
            # only the lower triangle (incl. diagonal) is drawn
            mask=np.triu(np.ones(to_plot.shape, dtype=bool), k=1),
            cbar=False,
            ax=channel_ax,
            vmin=0,
            vmax=1,
            cmap="magma" if cmap is None else cmap,
            linewidth=0.7,
            norm=norm,
            fmt=".2f",
            annot_kws={"fontsize": 4},
        )
        channel_ax.set_title(channel)
        channel_ax.set_xlabel("Reference")
        channel_ax.set_ylabel("Comparison")
        # NOTE(review): Axes.grid() without arguments *toggles* grid visibility; kept as in
        # the original, but pass True/False explicitly if a fixed state is intended.
        channel_ax.grid()
        channel_ax.set_axisbelow(True)


show_isos(LOAD_CASES[0], channels=["03CHST0000??50ACZD"])

In [None]:
import matplotlib as mpl


def plot_channel_group(
    case: str,
    channels: List[List[str]],
    formats: Optional[List[str]] = None,
    grp_name: Optional[str] = None,
) -> None:
    """Plot the ISO 18571 heatmaps of several channels of one load case in one figure.

    The figure is a ``plt.subplot_mosaic`` with a thin colorbar row on top, followed by one
    row of heatmaps per inner list of ``channels``. Ratings are coloured in four bands with
    boundaries 0, 0.58, 0.8, 0.94 and 1.

    Args:
        case: Load case key of ``ALL_GROUPED_ISOS``.
        channels: Rows of channel names. All rows must have the same length (the mosaic is
            rectangular); the first row sets the width of the colorbar row.
        formats: File extensions to save the figure as (e.g. ``["png", "pdf"]``). Each is
            written to ``ISO_PATH / "Figures" / <FMT upper-case>`` and the figure is closed
            afterwards. When None, nothing is saved and the figure stays open for display.
        grp_name: Group name used in the saved file name ``<case>_<grp_name>.<fmt>``.

    Raises:
        ValueError: If ``channels`` is empty or its first row is empty.
    """
    if not channels or not channels[0]:
        raise ValueError("channels must contain at least one non-empty row")
    flat_channels = [ch for row in channels for ch in row]
    LOG.info("Processing load case %s with %s channels", case, len(flat_channels))

    # Width: 448.13095 pt at 72 pt per inch (presumably the target document's text width —
    # confirm). Each heatmap row gets half of a 1.1 x width panel height.
    fig_width: float = 1 * (448.13095 / 72)
    fig_height: float = 1.1 * fig_width
    fig, ax = plt.subplot_mosaic(
        [["none"] * len(channels[0]), *channels],
        figsize=(fig_width, 0.5 * len(channels) * fig_height),
        layout="constrained",
        # keep the colorbar row thin relative to the heatmap rows
        gridspec_kw={"height_ratios": (0.025, *([1] * len(channels)))},
    )
    cmap = mpl.colors.ListedColormap(["indianred", "orange", "yellowgreen", "forestgreen"])
    bounds = [0, 0.58, 0.8, 0.94, 1]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    fig.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        cax=ax["none"],
        location="top",
    )
    ax["none"].set(frame_on=False)
    # assumes the colorbar places exactly one tick on each boundary
    ax["none"].set_xticklabels([f"{x:.2f}" for x in bounds])

    for ch in flat_channels:
        LOG.info("Processing channel %s", ch)
        show_isos(case=case, channels=[ch], ax=ax[ch], cmap=cmap, norm=norm)

    if formats is not None:
        p_path = ISO_PATH / "Figures"
        for fmt in formats:
            fp_path = p_path / fmt.upper()
            fp_path.mkdir(exist_ok=True, parents=True)
            # Save this figure explicitly rather than pyplot's implicit "current" figure.
            fig.savefig(fp_path / f"{case.replace(' ', '_')}_{grp_name}.{fmt}")
        plt.close(fig)


plot_channel_group(
    LOAD_CASES[2],
    [
        ["00COG00000VH00VEXD", "00COG00000VH00VEYD"],
        ["00COG00000VH00ACXD", "00COG00000VH00ACYD"],
    ],
)

In [None]:
# Channel groups to plot. Each value is a grid of channel names (one inner list per row
# of heatmaps); the key is used as the figure's group name when saving.
channel_sets: Dict[str, List[List[str]]] = {
    # vehicle pulse
    "Pulses": [
        ["00COG00000VH00VEXD", "00COG00000VH00VEYD"],
        ["00COG00000VH00ACXD", "00COG00000VH00ACYD"],
    ],
}

for s in ("03",):
    # RHS group: FAB and belt channels
    channel_sets[f"{s}_RHS"] = [
        [f"{s}FAB00000VH00PRRD", f"{s}BELTBUSLVH00DSRD"],
        [f"{s}BELTB000VH00DSRD", f"{s}BELTB000VH00FORD"],
        [f"{s}BELTB300VH00FORD", f"{s}BELTB400VH00FORD"],
        [f"{s}BELTB500VH00FORD", f"{s}BELTB600VH00FORD"],
    ]
    # body regions: a 2x2 grid of the R/X/Y/Z acceleration channels each
    channel_sets.update(
        {
            f"{s}_{part}": [
                [f"{s}{part}0000??50ACRD", f"{s}{part}0000??50ACXD"],
                [f"{s}{part}0000??50ACYD", f"{s}{part}0000??50ACZD"],
            ]
            for part in ("HEAD", "CHST", "PELV")
        }
    )
    # femur, left and right
    channel_sets[f"{s}_FMR"] = [[f"{s}FEMRLE00??50FORD", f"{s}FEMRRI00??50FORD"]]

# Save one figure per (load case, channel group).
for case in LOAD_CASES:
    for channel_set, channel_grid in channel_sets.items():
        plot_channel_group(case=case, channels=channel_grid, formats=["png", "pdf"], grp_name=channel_set)

In [None]:
# Load cases for which grouped ISO ratings were computed.
ALL_GROUPED_ISOS.keys()