In [None]:
import logging
import os
import sys
from itertools import product
from pathlib import Path
from typing import Dict, List, Optional

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Rectangle
from seaborn._statistics import LetterValues
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.preprocessing import RobustScaler

# --- Notebook bootstrap -------------------------------------------------------
# Resolve the location of this notebook file from the kernel's module locals.
# NOTE(review): "__vsc_ipynb_file__" is presumably injected by the VS Code Jupyter
# extension only; in other front ends this lookup would raise KeyError - confirm.
NOTEBOOK_PATH: Path = Path(IPython.extract_module_locals()[1]["__vsc_ipynb_file__"])
# The project root is taken to be two directory levels above the notebook file.
PROJECT_DIR: Path = NOTEBOOK_PATH.parent.parent
# Must happen before the `src.*` imports below so the project package is importable.
sys.path.append(str(PROJECT_DIR))
import src.utils.custom_log as custom_log
import src.utils.json_util as json_util
from src._StandardNames import StandardNames
from src.utils.PathChecker import PathChecker
from src.utils.set_rcparams import set_rcparams
from src.utils.Csv import Csv
from src.data.fe_processing.InjuryCalculator import InjuryCalculator
from src.data.fe_processing.IsoMme import IsoMme

# All relative paths used further down are resolved against the project root.
os.chdir(PROJECT_DIR)
set_rcparams()

# Notebook-wide logger, initialised at INFO level.
LOG: logging.Logger = logging.getLogger(__name__)
custom_log.init_logger(log_lvl=logging.INFO)
LOG.info("Log start, project directory is %s (exist: %s)", PROJECT_DIR, PROJECT_DIR.is_dir())

# Shared helper instances: path validation and standardised column names.
CHECK: PathChecker = PathChecker()
STR: StandardNames = StandardNames()

# Figure width used for all plots.
# NOTE(review): looks like a text width in points converted to inches (72 pt/inch)
# minus a 0.2 margin - confirm against the target document.
FULL_WIDTH: float = 448.13095 / 72 - 0.2

# Figures are stored in reports/figures/<notebook name>/; the sub folder is created on demand.
FIG_DIR: Path = CHECK.check_directory(PROJECT_DIR / "reports" / "figures", exit=False)
FIG_DIR /= NOTEBOOK_PATH.stem
FIG_DIR.mkdir(parents=True, exist_ok=True)
LOG.info("Figure directory is %s (exist: %s)", FIG_DIR, FIG_DIR.is_dir())
print(FULL_WIDTH)

# Default seaborn palette; the bare expression renders it as the cell output.
COLORS: sns.palettes._ColorPalette = sns.color_palette()
COLORS

In [None]:
# Folder holding the stored experiment results (relative to the working directory).
EXP_DIR: Path = Path("experiments")

# Test-set prediction folders, keyed by the percentile used to select them later on.
D_DIRS: Dict[int, Path] = {
    percentile: CHECK.check_directory(EXP_DIR / folder_name)
    for percentile, folder_name in (
        (5, "2024-12-16-08-32-37_test_set_predictions_5HIII"),
        (95, "2024-12-16-12-49-35_test_set_predictions_95HIII"),
    )
}

# DOE table of the test set and DOE table used as reference for the factor ranges.
DOE_PATH: Path = CHECK.check_file(Path("data", "doe", "doe_sobol_test_20240829_135200", "doe.parquet"))
DOE_REF_PATH: Path = CHECK.check_file(Path("data", "doe", "doe_sobol_20240705_194200", "doe_combined.parquet"))

In [3]:
# File names of the stored tables inside each prediction folder:
# model predictions and the matching reference values.
PRED_NAME: str = "y_pred_test.parquet"
TRUE_NAME: str = "y_true_test.parquet"

In [4]:
# Column name -> short label used on the plot axes (mathtext for the subscripts).
RENAMER: Dict[str, str] = dict(
    Chest_Deflection="CDC",
    Chest_VC="CVC",
    Chest_a3ms="CAC$_3$",
    Femur_Fz_Max_Compression="FCC",
    Head_HIC15="HIC$_{15}$",
    Head_a3ms="HAC$_3$",
    Neck_Fx_Shear_Max="NSC",
    Neck_Fz_Max_Tension="NTC",
    Neck_My_Extension="NEC",
)

# Criteria in the order they are shown (rows of the figure / rows of the result tables).
INJ_CRIT: List[str] = (
    ["Head_HIC15", "Head_a3ms"]
    + ["Neck_My_Extension", "Neck_Fz_Max_Tension", "Neck_Fx_Shear_Max"]
    + ["Chest_Deflection", "Chest_VC", "Chest_a3ms"]
    + ["Femur_Fz_Max_Compression"]
)

In [None]:
# Preview of the reference table ("y_true_test.parquet") stored in the folder registered under key 5.
pd.read_parquet(D_DIRS[5] / TRUE_NAME)

