In [None]:
from __future__ import annotations

import logging
import os
import sys
from itertools import product
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union

import IPython
import matplotlib
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.scale import ScaleBase
from seaborn._statistics import LetterValues
from sklearn.metrics import r2_score
from sklearn.preprocessing import RobustScaler

# Locate this notebook on disk via the caller's namespace (index 1 of extract_module_locals is the locals dict).
# NOTE(review): `__vsc_ipynb_file__` appears to be injected only by the VS Code Jupyter integration;
# in plain Jupyter/JupyterLab this lookup presumably raises KeyError — confirm before running elsewhere.
NOTEBOOK_PATH: Path = Path(IPython.extract_module_locals()[1]["__vsc_ipynb_file__"])
# The project root is the grandparent directory of the notebook file.
PROJECT_DIR: Path = NOTEBOOK_PATH.parent.parent
# Make the project's `src` package importable; must stay before the `src.*` imports below.
sys.path.append(str(PROJECT_DIR))
import src.utils.custom_log as custom_log
import src.utils.json_util as json_util
from src._StandardNames import StandardNames
from src.utils.PathChecker import PathChecker
from src.utils.set_rcparams import set_rcparams

# Relative paths used later in the notebook (e.g. Path("experiments") / ...) resolve against the project root.
os.chdir(PROJECT_DIR)
# Project-wide matplotlib rcParams.
set_rcparams()

LOG: logging.Logger = logging.getLogger(__name__)
custom_log.init_logger(log_lvl=logging.INFO)
LOG.info("Log start, project directory is %s (exist: %s)", PROJECT_DIR, PROJECT_DIR.is_dir())

CHECK: PathChecker = PathChecker()
STR: StandardNames = StandardNames()

# Target figure width in inches: 448.13095 pt / 72 pt-per-inch minus 0.2 in.
# Presumably the LaTeX \textwidth of the target document — confirm.
FULL_WIDTH: float = 448.13095 / 72 - 0.2

# Figures go to reports/figures/<notebook name>/ (created if missing).
# NOTE(review): check_directory(..., exit=False) presumably validates the path without aborting — confirm in PathChecker.
FIG_DIR: Path = CHECK.check_directory(PROJECT_DIR / "reports" / "figures", exit=False)
FIG_DIR /= NOTEBOOK_PATH.stem
FIG_DIR.mkdir(parents=True, exist_ok=True)
LOG.info("Figure directory is %s (exist: %s)", FIG_DIR, FIG_DIR.is_dir())

In [None]:
# Experiment folder whose `study.sqlite3` holds the Optuna study analysed in this notebook.
EXP_DIR: Path = CHECK.check_directory(
    PROJECT_DIR.joinpath("experiments", "2024-11-10-21-13-00_pure_cnn_optuna_doe_sobol_20240705_194200"),
    exit=False,
)

In [None]:
# Default seaborn palette; later cells pick colours from it by position (e.g. COLORS[0], COLORS[4]).
COLORS: sns.palettes._ColorPalette = sns.color_palette(palette=None)
COLORS

In [None]:
def read_study(
    exp_dir: Optional[Path] = None,
    study_name: str = "AnnUniversalPureKeras",
) -> pd.DataFrame:
    """Load the completed trials of the Optuna study as a DataFrame.

    :param exp_dir: Experiment directory containing ``study.sqlite3``; defaults to ``EXP_DIR``.
    :param study_name: Name of the study inside the SQLite storage.
    :return: One row per trial in state ``COMPLETE``, indexed by trial number. It holds the
        ``value`` column and every column whose name contains ``params``, sorted by name.
        Missing entries (NaN; presumably parameters not sampled in a trial) are filled
        with ``False``.
    """
    if exp_dir is None:
        exp_dir = EXP_DIR
    storage = f"sqlite:///{exp_dir.absolute() / 'study.sqlite3'}"
    study = optuna.load_study(study_name=study_name, storage=storage).trials_dataframe()
    study = study.set_index("number")

    rel_cols = sorted(col for col in study.columns if "params" in col or col == "value")
    # Select rows and columns in one .loc call and fill non-inplace. The old in-place fillna
    # on a boolean-filtered frame could raise SettingWithCopyWarning.
    completed = study.loc[study["state"].eq("COMPLETE"), rel_cols]
    return completed.fillna(False)


