In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import matplotlib.pyplot as plt
import polars as pl

from ethos.constants import PROJECT_ROOT
from ethos.inference.constants import Task
from ethos.metrics import compute_drg_results, compute_metrics, compute_sofa_results
from ethos.task_processing import TASK_RESULTS_PROCESSING_FUNC, join_metadata

# Default colour for forest-plot whiskers and markers (see plot_forest).
color: str = "#00A1D9"


def our_join_metadata(df: pl.DataFrame, input_dir: Path) -> pl.DataFrame:
    """Join run metadata onto ``df`` and derive short ``dataset``/``model``/``temp`` labels.

    The frame returned by ``join_metadata`` is expected to carry ``model_fp`` and
    ``temperature`` columns (NOTE(review): assumed from the select below — confirm).
    ``dataset`` is the 4th-from-last and ``model`` the 2nd-from-last ``/``-separated
    component of ``model_fp``; ``model`` additionally has the common
    ``layer_3_do_0.3_`` prefix removed.

    Args:
        df: per-file results frame to annotate.
        input_dir: directory the results were read from, forwarded to ``join_metadata``.

    Returns:
        ``dataset``, ``model``, ``temp`` followed by all of ``df``'s original columns.
    """
    model_prefix = "layer_3_do_0.3_"
    path_parts = pl.col("model_fp").str.split("/")
    return join_metadata(df, input_dir).select(
        path_parts.list[-4].alias("dataset"),
        # strip_prefix leaves names without the prefix intact, whereas slicing a fixed
        # number of characters would silently truncate them.
        path_parts.list[-2].str.strip_prefix(model_prefix).alias("model"),
        pl.col("temperature").alias("temp"),
        *df.columns,
    )


def is_valid_file(fp: Path) -> bool:
    """Decide whether a result file should be included in this analysis.

    A file qualifies when its name starts with ``mimic_synth``, does not contain
    ``val`` anywhere, and contains either ``+`` or ``<fold>_best`` for one of the
    folds little/small/big or their ``_synth`` variants.
    """
    name = fp.name
    if not name.startswith("mimic_synth") or "val" in name:
        return False

    base_folds = ("little", "small", "big")
    all_folds = (*base_folds, *(f"{fold}_synth" for fold in base_folds))
    markers = ["+"] + [f"{fold}_best" for fold in all_folds]
    return any(marker in name for marker in markers)


def plot_forest(
    df: pl.DataFrame,
    x: str = "fitted_auc",
    x_ci: str = "fitted_auc_ci",
    y: str = "model",
    title: str = "",
    lw=3,
    color=color,
    sort_expr: str | pl.Expr | None = None,
):
    """Draw a horizontal forest plot (point estimate + CI whiskers) on the current axes.

    Rows are sorted ascending by ``sort_expr`` (default: ``x``) with nulls last, and
    row ``i`` is drawn at height ``i``; with the default sort the largest value is
    therefore drawn highest.

    Args:
        df: results frame to plot.
        x: column holding the point estimate.
        x_ci: column holding a two-element ``[low, high]`` confidence interval.
        y: column used for the y-axis tick labels.
        title: plot title.
        lw: line width of the CI whiskers and caps.
        color: colour shared by whiskers, caps and the diamond marker.
        sort_expr: column name or expression to sort by; defaults to ``x``.
    """
    if sort_expr is None:
        sort_expr = x

    df = df.sort(sort_expr, nulls_last=True)
    for i, (m, ci) in enumerate(df[x, x_ci].rows()):
        # Missing values are expected (hence nulls_last): keep the row's tick but skip
        # drawing the missing part instead of failing to unpack ``None``.
        if ci is not None and None not in ci:
            lo, hi = ci
            plt.plot([lo, hi], [i, i], color=color, lw=lw)
            # Short vertical caps at both CI ends.
            plt.plot([lo, lo], [i - 0.3, i + 0.3], color=color, lw=lw)
            plt.plot([hi, hi], [i - 0.3, i + 0.3], color=color, lw=lw)
        if m is not None:
            plt.plot(m, i, marker="D", color=color, markersize=10)

    plt.yticks(list(range(len(df))), df[y])
    plt.grid(True)
    plt.title(title)


# Number of bootstrap resamples behind every confidence interval in this notebook.
n_bootstraps = 1_000

