In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import polars as pl
from matplotlib import pyplot as plt

from ethos.constants import PROJECT_ROOT
from ethos.inference.constants import Task
from ethos.metrics import compute_drg_results, compute_metrics, compute_sofa_results
from ethos.task_processing import TASK_RESULTS_PROCESSING_FUNC, join_metadata

# Accent color shared by the forest plots (default of `plot_forest`) and the calibration curves.
color = "#00A1D9"


def our_join_metadata(df: pl.DataFrame, input_dir: Path) -> pl.DataFrame:
    """Join run metadata onto `df` and put parsed run descriptors in front of its columns.

    The result keeps only these leading descriptor columns followed by every original column
    of `df`; any other metadata columns produced by `join_metadata` are dropped:

    - ``dataset``: 4th-from-last "/"-separated component of the joined ``model_fp`` column.
    - ``model``: 2nd-from-last component of ``model_fp``.
    - ``fold``: last "/"-separated component of the joined ``input_dir`` column.
    - ``temp``: the joined ``temperature`` column, renamed.
    """
    model_fp_parts = pl.col("model_fp").str.split("/")
    descriptor_cols = [
        model_fp_parts.list.get(-4).alias("dataset"),
        model_fp_parts.list.get(-2).alias("model"),
        pl.col("input_dir").str.split("/").list.last().alias("fold"),
        pl.col("temperature").alias("temp"),
    ]
    joined = join_metadata(df, input_dir)
    return joined.select(*descriptor_cols, *df.columns)


def plot_forest(
    ax: plt.Axes,
    df: pl.DataFrame,
    x: str = "model",
    y: str = "fitted_auc",
    y_ci: str = "fitted_auc_ci",
    title: str = "",
    lw=3,
    color=color,
    sort_expr: str | pl.Expr | None = None,
) -> plt.Axes:
    """Draw a forest plot on `ax`: one diamond estimate plus a capped CI whisker per row of `df`.

    Args:
        ax: Axes to draw on.
        df: One row per estimate to plot.
        x: Column whose values become the x tick labels.
        y: Column holding the point estimates.
        y_ci: Column holding a two-element ``[low, high]`` interval per row.
        title: Axes title.
        lw: Line width of the whiskers and caps.
        color: Color of the whiskers, caps and markers.
        sort_expr: Column name or expression that orders the rows (ascending, nulls last).
            Defaults to `y`.

    Returns:
        The same `ax`, for further customisation by the caller.
    """
    if sort_expr is None:
        sort_expr = y

    df = df.sort(sort_expr, nulls_last=True)
    cap_half_width = 0.3
    for i, (m, ci) in enumerate(df[y, y_ci].rows()):
        # Rows are sorted nulls-last because metrics may be missing; a null CI cannot be
        # unpacked, so such rows keep their tick label but get no whisker.
        if ci is not None:
            lo, hi = ci
            ax.plot([i, i], [lo, hi], color=color, lw=lw)
            ax.plot([i - cap_half_width, i + cap_half_width], [lo, lo], color=color, lw=lw)
            ax.plot([i - cap_half_width, i + cap_half_width], [hi, hi], color=color, lw=lw)
        if m is not None:
            ax.plot(i, m, marker="D", color=color, markersize=10)

    ax.set_xticks(list(range(len(df))))
    ax.set_xticklabels(df[x])
    ax.grid(True)
    ax.set_title(title)
    return ax


def plot_two_forests(data: pl.DataFrame, **kwargs):
    """Draw two side-by-side forest plots for the two validation studies in `data`.

    Left panel: rows with fold "val1"; x labels are "<k>/20" when the model name's last "_"
    token is an integer k, otherwise "full", ordered by k with "full" last.
    Right panel: rows with fold "val2"; x labels come from ``temp``, ordered by ``temp``.

    An optional ``title`` keyword (default "") prefixes both panel titles; all other keyword
    arguments are forwarded to `plot_forest`.
    """
    title = kwargs.pop("title", "")
    _, (size_ax, temp_ax) = plt.subplots(1, 2, figsize=(10, 4))

    # Non-integer suffixes cast to null, so "+ /20" stays null and becomes "full".
    size_label = (
        pl.col("model").str.split("_").list.last().cast(pl.Int64, strict=False).cast(pl.Utf8)
        + pl.lit("/20")
    ).fill_null("full")
    by_size = data.filter(fold="val1").with_columns(model=size_label)
    # Numeric order on k; "full" casts to null and plot_forest sorts nulls last.
    size_order = pl.col("model").str.split("/").list.first().cast(pl.Int64, strict=False)
    plot_forest(
        size_ax, by_size, x="model", title=f"{title} varying size", sort_expr=size_order, **kwargs
    )

    by_temp = data.filter(fold="val2")
    plot_forest(
        temp_ax,
        by_temp,
        x="temp",
        title=f"{title} varying temperatures",
        sort_expr="temp",
        **kwargs,
    )


