In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib.ticker import FormatStrFormatter
from sklearn.metrics import auc, roc_curve, brier_score_loss
from sklearn.calibration import calibration_curve
from scipy import stats
%matplotlib inline

# Analysis configuration: bootstrap prediction sets per model, calibration bins,
# input/output locations, and the models to compare.
bootstraps = 10
n_bins = 10
# Root of the shared project folder. Set the STS_ECG_DIR environment variable to
# run the notebook outside the original Dropbox layout; the default is unchanged.
sts_ecg_dir = os.path.expanduser(os.environ.get("STS_ECG_DIR", "~/dropbox/sts-ecg"))
plot_dir = os.path.join(sts_ecg_dir, "figures-and-tables")
prediction_dir = os.path.join(sts_ecg_dir, "predictions")
label = "sts_death"
# models[i] is plotted with titles[i] (figure title) and verbose_titles[i] (file name),
# so the three lists must stay index-aligned.
models = ["v30", "deep-sts-preop-v13-swish", "ecgnet-stsnet"]
titles = ["ECGNet", "STSNet", "ECGNet STSNet"]
verbose_titles = ["ECGNet v30", "STSNet v13 swish", "ECGNet STSNet"]
assert len(models) == len(titles) == len(verbose_titles), "models and titles must be index-aligned"
# If True, each bootstrap's predictions are rescaled to [0, 1] before binning.
min_max_scale = False

os.makedirs(plot_dir, exist_ok=True)

In [None]:
def get_predictions(models: "list[str]", label: str, prediction_dir: str, min_max_scale: bool, n_bootstraps: "int | None" = None):
    """Load per-bootstrap test-set predictions for each model.

    Reads ``<prediction_dir>/<model>/<bootstrap>/predictions_test.csv`` for every
    ``bootstrap`` in ``range(n_bootstraps)`` and extracts the
    ``{label}_{label}_actual`` and ``{label}_{label}_predicted`` columns.

    Args:
        models: model sub-directory names under ``prediction_dir``.
        label: label name used to build the CSV column names.
        prediction_dir: root directory holding the prediction CSVs.
        min_max_scale: if True, rescale each bootstrap's predictions to [0, 1]
            independently. A constant prediction vector maps to all zeros
            rather than 0/0 = NaN.
        n_bootstraps: number of bootstrap sub-directories per model. Defaults to
            the notebook-level ``bootstraps`` setting.

    Returns:
        ``(data, y_hat_min, y_hat_max)``. ``data`` maps each model to a list of
        ``(y, y_hat)`` Series tuples, one per bootstrap. ``y_hat_min`` and
        ``y_hat_max`` span every prediction loaded, after optional scaling.
    """
    if n_bootstraps is None:
        n_bootstraps = bootstraps  # notebook-level config value
    data = defaultdict(list)
    # Seed with +/-inf so the range reflects the data even if predictions fall outside [0, 1].
    y_hat_min = float("inf")
    y_hat_max = float("-inf")
    for model in models:
        for bootstrap in range(n_bootstraps):
            fpath = os.path.join(prediction_dir, model, str(bootstrap), "predictions_test.csv")
            df = pd.read_csv(fpath)
            y = df[f'{label}_{label}_actual']
            y_hat = df[f'{label}_{label}_predicted']
            if min_max_scale:
                span = y_hat.max() - y_hat.min()
                y_hat = y_hat - y_hat.min()
                if span > 0:  # avoid 0/0 for constant predictions
                    y_hat = y_hat / span
            y_hat_min = min(y_hat_min, y_hat.min())
            y_hat_max = max(y_hat_max, y_hat.max())
            data[model].append((y, y_hat))
    return data, y_hat_min, y_hat_max
# Load every model's bootstrap predictions once. x_min/x_max span all models and are
# passed unchanged to every calibration plot, so all plots share the same bin edges.
data, x_min, x_max = get_predictions(
    models=models,
    label=label,
    prediction_dir=prediction_dir,
    min_max_scale=min_max_scale,
)

In [None]:
def _bin_bootstrap_predictions(data, bins):
    """Bin each bootstrap's predictions on ``bins`` and collect per-bin statistics.

    Args:
        data: sequence of ``(y, y_hat)`` pandas Series pairs, one per bootstrap.
        bins: increasing bin edges passed to ``pd.cut``.

    Returns:
        ``(brier_scores, died_counts, pred_probs, true_probs, bin_counts)``.
        ``brier_scores`` has shape ``(len(data),)``. The other four have shape
        ``(len(data), len(bins) - 1)``. ``pred_probs`` and ``true_probs`` are NaN
        for bins that are empty in a given bootstrap (0 / 0).
    """
    n_boot = len(data)
    n_bin = len(bins) - 1
    brier_scores = np.zeros(n_boot)
    died_counts = np.zeros((n_boot, n_bin))
    pred_probs = np.zeros((n_boot, n_bin))
    true_probs = np.zeros((n_boot, n_bin))
    bin_counts = np.zeros((n_boot, n_bin))
    for i, (y, y_hat) in enumerate(data):
        brier_scores[i] = brier_score_loss(y, y_hat, pos_label=1)
        # observed=False keeps empty bins, so every row has exactly n_bin entries.
        bin_mask = pd.cut(y_hat, bins)
        y_hat_sums = y_hat.groupby(bin_mask, observed=False).sum()
        y_sums = y.groupby(bin_mask, observed=False).sum()
        counts = y.groupby(bin_mask, observed=False).count()
        died_counts[i] = y_sums
        pred_probs[i] = y_hat_sums / counts
        true_probs[i] = y_sums / counts
        bin_counts[i] = counts
    return brier_scores, died_counts, pred_probs, true_probs, bin_counts