In [None]:
def _capped_r2_by_factor(y_true, y_pred, doe, select):
    """Compute R2 per criterion on a factor-dependent subset of the samples.

    Args:
        y_true: reference values, one column per criterion
        y_pred: predicted values, same layout as y_true
        doe: factor table, one column per factor
        select: callable `(factor_name, factor_values) -> boolean mask` choosing the samples to score

    Returns:
        DataFrame with the columns of y_true as rows and the factors as columns;
        factors whose selection is empty are left out. Negative scores are set to 0.
    """
    scores = {}
    for val in doe.columns:
        idx = y_true[select(val, doe[val])].index
        if len(idx):
            raw = r2_score(y_true=y_true.loc[idx], y_pred=y_pred.loc[idx], multioutput="raw_values")
            # the secondary axis of the figure only spans [0, 1], so negative R2 is floored at 0
            scores[val] = pd.Series(raw, index=y_true.columns).clip(lower=0)
    return pd.DataFrame(scores)


def plot(perc: int = 95) -> None:
    """Plot the prediction loss over every DOE factor together with segment-wise R2 scores.

    One row per entry of INJ_CRIT, one column per DOE factor. Each panel shows the loss
    (y_true - y_pred) as scatter, the min/max of the reference DOE as dashed vertical lines and,
    on a secondary axis, the capped R2 score below, inside and above that reference range.
    The figure is stored as `loss_r2_<perc>.pdf` in FIG_DIR.

    Args:
        perc: key into D_DIRS selecting the prediction folder; also used in the output file name
    """
    # calculate loss
    y_true = pd.read_parquet(D_DIRS[perc] / TRUE_NAME)
    y_pred = pd.read_parquet(D_DIRS[perc] / PRED_NAME)
    # Index.equals copes with different lengths, where an element-wise `==` raises
    if y_true.index.equals(y_pred.index):
        LOG.info("Index of Y_TRUE and Index of Y_PRED are identical")
    else:
        LOG.error("Index of Y_TRUE and Index of Y_PRED are NOT identical")
    loss = y_true - y_pred

    # get DOEs
    # NOTE(review): both filters pin the percentile to 5 regardless of `perc` - confirm this is intended
    doe = pd.read_parquet(DOE_PATH, filters=[(STR.perc, "==", 5), (STR.id, "in", loss.index)]).drop(columns=STR.perc)
    doe_ref = pd.read_parquet(DOE_REF_PATH, filters=[(STR.perc, "==", 5)]).drop(columns=STR.perc)
    doe_ref_min, doe_ref_max = doe_ref.min(), doe_ref.max()
    del doe_ref

    # R2 below, above and inside the factor range covered by the reference DOE
    scores_min = _capped_r2_by_factor(y_true, y_pred, doe, lambda val, col: col.lt(doe_ref_min[val]))
    scores_max = _capped_r2_by_factor(y_true, y_pred, doe, lambda val, col: col.gt(doe_ref_max[val]))
    scores_mid = _capped_r2_by_factor(
        y_true, y_pred, doe, lambda val, col: col.between(doe_ref_min[val], doe_ref_max[val])
    )

    # factor values in the order of the loss table and their extent (loop invariant)
    doe_sel = doe.loc[loss.index]
    x_lo, x_hi = doe_sel.min(), doe_sel.max()

    # the rows below iterate over INJ_CRIT, hence the grid is sized by it; row 0 hosts the legend
    n_crit = len(INJ_CRIT)
    fig, ax = plt.subplots(
        nrows=n_crit + 1,
        ncols=doe.shape[1],
        sharey="row",
        sharex="col",
        layout="constrained",
        height_ratios=[0.1, *[1] * n_crit],
    )
    gs = ax[0, 0].get_gridspec()
    for ax_ in ax[0, :]:
        ax_.remove()
    axbig = fig.add_subplot(gs[0, :])
    axbig.axis("off")

    sc_col, r2_col = COLORS[0], COLORS[1]
    for j, crit in enumerate(INJ_CRIT, 1):
        for i, val in enumerate(doe.columns):
            panel = ax[j, i]
            panel.scatter(doe_sel[val], loss[crit], s=1, c=sc_col, label="Loss")
            if i == 0:
                panel.set_ylabel(RENAMER[crit], c=sc_col)
            if j == n_crit:
                panel.set_xlabel(val)
            panel.grid()
            panel.set_axisbelow(True)
            panel.axvline(doe_ref_min[val], c="black", alpha=0.5, ls="--", label="Factor Range from Development-set")
            panel.axvline(doe_ref_max[val], c="black", alpha=0.5, ls="--")
            panel.axhline(0, c=sc_col, alpha=0.8, label="Loss = 0")
            panel.set_xlim([x_lo[val], x_hi[val]])

            # secondary axis with the R2 score of each segment as line plus shaded bar
            twin = panel.twinx()
            twin.set_ylim([0, 1])
            twin.set_yticks(np.linspace(0, 1, 6))
            segments = (
                (scores_mid, doe_ref_min[val], doe_ref_max[val], {"label": "R2-score @ Segment"}),
                (scores_min, x_lo[val], doe_ref_min[val], {}),
                (scores_max, doe_ref_max[val], x_hi[val], {}),
            )
            for seg_scores, left, right, line_kwargs in segments:
                if val not in seg_scores.columns:
                    # no samples in this segment, nothing to draw
                    continue
                level = seg_scores.loc[crit, val]
                twin.plot([left, right], [level] * 2, c=r2_col, **line_kwargs)
                twin.add_patch(
                    Rectangle(xy=(left, 0), width=right - left, height=level, facecolor=r2_col, alpha=0.2)
                )
            if i != doe.shape[1] - 1:
                twin.set_yticklabels([])
            else:
                twin.set_ylabel("R² (Capped)", c=r2_col)
                twin.set_yticklabels(twin.get_yticklabels(), c=r2_col)

    # legend entries: loss artists of the first panel plus R2 artists of the last secondary axis
    loss_handles, loss_labels = ax[1, 0].get_legend_handles_labels()
    r2_handles, r2_labels = twin.get_legend_handles_labels()
    axbig.legend(loss_handles + r2_handles, loss_labels + r2_labels, loc="center", ncol=4)
    fig.align_ylabels(ax)
    fig.set_figheight(FULL_WIDTH * 1.3)
    fig.set_figwidth(FULL_WIDTH)

    fig.savefig(FIG_DIR / f"loss_r2_{perc:02d}.pdf")


