In [None]:
import json
import logging
import os
import sys
from collections import defaultdict
from itertools import product
from pathlib import Path
from typing import Dict, List, Optional

import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pyarrow.parquet as pq
import seaborn as sns
import textalloc as ta
from IPython.display import display
from sklearn.dummy import DummyClassifier
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score, precision_score, r2_score, recall_score
from tqdm import tqdm

# Locate this notebook on disk. NOTE(review): "__vsc_ipynb_file__" is a variable
# injected by the VS Code Jupyter extension; outside VS Code this lookup presumably
# fails with a KeyError — confirm before running the notebook elsewhere.
NOTEBOOK_PATH: Path = Path(IPython.extract_module_locals()[1]["__vsc_ipynb_file__"])
# Project root = parent of the directory that contains this notebook.
PROJECT_DIR: Path = NOTEBOOK_PATH.parent.parent
# Make the project root importable so the `src.*` imports below resolve.
sys.path.append(str(PROJECT_DIR))
import src.utils.custom_log as custom_log
import src.utils.json_util as json_util
from src._StandardNames import StandardNames
from src.evaluate._Data import Data
from src.load.LoadForClassification import RENAMER, LoadForClassification
from src.utils.PathChecker import PathChecker
from src.utils.set_rcparams import set_rcparams

# Work from the project root so cwd-relative paths (e.g. Path("data") / ...) resolve against it.
os.chdir(PROJECT_DIR)
# Presumably applies the project's matplotlib rcParams (fonts, sizes) — see src.utils.set_rcparams.
set_rcparams()

LOG: logging.Logger = logging.getLogger(__name__)
custom_log.init_logger(log_lvl=logging.INFO)
LOG.info("Log start, project directory is %s (exist: %s)", PROJECT_DIR, PROJECT_DIR.is_dir())

# Shared helpers: path validation, and the project's standard names
# (file names such as fname_channels, index level names such as id/time).
CHECK: PathChecker = PathChecker()
STR: StandardNames = StandardNames()

# Figures are written to reports/figures/<notebook stem>/, created if missing.
# NOTE(review): exit=False presumably makes the PathChecker checks non-fatal — confirm.
FIG_DIR: Path = CHECK.check_directory(PROJECT_DIR / "reports" / "figures", exit=False)
FIG_DIR /= NOTEBOOK_PATH.stem
FIG_DIR.mkdir(parents=True, exist_ok=True)
LOG.info("Figure directory is %s (exist: %s)", FIG_DIR, FIG_DIR.is_dir())

# Input data directory of the DOE run evaluated in this notebook.
DATA_DIR: Path = CHECK.check_directory(PROJECT_DIR / "data" / "doe" / "doe_sobol_20240705_194200", exit=False)
# Injury-criteria table (filtered by PERC and ID in the cells below).
INJ_FPATH: Path = CHECK.check_file(DATA_DIR / STR.fname_injury_crit, exit=False)
# Class-label files derived from the injury-criteria file, keyed by number of classes.
CLASSES_FPATHS: Dict[int, Path] = {
    i: CHECK.check_file(DATA_DIR / f"{INJ_FPATH.stem}_classes_{i}.parquet", exit=False) for i in (2, 3, 5, 7)
}
# Channel time series (grouped by TIME in the plotting cells below).
CHANNELS_FPATH: Path = CHECK.check_file(DATA_DIR / STR.fname_channels, exit=False)
# JSON with the simulation IDs to exclude, per percentile.
J_PATH: Path = CHECK.check_file(DATA_DIR / STR.fname_dropped_ids, exit=False)

In [None]:
# Simulation IDs to exclude, keyed by percentile. JSON object keys are always
# strings, so they are converted back to int; the ID lists become sets for use
# in the ("ID", "not in", ...) parquet filters of the following cells.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform-dependent locale encoding.
with open(J_PATH, encoding="utf-8") as f:
    DROPPED_IDS = {int(k): set(i) for k, i in json.load(f).items()}
DROPPED_IDS