TRIALS: pd.DataFrame = read_study()
TRIALS

In [None]:
# Summary statistics of the objective (R2 score) over all completed trials.
TRIALS.loc[:, "value"].describe()

In [None]:
# Kernel-size parameters params_kernel_size_{0,1,2}_0 of the 20 best trials.
TRIALS.nlargest(20, "value").loc[:, [f"params_kernel_size_{branch}_0" for branch in range(3)]]

In [None]:
# All trials whose objective lies within 4e-3 of the best one.
TRIALS.loc[TRIALS["value"] >= TRIALS["value"].max() - 4e-3]

In [None]:
# Objective value of every completed trial; the red dashed line marks the near-best
# cut-off (best - 4e-3) used to select trials in the cell above.
fig, ax = plt.subplots(figsize=(FULL_WIDTH, 4))
sns.stripplot(TRIALS["value"], ax=ax)
ax.axhline(TRIALS["value"].max() - 4e-3, color="red", linestyle="--", label="best - 4e-3")
ax.set_title("R2 score of all completed trials")
ax.set_ylabel("R2 score")
ax.legend()
# Render the figure explicitly; this also suppresses the Line2D repr of the last expression.
plt.show()

In [None]:
def make_hist(thres=0.9):
    """Plot a histogram of the trial objective values that reach a threshold.

    The figure is written to ``FIG_DIR / "r2_hist.pdf"``.

    :param thres: Only trials with ``value >= thres`` are included (40 bins).
    """
    scores = TRIALS["value"]
    fig, ax = plt.subplots(figsize=(0.7 * FULL_WIDTH, 0.4 * FULL_WIDTH))
    sns.histplot(scores.loc[scores >= thres], ax=ax, bins=40)
    ax.set(title="Distribution of R2 scores", xlabel="R2 score", ylabel="Frequency")
    ax.grid()
    fig.savefig(FIG_DIR / "r2_hist.pdf")


make_hist()

In [None]:
def get_uniques(thres=0.9045):
    idx = TRIALS[TRIALS["value"].ge(thres)].index
    print(len(idx))
    for col in TRIALS.columns:
        if "params" in col:
            print(col, f"{len(TRIALS.loc[idx, col].unique())}/{len(TRIALS[col].unique())}", sorted(TRIALS.loc[idx, col].unique()))


get_uniques()

In [5]:
# `params_*` columns excluded from the per-parameter history plots (alphabetically sorted).
IGNORE: List[str] = [
    "params_" + name
    for name in (
        # channel-code style parameter names
        "03CHST0000OCCUACXD",
        "03CHST0000OCCUACYD",
        "03CHST0000OCCUACZD",
        "03CHST0000OCCUDSXD",
        "03CHSTLOC0OCCUDSXD",
        "03CHSTLOC0OCCUDSYD",
        "03CHSTLOC0OCCUDSZD",
        "03FEMRLE00OCCUFOZD",
        "03FEMRRI00OCCUFOZD",
        "03HEAD0000OCCUACXD",
        "03HEAD0000OCCUACYD",
        "03HEAD0000OCCUACZD",
        "03HEADLOC0OCCUDSXD",
        "03HEADLOC0OCCUDSYD",
        "03HEADLOC0OCCUDSZD",
        "03NECKUP00OCCUFOXD",
        "03NECKUP00OCCUFOZD",
        "03NECKUP00OCCUMOYD",
        "03PELV0000OCCUACXD",
        "03PELV0000OCCUACYD",
        "03PELV0000OCCUACZD",
        "03PELVLOC0OCCUDSXD",
        "03PELVLOC0OCCUDSYD",
        "03PELVLOC0OCCUDSZD",
        # injury-criterion parameter names
        "Chest_Deflection",
        "Chest_VC",
        "Chest_a3ms",
        "Femur_Fz_Max_Compression",
        "Femur_Fz_Max_Tension",
        "Head_HIC15",
        "Head_a3ms",
        "Neck_Fx_Shear_Max",
        "Neck_Fz_Max_Compression",
        "Neck_Fz_Max_Tension",
        "Neck_My_Extension",
        "Neck_My_Flexion",
        "Neck_Nij",
    )
] + [
    # 4 x 4 grid of per-layer settings: kernel_size_0_0 ... kernel_size_3_3, then n_filters_0_0 ... n_filters_3_3
    f"params_{setting}_{row}_{col}"
    for setting in ("kernel_size", "n_filters")
    for row in range(4)
    for col in range(4)
]

