In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from sklearn.metrics import auc, roc_curve, brier_score_loss
from sklearn.calibration import calibration_curve
from scipy import stats
%matplotlib widget

# Number of bootstrap runs; get_predictions reads one sub-directory per index in range(bootstraps).
bootstraps = 10
# Number of quantile bins used when building the calibration plot.
quantiles = 10
# Root directory that holds the per-bootstrap prediction folders ("~" is expanded to the home directory).
prediction_dir = os.path.expanduser("~/dropbox/sts-ecg/predictions/v14/")

In [None]:
def get_predictions(prediction_dir: str, label: str, n_bootstraps=None):
    """Load the test-set predictions of every bootstrap run.

    For each run index ``i`` in ``range(n_bootstraps)`` the file
    ``<prediction_dir>/<i>/predictions_test.csv`` is read and the columns
    ``<label>_<label>_actual`` and ``<label>_<label>_predicted`` are extracted.

    Args:
        prediction_dir: Directory containing one sub-directory per bootstrap run.
        label: Label name used to build the two column names.
        n_bootstraps: Number of runs to read. When ``None`` (the default) the
            module-level ``bootstraps`` setting is used, as before.

    Returns:
        A list with one ``(y, y_hat)`` tuple of pandas Series per bootstrap run.

    Raises:
        FileNotFoundError: If a run's ``predictions_test.csv`` does not exist.
        KeyError: If either expected column is missing from a CSV.
    """
    # Fall back to the notebook-wide setting only when the caller did not choose a count.
    run_count = bootstraps if n_bootstraps is None else n_bootstraps
    actual_col = f'{label}_{label}_actual'
    predicted_col = f'{label}_{label}_predicted'

    data = []
    for bootstrap in range(run_count):
        fpath = os.path.join(prediction_dir, str(bootstrap), "predictions_test.csv")
        df = pd.read_csv(fpath)
        data.append((df[actual_col], df[predicted_col]))
    return data
# Load the per-bootstrap (actual, predicted) pairs for the "sts_death" label.
data = get_predictions(prediction_dir=prediction_dir, label="sts_death")

In [None]:
def _binned_calibration(y, y_hat, edges):
    """Return per-bin mean predicted value and mean of ``y``, NaN for empty bins.

    Bin ``k`` collects the predictions with ``edges[k] <= y_hat < edges[k + 1]``;
    the last bin has no upper bound. Both returned arrays have ``len(edges)``
    entries regardless of how many bins are populated.
    """
    n_bins = len(edges)
    bin_ids = np.digitize(y_hat, edges) - 1

    bin_total = np.bincount(bin_ids, minlength=n_bins)
    bin_pred_sum = np.bincount(bin_ids, weights=y_hat, minlength=n_bins)
    bin_true_sum = np.bincount(bin_ids, weights=y, minlength=n_bins)

    prob_pred = np.full(n_bins, np.nan)
    prob_true = np.full(n_bins, np.nan)
    filled = bin_total != 0
    prob_pred[filled] = bin_pred_sum[filled] / bin_total[filled]
    prob_true[filled] = bin_true_sum[filled] / bin_total[filled]
    return prob_pred, prob_true