In [None]:
def eval_inj(perc, columns: Optional[List[str]] = None):
    """Display each injury criterion's median and the R² of a constant-median predictor.

    Loads the injury criteria of one percentile from ``INJ_FPATH`` (excluding the
    IDs in ``DROPPED_IDS[perc]``), predicts every sample with its column median and
    displays a table with the median ("Value") and the resulting R² per column.

    Args:
        perc: percentile to select from the ``PERC`` index level (e.g. 5 or 95).
        columns: injury-criterion columns to evaluate; defaults to the nine
            criteria used throughout this notebook.
    """
    if columns is None:
        columns = [
            "Head_HIC15",
            "Head_a3ms",
            "Neck_My_Extension",
            "Neck_Fz_Max_Tension",
            "Neck_Fx_Shear_Max",
            "Chest_Deflection",
            "Femur_Fz_Max_Compression",
            "Chest_VC",
            "Chest_a3ms",
        ]
    db = (
        pd.read_parquet(
            INJ_FPATH,
            filters=[("PERC", "==", perc), ("ID", "not in", DROPPED_IDS[perc])],
            columns=columns,
        )
        .droplevel("PERC")
        .astype(np.float32)
    )
    db_med_ = db.median(axis=0)
    # Baseline predictor: every row predicts the column median. Built by explicit
    # tiling rather than `(db * 0 + 1) * median`, which is hard to read and turns
    # any inf/NaN cell of `db` into a NaN prediction.
    db_med = np.tile(db_med_.to_numpy(), (len(db), 1))

    sc = pd.Series(r2_score(db, db_med, multioutput="raw_values"), index=db.columns)
    display(pd.DataFrame({"Value": db_med_, "R2": sc}))


eval_inj(5)

In [None]:
# Same median-baseline evaluation, for percentile 95.
eval_inj(perc=95)

In [None]:
def _trunc_2dp(value) -> str:
    """Format ``value`` truncated (not rounded) to two decimals, e.g. 0.299 -> "0.29".

    ``value * 100`` is rounded to 9 decimals before flooring: binary floating point
    gives e.g. ``0.29 * 100 == 28.999999999999996``, and flooring that directly
    would wrongly print "0.28".
    """
    return f"{np.floor(np.round(value * 100, 9)) / 100:.2f}"


def eval_class(displ=True):
    """Score a most-frequent-class baseline on the discretized injury criteria.

    For every class count in (2, 3, 5, 7) and percentile in (5, 95), the class
    labels are loaded from ``CLASSES_FPATHS`` (excluding ``DROPPED_IDS``), a
    multi-output ``DummyClassifier(strategy="most_frequent")`` is fitted, and for
    each criterion the predicted mode, its relative frequency and the weighted F1
    score are collected.

    Args:
        displ: if True, display one table per (class count, percentile). If False,
            print, per criterion, one ``&``-separated row per class count
            (presumably for a LaTeX table) with mode, frequency and F1 for both
            percentiles; frequency and F1 are truncated to two decimals.
    """
    rel_cols = [
        "Head_HIC15",
        "Head_a3ms",
        "Neck_My_Extension",
        "Neck_Fz_Max_Tension",
        "Neck_Fx_Shear_Max",
        "Chest_Deflection",
        "Chest_VC",
        "Femur_Fz_Max_Compression",
    ]
    class_counts = (2, 3, 5, 7)
    percs = (5, 95)
    # db[n_classes][perc] -> per-criterion table with columns Mode, Frequency, F1.
    db = defaultdict(dict)
    for n_classes in class_counts:
        for perc in percs:
            db_ = (
                pd.read_parquet(
                    CLASSES_FPATHS[n_classes],
                    filters=[("PERC", "==", perc), ("ID", "not in", DROPPED_IDS[perc])],
                    columns=rel_cols,
                )
                .droplevel("PERC")
                .astype(np.int32)[rel_cols]
            )

            # Multi-output baseline: one most-frequent class per criterion
            # (DummyClassifier ignores X, so passing the labels as X is harmless).
            clf = DummyClassifier(strategy="most_frequent")
            clf.fit(db_[rel_cols], db_[rel_cols])
            y_pred = pd.DataFrame(clf.predict(db_), index=db_.index, columns=rel_cols)

            # With several outputs, classes_ / class_prior_ hold one array per criterion.
            modes = [classes[np.argmax(prior)] for classes, prior in zip(clf.classes_, clf.class_prior_)]

            db[n_classes][perc] = pd.DataFrame(
                {
                    "Mode": modes,
                    "Frequency": [np.max(prior) for prior in clf.class_prior_],
                    # Labels fixed to 0..n_classes-1; zero_division=0.0 scores undefined
                    # precision/recall as 0 without emitting a warning.
                    "F1": [
                        f1_score(
                            db_[c],
                            y_pred[c],
                            average="weighted",
                            labels=list(range(n_classes)),
                            zero_division=0.0,
                        )
                        for c in rel_cols
                    ],
                },
                index=rel_cols,
            )

            if displ:
                print(n_classes, perc)
                display(db[n_classes][perc].loc[rel_cols])

    if not displ:
        for idx in rel_cols:
            print("#" * 5, idx, "#" * 5)
            for n_classes in class_counts:
                full_str = f"{n_classes:1d}"
                for perc in percs:
                    # Row order follows the table columns: Mode, Frequency, F1.
                    mode, freq, f1 = db[n_classes][perc].loc[idx].values
                    full_str += " & " + " & ".join([str(int(mode)), _trunc_2dp(freq), _trunc_2dp(f1)])

                print(full_str)


