In [None]:
import logging
import os
from itertools import combinations
from pathlib import Path
from typing import List, Optional, Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import polars as pl
import seaborn as sns
from IPython.display import display


# NOTE(review): hard-coded absolute project root. The relative paths used later in this
# notebook ("reports/...", "experiments/...") are resolved against this working directory,
# and presumably the `src` imports below rely on it as well — confirm before moving the repo.
os.chdir("/root/py_projects/aihiii")

import src.utils.json_util as json_util
from src._StandardNames import StandardNames
from src.utils.custom_log import init_logger
from src.utils.set_rcparams import set_rcparams

# Presumably applies the project's matplotlib rcParams — see src.utils.set_rcparams.
set_rcparams()

LOG: logging.Logger = logging.getLogger(__name__)
# Shared instance whose attributes (STR.perc, STR.fold, STR.fname_para, ...) are used as
# column, index-level and file names throughout the notebook.
STR: StandardNames = StandardNames()

init_logger(log_lvl=logging.INFO)
LOG.info("Working directory: %s", os.getcwd())

# Base figure width used when sizing the saved figure.
# NOTE(review): presumably a LaTeX text width in points converted to inches; if so, TeX
# points are 1/72.27 in rather than 1/72 in — confirm which unit 448.13095 is given in.
WIDTH: float = 448.13095 / 72

In [None]:
# Output folder for the figures written by this notebook (relative to the working
# directory); it is created here if it does not exist yet.
FIG_DIR: Path = Path("reports", "figures", "characterize_50th")
FIG_DIR.mkdir(parents=True, exist_ok=True)
LOG.info("Figures in %s, exist - %s", FIG_DIR, FIG_DIR.is_dir())

In [None]:
# Folder that is searched for experiment result directories (relative to the working directory).
DATA_DIR: Path = Path("experiments")
LOG.info("Data in %s, exist - %s", DATA_DIR, DATA_DIR.is_dir())

In [None]:
# Result directories of the PCA + ANN runs, selected by name pattern and sorted by path
# so that the processing order is deterministic.
PCA_DIRS: List[Path] = sorted(
    DATA_DIR.glob(
        "2024-12-1*-*-*-*_pca_ann_*HIII_injury_criteria_from_doe_sobol_20240705_194200_ft_channels"
    )
)
LOG.info("PCA dirs (n=%s):\n%s", len(PCA_DIRS), PCA_DIRS)

In [None]:
def get_data() -> pd.DataFrame:
    """Collect the results of every run directory in ``PCA_DIRS`` into one table.

    For each run directory the results CSV (``STR.fname_results_csv``) is read with its
    first two columns as index, the rows whose first index level equals ``-1`` are kept
    and the ``STR.fold`` level is dropped (presumably ``-1`` marks the final fit rather
    than a cross-validation fold — TODO confirm). The run's PCA kernel, number of PCA
    components and percentile are read from the parameter file (``STR.fname_para``) and
    appended as the index levels ``"Kernel"``, ``"N_COMPONENTS"`` and ``STR.perc``.
    A ``"Median"`` column holding the row-wise median over all result columns is added.

    Returns:
        All runs concatenated and sorted by index; the column axis is named
        ``"Injury_Criterion"``.

    Raises:
        ValueError: If ``PCA_DIRS`` is empty. Without this check the failure only
            surfaced as pandas' generic "No objects to concatenate".
    """
    if not PCA_DIRS:
        raise ValueError("PCA_DIRS is empty: no experiment directories matched, nothing to load")

    frames = []
    for res_dir in PCA_DIRS:
        LOG.info("Processing %s", res_dir)

        # get results: keep only the rows of index value -1 on the first level
        scores = (
            pd.read_csv(res_dir / STR.fname_results_csv, index_col=[0, 1])
            .loc[(-1, slice(None)), :]
            .droplevel(STR.fold)
        )

        # get para: describe the run with its percentile and PCA configuration
        para = json_util.load(f_path=res_dir / STR.fname_para)
        kernel = para[STR.pipeline]["pca_kernel"]
        scores[STR.perc] = para[STR.perc][STR.target][0]
        # A missing kernel is stored as the string "None" so it stays a usable category label.
        scores["Kernel"] = "None" if kernel is None else kernel
        scores["N_COMPONENTS"] = para[STR.pipeline]["n_pca_components"]

        # Move the descriptors into the index first, so the median below only covers
        # the remaining result columns.
        scores = scores.set_index(["Kernel", "N_COMPONENTS", STR.perc], append=True)
        scores["Median"] = scores.median(axis=1)
        frames.append(scores)

    results = pd.concat(frames).sort_index()
    results.columns.name = "Injury_Criterion"

    return results