def plot_calibrations_across_bootstraps(data, plot_title, n_quantiles=None):
    """Plot a quantile-binned calibration curve averaged over bootstrap runs.

    Each run's predictions are split into bins whose lower edges are the
    quantiles of that run's ``y_hat``. Per bin, the mean prediction and the mean
    of ``y`` are computed, then averaged across runs. Error bars show the
    standard deviation of the per-run observed rate. The title reports the mean
    and standard deviation of the per-run Brier scores.

    Args:
        data: Sequence of ``(y, y_hat)`` pairs, one per bootstrap run.
        plot_title: Prefix for the figure title.
        n_quantiles: Number of quantile bins. When ``None`` (the default) the
            module-level ``quantiles`` setting is used.

    Raises:
        ValueError: If ``data`` is empty.
    """
    n_runs = len(data)
    if n_runs == 0:
        raise ValueError("data must contain at least one (y, y_hat) pair")
    n_bins = quantiles if n_quantiles is None else n_quantiles

    sns.set(style="white", palette="muted", color_codes=True)
    sns.set_context("talk")

    # linspace guarantees exactly n_bins levels; arange with a float step can
    # produce one element too many because of rounding.
    levels = np.linspace(0.0, 1.0, n_bins, endpoint=False)

    # Size everything by the runs actually supplied rather than a global count.
    brier_scores = np.zeros(n_runs)
    prob_pred = np.zeros((n_runs, n_bins))
    prob_true = np.zeros((n_runs, n_bins))
    for i, (y, y_hat) in enumerate(data):
        brier_scores[i] = brier_score_loss(y, y_hat, pos_label=1)
        # mquantiles returns a masked array; plain ndarray edges are enough here.
        edges = np.asarray(stats.mstats.mquantiles(y_hat, levels))
        prob_pred[i], prob_true[i] = _binned_calibration(y, y_hat, edges)

    mean_brier_score = brier_scores.mean()
    std_brier_score = brier_scores.std()

    # Tied predictions can leave a bin empty (NaN) in some runs. Drop bins that
    # are empty in every run and ignore NaNs when averaging the rest.
    populated = ~np.isnan(prob_pred).all(axis=0)
    prob_pred = prob_pred[:, populated]
    prob_true = prob_true[:, populated]
    mean_prob_pred = np.nanmean(prob_pred, axis=0)
    mean_prob_true = np.nanmean(prob_true, axis=0)
    std_prob_true = np.nanstd(prob_true, axis=0)

    # plotting
    fig, ax = plt.subplots(figsize=(6, 6))
    # Diagonal marks perfect calibration.
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8)

    ax.errorbar(
        mean_prob_pred,
        mean_prob_true,
        std_prob_true,
        fmt=".",
        ecolor="cornflowerblue",
        elinewidth=2.5,
        capsize=2,
    )
    # Round each axis maximum up to the next multiple of 0.1 and use a square view.
    x_max = int(mean_prob_pred.max() * 10 + 1) / 10
    y_max = int(mean_prob_true.max() * 10 + 1) / 10
    lim = [0, max(x_max, y_max)]
    ticks = np.arange(0, lim[1] * 6/5, lim[1] / 5)
    ax.set(
        xticks=ticks,
        yticks=ticks,
        xlim=lim,
        ylim=lim,
        # "\\pm" yields the same text as before without the invalid "\p" escape.
        title=f"{plot_title}: Brier score = {mean_brier_score:0.3f} $\\pm$ {std_brier_score:0.3f}",
        xlabel="Predicted",
        ylabel="Actual",
    )
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.grid()
    plt.tight_layout()
    plt.show()

In [None]:
# Draw the calibration plot for the loaded predictions, titled "ECGNet".
plot_calibrations_across_bootstraps(data, "ECGNet")

In [None]:
def plot_rocs_across_bootstraps(data, plot_title):
    """Plot the mean ROC curve across bootstrap runs with a +/- 1 std band.

    Each run's ROC curve is interpolated onto a common grid of 100 false
    positive rates, and the interpolated curves are averaged. The AUC in the
    title is the area under the mean curve; the value after the plus-minus sign
    is the standard deviation of the per-run AUCs.

    Args:
        data: Sequence of ``(y, y_hat)`` pairs, one per bootstrap run.
        plot_title: Prefix for the figure title.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one (y, y_hat) pair")

    sns.set(style="white", palette="muted", color_codes=True)
    sns.set_context("talk")

    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    aucs = []
    for y, y_hat in data:
        fpr, tpr, _ = roc_curve(y, y_hat)
        aucs.append(auc(fpr, tpr))
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        # Pin the curve to the origin.
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)

    fig, ax = plt.subplots(figsize=(6, 6))
    # Diagonal is the chance-level reference.
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=.8)

    mean_tpr = np.mean(tprs, axis=0)
    # Pin the averaged curve to (1, 1).
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    ax.plot(
        mean_fpr,
        mean_tpr,
        color="cornflowerblue",
        alpha=0.8,
        lw=3,
    )

    # Shaded band: mean +/- one standard deviation, clipped to [0, 1].
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax.fill_between(
        mean_fpr,
        tprs_lower,
        tprs_upper,
        color='lightgrey',
        alpha=.3,
    )
    ax.set(
        xlim=[0, 1],
        ylim=[0, 1],
        # "\\pm" yields the same text as before without the invalid "\p" escape.
        title=f"{plot_title}: AUC = {mean_auc:0.2f} $\\pm$ {std_auc:0.2f}",
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
    )
    ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.grid()
    plt.tight_layout()
    plt.show()

In [None]:
# Draw the ROC plot for the loaded predictions.
# NOTE(review): this uses the same `data` as the calibration plot, but is titled "STSNet" rather than "ECGNet" — confirm which model name is intended.
plot_rocs_across_bootstraps(data=data, plot_title="STSNet")