# Number of bootstrap resamples used for the confidence intervals computed in this notebook.
n_bootstraps = 1000


def is_file_valid(fp: Path) -> bool:
    return fp.name.startswith(dataset_name) and "val" in fp.name


# Result files are selected by this name prefix (read by `is_file_valid`).
dataset_name = "mimic_synth"
# Per-task results are read from <PROJECT_ROOT>/results/<task>.
result_dir = PROJECT_ROOT / "results"
# Task -> metrics DataFrame, filled in by the task cells below.
all_results = {}

In [None]:
drg_results_dir = result_dir / Task.DRG_PREDICTION
drg_process_func = TASK_RESULTS_PROCESSING_FUNC[Task.DRG_PREDICTION]
# Top-1 DRG accuracy (with bootstrapped CIs) for each valid result file, tagged with the file
# name, then enriched with run metadata and ranked best-first.
all_results[Task.DRG_PREDICTION] = (
    pl.concat(
        [
            compute_drg_results(drg_process_func(fp, top_k=1), n_bootstraps=n_bootstraps)
            .with_columns(pl.lit(fp.name).alias("name"))
            for fp in drg_results_dir.iterdir()
            if is_file_valid(fp)
        ]
    )
    .pipe(our_join_metadata, drg_results_dir)
    .sort(by="acc_top_1", descending=True)
)

In [None]:
# Forest plots of top-1 DRG accuracy across training-data fractions and inference temperatures.
plot_two_forests(
    data=all_results[Task.DRG_PREDICTION], title="DRG", y="acc_top_1", y_ci="acc_top_1_ci"
)

In [None]:
sofa_results_dir = result_dir / Task.SOFA_PREDICTION
sofa_process_func = TASK_RESULTS_PROCESSING_FUNC[Task.SOFA_PREDICTION]
all_results[Task.SOFA_PREDICTION] = (
    pl.concat(
        [
            pl.from_dict(
                compute_sofa_results(
                    *sofa_process_func(fp)["true_sofa", "pred_sofa"], n_bootstraps=n_bootstraps
                )
            ).with_columns(pl.lit(fp.name).alias("name"))
            for fp in sofa_results_dir.iterdir()
            if is_file_valid(fp)
        ]
    )
    # Each metric arrives as a struct {score, ci_low, ci_high}; flatten it into a scalar column
    # plus a two-element [ci_low, ci_high] list column (the shape plot_forest unpacks).
    # All four expressions must stay in ONE with_columns so they all read the original structs.
    .with_columns(
        [
            expr
            for metric in ("r2", "mae")
            for expr in (
                pl.col(metric).struct.field("score").alias(metric),
                pl.concat_list(
                    pl.col(metric).struct.field("ci_low"),
                    pl.col(metric).struct.field("ci_high"),
                ).alias(f"{metric}_ci"),
            )
        ]
    )
    .pipe(our_join_metadata, sofa_results_dir)
    .sort(by="r2", descending=True)
)

In [None]:
# Forest plots of the SOFA R² across training-data fractions and inference temperatures.
plot_two_forests(data=all_results[Task.SOFA_PREDICTION], title="SOFA", y="r2", y_ci="r2_ci")

In [None]:
def compute_results(task: Task, n_bootstraps: int, **kwargs) -> pl.DataFrame:
    """Compute bootstrapped binary-prediction metrics for every validation run of `task`.

    Each entry of ``result_dir / task`` accepted by `is_file_valid` is loaded with the task's
    results-processing function, scored with `compute_metrics` on its ``expected``/``actual``
    columns, and turned into one row. Run metadata is joined on via `our_join_metadata` and the
    rows are sorted by ``fitted_auc``, highest first.

    Args:
        task: Task whose result directory is scanned.
        n_bootstraps: Number of bootstrap resamples passed to `compute_metrics`.
        **kwargs: Extra keyword arguments forwarded to the task's processing function.

    Returns:
        One row per result file: the metric columns, ``rep_num`` (mean of the processed frame's
        ``counts`` column) and the metadata columns; the helper ``name`` column is dropped.

    Raises:
        FileNotFoundError: If no entry of the task's result directory passes `is_file_valid`.
    """
    proc_func = TASK_RESULTS_PROCESSING_FUNC[task]
    task_results_dir = result_dir / task

    def compute_metrics_for_single_case(fp: Path) -> dict:
        df = proc_func(fp, **kwargs)
        res = compute_metrics(*df["expected", "actual"], n_bootstraps=n_bootstraps)
        return {"name": fp.name, **res, "rep_num": df["counts"].mean()}

    # iterdir() yields entries in arbitrary filesystem order; sorting makes row construction
    # (and the order of ties after the final sort) reproducible.
    valid_fps = sorted(fp for fp in task_results_dir.iterdir() if is_file_valid(fp))
    if not valid_fps:
        # Otherwise polars builds an empty, column-less frame and the metadata join is likely
        # to fail later with a much less helpful error.
        raise FileNotFoundError(f"No valid result files for {task} in {task_results_dir}")

    rows = [compute_metrics_for_single_case(fp) for fp in valid_fps]
    return (
        pl.DataFrame(rows)
        .pipe(our_join_metadata, task_results_dir)
        .drop("name")
        .sort("fitted_auc", descending=True)
    )

