In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import polars as pl
import seaborn as sns

from ethos.constants import PROJECT_ROOT
from ethos.inference.constants import Task
from ethos.metrics import compute_drg_results, compute_metrics, compute_sofa_results
from ethos.task_processing import TASK_RESULTS_PROCESSING_FUNC, join_metadata


def our_join_metadata(df: pl.DataFrame, input_dir: Path) -> pl.DataFrame:
    """Attach run metadata to a results table and derive display columns.

    Calls ``join_metadata(df, input_dir)`` and returns a frame whose leading
    columns are ``dataset``, ``model``, ``variant`` and ``temp``, followed by
    all of the columns of the incoming ``df``.

    Args:
        df: Results table; its columns are carried through unchanged.
        input_dir: Directory forwarded to ``join_metadata``.

    Returns:
        The metadata-enriched frame described above.
    """
    original_columns = df.columns
    fp_parts = pl.col("model_fp").str.split("/")

    # First pass: pull the raw identifiers out of the "/"-separated `model_fp`
    # (4th-from-last and 2nd-from-last components) and rename `temperature`.
    with_run_info = join_metadata(df, input_dir).select(
        fp_parts.list[-4].alias("dataset"),
        fp_parts.list[-2].alias("model"),
        pl.col("temperature").alias("temp"),
        *original_columns,
    )

    # Short model label: drop the first five "_"-separated tokens of the raw
    # model name; a name with nothing left after that becomes "original".
    short_model = (
        pl.col("model").str.split("_").list.slice(5).list.join("_").replace("", "original")
    )

    # Variant is detected on the raw (unshortened) model name, because every
    # expression of a single `select` reads the input frame.
    raw_model = pl.col("model")
    variant = (
        pl.when(raw_model.str.contains("small"))
        .then(pl.lit("small"))
        .when(raw_model.str.contains("little"))
        .then(pl.lit("little"))
        .otherwise(pl.lit("big"))
    )

    return with_run_info.select(
        "dataset",
        short_model.alias("model"),
        variant.alias("variant"),
        "temp",
        *original_columns,
    )


def is_valid_file(fp: Path, prefix: str = "mimic_synth") -> bool:
    """Tell whether a result file belongs to the runs compared in this notebook.

    A file qualifies when its name starts with ``prefix``, mentions one of the
    model-size variants ("small", "big", "little") and contains no "+".

    Args:
        fp: Path of the candidate result file; only ``fp.name`` is inspected.
        prefix: Required start of the file name. Defaults to "mimic_synth",
            so existing single-argument calls behave as before.

    Returns:
        True if the file should be included, False otherwise.
    """
    name = fp.name
    return (
        name.startswith(prefix)
        and any(variant in name for variant in ("small", "big", "little"))
        and "+" not in name
    )


# Root directory of the stored inference results; each task has its own sub-directory.
result_dir = PROJECT_ROOT / "results"
# Dataset identifier (not referenced by the cells visible in this notebook).
dataset_name = "mimic_synth"

# Number of bootstrap resamples, forwarded to the metric helpers and reused
# for the calibration-curve confidence bands.
n_bootstraps = 1000

# Results table per task, filled in by the cells below.
all_results = {}

In [None]:
# DRG prediction: score every matching result file with top_k=1, tag each row
# with its file name, attach the run metadata and rank by `acc_top_1`.
drg_results_dir = result_dir / Task.DRG_PREDICTION
drg_process_func = TASK_RESULTS_PROCESSING_FUNC[Task.DRG_PREDICTION]

all_results[Task.DRG_PREDICTION] = our_join_metadata(
    pl.concat(
        compute_drg_results(
            drg_process_func(fp, top_k=1),
            n_bootstraps=n_bootstraps,
        ).with_columns(pl.lit(fp.name).alias("name"))
        for fp in drg_results_dir.iterdir()
        if is_valid_file(fp)
    ),
    drg_results_dir,
).sort("acc_top_1", descending=True)
all_results[Task.DRG_PREDICTION]

In [None]:
# Horizontal bar chart of the `acc_top_1` column per model, coloured by variant.
# `hue_order` is pinned to the same list the ICU-admission and hospital-mortality
# plots use, so a variant keeps the same colour in every figure regardless of
# how each results table happens to be sorted.
g = sns.catplot(
    kind="bar",
    data=all_results[Task.DRG_PREDICTION].to_pandas(),
    x="acc_top_1",
    y="model",
    hue="variant",
    hue_order=["big", "small", "little"],
    orient="h",
)
g.figure.suptitle("DRG Prediction")
g.figure.tight_layout()