In [11]:
def _parcoords_is_numeric(column) -> bool:
    """True when every entry is a Python or NumPy int/float (Python bools count as ints)."""
    return all(isinstance(item, (int, float, np.integer, np.floating)) for item in column)


def _parcoords_column_index(field, labels):
    """Resolve a column given either by label (requires ``labels``) or by integer index."""
    if isinstance(field, str):
        if labels is None:
            raise ValueError(f"Column {field!r} referenced by label, but no `labels` were given.")
        return list(labels).index(field)
    return field


def _parcoords_line_style(c, colors, c_thres):
    """Return ``(colour, alpha, linestyle, linewidth)`` for a line whose colour-field value is ``c``."""
    if c is None:
        # No colour field: every line is drawn in the "rest" colour.
        return colors[1], 0.5, "-", 1
    if c == c_thres[-1]:
        return colors[-1], 1, ":", 3  # best
    if c <= c_thres[0]:
        return colors[0], 0.2, "-", 1  # bottom
    if c >= c_thres[1]:
        return colors[2], 1, "-", 1  # top
    return colors[1], 0.1, "-", 1  # rest


def plot_parcoords(
    values: Union[Sequence, np.ndarray],
    labels: Union[Sequence[str], np.ndarray[str]] = None,
    title: str = None,
    color_field: Union[str, int] = None,
    colors: Sequence[str] = ("red", "grey", "green", "black"),
    c_thres: Sequence[float] = (0.85, 0.905, 1),
    scale: Union[
        Sequence[Tuple[Union[str, int], Union[str, ScaleBase]]],
        Sequence[Union[str, ScaleBase]],
        str,
    ] = None,
    figsize: Tuple[float, float] = None,
    y_limits: Union[Sequence, np.ndarray] = None,
    axs: np.ndarray[plt.Axes] = None,
) -> Tuple[plt.Figure, np.ndarray[plt.Axes]]:
    """Plotting function for parallel coordinate plots.
    Adapted from https://github.com/VoigtPeter/parcoords.git

    Each row of ``values`` becomes a poly-line across one vertical axis per column.
    Columns that are not purely numeric are treated as nominal. Their unique values are
    placed at integer positions and used as the tick labels.

    :param values: 2-dimensional sequence or numpy-array containing
        row-vectors of the data to display. Never modified. (required)
    :param labels: Sequence containing the column labels. (optional)
    :param title: Title of the figure. (optional)
    :param color_field: Label (`labels` must be provided) or index (negative allowed)
        of the column whose value selects each line's style via `c_thres`. If not
        provided, all lines are drawn in ``colors[1]``. (optional)
    :param colors: Four colours, in the order (bottom, rest, top, best).
    :param c_thres: Three thresholds (bottom, top, best) on the colour-field value, checked in this order:
        ``== best`` -> ``colors[3]``, dotted and thick;
        ``<= bottom`` -> ``colors[0]``, alpha 0.2;
        ``>= top`` -> ``colors[2]``;
        otherwise ``colors[1]``, alpha 0.1.
    :param scale: Sequence of scale types. Must be in one of the forms:
        [({field label/index}, {"linear", "log", ...}), ...] or
        [{"linear", "log", ...}, ...] (applied from the first axis on) or
        {"linear", "log", ...}.
        (optional, default: linear)
    :param figsize: Size of the figure as ``(height, width)`` in inches. NOTE: this is the
        reverse of matplotlib's ``(width, height)``, kept for existing callers. (optional)
    :param y_limits: The min- & max-limits for the axes. Must be in the form of:
        [(`min`, `max`), ...] for all axes. (optional)
    :param axs: An existing axes array, to be used when adding more datapoints.
        (optional)
    :return: The figure object and the axes (as ndarray).
    :raises ValueError: if fewer than two columns are given, or a column is referenced by
        label without `labels`.
    """
    # Work on an object-dtype copy, transposed to one row per column. The caller's data is never
    # mutated, and nominal columns can be replaced by integer codes whatever the input dtype.
    values = np.array(values, dtype="object").T
    no_of_cols = len(values)
    if no_of_cols < 2:
        raise ValueError("plot_parcoords needs at least two columns.")
    # Resolve early so that an unknown label fails before anything is drawn.
    color_idx = None if color_field is None else _parcoords_column_index(color_field, labels)

    if axs is None:
        # initialize figure and axes; squeeze=False keeps an array even for a single panel
        fig, axs = plt.subplots(
            1,
            no_of_cols - 1,
            sharey="none",
            gridspec_kw=dict(wspace=0),
            squeeze=False,
        )
        axs = axs.ravel()
        # The last column is drawn on a twin of the last panel (right-hand spine).
        axs = np.append(axs, axs[-1].twinx())

        # calculate limits from data & transform nominal columns
        ylims = []
        for i, (column, ax) in enumerate(zip(values, axs)):
            if not _parcoords_is_numeric(column):
                mappings, column = np.unique(column, return_inverse=True)
                values[i] = column
                ax.set_yticks(range(len(mappings)), labels=mappings)
            if y_limits is None:
                ylims.append([column.min(axis=0), column.max(axis=0)])

        if y_limits is None:
            y_limits = ylims

    else:
        # Re-using axes: map nominal entries onto the tick positions created by the first call.
        for column, ax in zip(values, axs):
            if not _parcoords_is_numeric(column):
                tick_codes = {text.get_text(): code for code, text in enumerate(ax.get_yticklabels())}
                # `column` is a view of a row of `values`, so these writes update `values` directly.
                for x, value in enumerate(column):
                    if isinstance(value, str):
                        column[x] = tick_codes[value]
        if y_limits is not None:
            print("Warning: setting `y_limits` when using existing axes has no effect.")
            y_limits = None
        if scale is not None:
            print("Warning: setting `scale` when using existing axes has no effect.")
            scale = None
        fig = axs[0].get_figure()

    if scale is not None:
        if isinstance(scale, str):
            for ax in axs:
                ax.set_yscale(scale)
        elif isinstance(scale, Sequence) and len(scale) > 0 and all(isinstance(s, str) for s in scale):
            # zip truncates: a shorter list leaves the remaining axes untouched.
            for ax, scale_type in zip(axs, scale):
                ax.set_yscale(scale_type)
        elif (
            isinstance(scale, Sequence)
            and len(scale) > 0
            and all(
                isinstance(s, Sequence) and len(s) == 2 and isinstance(s[0], (int, str)) and isinstance(s[1], str)
                for s in scale
            )
        ):
            for column, scale_type in scale:
                axs[_parcoords_column_index(column, labels)].set_yscale(scale_type)
        else:
            print(f"Warning: invalid value '{scale}' passed to `scale`. Attribute is ignored.")

    last_ax = len(axs) - 1
    for i, ax in enumerate(axs):
        if i < last_ax:
            ax.set_xlim([i, i + 1])
        if y_limits is not None:
            ax.set_ylim(y_limits[i])
        ax.spines[["bottom", "top"]].set_visible(False)
        ax.get_xaxis().set_visible(False)
        if labels is not None:
            ax.text(
                0 if i < last_ax else 1,
                -0.02,
                labels[i],
                horizontalalignment="center",
                verticalalignment="top",
                rotation=90,
                transform=ax.transAxes,
                weight="bold" if i == no_of_cols - 1 else "normal",
            )

    # Read the colour field after nominal columns were encoded, so codes are compared consistently.
    color_values = values[color_idx] if color_idx is not None else [None] * len(values[0])
    for i in range(no_of_cols - 1):
        for y1, y2, c in zip(values[i], values[i + 1], color_values):
            colo, alpha, ls, lw = _parcoords_line_style(c, colors, c_thres)
            # Convert each data value to the axes fraction of its own axis. This lets one straight
            # line join two independently scaled axes.
            axs[i].axline(
                [
                    0,
                    axs[i].transLimits.transform(axs[i].transScale.transform([0, y1]))[1],
                ],
                [
                    1,
                    axs[i + 1].transLimits.transform(axs[i + 1].transScale.transform([1, y2]))[1],
                ],
                c=colo,
                alpha=alpha,
                transform=axs[i].transAxes,
                ls=ls,
                lw=lw,
            )

    for i in range(no_of_cols):
        for tick in axs[i].get_yticklabels():
            tick.set_rotation("vertical")
            if i == no_of_cols - 1:
                tick.set_weight("bold")
                tick.set_verticalalignment("center")
            else:
                tick.set_verticalalignment("bottom")

    if fig is not None:
        fig.subplots_adjust(wspace=0, top=0.85)
        if title is not None:
            fig.suptitle(title)
        if figsize is not None:
            # (height, width) order; see the docstring.
            fig.set_figheight(figsize[0])
            fig.set_figwidth(figsize[1])

    return fig, axs

