In [None]:
# Downgrade jedi so tab autocomplete works
! pip install jedi==0.17.2

In [None]:
import os
import csv
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
%matplotlib inline

In [None]:
# Plot and global settings
# Seaborn look: "whitegrid" style with the "talk" plotting context.
sns.set_style("whitegrid")
sns.set_context("talk")
# Figure height handed to plt.subplots(figsize=...) inside bar_plot_aucs.
fig_height = 5
# Number of bootstraps (presumably one AUC per bootstrap per model -- TODO confirm).
num_bootstraps = 10
# Font sizes for the mean-AUC text and the significance asterisk drawn on the bar plot.
font_size_txt_on_plot = 16
font_size_asterisk_on_plot = 20
# Extension appended to the file name of saved figures.
image_ext = ".eps"
# Color palette passed to the bar plot.
palette = sns.color_palette("pastel")

In [None]:
# Today's date as a YYYY-MM-DD string; it prefixes the saved figure file name.
from datetime import date
today = date.today().isoformat()
print(today)

In [None]:
def bar_plot_aucs(
    model_names: list,
    aucs: dict,
    pvals: list,
    y_min: float = 0.7,
    y_max: float = 0.95,
    x_tick_labels: list = None,
    x_label: str = None,
    plot_title: str = None,
    fpath: str = None,
    palette = None):
    """Plot mean AUC +/- standard deviation for each model as a bar chart.

    Reads the notebook-level settings ``fig_height``,
    ``font_size_txt_on_plot`` and ``font_size_asterisk_on_plot``.

    Args:
        model_names: keys of ``aucs`` to plot, in left-to-right bar order.
        aucs: maps model name to a sequence of AUC values; NaNs are ignored
            when computing the mean and standard deviation.
        pvals: one p-value per bar, in the same order as ``model_names``;
            bars with p < 0.05 are marked with an asterisk.
        y_min: lower limit of the y axis.
        y_max: upper limit of the y axis.
        x_tick_labels: labels for the bars; when omitted or empty,
            ``model_names`` are used.
        x_label: optional x-axis label.
        plot_title: optional title.
        fpath: if given, the figure is saved to this path at 300 dpi.
        palette: bar colors; defaults to seaborn's "pastel" palette.

    Returns:
        The matplotlib Axes holding the plot.
    """
    if palette is None:
        palette = sns.color_palette("pastel")

    # None (not a mutable []) is the default; an explicit empty list is still
    # treated as "use the model names".
    if x_tick_labels is None or len(x_tick_labels) == 0:
        x_tick_labels = model_names

    y_delta = 0.1

    means = [np.nanmean(aucs[model]) for model in model_names]
    stds = [np.nanstd(aucs[model]) for model in model_names]

    fig_width = 1 + len(model_names)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    x_pos = np.arange(len(model_names))

    # Draw on the axes created above instead of relying on the implicit
    # "current axes" state.
    sns.barplot(x=x_pos, y=means, yerr=stds, palette=palette, ax=ax)
    ax.set_ylabel('AUC')
    ax.set_ylim([y_min, y_max])
    # +0.01 so that y_max itself is included despite arange's open upper end.
    ax.set_yticks(np.arange(y_min, y_max+0.01, y_delta))

    ax.set_xticks(range(len(model_names)))
    ax.set_xticklabels(x_tick_labels)

    if plot_title is not None:
        ax.set_title(plot_title)

    if x_label is not None:
        ax.set_xlabel(x_label)

    # Indicate significance for each column: an asterisk on top of the
    # error bar. NaN p-values compare False and are therefore not marked.
    for i, patch in enumerate(ax.patches):
        if pvals[i] < 0.05:
            ax.text(
                patch.get_x() + patch.get_width() / 2,
                means[i] + stds[i],
                "*",
                ha='center',
                fontsize=font_size_asterisk_on_plot,
            )

    # Print the mean AUC just above the bottom of each bar; the offset shifts
    # the text left so it sits roughly centered on the bar.
    offset = -0.25
    for i, v in enumerate(means):
        ax.text(
            x_pos[i]+offset,
            y_min + 0.01,
            f'{v:0.2f}',
            fontsize=font_size_txt_on_plot,
        )
    ax.tick_params(axis="x", labelrotation=90)
    fig.tight_layout()

    if fpath is not None:
        fig.savefig(fpath, dpi=300)
        print(f"Saved {fpath}")

    return ax

In [None]:
def get_aucs_from_column(df: pd.DataFrame, col_name: str, bootstraps: int = 10) -> list:
    """Extract the per-bootstrap AUCs of one model as a list of floats.

    Args:
        df: dataframe whose columns are model names and whose index contains
            the bootstrap numbers as strings ("0", "1", ...).
        col_name: model name, i.e. the column to read.
        bootstraps: number of bootstraps to read (index labels "0" to
            str(bootstraps - 1)).

    Returns:
        A list of length ``bootstraps`` with one float per bootstrap; entries
        that are missing or cannot be converted to float are NaN. An empty
        list is returned if ``col_name`` is not a column of ``df``.
    """
    if col_name not in df:
        return []

    column = df[col_name]
    aucs = []
    for bootstrap in range(bootstraps):
        try:
            aucs.append(float(column.loc[str(bootstrap)]))
        # Only absorb lookup / conversion failures; a bare "except:" would
        # also swallow KeyboardInterrupt and SystemExit.
        except (KeyError, TypeError, ValueError):
            aucs.append(np.nan)
            print(f"no valid auc found at bootstrap {bootstrap}")
    return aucs