In [None]:
all_results[Task.READMISSION] = compute_results(
    task=Task.READMISSION,
    n_bootstraps=n_bootstraps,
    warn_on_dropped=False,  # forwarded to the task's results-processing function
)

In [None]:
# Fitted-AUC forest plots (plot_forest's default y) across data fractions and temperatures.
plot_two_forests(data=all_results[Task.READMISSION], title="30-day Readmission")

In [None]:
all_results[Task.ICU_ADMISSION] = compute_results(
    task=Task.ICU_ADMISSION,
    n_bootstraps=n_bootstraps,
    warn_on_dropped=False,  # forwarded to the task's results-processing function
)

In [None]:
# Fitted-AUC forest plots (plot_forest's default y) across data fractions and temperatures.
plot_two_forests(data=all_results[Task.ICU_ADMISSION], title="ICU Admission")

In [None]:
all_results[Task.HOSPITAL_MORTALITY] = compute_results(
    task=Task.HOSPITAL_MORTALITY,
    n_bootstraps=n_bootstraps,
    warn_on_dropped=False,  # forwarded to the task's results-processing function
)

In [None]:
# Fitted-AUC forest plots (plot_forest's default y) across data fractions and temperatures.
plot_two_forests(data=all_results[Task.HOSPITAL_MORTALITY], title="Hospital Mortality")

In [None]:
from ethos.metrics.paper_fts import col_to_title, eval_overall_results, join_results, score_to_str

# Stage 1: all tasks on fold "val1" (varying training-data size), joined per model.
# Model names ending in an integer token k are relabelled k/20 as a percentage; any other
# suffix (e.g. "full") counts as 20/20 = 100%.
df = (
    join_results(
        {task: task_df.filter(fold="val1") for task, task_df in all_results.items()},
        on="model",
    )
    .with_columns(
        pl.col("model")
        .str.split("_")
        .list.last()
        .cast(pl.Int64, strict=False)
        .fill_null(20)
        .truediv(20)
        .map_elements(lambda x: f"{x:.0%}", return_dtype=pl.Utf8)
        .alias("model")
    )
    .reverse()
)

df = df.join(eval_overall_results(df), on="model", maintain_order="left")
df

In [None]:
from ethos.metrics.paper_fts import print_overall_score

ax = print_overall_score(df, figsize=(4, 2))
ax.set_ylabel("Fraction of Data")
ax.set_xlabel("Overall Score (95% CI)")