In [None]:
# SOFA prediction: compute the regression metrics for every matching result
# file, flatten the `r2` / `mae` structs into a score column plus a
# [ci_low, ci_high] list column, attach the run metadata and rank by `r2`.
sofa_results_dir = result_dir / Task.SOFA_PREDICTION
sofa_process_func = TASK_RESULTS_PROCESSING_FUNC[Task.SOFA_PREDICTION]

all_results[Task.SOFA_PREDICTION] = our_join_metadata(
    pl.concat(
        pl.from_dict(
            compute_sofa_results(
                *sofa_process_func(fp)["true_sofa", "pred_sofa"],
                n_bootstraps=n_bootstraps,
            )
        ).with_columns(pl.lit(fp.name).alias("name"))
        for fp in sofa_results_dir.iterdir()
        if is_valid_file(fp)
    ).with_columns(
        pl.col("r2").struct.field("score").alias("r2"),
        pl.concat_list(
            [pl.col("r2").struct.field("ci_low"), pl.col("r2").struct.field("ci_high")]
        ).alias("r2_ci"),
        pl.col("mae").struct.field("score").alias("mae"),
        pl.concat_list(
            [pl.col("mae").struct.field("ci_low"), pl.col("mae").struct.field("ci_high")]
        ).alias("mae_ci"),
    ),
    sofa_results_dir,
).sort("r2", descending=True)
all_results[Task.SOFA_PREDICTION]

In [None]:
# Horizontal bar chart of the `r2` column per model, coloured by variant.
# `hue_order` is pinned to the same list the ICU-admission and hospital-mortality
# plots use, so a variant keeps the same colour in every figure regardless of
# how each results table happens to be sorted.
g = sns.catplot(
    kind="bar",
    data=all_results[Task.SOFA_PREDICTION].to_pandas(),
    x="r2",
    y="model",
    hue="variant",
    hue_order=["big", "small", "little"],
    orient="h",
)
g.figure.suptitle("SOFA Prediction")

In [None]:
def compute_results(task: Task, n_bootstraps: int, **kwargs) -> pl.DataFrame:
    """Build the metrics table for one task.

    Every entry of ``result_dir / task`` accepted by ``is_valid_file`` is loaded
    with the task's processing function, scored with ``compute_metrics`` and
    turned into one row. The rows are enriched via ``our_join_metadata`` and
    returned sorted by ``fitted_auc`` (best first), without the ``name`` column.

    Args:
        task: Task whose results directory is scanned.
        n_bootstraps: Forwarded to ``compute_metrics``.
        **kwargs: Forwarded to the task's results-processing function.

    Returns:
        One row per result file with metadata and metric columns.
    """
    load_task_results = TASK_RESULTS_PROCESSING_FUNC[task]
    task_results_dir = result_dir / task

    def compute_metrics_for_single_case(fp):
        # One result file -> one row: file name, metrics, mean of `counts`.
        case_df = load_task_results(fp, **kwargs)
        metrics = compute_metrics(*case_df["expected", "actual"], n_bootstraps=n_bootstraps)
        return {"name": fp.name, **metrics, "rep_num": case_df["counts"].mean()}

    rows = (
        compute_metrics_for_single_case(fp)
        for fp in task_results_dir.iterdir()
        if is_valid_file(fp)
    )
    with_metadata = our_join_metadata(pl.DataFrame(rows), task_results_dir)
    return with_metadata.drop("name").sort("fitted_auc", descending=True)

In [None]:
# Metrics table for the readmission task; `warn_on_dropped=False` is passed
# through to the task's results-processing function.
all_results[Task.READMISSION] = compute_results(
    task=Task.READMISSION,
    n_bootstraps=n_bootstraps,
    warn_on_dropped=False,
)
all_results[Task.READMISSION]

In [None]:
# Horizontal bar chart of the `fitted_auc` column per model, coloured by variant.
# `hue_order` is pinned to the same list the ICU-admission and hospital-mortality
# plots use, so a variant keeps the same colour in every figure regardless of
# how each results table happens to be sorted.
g = sns.catplot(
    kind="bar",
    data=all_results[Task.READMISSION].to_pandas(),
    x="fitted_auc",
    y="model",
    hue="variant",
    hue_order=["big", "small", "little"],
    orient="h",
)
g.figure.suptitle("30-day Readmission")

