In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score

# Work from the project root so the relative `src.*` imports below and the
# relative `models/` / `data/` paths used later resolve. The location defaults
# to the original checkout path but can be overridden via the AIHIII_ROOT
# environment variable (e.g. on another machine or user account).
os.chdir(path=os.environ.get("AIHIII_ROOT", "/root/py_projects/aihiii"))

from src._StandardNames import StandardNames
from src.build.AnnUniversalImportableFTExtractor import AnnUniversal
from src.evaluate._Data import Data
import src.utils.json_util as json_util

# Project-wide name registry; supplies file names such as fname_channels,
# fname_injury_crit and fname_dropped_ids used below.
STR: StandardNames = StandardNames()
# Percentile of the target injury criteria (inputs are read at the 50th);
# also formatted into the model directory name.
PERC: int = 95

In [None]:
# Directory of the trained reference CNN for the PERC-th percentile targets,
# and the DOE directory holding the channel / injury-criteria files.
b_path = Path("models", "CNN", "Reference", f"{PERC:02d}HIII_tg_injury_criteria")
d_path = Path("data", "doe", "doe_sobol_20240705_194200")
# Quick sanity check: both should display True.
(b_path.is_dir(), d_path.is_dir())

In [3]:
# Injury criteria used as regression targets (read at the PERC-th percentile).
INJURY_CRITERIA = [
    "Head_HIC15",
    "Head_a3ms",
    "Neck_My_Max",
    "Neck_Fz_Max_Tension",
    "Neck_Fx_Shear_Max",
    "Chest_Deflection",
    "Femur_Fz_Max_Compression",
]
# Input files (channels + injury criteria), read at the 50th percentile.
x_files = [d_path / STR.fname_channels, d_path / STR.fname_injury_crit]

# First pass: load only the inputs to discover the available sample ids.
# The targets are not needed yet; they are loaded once, already filtered, below.
x = Data()
x.set_from_files(file_paths=x_files, percentiles=[50])
idxs = x.get_tabular().index

# Remove ids listed as dropped for either the target percentile or the 50th.
drops = json_util.load(d_path / STR.fname_dropped_ids)
idxs = idxs.drop(drops[str(PERC)] + drops["50"])

# Second pass: reload inputs and targets restricted to the surviving ids.
x, y = Data(), Data()
x.set_from_files(file_paths=x_files, percentiles=[50], idxs=idxs)
y.set_from_files(
    file_paths=[d_path / STR.fname_injury_crit],
    percentiles=[PERC],
    columns=INJURY_CRITERIA,
    idxs=idxs,
)

In [None]:
# Restore the trained regressor from `b_path`. Label names are taken from the
# target table, presumably so the model's outputs line up with `y`'s columns.
ann = AnnUniversal()
ann.load(
    model_dir=b_path,
    is_regression=True,
    is_multi_channel_regression=False,
    label_names=y.get_tabular().columns,
)

In [None]:
# Predict the targets for the filtered inputs `x`. The result is used below as a
# table via `y_pred.get_tabular()` and is expected to carry the same column
# names as `y` (the plots index it by `y`'s columns).
y_pred = ann.predict(x=x)

In [None]:
# Peek at the first five predicted rows.
y_pred.get_tabular().head(5)

In [None]:
def plot(ch: str, margin: float = 0.2) -> None:
    """Draw a predicted-vs-true parity plot for one target column.

    Reads the notebook-level ``y`` (ground truth) and ``y_pred`` (model
    output) tables. It overlays KDE contours and a scatter of the paired
    values, draws the identity line on square axes, and shows R² in the title.

    Args:
        ch: Column name present in both ``y.get_tabular()`` and
            ``y_pred.get_tabular()``.
        margin: Relative padding applied outward to each axis bound. The
            default 0.2 reproduces the previous limits (``0.8*min`` to
            ``1.2*max``) for strictly positive data.
    """
    fig, ax = plt.subplots()
    y_true = y.get_tabular()[ch]
    y_pred_ = y_pred.get_tabular()[ch]

    # seaborn aligns pandas inputs on their index, whereas scatter and
    # r2_score pair values by position. Pass positional arrays so the KDE,
    # the scatter and the R² all pair samples the same way.
    sns.kdeplot(x=y_true.to_numpy(), y=y_pred_.to_numpy(), ax=ax, fill=False, levels=10)
    ax.scatter(y_true, y_pred_, s=1, c="orange", alpha=0.5)

    # Shared square limits covering both series. Pad each bound outward by a
    # fraction of its magnitude: multiplying by 0.8 / 1.2 directly moves the
    # bounds inward when they are negative and clips data.
    low = min(y_true.min(), y_pred_.min())
    high = max(y_true.max(), y_pred_.max())
    low -= margin * abs(low)
    high += margin * abs(high)
    if low == high:  # e.g. every value is 0: avoid identical (degenerate) limits
        low, high = low - 1.0, high + 1.0

    # Identity line (perfect prediction); drawn once in data coordinates.
    ax.plot([low, high], [low, high], c="black", linestyle="--")
    ax.set_xlim([low, high])
    ax.set_ylim([low, high])

    ax.set_title(f"{ch} - R²={r2_score(y_true, y_pred_):.2f}")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")


# One parity plot per target column, in alphabetical order. A plain loop avoids
# displaying the list of `None` results a list comprehension would leave as output.
for channel in sorted(y.get_tabular().columns):
    plot(channel)