In [None]:
def alt_cat(n_worst: int = 20, top_tol: float = 2e-3):
    """Draw a parallel-coordinate plot of the CNN hyper-parameters against the R2 score.

    Lines are styled by the R2 score (last column) using plot_parcoords' default colours.
    Scores up to the ``n_worst``-th lowest count as "bottom", scores within ``top_tol``
    of the best count as "top", and the best score gets its own style. The figure is
    saved as ``parcoords_cnn.pdf`` in ``FIG_DIR``.

    :param n_worst: Number of lowest-scoring trials highlighted as "bottom".
    :param top_tol: Trials with score >= best - top_tol are highlighted as "top".
    """
    trials = TRIALS.copy()

    # Intended type of each plotted hyper-parameter. Booleans are then turned into the
    # strings "True"/"False" so plot_parcoords treats them as nominal axes.
    dtypes = {
        "params_conv_depth": float,
        "params_conv_width": float,
        "params_dense_regularizer": str,
        "params_dropout_conv": bool,
        "params_file_names_ai_in": str,
        "params_fst_dense_layer_shape": int,
        "params_last_dense_layer_shape": int,
        "params_learning_rate": float,
        "params_n_dense_layers": int,
        "params_pooling_size": int,
        "params_pooling_strategy": str,
        "params_share_dense": bool,
        "params_spatial_dropout_rate": float,
        "params_temporal_feature_n_tsps": int,
    }
    for col, dtype in dtypes.items():
        trials[col] = trials[col].astype(dtype)
        if dtype is bool:
            trials[col] = trials[col].astype(str)
    # Drop the "params_" prefix for readable axis labels.
    trials = trials.rename(
        columns={col: "_".join(col.split("_")[1:]) for col in trials.columns if col.startswith("params")}
    )
    trials = trials.rename(columns={"value": "R2-score"})

    col_ord = [
        "learning_rate",
        "share_dense",
        "dense_regularizer",
        "n_dense_layers",
        "fst_dense_layer_shape",
        "last_dense_layer_shape",
        "file_names_ai_in",
        "temporal_feature_n_tsps",
        "conv_depth",
        "conv_width",
        "pooling_strategy",
        "pooling_size",
        "dropout_conv",
        "spatial_dropout_rate",
        "R2-score",
    ]

    scores = trials["R2-score"]
    best = scores.max()
    fig, _ = plot_parcoords(
        trials[col_ord].to_numpy(),
        labels=col_ord,
        figsize=(FULL_WIDTH, 1.3 * FULL_WIDTH),
        color_field=-1,
        c_thres=[scores.nsmallest(n_worst).iloc[-1], best - top_tol, best],
        # A one-element list only sets the first axis (learning_rate) to log scale.
        scale=["log"],
    )

    fig.savefig(FIG_DIR / "parcoords_cnn.pdf")