def _hosmer_lemeshow(died_counts, pred_probs, bin_counts):
    """Compute the Hosmer-Lemeshow statistic on bootstrap-averaged bin statistics.

    Returns:
        ``(hl_score, p_value, n_groups)``. ``n_groups`` is the number of bins that
        were non-empty in at least one bootstrap. The test uses ``n_groups - 2``
        degrees of freedom.
    """
    mean_pred_prob = np.nanmean(pred_probs, axis=0)
    mean_bin_count = np.nanmean(bin_counts, axis=0)
    observed_pos = np.nanmean(died_counts, axis=0)
    observed_neg = np.nanmean(bin_counts - died_counts, axis=0)
    expected_pos = mean_pred_prob * mean_bin_count
    expected_neg = (1 - mean_pred_prob) * mean_bin_count

    # A bin is NaN only if it was empty in every bootstrap. Apply the same mask to
    # all four arrays so observed and expected counts stay paired bin-by-bin.
    keep = ~np.isnan(expected_pos)
    o1, o0 = observed_pos[keep], observed_neg[keep]
    e1, e0 = expected_pos[keep], expected_neg[keep]
    n_groups = int(keep.sum())

    hl_score = float(np.sum((o1 - e1) ** 2 / e1 + (o0 - e0) ** 2 / e0))
    p_value = stats.chi2.sf(hl_score, max(0, n_groups - 2))
    return hl_score, p_value, n_groups


def plot_calibrations_across_bootstraps(data, plot_title, file_title, x_min, x_max, n_bins, plot_dir):
    """Plot a bootstrap-averaged calibration curve annotated with a Hosmer-Lemeshow test.

    Each bootstrap's predictions are binned into ``n_bins`` equal-width bins spanning
    ``[x_min, x_max]``. Per-bin predicted and observed event rates are averaged
    across bootstraps; error bars are standard errors across bootstraps. The figure
    is saved as ``calibration-<file_title>.png`` in ``plot_dir``.

    Args:
        data: sequence of ``(y, y_hat)`` pandas Series pairs, one per bootstrap.
        plot_title: text prefixed to the figure title.
        file_title: used in the output file name (spaces become dashes).
        x_min, x_max: prediction range used to build the bin edges.
        n_bins: number of calibration bins.
        plot_dir: directory the PNG is written to.

    Returns:
        dict with keys ``hl_score``, ``p_value``, ``brier_mean`` and ``brier_std``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one (y, y_hat) bootstrap")

    sns.set(style="white", palette="muted", color_codes=True)
    sns.set_context("talk")

    bins = np.linspace(x_min, x_max, n_bins + 1)
    # pd.cut uses right-closed intervals, so x_min itself would fall outside the
    # first bin. Widen both outer edges slightly.
    bins[0] -= 0.0001
    bins[-1] += 0.0001

    brier_scores, died_counts, pred_probs, true_probs, bin_counts = _bin_bootstrap_predictions(data, bins)
    hl_score, p, _ = _hosmer_lemeshow(died_counts, pred_probs, bin_counts)

    mean_pred_prob = np.nanmean(pred_probs, axis=0)
    sem_pred_prob = stats.sem(pred_probs, axis=0, nan_policy='omit')
    mean_true_prob = np.nanmean(true_probs, axis=0)
    sem_true_prob = stats.sem(true_probs, axis=0, nan_policy='omit')

    lim = [-0.1, 1.1]
    fig, ax = plt.subplots(figsize=(6, 6))
    # Diagonal reference line: perfect calibration.
    ax.plot(lim, lim, linestyle='--', lw=2, color='r', alpha=.8)
    ax.errorbar(
        x=mean_pred_prob,
        y=mean_true_prob,
        xerr=sem_pred_prob,
        yerr=sem_true_prob,
        fmt=".",
        ecolor="cornflowerblue",
        elinewidth=2.5,
        capsize=2,
    )
    ticks = np.arange(0, 1.1, 0.2)
    ax.set(
        xticks=ticks,
        yticks=ticks,
        xlim=lim,
        ylim=lim,
        title=f"{plot_title}: HL score = {hl_score:.1f}, p = {p:.2f}",
        xlabel="Predicted",
        ylabel="Actual",
    )
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.grid()
    fig.tight_layout()

    fpath = os.path.join(plot_dir, f'calibration-{file_title.replace(" ", "-")}.png')
    fig.savefig(fpath)
    print(f"Saved {fpath}")
    return {
        "hl_score": hl_score,
        "p_value": p,
        "brier_mean": float(brier_scores.mean()),
        "brier_std": float(brier_scores.std()),
    }

In [None]:
# One calibration figure per model. Look up each model's predictions by key, so
# titles cannot be mismatched with models if `data`'s key order ever differs from
# `models`.
for model, title, verbose_title in zip(models, titles, verbose_titles):
    model_data = data.get(model)
    if not model_data:
        print(f"No predictions loaded for {model}; skipping")
        continue
    plot_calibrations_across_bootstraps(
        data=model_data,
        plot_title=title,
        file_title=verbose_title,
        x_min=x_min,
        x_max=x_max,
        n_bins=n_bins,
        plot_dir=plot_dir,
    )