In [None]:
# Metrics table for the ICU-admission task; `warn_on_dropped=False` is passed
# through to the task's results-processing function.
all_results[Task.ICU_ADMISSION] = compute_results(
    task=Task.ICU_ADMISSION,
    n_bootstraps=n_bootstraps,
    warn_on_dropped=False,
)
all_results[Task.ICU_ADMISSION]

In [None]:
# Horizontal bar chart of the `fitted_auc` column per model for ICU admission;
# bars are coloured by variant in a fixed big / small / little order.
g = sns.catplot(
    data=all_results[Task.ICU_ADMISSION].to_pandas(),
    kind="bar",
    orient="h",
    y="model",
    x="fitted_auc",
    hue="variant",
    hue_order=["big", "small", "little"],
)
g.figure.suptitle("ICU Admission")

In [None]:
# Metrics table for the hospital-mortality task; `warn_on_dropped=False` is
# passed through to the task's results-processing function.
all_results[Task.HOSPITAL_MORTALITY] = compute_results(
    task=Task.HOSPITAL_MORTALITY,
    n_bootstraps=n_bootstraps,
    warn_on_dropped=False,
)
all_results[Task.HOSPITAL_MORTALITY]

In [None]:
# Horizontal bar chart of the `fitted_auc` column per model for hospital
# mortality; bars are coloured by variant in a fixed big / small / little order.
g = sns.catplot(
    data=all_results[Task.HOSPITAL_MORTALITY].to_pandas(),
    kind="bar",
    orient="h",
    y="model",
    x="fitted_auc",
    hue="variant",
    hue_order=["big", "small", "little"],
)
g.figure.suptitle("Hospital Mortality")

In [None]:
from ethos.metrics.paper_fts import (
    col_to_title,
    eval_overall_results,
    join_results,
    print_overall_score,
    score_to_str,
)

# Tasks that enter the overall comparison, in the column order of the table.
tasks = [
    Task.DRG_PREDICTION,
    Task.SOFA_PREDICTION,
    Task.READMISSION,
    Task.ICU_ADMISSION,
    Task.HOSPITAL_MORTALITY,
]

# A row of the comparison is identified by model name plus size variant.
on = ["model", "variant"]

# Per-task results joined on (model, variant); "little" runs are left out.
per_task_df = (
    join_results(all_results, on=on)
    .select(*on, *tasks)
    .filter(pl.col("variant") != "little")
)

# Overall score per model, computed from the table with variants pivoted into columns.
overall_df = eval_overall_results(per_task_df.pivot(on="variant", index="model"))

# Split the model name at its first "_" into a `model` part and a `temp` part.
model_and_temp = (
    pl.col("model")
    .str.split_exact("_", 1)
    .struct.rename_fields(["model", "temp"])
    .struct.unnest()
)

# Label for "synth" rows: "synth_temp_<value>", where <value> is `temp` without
# its first four characters (presumably the literal "temp" - confirm against the
# model names) parsed as a float; a missing `temp` is shown as 1.0.
synth_label = pl.lit("synth_temp_") + pl.col("temp").str.slice(4).cast(pl.Float64).fill_null(
    1
).cast(str)

df = (
    per_task_df.join(overall_df, on="model")
    .sort(pl.col("overall_score").struct["score"], "variant", descending=[True, False])
    .with_columns(model_and_temp)
    .with_columns(
        # A bare string passed to `otherwise` is read as a column name, so
        # every non-"synth" row keeps its `model` value.
        model=pl.when(pl.col("model") == "synth").then(synth_label).otherwise("model")
    )
    .drop("temp")
)
df

In [None]:
from matplotlib import pyplot as plt

# Make sure the output directory exists before anything is written to it.
figure_dir = PROJECT_ROOT / "figures"
figure_dir.mkdir(parents=True, exist_ok=True)

# Overall-score plot for the comparison table, labelled and saved as a PDF.
ax = print_overall_score(df, figsize=(4, 2))
ax.set_xlabel("Overall Score (95% CI)")
ax.set_ylabel("Training Dataset")

