In [None]:
import os
import csv
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
%matplotlib widget

In [None]:
# Figure height (inches) shared by every bar plot produced by plot_performance.
fig_height = 5

# Global seaborn look: white background with grid lines, "talk"-scaled fonts and lines.
sns.set_style(style="whitegrid")
sns.set_context(context="talk")

In [None]:
# plots mean AUC +/- std for each model
# Inputs: model names as a list, Pandas DataFrame of exported results CSV from Google Spreadsheets

def plot_performance(
    model_names: list,
    df: pd.DataFrame,
    y_min: float = 0.7,
    y_max: float = 0.95,
    x_label_strs: list = [],
    plot_title: str = None,
    fpath: str = None,
    palette = None):

    if palette is None:
        palette = sns.color_palette("pastel")
    
    y_delta = 0.1
    
    stats = {model: {'aucs': [], 'mean': 0, 'std': 0} for model in model_names}

    for model in model_names:
        aucs = []
        for bootstrap in range(10):
            auc_str = df[df.index == str(bootstrap)][model].values[0]
            aucs.append(float(auc_str))

        stats[model]['aucs'] = aucs    
        stats[model]['mean'] = np.nanmean(aucs)
        stats[model]['std'] = np.nanstd(aucs)
        
    means = [stats[model]['mean'] for model in model_names]
    stds = [stats[model]['std'] for model in model_names]
        
    fig_width = len(model_names) * 1.25
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    x_pos = np.arange(len(model_names))
    
    bar = sns.barplot(x=x_pos, y=means, yerr=stds, palette=palette)
    ax.set_ylabel('AUC')
    ax.set_ylim([y_min, y_max])
    ax.set_yticks(np.arange(y_min, y_max+0.01, y_delta))
    
    ax.set_xticks(range(len(model_names)))
    ax.set_xticklabels(model_names if len(x_label_strs) == 0 else x_label_strs)

    if plot_title is not None:
        ax.set_title(plot_title)
    
    offset = -0.25
    for i, v in enumerate(means):
        plt.text(x_pos[i]+offset, y_min + 0.01, f'{v:0.2f}')
    plt.xticks(rotation=90)
    plt.tight_layout()
    if fpath is not None:
        plt.savefig(fpath, dpi=300)

In [None]:
def get_aucs_from_column(df: pd.DataFrame, col_name: str, bootstraps: int = 10) -> list:
    """Extract one model's per-bootstrap AUCs as a list of floats.

    Values are read from ``df[col_name]`` at the row labels "0", "1", ...,
    ``str(bootstraps - 1)``.

    Args:
        df: Results table whose columns are model names.
        col_name: Model (column) name to read.
        bootstraps: Number of bootstrap rows to read.

    Returns:
        A list of ``bootstraps`` floats, with ``np.nan`` (and a printed notice)
        for any bootstrap whose row is missing or whose value cannot be converted
        to float. An empty list if ``col_name`` is not a column of ``df``.
    """
    if col_name not in df:
        return []
    column = df[col_name]
    aucs = []
    for bootstrap in range(bootstraps):
        try:
            aucs.append(float(column.loc[str(bootstrap)]))
        except (KeyError, TypeError, ValueError):
            # KeyError: row label missing. TypeError: NA/None or non-scalar value.
            # ValueError: text that does not parse as a number.
            # Anything else (e.g. KeyboardInterrupt) is no longer swallowed.
            aucs.append(np.nan)
            print(f"no valid auc found at bootstrap {bootstrap}")
    return aucs

## Load CSVs and concatenate horizontally into one wide dataframe

In [None]:
csv_filenames = ['stsnet', 'sts-shallow', 'sts-cabg', 'sts-valve', 'sts-cabg-valve']
rootdir = os.path.expanduser("~/dropbox/sts-ecg/figures-and-tables")

frames = []
for csv_filename in csv_filenames:
    fpath = os.path.join(rootdir, f"{csv_filename}.csv")
    df_ = pd.read_csv(fpath, low_memory=False, index_col=0)
    # If any column name lacks "sts", prefix every column with the file name so columns
    # from different files stay distinguishable (e.g. "sts-cabg_death").
    if any("sts" not in col for col in df_.columns):
        df_.columns = [f"{csv_filename}_{col}" for col in df_.columns]
    frames.append(df_)

# Concatenate once instead of growing `df` inside the loop. This also avoids pandas'
# deprecation warning about concatenating with an empty DataFrame.
df = pd.concat(frames, axis=1) if frames else pd.DataFrame()

# Downstream lookups use string row labels (str(bootstrap)); make sure they match even
# if pandas parsed the index as integers.
df.index = df.index.astype(str)
# With index_col=0 the first CSV column becomes the index, so the former
# rename(columns={'Unnamed: 0': 'parameter'}) never matched; name the index instead.
df.index.name = 'parameter'

## Get AUCs from dataframe as a dict (keyed by model name) of lists of floats

In [None]:
# Map every results column (model name) to its list of per-bootstrap AUC floats.
aucs = {
    column: get_aucs_from_column(df=df, col_name=column, bootstraps=10)
    for column in df.columns
}

In [None]:
root = os.path.expanduser("~/dropbox/sts-ecg/figures-and-tables")

outcomes = [
    "death",
    "stroke",
    "renal",
    "vent",
    "reop",
    "stay",
    "dsw",
]

cohorts = ["all", "cabg", "valve", "cabg-valve", "others"]

# Cohorts plotted without a baseline STS bar (only sts-cabg, sts-valve and
# sts-cabg-valve results are loaded).
COHORTS_WITHOUT_STS = ("all", "others")
# p-value threshold for marking the STSNet label with "*".
SIG_LEVEL = 0.05

for outcome in outcomes:
    for cohort in cohorts:
        model_1 = f'sts-shallow-v001 {outcome} infer {cohort}'
        model_2 = f'stsnet-v046 {outcome} infer {cohort}'
        # Paired t-test over bootstraps. The pairing assumes bootstrap i is the same
        # resample for both models; verify against how the results were exported.
        p_val = stats.ttest_rel(aucs[model_1], aucs[model_2]).pvalue
        sig_star = "*" if p_val < SIG_LEVEL else ""

        models_to_plot = [model_1, model_2]
        x_label_strs = ["LogReg (single task)", f"STSNet (multitask){sig_star}"]
        if cohort in COHORTS_WITHOUT_STS:
            # Remove the first color from the palette because we are not plotting baseline
            # STS AUC, so LogReg/STSNet keep the colors they get in plots that include STS.
            palette = sns.color_palette("pastel")[1:]
        else:
            models_to_plot.insert(0, f'sts-{cohort}_{outcome}')
            x_label_strs.insert(0, "STS")
            palette = None  # plot_performance falls back to the full pastel palette

        plot_performance(
            model_names=models_to_plot,
            df=df,
            y_min=0.5,
            y_max=0.9,
            x_label_strs=x_label_strs,
            plot_title=f"{outcome} {cohort}",
            fpath=os.path.join(root, f"{outcome}-{cohort}.png"),
            palette=palette,
        )