# Wide table: one row per (data split, kernel, n_components, percentile), one column per
# injury criterion plus "Median". The bare name on the last line displays it in the notebook.
RESULTS: pd.DataFrame = get_data()
RESULTS

In [None]:
# Long format of RESULTS: the column axis ("Injury_Criterion") is stacked into the index,
# the values land in column "R2", and all index levels become regular columns.
RESULTS_L: pd.DataFrame = RESULTS.stack().to_frame("R2").reset_index()
RESULTS_L

In [None]:
# Rows whose last index level (STR.perc) equals 5; the now-constant level is dropped for display.
RESULTS.loc[pd.IndexSlice[:, :, :, 5], :].droplevel(STR.perc)

In [None]:
# "Test" rows whose last index level (STR.perc) equals 95; the percentile level is dropped for display.
RESULTS.loc[pd.IndexSlice["Test", :, :, 95], :].droplevel(STR.perc)

In [None]:
def plot(perc: int):
    """Draw horizontal R2 bar charts per injury criterion for one percentile.

    The rows of ``RESULTS_L`` whose ``STR.perc`` column equals ``perc`` are shown in a
    facet grid with one column per ``Kernel`` and one row per ``N_COMPONENTS``; bars are
    split by ``Data`` ("Train" / "Test"). Every facet gets a grid, an x-axis fixed to
    [0, 1] with 21 ticks, and a dashed vertical line at the highest "Test" value of the
    ``"Median"`` column of ``RESULTS`` for this percentile. The percentile and that
    value are printed.

    Args:
        perc: Percentile value to select (the notebook calls this with 5 and 95).
    """
    g = sns.catplot(
        data=RESULTS_L[RESULTS_L[STR.perc].eq(perc)],
        y="Injury_Criterion",
        x="R2",
        hue="Data",
        col="Kernel",
        row="N_COMPONENTS",
        kind="bar",
        hue_order=["Train", "Test"],
        orient="h",
    )

    # The reference value is the same for every facet, so it is looked up once instead of
    # once per axis and again for the print below.
    best_test_median = RESULTS.loc[("Test", *[slice(None)] * 2, perc), "Median"].max()

    for ax in g.axes.flat:
        ax.grid()
        ax.set_xlim(0, 1)
        ax.set_xticks(np.linspace(0, 1, 21))
        ax.axvline(best_test_median, c="black", ls="--")
    print(perc, best_test_median)

plot(perc=5)

In [None]:
# Same facet grid as the plot(perc=5) cell, here for the rows with percentile 95.
plot(perc=95)

In [None]:
def plot2():
    """Plot the median R2 per PCA kernel and save it as ``pca_results.pdf``.

    The figure has a thin legend-only row on top and a 2x2 grid below it: one row per
    percentile (5, 95) and one column per number of PCA components (20, 40). Each panel
    shows the ``"Median"`` column of ``RESULTS`` per ``Kernel``, split into
    "Training-set" and "Validation-set" bars that are labelled with their value.
    The figure is written to ``FIG_DIR / "pca_results.pdf"``.
    """
    fig, axes = plt.subplot_mosaic(
        mosaic=[["L", "L"], ["5_20", "5_40"], ["95_20", "95_40"]],
        layout="constrained",
        height_ratios=[0.1, 1, 1],
        sharex=True,
        sharey=True,
    )

    # Median scores in long form, with the split names replaced by their display labels.
    medians = (
        pd.DataFrame({"R2": RESULTS["Median"]})
        .reset_index()
        .replace({"Train": "Training-set", "Test": "Validation-set"})
    )

    for perc in [5, 95]:
        for n_comp in [20, 40]:
            panel = axes[f"{perc}_{n_comp}"]
            selection = medians[medians[STR.perc].eq(perc) & medians["N_COMPONENTS"].eq(n_comp)]
            sns.barplot(
                data=selection,
                x="Kernel",
                y="R2",
                hue="Data",
                ax=panel,
                alpha=0.5,
                hue_order=["Training-set", "Validation-set"],
            )

            # Write the value inside each bar, for both hue groups.
            for hue_idx in (0, 1):
                panel.bar_label(panel.containers[hue_idx], fmt="%.2f", padding=-9)

            panel.grid()
            panel.set_ylabel("Median of R2-score")
            panel.set_ylim(0, 1)
            panel.set_yticks(np.linspace(0, 1, 11))
            panel.set_axisbelow(True)
            panel.set_title(f"HIII-{perc:02d}{'F' if perc==5 else 'M'} with n_components={n_comp}")

            # Copy the panel's legend entries into the legend-only axis, then drop the
            # panel's own legend; the shared legend is rebuilt on every pass.
            axes["L"].legend(*panel.get_legend_handles_labels(), ncols=2, loc="upper center")
            panel.legend().remove()
    axes["L"].axis("off")

    fig.set_figwidth(WIDTH - 0.2)
    fig.set_figheight(0.5 * WIDTH)
    fig.savefig(FIG_DIR / "pca_results.pdf")