alt_cat()

In [None]:
def plot_history(param="params_file_names_ai_in"):
    """Scatter the objective value over the trial number, one colour per value of ``param``.

    The figure is saved as ``r2_over_trials_<param>.pdf`` in ``FIG_DIR``.

    :param param: Name of a ``TRIALS`` column whose distinct values define the groups.
    """
    fig, ax = plt.subplots(figsize=(FULL_WIDTH, 0.4 * FULL_WIDTH))
    for conf in TRIALS[param].unique():
        group = TRIALS[TRIALS[param].eq(conf)]
        ax.scatter(group.index, group["value"], label=conf, s=4)
    ax.set_xlabel("Trial Number")
    ax.set_ylabel("R2-score")
    ax.grid()
    # Legend title is the column name without its "params_" prefix.
    ax.legend(title="_".join(param.split("_")[1:]))
    # f-string (previously missing): each parameter now gets its own file instead of all calls
    # overwriting a file literally named "r2_over_trials_{param}.pdf".
    fig.savefig(FIG_DIR / f"r2_over_trials_{param}.pdf")


plot_history()

In [None]:
# One history plot per hyper-parameter column, skipping the columns listed in IGNORE.
# A plain loop replaces the side-effect-only list comprehension, whose `_ = [...]` also
# shadowed IPython's last-output variable `_`.
for param_col in TRIALS.columns:
    if param_col.startswith("params") and param_col not in IGNORE:
        plot_history(param_col)