figure_dir = PROJECT_ROOT / "figures"
figure_dir.mkdir(exist_ok=True, parents=True)
# Save the figure that owns `ax` instead of pyplot's implicit "current" figure, so the right
# plot is written even if another figure was made current in between.
ax.figure.savefig(
    figure_dir / "1_train_division.pdf",
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# LaTeX table for the paper: human-readable column titles, every non-string column rendered
# via `score_to_str`, first column left-aligned and the rest centred.
print(
    df.rename(col_to_title, strict=False)
    .with_columns(pl.all().exclude(pl.Utf8).map_elements(score_to_str, return_dtype=pl.Utf8))
    .to_pandas()
    .to_latex(
        label="tab:stage1-train-division",
        escape=True,
        column_format=f"l{'c' * (df.width - 1)}",
        index=False,
    )
)

In [None]:
# Stage 2: all tasks on fold "val2" (varying inference temperature), joined per temperature,
# with `temp` cast to float as the sort key.
df = join_results(
    {task: task_df.filter(fold="val2") for task, task_df in all_results.items()},
    on="temp",
    sort=pl.col("temp").cast(pl.Float64),
).pipe(lambda joined: joined.join(eval_overall_results(joined), on="temp", maintain_order="left"))
df

In [None]:
ax = print_overall_score(df, figsize=(4, 2))
ax.set_ylabel("Inference Temperature")
ax.set_xlabel("Overall Score (95% CI)")
# Save the figure that owns `ax` instead of pyplot's implicit "current" figure, so the right
# plot is written even if another figure was made current in between.
ax.figure.savefig(
    figure_dir / "2_inference_temperature.pdf",
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# LaTeX table for the paper: human-readable column titles, every non-string column rendered
# via `score_to_str`, first column left-aligned and the rest centred.
print(
    df.rename(col_to_title, strict=False)
    .with_columns(pl.all().exclude(pl.Utf8).map_elements(score_to_str, return_dtype=pl.Utf8))
    .to_pandas()
    .to_latex(
        label="tab:stage2-inference-temperature",
        escape=True,
        column_format=f"l{'c' * (df.width - 1)}",
        index=False,
    )
)

In [None]:
import numpy as np
from matplotlib.offsetbox import AnchoredText
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

n_bins = 10

tasks = [
    Task.READMISSION,
    Task.ICU_ADMISSION,
    Task.HOSPITAL_MORTALITY,
]

# (result-file name, inference temperature) of the runs shown as the rows of the figure.
fns = [
    ("mimic_synth_layer_3_do_0.3_full_val2_best_egaf72qa", 1),
    ("mimic_synth_layer_3_do_0.3_full_val2_best_temp0.9_egaf72qa", 0.9),
    ("mimic_synth_layer_3_do_0.3_full_val2_best_temp0.8_egaf72qa", 0.8),
]


def _bootstrap_calibration_band(
    task_df: pl.DataFrame, mean_pred: np.ndarray, n_bins: int, n_bootstraps: int
) -> tuple[np.ndarray, np.ndarray]:
    """Return the 2.5th and 97.5th percentiles of bootstrapped calibration curves.

    Each resample draws ``len(task_df)`` rows with replacement (seeded by the resample index, so
    the band is reproducible), recomputes the calibration curve of ``expected`` vs ``actual``
    and evaluates it at ``mean_pred``. When a resample's curve has a different number of points
    than ``mean_pred``, it is linearly interpolated onto ``mean_pred``.
    """
    boot_fracs = np.zeros((n_bootstraps, len(mean_pred)))
    for seed in range(n_bootstraps):
        resample = task_df["expected", "actual"].sample(
            fraction=1, with_replacement=True, seed=seed
        )
        frac_bs, mean_pred_bs = calibration_curve(*resample, n_bins=n_bins)
        boot_fracs[seed] = (
            frac_bs
            if len(frac_bs) == len(mean_pred)
            else np.interp(mean_pred, mean_pred_bs, frac_bs)
        )
    ci_lower, ci_upper = np.percentile(boot_fracs, [2.5, 97.5], axis=0)
    return ci_lower, ci_upper


size = 8
n_rows, n_cols = len(fns), len(tasks)
# squeeze=False keeps `axes` 2-D even if `fns` or `tasks` is cut down to a single entry.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(size, size * n_rows / n_cols), squeeze=False
)
for i, (fn, temperature) in enumerate(fns):
    for j, task in enumerate(tasks):
        # Named `task_df`, not `df`, so this loop does not overwrite the global `df` summary
        # table built earlier in the notebook.
        task_df = TASK_RESULTS_PROCESSING_FUNC[task](
            result_dir / task / fn, warn_on_dropped=False
        )
        y_true, y_prob = task_df["expected"], task_df["actual"]

        frac_pos, mean_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)
        ci_lower, ci_upper = _bootstrap_calibration_band(
            task_df, mean_pred, n_bins, n_bootstraps
        )

        ax = axes[i, j]
        ax.fill_between(
            mean_pred,
            ci_lower,
            ci_upper,
            color="gray",
            label="95% Confidence Interval",
        )
        ax.plot([0, 1], [0, 1], linestyle="--", color="black", label="Perfect Calibration")
        ax.plot(
            mean_pred,
            frac_pos,
            color=color,
            lw=5,
            label="ETHOS Calibration",
        )
        ax.set_xlim([-0.01, 1.01])
        ax.set_ylim([-0.01, 1.01])

        # Grid layout: x labels only on the bottom row, y labels only on the left column.
        if i == n_rows - 1:
            ax.set_xlabel(col_to_title[task])
        else:
            ax.set_xticks([])
        if j == 0:
            ax.set_ylabel(f"{temperature:.1f}")
        else:
            ax.set_yticks([])
        ax.grid(False)
        ax.add_artist(
            AnchoredText(
                f"Brier score: {brier_score_loss(y_true, y_prob):.3f}",
                loc="lower right",
                pad=0,
                borderpad=0.1,
                frameon=False,
            )
        )

# Figure-level labels, spacing and saving run once, after every panel has been drawn
# (previously they sat inside the row loop and the PDF was rewritten for every row).
fig.supxlabel("Binary Downstream Tasks")
fig.supylabel("Inference Temperature")
fig.subplots_adjust(hspace=0.1, wspace=0.1)
(PROJECT_ROOT / "figures").mkdir(exist_ok=True, parents=True)
fig.savefig(
    PROJECT_ROOT / "figures" / "supp_calibration_curves.pdf",
    dpi=300,
    bbox_inches="tight",
)