plot()

In [None]:
def get_errors() -> pd.DataFrame:
    """Collect MAE and R2 per criterion for the prediction folders registered under 5 and 95.

    The long-format table (one row per folder key and criterion) is displayed as a side effect.

    Returns:
        DataFrame with one row per entry of INJ_CRIT (in that order) and a two-level column
        index (metric, PERC) holding the "MAE" and "R2" values.
    """
    per_perc = []
    for perc in (5, 95):
        y_true = pd.read_parquet(D_DIRS[perc] / TRUE_NAME)
        y_pred = pd.read_parquet(D_DIRS[perc] / PRED_NAME)
        # Index.equals copes with different lengths, where an element-wise `==` raises
        if y_true.index.equals(y_pred.index):
            LOG.info("Index of Y_TRUE and Index of Y_PRED are identical")
        else:
            LOG.error("Index of Y_TRUE and Index of Y_PRED are NOT identical")

        # "raw_values" yields one score per column, in the column order of y_true
        sc = pd.DataFrame(
            {
                "MAE": mean_absolute_error(y_true, y_pred, multioutput="raw_values"),
                "R2": r2_score(y_true, y_pred, multioutput="raw_values"),
            },
        )
        sc["PERC"] = perc
        sc["Crit"] = y_true.columns
        per_perc.append(sc)

    scores = pd.concat(per_perc, ignore_index=True)
    display(scores)
    scores = scores.set_index(["PERC", "Crit"])

    # pivot PERC into the columns and order the rows like INJ_CRIT
    return scores.unstack("PERC").loc[INJ_CRIT]


get_errors()

In [None]:
# Base directory holding the two sets of results that are compared below.
# NOTE(review): absolute, machine-specific mount point - consider making it configurable.
FE_B_PATH = CHECK.check_directory(Path("/mnt", "q", "Val_Chain_Sims", "AB_Testing"))

# Result folders of the two model variants, keyed by the name used in the comparison.
MODEL_PATHS = {
    model_key: CHECK.check_directory(FE_B_PATH / folder_name)
    for model_key, folder_name in (("TEST", "400_HIII"), ("DEV", "990_Carpet_Rigid"))
}

In [None]:
def read_score(rel_col: Optional[List[str]] = None) -> None:
    """Display the absolute difference of the stored injury criteria between the "TEST" and "DEV" folders.

    Every `injury_criteria.parquet` below the folders in MODEL_PATHS is read, its row labelled "D"
    is kept and tagged with the names of the file's parent folder ("Case") and grandparent
    folder ("Assembly"). The tables of both models are subtracted (aligned on Assembly/Case)
    and the absolute difference is displayed in full, averaged per assembly, averaged per case
    and averaged over all rows.

    Args:
        rel_col: columns to read from each parquet file; None reads all columns

    Raises:
        ValueError: if no `injury_criteria.parquet` is found below one of the model folders
    """
    dbs = {}
    for model_name, model_path in MODEL_PATHS.items():
        LOG.info("Reading %s", model_name)
        model_db = []
        # sorted: rglob yields files in file-system order, which is not reproducible
        for f_path in sorted(model_path.rglob("injury_criteria.parquet")):
            LOG.info("Reading %s", f_path)
            model_db.append(
                pd.read_parquet(f_path, columns=rel_col)
                .loc[["D"]]
                .assign(Case=f_path.parent.stem, Assembly=f_path.parent.parent.stem)
            )
        if not model_db:
            # same exception type pd.concat would raise, but naming the culprit
            raise ValueError(f"No injury_criteria.parquet found below {model_path}")
        LOG.info("Concatenating %s", model_name)
        dbs[model_name] = pd.concat(model_db, ignore_index=True).set_index(["Assembly", "Case"])

    absolute_error = (dbs["TEST"] - dbs["DEV"]).abs()

    display(absolute_error)

    display(absolute_error.groupby("Assembly").mean())
    display(absolute_error.groupby("Case").mean().T)
    display(absolute_error.mean(axis=0))

read_score(rel_col=INJ_CRIT)