In [None]:
def plots_analyse_5(
    exp_path: Path,
    inj_crits: List[str],
    lims: Sequence[Sequence[float]] = ((5, 35), (0, 2)),
):
    """Plot predicted vs. true injury criteria with letter-value (boxen) marginals.

    Reads ``y_true_Test_Fold-1.parquet`` and ``y_pred_Test_Fold-1.parquet`` from ``exp_path``.
    Criteria are laid out two per row. Each criterion gets a main panel: prediction over
    truth as scatter plus KDE, an identity line and the R2 score. A boxen plot of the true
    values sits above the panel and one of the predictions to its right. The figure is saved
    as ``inj_crits_<crit_1>_<crit_2>_....pdf`` in ``FIG_DIR``.

    :param exp_path: Experiment directory holding the parquet files.
    :param inj_crits: Column names of the criteria to plot.
    :param lims: Axis limits ``[low, high]``, one pair per entry of ``inj_crits``.
    :raises ValueError: if fewer limit pairs than criteria are given.
    """
    if len(lims) < len(inj_crits):
        raise ValueError(f"Need one axis-limit pair per criterion, got {len(lims)} for {len(inj_crits)}.")

    # Short display names; criteria not listed fall back to their column name.
    renamer = {"Chest_Deflection": "CDC", "Femur_Fz_Max_Compression": "FCC", "Head_a3ms": "HAC$_3$", "Chest_a3ms": "CAC$_3$"}

    # Mosaic: per criterion a 2x2 group [true boxen, empty] over [data, predicted boxen],
    # two criteria per row. Unused slots use ".", subplot_mosaic's default empty sentinel.
    n_cols = 2
    n_rows = (len(inj_crits) + n_cols - 1) // n_cols
    mosaic = []
    for row in range(n_rows):
        top, bottom = [], []
        for slot in range(row * n_cols, (row + 1) * n_cols):
            if slot < len(inj_crits):
                top += [f"letter_true_{slot}", f"empty_{slot}"]
                bottom += [f"data_{slot}", f"letter_pred_{slot}"]
            else:
                top += [".", "."]
                bottom += [".", "."]
        mosaic += [top, bottom]
    fig, ax = plt.subplot_mosaic(
        mosaic=mosaic,
        height_ratios=[0.2, 1] * n_rows,
        width_ratios=[1, 0.2] * n_cols,
        layout="constrained",
    )

    # Read each file once for all criteria.
    y_true_all = pd.read_parquet(exp_path / "y_true_Test_Fold-1.parquet", columns=list(inj_crits))
    y_pred_all = pd.read_parquet(exp_path / "y_pred_Test_Fold-1.parquet", columns=list(inj_crits))

    for i, inj_crit in enumerate(inj_crits):
        name = renamer.get(inj_crit, inj_crit)
        y_true_str, y_pred_str = f"True {name}", f"Predicted {name}"
        # Align predictions to the order of the ground-truth samples.
        db = pd.DataFrame(
            {
                y_true_str: y_true_all[inj_crit].to_numpy(),
                y_pred_str: y_pred_all.loc[y_true_all.index, inj_crit].to_numpy(),
            }
        )

        data_ax = ax[f"data_{i}"]
        data_ax.scatter(db[y_true_str], db[y_pred_str], alpha=0.5, s=2, c=COLORS[4], marker=".")
        sns.kdeplot(
            data=db,
            x=y_true_str,
            y=y_pred_str,
            ax=data_ax,
            color=COLORS[-3],
            linewidths=0.5,
        )
        # Identity line: perfect predictions lie on it.
        data_ax.plot(lims[i], lims[i], color="black", linestyle="--", alpha=0.2)
        data_ax.set_xlim(lims[i])
        data_ax.set_ylim(lims[i])
        data_ax.grid()
        data_ax.text(
            0.85,
            0.05,
            f"R2 {r2_score(db[y_true_str], db[y_pred_str]):.3f}",
            horizontalalignment="center",
            verticalalignment="center",
            transform=data_ax.transAxes,
            fontweight="bold",
        )

        true_ax = ax[f"letter_true_{i}"]
        sns.boxenplot(x=y_true_all[inj_crit], ax=true_ax, color=COLORS[0], flier_kws={"s": 0.5})
        true_ax.set_xlim(lims[i])
        true_ax.set_xlabel("")
        true_ax.set_ylabel("True")
        true_ax.set_xticklabels([])

        pred_ax = ax[f"letter_pred_{i}"]
        sns.boxenplot(y=y_pred_all[inj_crit], ax=pred_ax, color=COLORS[1], flier_kws={"s": 0.5})
        pred_ax.set_ylim(lims[i])
        pred_ax.set_ylabel("")
        pred_ax.set_xlabel("Pred")
        pred_ax.set_yticklabels([])

        ax[f"empty_{i}"].axis("off")

    fig.set_figwidth(FULL_WIDTH)
    fig.set_figheight(0.5 * FULL_WIDTH * n_rows)
    fig.savefig(FIG_DIR / f"inj_crits_{'_'.join(inj_crits)}.pdf")