plot2()

In [None]:
# Median R2 per run as a flat table: index levels become columns, the value column is "R2".
RESULTS["Median"].to_frame("R2").reset_index()

In [None]:
import pickle
from sklearn.decomposition import PCA

# Load the fitted feature extractor of one run and report the number of components
# together with the summed explained-variance ratio.
# NOTE: pickle.load can execute arbitrary code — only open files from trusted runs.
with Path(
    "experiments/2024-11-08-15-50-55_pca_ann_05HIII_injury_criteria_from_doe_sobol_20240705_194200_ft_channels/feature_extractor.pkl"
).open("rb") as f:
    fe: PCA = pickle.load(f)

print(len(fe.explained_variance_ratio_), np.sum(fe.explained_variance_ratio_))

In [None]:
# Report the number of components and the summed explained-variance ratio of the
# feature extractor stored for the 22-14-27 run.
# NOTE: pickle.load can execute arbitrary code — only open files from trusted runs.
with Path(
    "experiments/2024-11-08-22-14-27_pca_ann_05HIII_injury_criteria_from_doe_sobol_20240705_194200_ft_channels/feature_extractor.pkl"
).open("rb") as f:
    fe: PCA = pickle.load(f)

print(len(fe.explained_variance_ratio_), np.sum(fe.explained_variance_ratio_))

In [None]:
# Report the number of components and the summed explained-variance ratio of the
# feature extractor stored for the 15-50-55 run.
# NOTE(review): this is the same file as the first feature-extractor cell in this
# notebook, so the printed values repeat that cell's output.
# NOTE: pickle.load can execute arbitrary code — only open files from trusted runs.
with Path(
    "experiments/2024-11-08-15-50-55_pca_ann_05HIII_injury_criteria_from_doe_sobol_20240705_194200_ft_channels/feature_extractor.pkl"
).open("rb") as f:
    fe: PCA = pickle.load(f)

print(len(fe.explained_variance_ratio_), np.sum(fe.explained_variance_ratio_))

In [None]:
# Scratch check: counts the entries of a literal list of channel-code strings (25 entries).
# NOTE(review): presumably these are the input channels of the "ft_channels" runs whose
# feature extractors are loaded above — confirm against the run parameters.
len([
      "03CHST0000OCCUACXD",
      "03CHST0000OCCUACYD",
      "03CHST0000OCCUACZD",
      "03CHST0000OCCUDSXD",
      "03CHSTLOC0OCCUDSXD",
      "03CHSTLOC0OCCUDSYD",
      "03CHSTLOC0OCCUDSZD",
      "03FEMRLE00OCCUFOZD",
      "03FEMRRI00OCCUFOZD",
      "03HEAD0000OCCUACXD",
      "03HEAD0000OCCUACYD",
      "03HEAD0000OCCUACZD",
      "03HEADLOC0OCCUDSXD",
      "03HEADLOC0OCCUDSYD",
      "03HEADLOC0OCCUDSZD",
      "03NECKUP00OCCUFOXD",
      "03NECKUP00OCCUFOYD",
      "03NECKUP00OCCUFOZD",
      "03NECKUP00OCCUMOYD",
      "03PELV0000OCCUACXD",
      "03PELV0000OCCUACYD",
      "03PELV0000OCCUACZD",
      "03PELVLOC0OCCUDSXD",
      "03PELVLOC0OCCUDSYD",
      "03PELVLOC0OCCUDSZD"
    ])