eval_class(displ=False)

In [None]:
def eval_class_conf(n_classes: int, perc: int, col: str):
    """Display the confusion matrix of a most-frequent-class baseline for one criterion.

    Prints the classes seen during fitting, then displays an
    ``n_classes x n_classes`` confusion matrix (rows "True k", columns "Pred k").

    Args:
        n_classes: number of classes the criterion was discretized into; selects the
            file in ``CLASSES_FPATHS`` and fixes the matrix to labels ``0..n_classes-1``.
        perc: percentile to select from the ``PERC`` index level.
        col: injury-criterion column to evaluate.
    """
    db_ = (
        pd.read_parquet(
            CLASSES_FPATHS[n_classes],
            filters=[("PERC", "==", perc), ("ID", "not in", DROPPED_IDS[perc])],
            columns=[col],
        )
        .droplevel("PERC")
        .astype(np.int32)
    )
    y_true = db_[col]

    # Single-output baseline fitted on the 1-D target (DummyClassifier ignores X).
    clf = DummyClassifier(strategy="most_frequent")
    clf.fit(db_, y_true)
    y_pred = pd.Series(clf.predict(db_), index=db_.index, name=col)
    print(clf.classes_)

    # A fixed label list keeps rows/columns for classes that never occur.
    labels = list(range(n_classes))
    conf = pd.DataFrame(
        confusion_matrix(y_true, y_pred, labels=labels),
        index=[f"True {x}" for x in labels],
        columns=[f"Pred {x}" for x in labels],
    )
    display(conf)


eval_class_conf(n_classes=3, perc=5, col="Femur_Fz_Max_Compression")

In [None]:
def conv_intervals(chs, perc: int, zone: tuple = (60, 120)):
    """Plot the per-time-step median and inter-quartile range of selected channels.

    One panel per channel plus a legend strip on top; the figure is saved to
    ``FIG_DIR / f"intervals_{perc}.pdf"``. Also prints the number of distinct
    simulation IDs left after filtering.

    Args:
        chs: channel columns of ``CHANNELS_FPATH``; each becomes one panel.
        perc: percentile to select from the ``PERC`` index level.
        zone: x positions (ms) of the dashed "Zone of Interest" markers.
    """
    # NOTE(review): presumably the LaTeX \textwidth in pt converted to inches
    # (72 pt/in) minus a small margin — confirm against the document class.
    f_width = 448.13095 / 72 - 0.2
    fig, ax = plt.subplot_mosaic(mosaic=[["L"] * len(chs), chs], layout="constrained", height_ratios=[0.1, 1])

    # Every channel uses the same row filter, so load them all in a single read.
    db = pd.read_parquet(
        CHANNELS_FPATH,
        columns=list(chs),
        filters=[("PERC", "==", perc), ("ID", "not in", DROPPED_IDS[perc])],
    ).droplevel("PERC")
    print(len(set(db.index.get_level_values(STR.id))))
    quartiles = db.groupby("TIME").quantile([0.25, 0.5, 0.75])
    quartiles.index = quartiles.index.set_names(["TIME", "Quantile"])

    for ch in chs:
        q = quartiles[ch].unstack("Quantile")
        ax[ch].plot(q.index, q[0.5], label="Median")
        ax[ch].fill_between(q.index, q[0.25], q[0.75], alpha=0.3, label="IQR")
        ax[ch].grid()
        for k, v in enumerate(zone):
            # Only the first marker is labelled so the legend lists "Zone of Interest"
            # once; empty labels are skipped by get_legend_handles_labels.
            ax[ch].axvline(v, c="black", ls="--", label="Zone of Interest" if k == 0 else "")
        ax[ch].set_xlabel("Time [ms]")
        ax[ch].set_ylabel(ch.replace("OCCU", f"H3{perc:02d}"), {"fontname": "CMU Typewriter Text", "fontsize": "large", "fontweight": "bold"})
    # Shared legend in the top strip, built from the first channel's handles.
    ax["L"].legend(*ax[chs[0]].get_legend_handles_labels(), loc="upper center", ncols=3)
    ax["L"].axis("off")

    fig.set_figheight(0.35 * f_width)
    fig.set_figwidth(f_width)
    fig.savefig(FIG_DIR / f"intervals_{perc}.pdf")