plots_analyse_5(
    exp_path=Path("experiments")
    / "2024-12-20-13-40-40_pure_cnn_single_full_output_05HIII_injury_criteria_from_doe_sobol_20240705_194200",
    inj_crits=["Chest_Deflection", "Femur_Fz_Max_Compression"],
)

In [None]:
# Same layout for the Head_a3ms / Chest_a3ms criteria, with their own axis limits.
plots_analyse_5(
    exp_path=Path(
        "experiments",
        "2024-12-20-13-40-40_pure_cnn_single_full_output_05HIII_injury_criteria_from_doe_sobol_20240705_194200",
    ),
    inj_crits=["Head_a3ms", "Chest_a3ms"],
    lims=[[25, 150], [35, 70]],
)

In [None]:
def plots2_analyse_5(
    exp_path: Path,
    inj_crits: List[str],
    lims: Sequence[Sequence[float]] = ((5, 35), (0, 2), (25, 150), (35, 70)),
):
    """Plot predicted vs. true injury criteria with semi-transparent boxen marginals in a multi-row grid.

    Reads ``y_true_Test_Fold-1.parquet`` and ``y_pred_Test_Fold-1.parquet`` from ``exp_path``.
    Criteria are laid out two per row. Each criterion gets a main panel: prediction over
    truth as scatter plus KDE, an identity line and the R2 score. A boxen plot (alpha 0.6)
    of the true values sits above the panel and one of the predictions to its right. Axis
    labels are aligned across panels, and the figure is saved as
    ``inj_crits_<crit_1>_<crit_2>_....pdf`` in ``FIG_DIR``.

    :param exp_path: Experiment directory holding the parquet files.
    :param inj_crits: Column names of the criteria to plot.
    :param lims: Axis limits ``[low, high]``, one pair per entry of ``inj_crits``.
    :raises ValueError: if fewer limit pairs than criteria are given.
    """
    if len(lims) < len(inj_crits):
        raise ValueError(f"Need one axis-limit pair per criterion, got {len(lims)} for {len(inj_crits)}.")

    # Short display names; criteria not listed fall back to their column name.
    renamer = {"Chest_Deflection": "CDC", "Femur_Fz_Max_Compression": "FCC", "Head_a3ms": "HAC$_3$", "Chest_a3ms": "CAC$_3$"}

    # Mosaic: per criterion a 2x2 group [true boxen, empty] over [data, predicted boxen],
    # two criteria per row. Unused slots use ".", subplot_mosaic's default empty sentinel.
    n_cols = 2
    n_rows = (len(inj_crits) + n_cols - 1) // n_cols
    mosaic = []
    for row in range(n_rows):
        top, bottom = [], []
        for slot in range(row * n_cols, (row + 1) * n_cols):
            if slot < len(inj_crits):
                top += [f"letter_true_{slot}", f"empty_{slot}"]
                bottom += [f"data_{slot}", f"letter_pred_{slot}"]
            else:
                top += [".", "."]
                bottom += [".", "."]
        mosaic += [top, bottom]
    fig, ax = plt.subplot_mosaic(
        mosaic=mosaic,
        height_ratios=[0.2, 1] * n_rows,
        width_ratios=[1, 0.2] * n_cols,
        layout="constrained",
    )

    # Read each file once for all criteria.
    y_true_all = pd.read_parquet(exp_path / "y_true_Test_Fold-1.parquet", columns=list(inj_crits))
    y_pred_all = pd.read_parquet(exp_path / "y_pred_Test_Fold-1.parquet", columns=list(inj_crits))

    for i, inj_crit in enumerate(inj_crits):
        name = renamer.get(inj_crit, inj_crit)
        y_true_str, y_pred_str = f"True {name}", f"Predicted {name}"
        # Align predictions to the order of the ground-truth samples.
        db = pd.DataFrame(
            {
                y_true_str: y_true_all[inj_crit].to_numpy(),
                y_pred_str: y_pred_all.loc[y_true_all.index, inj_crit].to_numpy(),
            }
        )

        data_ax = ax[f"data_{i}"]
        data_ax.scatter(db[y_true_str], db[y_pred_str], alpha=0.5, s=2, c=COLORS[4], marker=".")
        sns.kdeplot(
            data=db,
            x=y_true_str,
            y=y_pred_str,
            ax=data_ax,
            color=COLORS[-3],
            linewidths=0.5,
        )
        # Identity line: perfect predictions lie on it.
        data_ax.plot(lims[i], lims[i], color="black", linestyle="--", alpha=0.2)
        data_ax.set_xlim(lims[i])
        data_ax.set_ylim(lims[i])
        data_ax.grid()
        data_ax.text(
            0.85,
            0.05,
            f"R2 {r2_score(db[y_true_str], db[y_pred_str]):.3f}",
            horizontalalignment="center",
            verticalalignment="center",
            transform=data_ax.transAxes,
            fontweight="bold",
        )

        true_ax = ax[f"letter_true_{i}"]
        sns.boxenplot(x=y_true_all[inj_crit], ax=true_ax, color=COLORS[0], flier_kws={"s": 0.5}, alpha=0.6)
        true_ax.set_xlim(lims[i])
        true_ax.set_xlabel("")
        true_ax.set_ylabel("True")
        true_ax.set_xticklabels([])

        pred_ax = ax[f"letter_pred_{i}"]
        sns.boxenplot(y=y_pred_all[inj_crit], ax=pred_ax, color=COLORS[1], flier_kws={"s": 0.5}, alpha=0.6)
        pred_ax.set_ylim(lims[i])
        pred_ax.set_ylabel("")
        pred_ax.set_xlabel("Pred")
        pred_ax.set_yticklabels([])

        ax[f"empty_{i}"].axis("off")

    fig.align_ylabels(list(ax.values()))
    fig.align_xlabels(list(ax.values()))
    fig.set_figwidth(FULL_WIDTH)
    # 0.3 * FULL_WIDTH per row of panels (0.6 * FULL_WIDTH for the two-row, four-criterion layout).
    fig.set_figheight(0.3 * FULL_WIDTH * n_rows)
    fig.savefig(FIG_DIR / f"inj_crits_{'_'.join(inj_crits)}.pdf")


plots2_analyse_5(
    exp_path=Path("experiments")
    / "2024-12-20-13-40-40_pure_cnn_single_full_output_05HIII_injury_criteria_from_doe_sobol_20240705_194200",
    inj_crits=["Chest_Deflection", "Femur_Fz_Max_Compression", "Head_a3ms", "Chest_a3ms"],
)