# NOTE(review): not referenced in the visible cells; is_valid_file hard-codes the
# same "mimic_synth" prefix.
dataset_name = "mimic_synth"
# Per-task results are read from <PROJECT_ROOT>/results/<task>.
result_dir = PROJECT_ROOT / "results"
# Task -> metrics DataFrame, filled in by the cells below.
all_results = dict()

In [None]:
drg_results_dir = result_dir / Task.DRG_PREDICTION
drg_process_func = TASK_RESULTS_PROCESSING_FUNC[Task.DRG_PREDICTION]

# Bootstrapped top-1 DRG metrics for every selected result file. Each frame is
# tagged with its file name in a temporary ``name`` column that is dropped after
# the metadata join.
drg_per_file = [
    compute_drg_results(
        drg_process_func(fp, top_k=1), n_bootstraps=n_bootstraps
    ).with_columns(name=pl.lit(fp.name))
    for fp in drg_results_dir.iterdir()
    if is_valid_file(fp)
]
drg_joined = our_join_metadata(pl.concat(drg_per_file), drg_results_dir)
all_results[Task.DRG_PREDICTION] = drg_joined.drop("name").sort(
    "acc_top_1", descending=True
)

In [None]:
# Top-1 DRG accuracy with its bootstrap CI, labelled by model (plot_forest's default y).
plot_forest(
    df=all_results[Task.DRG_PREDICTION],
    title="DRG Prediction",
    x="acc_top_1",
    x_ci="acc_top_1_ci",
)

In [None]:
sofa_results_dir = result_dir / Task.SOFA_PREDICTION
sofa_process_func = TASK_RESULTS_PROCESSING_FUNC[Task.SOFA_PREDICTION]


def _sofa_metrics_for_file(fp: Path) -> pl.DataFrame:
    """Bootstrapped SOFA metrics for one result file, tagged with its file name."""
    preds = sofa_process_func(fp)
    metrics = compute_sofa_results(
        *preds["true_sofa", "pred_sofa"], n_bootstraps=n_bootstraps
    )
    return pl.from_dict(metrics).with_columns(name=pl.lit(fp.name))


def _score_and_ci(metric: str) -> dict[str, pl.Expr]:
    """Split a ``{score, ci_low, ci_high}`` struct column into ``metric`` and ``metric_ci``."""
    fields = pl.col(metric).struct
    return {
        metric: fields["score"],
        f"{metric}_ci": pl.concat_list(fields["ci_low"], fields["ci_high"]),
    }


sofa_raw = pl.concat(
    _sofa_metrics_for_file(fp)
    for fp in sofa_results_dir.iterdir()
    if is_valid_file(fp)
)
all_results[Task.SOFA_PREDICTION] = (
    sofa_raw.with_columns(**_score_and_ci("r2"), **_score_and_ci("mae"))
    .pipe(our_join_metadata, sofa_results_dir)
    .drop("name")
    .sort("r2", descending=True)
)

In [None]:
# SOFA R^2 with its bootstrap CI, labelled by model (plot_forest's default y).
plot_forest(
    df=all_results[Task.SOFA_PREDICTION],
    title="SOFA Prediction",
    x="r2",
    x_ci="r2_ci",
)