conv_intervals(chs=["03HEAD0000OCCUACRD", "03NECKUP00OCCUMOYD"], perc=5)

In [114]:
def eval_channel(perc: int, columns: Optional[List[str]] = None) -> pd.DataFrame:
    """Load the channel time series of one percentile, excluding dropped IDs.

    Previously the loaded frame was discarded; it is now returned.

    Args:
        perc: percentile to select from the ``PERC`` index level (the level is dropped).
        columns: channel columns to read; ``None`` reads all columns.

    Returns:
        The filtered channel frame.
    """
    return pd.read_parquet(
        CHANNELS_FPATH,
        columns=columns,
        filters=[("PERC", "==", perc), ("ID", "not in", DROPPED_IDS[perc])],
    ).droplevel("PERC")

In [None]:
def plot_resampler(chs: List[str], perc: int, n_stps: List[int], sid: int):
    """Overlay one simulation's channels resampled to several numbers of time steps.

    One panel per channel plus a legend strip on top; resolutions are plotted in
    reverse order of ``n_stps``. The figure is saved to
    ``FIG_DIR / f"resampling_{perc}.pdf"``.

    Args:
        chs: channel columns to plot; each gets its own panel.
        perc: percentile of the simulation (passed to ``Data.set_from_files``).
        n_stps: numbers of time steps to resample to (any length).
        sid: simulation ID to load.
    """
    # Resolve against DATA_DIR rather than the current working directory.
    f_path = DATA_DIR / "channels.parquet"
    sim_data = Data()
    sim_data.set_from_files(file_paths=[f_path], percentiles=[perc], idxs=[sid], columns=chs)

    # NOTE(review): presumably the LaTeX \textwidth in pt converted to inches minus a margin.
    f_width = 448.13095 / 72 - 0.2
    fig, ax = plt.subplot_mosaic(mosaic=[["L"] * len(chs), chs], layout="constrained", height_ratios=[0.1, 1])
    line_styles = ["-", "--", "-.", ":"]

    for ch in chs:
        ax[ch].set_xlabel("Time [ms]")
        ax[ch].set_ylabel(ch.replace("OCCU", f"H3{perc:02d}"), {"fontname": "CMU Typewriter Text", "fontsize": "large", "fontweight": "bold"})
        ax[ch].grid()

    for j, n_tsp in enumerate(reversed(n_stps)):
        db: pd.DataFrame = sim_data.get_temporal_resampled(n_tsp)
        # NOTE(review): 1400 steps is treated as the reference and n_tsp / 140 as the
        # sampling rate in kHz, i.e. a 140 ms signal is assumed — confirm.
        label = f"{'Reference with' if n_tsp == 1400 else 'Resampled to'} {n_tsp/140:.1f} kHz"
        for ch in chs:
            ax[ch].plot(
                db.index.get_level_values(STR.time),
                db[ch],
                label=label,
                linewidth=2,
                # Cycle the styles so more than four resolutions do not raise IndexError.
                ls=line_styles[j % len(line_styles)],
            )
    # Shared legend in the top strip, built from the first channel's handles.
    ax["L"].legend(*ax[chs[0]].get_legend_handles_labels(), ncol=2, loc="upper center")
    ax["L"].axis("off")

    fig.set_figheight(0.4 * f_width)
    fig.set_figwidth(f_width)
    fig.savefig(FIG_DIR / f"resampling_{perc}.pdf")


plot_resampler(chs=["03HEAD0000OCCUACRD", "03NECKUP00OCCUMOYD"], perc=95, n_stps=[10, 28, 70, 1400], sid=66)