plt.savefig(
    figure_dir / "3_synthetic_temperature.pdf",
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Print the comparison table as LaTeX: columns are renamed through
# `col_to_title`, every non-string column is formatted with `score_to_str`,
# and the (model, variant) pair becomes the row index.
latex_table = (
    df.rename(col_to_title, strict=False)
    .with_columns(pl.exclude(pl.Utf8).map_elements(score_to_str, return_dtype=pl.Utf8))
    .to_pandas()
    .set_index(on)
    .to_latex(
        # One left-aligned column followed by centred columns for the rest.
        column_format="l" + "c" * (len(df.columns) - 1),
        label="tab:stage3-synthetic-temperature",
        multirow=True,
        escape=True,
    )
)
print(latex_table)

In [None]:
import numpy as np
from matplotlib.offsetbox import AnchoredText
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

# Calibration curves with bootstrap confidence bands: one row of panels per
# synthetic-data temperature, one column per binary task.
n_bins = 10
color = "#00A1D9"

tasks = [
    Task.READMISSION,
    Task.ICU_ADMISSION,
    Task.HOSPITAL_MORTALITY,
]


# (result file name, temperature shown as that row's y-axis label)
fns = [
    ("mimic_synth_layer_3_do_0.3_big_synth_temp1.1_best_temp0.9_ftnd8qsy", 1.1),
    ("mimic_synth_layer_3_do_0.3_big_synth_best_temp0.9_ctrt46s2", 1),
    ("mimic_synth_layer_3_do_0.3_big_synth_temp0.9_best_temp0.9_078crrg1", 0.9),
    ("mimic_synth_layer_3_do_0.3_big_synth_temp0.7_best_temp0.9_1vt6pmk4", 0.7),
]

size = 8
n_rows, n_cols = len(fns), len(tasks)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(size, size * n_rows / n_cols))
for i, (fn, temperature) in enumerate(fns):
    for j, task in enumerate(tasks):
        # Named `task_df` so the comparison table `df` built in the earlier
        # cell is not overwritten when this cell runs.
        task_df = TASK_RESULTS_PROCESSING_FUNC[task](
            result_dir / task / fn, warn_on_dropped=False
        )
        # Loop-invariant two-column frame: selected once instead of on every
        # bootstrap draw.
        labels_and_scores = task_df["expected", "actual"]

        frac_pos, mean_pred = calibration_curve(*labels_and_scores, n_bins=n_bins)

        bootstrapped_fracs = np.zeros((n_bootstraps, len(mean_pred)))
        for seed in range(n_bootstraps):
            frac_bs, mean_pred_bs = calibration_curve(
                *labels_and_scores.sample(fraction=1, with_replacement=True, seed=seed),
                n_bins=n_bins,
            )
            # A resample can yield a different number of bins; in that case
            # interpolate its curve onto the x positions of the original curve.
            bootstrapped_fracs[seed] = (
                frac_bs
                if len(frac_bs) == len(frac_pos)
                else np.interp(mean_pred, mean_pred_bs, frac_bs)
            )

        ax = axes[i, j]
        ci_lower, ci_upper = np.percentile(bootstrapped_fracs, [2.5, 97.5], axis=0)
        ax.fill_between(
            mean_pred,
            ci_lower,
            ci_upper,
            color="gray",
            label="95% Confidence Interval",
        )

        ax.plot([0, 1], [0, 1], linestyle="--", color="black", label="Perfect Calibration")
        ax.plot(
            mean_pred,
            frac_pos,
            color=color,
            lw=5,
            label="ETHOS Calibration",
        )
        ax.set_xlim([-0.01, 1.01])
        ax.set_ylim([-0.01, 1.01])
        # Only the bottom row carries x labels/ticks and only the first column
        # carries y labels/ticks.
        if i == len(fns) - 1:
            ax.set_xlabel(col_to_title[task])
        else:
            ax.set_xticks([])

        if j == 0:
            ax.set_ylabel(f"{temperature:.1f}")
        else:
            ax.set_yticks([])
        ax.grid(False)
        ax.add_artist(
            AnchoredText(
                f"Brier score: {brier_score_loss(task_df['expected'], task_df['actual']):.3f}",
                loc="lower right",
                pad=0,
                borderpad=0.1,
                frameon=False,
            )
        )

# Figure-level labels and the export run once, after every panel is drawn.
# They used to sit inside the outer loop, which rewrote the PDF after each row.
fig.supxlabel("Binary Downstream Tasks")
fig.supylabel("Temperature Used to Create Synthetic Data")
fig.subplots_adjust(hspace=0.1, wspace=0.1)
# Create the output directory here too, so this cell does not depend on an
# earlier cell having made it.
(PROJECT_ROOT / "figures").mkdir(exist_ok=True, parents=True)
fig.savefig(
    PROJECT_ROOT / "figures" / "supp_calibration_curves_synth.pdf",
    dpi=300,
    bbox_inches="tight",
)