In [None]:
def get_same_number_of_aucs(aucs1: list, aucs2: list) -> tuple:
    """Drop NaNs from both lists and truncate them to a common length.

    Args:
        aucs1: first list of AUC values (may contain NaN).
        aucs2: second list of AUC values (may contain NaN).

    Returns:
        A tuple ``(aucs1, aucs2)`` of two new lists without NaNs, each holding
        the first ``n`` remaining values, where ``n`` is the length of the
        shorter NaN-free list.
    """
    # Remove NaNs (NaN is the only value that is not equal to itself)
    aucs1 = [auc for auc in aucs1 if auc == auc]
    aucs2 = [auc for auc in aucs2 if auc == auc]
    # NOTE(review): NaNs are removed from each list independently, so values
    # at the same position may no longer come from the same bootstrap.
    # Confirm this is acceptable for paired tests run on the result.

    # Length of the shorter list
    common_len = min(len(aucs1), len(aucs2))

    # Keep the first common_len values of each list. Slicing with
    # common_len - 1 would silently drop the last valid value.
    return aucs1[:common_len], aucs2[:common_len]

In [None]:
# Status message only: it is printed whenever this cell runs and does not
# itself check that the helper-function cells above were executed.
print("All functions initialized")

## Set paths and create dirs

In [None]:
# Project root; "~" is expanded to the current user's home directory.
root = os.path.expanduser("~/dropbox/ecgnet-sts")
print(f"Set root path to: {root}")

# Output directory for figures. makedirs(exist_ok=True) also creates missing
# parent directories and does not fail if the directory already exists.
dirpath_figures = os.path.join(root, "figures-and-tables")
os.makedirs(dirpath_figures, exist_ok=True)
print(f"Set figures path to: {dirpath_figures}")

# Input directory holding the AUC CSV files (read only, never created here).
dirpath_auc_spreadsheets = os.path.join(root, "auc-spreadsheets")
print(f"Set directory path to AUC CSVs to: {dirpath_auc_spreadsheets}")

## Load CSVs and concatenate horizontally into one wide dataframe

In [None]:
# Base names (without ".csv") of the spreadsheets in dirpath_auc_spreadsheets.
csv_filenames = [
    "aucs"
]

# Read every CSV with its first column as the index, then concatenate once.
# Collecting the frames first avoids repeatedly copying a growing DataFrame
# and avoids concatenating onto an empty placeholder frame.
frames = []
for csv_filename in csv_filenames:
    fpath = os.path.join(dirpath_auc_spreadsheets, f"{csv_filename}.csv")
    frames.append(pd.read_csv(fpath, low_memory=False, index_col=0))

# Side by side (axis=1): each CSV contributes its columns, rows align on the index.
df = pd.concat(frames, axis=1) if frames else pd.DataFrame()

# Only has an effect if a column named 'Unnamed: 0' is present; returns a new
# frame instead of mutating in place.
df = df.rename(columns={'Unnamed: 0': 'parameter'})

print(f"DataFrame generated with shape {df.shape}")

## Get AUCs from dataframe as a dict (keyed by model name) of lists of floats

In [None]:
# One list of per-bootstrap AUCs for every column (model) of the dataframe.
# The bootstrap count comes from the num_bootstraps setting rather than a
# repeated literal, so it only needs to be changed in one place.
aucs = {
    model: get_aucs_from_column(df=df, col_name=model, bootstraps=num_bootstraps)
    for model in df.columns
}

## Plot AUCs

In [None]:
# Models to plot, in bar order; the first entry is the baseline that every
# other model is compared against.
models = [
    "use-all-ecgs-v001",
    "use-all-ecgs-v001-bn",
    "use-all-ecgs-v002",
    "use-all-ecgs-v002-bn",
    "use-all-ecgs-v003",
    "use-all-ecgs-v003-bn",
    "use-all-ecgs-v004",
    "use-all-ecgs-v004-bn",
    "use-all-ecgs-v005",
    "use-all-ecgs-v005-bn",
]

# Short tick labels, one per entry of `models` and in the same order.
x_tick_labels = [
    "v1",
    "v1-bn",
    "v2",
    "v2-bn",
    "v3",
    "v3-bn",
    "v4",
    "v4-bn",
    "v5",
    "v5-bn",
]
assert len(x_tick_labels) == len(models), "need exactly one tick label per model"

baseline = models[0]
pvals = []

for model in models:
    if model == baseline:
        # A paired t-test of the baseline against itself is degenerate (all
        # differences are zero), so record NaN directly; NaN is never marked
        # as significant on the plot.
        pvals.append(np.nan)
        continue
    aucs1, aucs2 = get_same_number_of_aucs(
        aucs1=aucs[baseline],
        aucs2=aucs[model],
    )
    # Paired t-test between the baseline AUCs and this model's AUCs.
    _, pval = stats.ttest_rel(aucs1, aucs2)
    pvals.append(pval)

# Trailing semicolon suppresses the Axes repr in the cell output.
bar_plot_aucs(
    model_names=models,
    aucs=aucs,
    pvals=pvals,
    y_min=0.5,
    y_max=1.0,
    plot_title="STS (all ECGs)",
    x_tick_labels=x_tick_labels,
    palette=palette,
    fpath=os.path.join(dirpath_figures, f"{today}-all-ecgs{image_ext}"),
);