In [None]:
def compute_results(task: Task, n_bootstraps: int, **kwargs) -> pl.DataFrame:
    """Compute bootstrapped metrics for every valid result file of ``task``.

    Each file under ``result_dir / task`` that passes ``is_valid_file`` is processed
    with the task's entry in ``TASK_RESULTS_PROCESSING_FUNC`` (``kwargs`` are forwarded
    to it) and scored with ``compute_metrics`` on its ``expected``/``actual`` columns.
    Each file yields one dict that becomes one row. Rows are joined with run metadata
    via ``our_join_metadata`` and sorted by ``fitted_auc``, best first.

    Args:
        task: task whose results directory is scanned.
        n_bootstraps: number of bootstrap resamples passed to ``compute_metrics``.
        **kwargs: extra keyword arguments for the task's processing function.

    Returns:
        One row per result file: metric columns, ``rep_num`` and metadata columns.

    Raises:
        FileNotFoundError: if no file in the task directory passes ``is_valid_file``.
    """
    proc_func = TASK_RESULTS_PROCESSING_FUNC[task]

    def compute_metrics_for_single_case(fp):
        df = proc_func(fp, **kwargs)
        res = compute_metrics(*df["expected", "actual"], n_bootstraps=n_bootstraps)
        # rep_num: mean of the processed frame's ``counts`` column for this file.
        return {"name": fp.name, **res, "rep_num": df["counts"].mean()}

    task_results_dir = result_dir / task
    rows = [
        compute_metrics_for_single_case(fp)
        for fp in task_results_dir.iterdir()
        if is_valid_file(fp)
    ]
    if not rows:
        # Fail fast with a clear message; an empty, column-less frame would otherwise
        # likely fail later in the metadata join or sort with a less obvious error.
        raise FileNotFoundError(
            f"No valid result files for task {task} in {task_results_dir}"
        )

    return (
        pl.DataFrame(rows)
        .pipe(our_join_metadata, task_results_dir)
        .drop("name")
        .sort("fitted_auc", descending=True)
    )

In [None]:
all_results[Task.READMISSION] = compute_results(
    task=Task.READMISSION,
    n_bootstraps=n_bootstraps,
)

In [None]:
# Uses plot_forest's default fitted_auc / fitted_auc_ci columns.
plot_forest(
    df=all_results[Task.READMISSION],
    title="30-day Hospital Readmission",
)

In [None]:
all_results[Task.ICU_ADMISSION] = compute_results(
    task=Task.ICU_ADMISSION,
    n_bootstraps=n_bootstraps,
    # Forwarded via **kwargs to the task's processing function.
    warn_on_dropped=False,
)

In [None]:
# Uses plot_forest's default fitted_auc / fitted_auc_ci columns.
plot_forest(
    df=all_results[Task.ICU_ADMISSION],
    title="ICU Admission",
)

In [None]:
all_results[Task.HOSPITAL_MORTALITY] = compute_results(
    task=Task.HOSPITAL_MORTALITY,
    n_bootstraps=n_bootstraps,
    # Forwarded via **kwargs to the task's processing function.
    warn_on_dropped=False,
)

In [None]:
# Uses plot_forest's default fitted_auc / fitted_auc_ci columns.
plot_forest(
    df=all_results[Task.HOSPITAL_MORTALITY],
    title="Hospital Mortality",
)

In [None]:
from ethos.metrics.paper_fts import col_to_title, eval_overall_results, join_results, score_to_str

tasks = [
    Task.DRG_PREDICTION,
    Task.SOFA_PREDICTION,
    Task.READMISSION,
    Task.ICU_ADMISSION,
    Task.HOSPITAL_MORTALITY,
]

# Key column shared by the per-task tables; presumably join_results aligns rows on it.
on = "model"

# Keep the key plus one column per task, then attach the overall_score computed by
# eval_overall_results and rank by its point estimate (best first).
# Select the key via `on` (not a hard-coded "model") so the two cannot drift apart.
df = join_results(all_results, on=on).select(on, *tasks)
df = df.join(eval_overall_results(df), on=on).sort(
    pl.col("overall_score").struct["score"], descending=True
)

In [None]:
from ethos.metrics.paper_fts import print_overall_score

ax = print_overall_score(df, figsize=(4, 3))
ax.set_ylabel("Training Dataset")
ax.set_xlabel("Overall Score (95% CI)")

figure_dir = PROJECT_ROOT / "figures"
figure_dir.mkdir(exist_ok=True, parents=True)
# Save through the Axes' own figure rather than pyplot's "current figure", so the
# labelled plot is what lands in the PDF even if another figure became current.
ax.figure.savefig(
    figure_dir / "4_final_results.pdf",
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Render the summary table as LaTeX: titled column headers, every non-string cell
# formatted with score_to_str, first column left-aligned and the rest centred.
latex_ready = df.rename(col_to_title, strict=False).with_columns(
    pl.exclude(pl.Utf8).map_elements(score_to_str, return_dtype=pl.Utf8),
)
n_value_columns = len(df.columns) - 1
latex_table = latex_ready.to_pandas().to_latex(
    index=False,
    column_format="l" + n_value_columns * "c",
    escape=True,
    label="tab:stage4-final-results",
)
print(latex_table)