In [None]:
import logging
import os
from itertools import combinations
from pathlib import Path
from typing import List, Optional, Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import polars as pl
import seaborn as sns
from IPython.display import display


# Switch to the project root: the relative paths used in this notebook
# (e.g. Path("experiments")) are resolved against the working directory.
# Presumably this is also what makes the `src` package imports that follow
# resolvable - TODO confirm.
# NOTE(review): hard-coded absolute path; os.chdir raises if it does not exist.
os.chdir("/root/py_projects/aihiii")

import src.utils.json_util as json_util
from src._StandardNames import StandardNames
from src.utils.custom_log import init_logger
from src.utils.set_rcparams import set_rcparams

# Project helper imported from src.utils.set_rcparams; presumably applies the
# project's matplotlib rcParams - TODO confirm.
set_rcparams()

# Module-level logger and the project's StandardNames instance; STR supplies the
# file names and column/level names referenced by later cells.
LOG: logging.Logger = logging.getLogger(__name__)
STR: StandardNames = StandardNames()

# Initialise logging at INFO level, then record the working directory.
init_logger(log_lvl=logging.INFO)
LOG.info("Working directory: %s", os.getcwd())

# NOTE(review): looks like a LaTeX text width in points converted to inches
# (division by 72) minus a 0.2 margin - confirm the intended units.
WIDTH: float = 448.13095 / 72 -0.2

In [None]:
# Output folder for figures, relative to the working directory; it is created
# together with any missing parent folders, and the outcome is logged.
FIG_DIR: Path = Path("reports", "figures", "eval_rocket")
FIG_DIR.mkdir(parents=True, exist_ok=True)
LOG.info("Figures in %s, exist - %s", FIG_DIR, FIG_DIR.is_dir())

In [None]:
# Directory that is searched for experiment result folders, relative to the
# working directory; the log line reports whether it exists.
DATA_DIR:Path = Path("experiments")
LOG.info("Data in %s, exist - %s", DATA_DIR, DATA_DIR.is_dir())

In [None]:
# Result folders inside DATA_DIR whose names match the pattern below, in sorted
# path order so that the processing order is deterministic.
DIRS: List[Path] = list(
    DATA_DIR.glob("2024-12-2*-*-*-*_rocket_ann_95HIII_injury_criteria_from_doe_sobol_20240705_194200_ft_channels")
)
DIRS.sort()
LOG.info("Rocket dirs (n=%s):\n%s", len(DIRS), DIRS)

In [None]:
def get_data() -> pd.DataFrame:
    """Collect the results of every run directory in ``DIRS`` into one table.

    For each directory the results CSV (``STR.fname_results_csv``) is read with
    its first two columns as index. Only rows whose first index level equals
    ``-1`` are kept, after which the ``STR.fold`` index level is dropped. The
    run's ``n_kernels`` pipeline setting is read from the parameter file
    (``STR.fname_para``) and appended to the index as level ``"Kernels"``
    (stored as the string ``"None"`` when the setting is ``None``).

    A ``"Median"`` column holds the row-wise median of the columns read from
    the CSV.

    Returns:
        pd.DataFrame: All runs concatenated and sorted by index. Columns are
        the CSV's columns followed by ``"Median"``.

    Raises:
        ValueError: If ``DIRS`` is empty, i.e. there is nothing to load.
    """
    frames: List[pd.DataFrame] = []
    for res_dir in DIRS:
        LOG.info("Processing %s", res_dir)

        # get results
        # NOTE(review): -1 on the first index level presumably marks the final /
        # aggregate fit rather than a single CV fold - confirm.
        scores = (
            pd.read_csv(res_dir / STR.fname_results_csv, index_col=[0, 1])
            .loc[(-1, slice(None)), :]
            .droplevel(STR.fold)
        )

        # Take the median before any run metadata is attached, so that only the
        # columns read from the CSV contribute to it.
        scores["Median"] = scores.median(axis=1)

        # get para
        para = json_util.load(f_path=res_dir / STR.fname_para)
        k = para[STR.pipeline]["n_kernels"]

        # NOTE(review): mixing the string "None" with numeric kernel counts gives
        # an object-typed index level - confirm the resulting sort order is intended.
        scores["Kernels"] = "None" if k is None else k
        frames.append(scores.set_index("Kernels", append=True))

    if not frames:
        raise ValueError("No result directories found in DIRS; nothing to load")

    return pd.concat(frames).sort_index()

# Load the combined results table once; the cells below read it via RESULTS.
RESULTS:pd.DataFrame = get_data()
# Bare expression: shown as this cell's output.
RESULTS

In [None]:
def plotter():
    """Display the flattened results and draw one scatter plot per result column.

    ``RESULTS`` is flattened with ``reset_index`` so that its index levels become
    ordinary columns. The flattened frame must provide ``"Kernels"`` (x axis)
    and ``"Data"`` (hue); presumably both are index levels of ``RESULTS`` -
    verify against ``get_data``. Every column of ``RESULTS`` gets its own new
    figure with that column on the y axis.
    """
    flat = RESULTS.reset_index()
    display(flat)

    for metric in RESULTS.columns:
        # One fresh figure per column; only the axes object is needed.
        _, axis = plt.subplots()
        sns.scatterplot(data=flat, x="Kernels", y=metric, hue="Data", ax=axis)

# Show the flattened table and draw the per-column scatter plots for RESULTS.
plotter()

In [None]:
# "Median" values arranged with one row per "Kernels" value and one column per
# "Data" value, rounded to two decimals; shown as this cell's output.
(
    RESULTS[["Median"]]
    .reset_index()
    .pivot(index="Kernels", columns="Data")
    .round(2)
)