In [None]:
# Downgrade jedi so tab autocomplete works
! pip install jedi==0.17.2

In [None]:
import os
import csv
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
%matplotlib inline

In [None]:
# Notebook-wide constants used by the data-loading and plotting cells.
num_bootstraps = 10
image_ext = ".eps"
font_size_txt_on_plot = 16
font_size_asterisk_on_plot = 20

# Seaborn look-and-feel shared by every figure.
sns.set_style("whitegrid")
sns.set_context("talk")
palette = sns.color_palette("pastel")

In [None]:
# Date stamp in YYYY-MM-DD form; used as the prefix of saved figure filenames.
from datetime import date
today = date.today().isoformat()
print(today)

In [None]:
# Calculate mean AUC across bootstraps for each model
def _mean_aucs(aucs: pd.DataFrame):
    means = {}
    for model in aucs["model"].unique():
        means[model] = np.mean(aucs[aucs["model"] == model]["auc"])
    return means

# Zero-pad a version number (e.g. 7 -> "007") for building model identifiers.
def pad_with_leading_zeros(num: int, leading_zeros: int = 3):
    """Return ``num`` as a decimal string left-padded with zeros.

    Args:
        num: Integer to format.
        leading_zeros: Minimum total width of the result (default 3). Numbers
            with more digits than this are returned unpadded.

    Returns:
        str: The zero-padded number, e.g. ``pad_with_leading_zeros(7) == "007"``.
    """
    # Honour the requested width; it was previously hard-coded to 3.
    return f"{num:0{leading_zeros}d}"

In [None]:
# Box plot of bootstrap AUCs per model, annotated with each model's mean AUC.
def box_plot_aucs(
    model_names: list,
    aucs: pd.DataFrame,
    y_min: float = 0.7,
    y_max: float = 0.95,
    x_tick_labels: list = None,
    x_label: str = None,
    plot_title: str = None,
    fpath: str = None,
    palette = None):
    """Draw one box per model from long-format bootstrap AUCs.

    Args:
        model_names: Values of the ``model`` column to plot, in the
            left-to-right order the boxes should appear.
        aucs: Long-format frame with ``model`` and ``auc`` columns.
        y_min, y_max: Y-axis limits; ticks every 0.1 starting at ``y_min``.
        x_tick_labels: Display labels, one per entry of ``model_names``.
            Falls back to ``model_names`` when None or empty.
        x_label: X-axis label; None leaves the axis unlabelled.
        plot_title: Optional title.
        fpath: If given, the figure is saved there at 300 dpi.
        palette: Seaborn palette; defaults to ``sns.color_palette("pastel")``.

    Returns:
        The matplotlib Axes containing the plot.

    Raises:
        ValueError: If ``x_tick_labels`` and ``model_names`` differ in length.
    """
    model_names = list(model_names)
    if not x_tick_labels:
        x_tick_labels = model_names
    if len(x_tick_labels) != len(model_names):
        raise ValueError(
            f"Got {len(x_tick_labels)} tick labels for {len(model_names)} models"
        )

    if palette is None:
        palette = sns.color_palette("pastel")

    # Restrict to the requested models (this previously read a global `models`).
    aucs_to_plot = aucs[aucs["model"].isin(model_names)]

    y_tick_step = 0.1

    # Grow the figure with the number of boxes and the longest model id.
    max_chars_model_name = max(len(m) for m in model_names)
    fig_width = 1 + len(model_names)
    fig_height = max_chars_model_name / 5 + fig_width / 5
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # `order` pins box i to model_names[i]. Without it seaborn orders boxes by
    # first appearance in the data, which can disagree with the tick labels.
    sns.boxplot(
        x="model",
        y="auc",
        data=aucs_to_plot,
        order=model_names,
        palette=palette,
        ax=ax,
    )

    ax.set_ylabel("AUC")
    ax.set_ylim([y_min, y_max])
    ax.set_yticks(np.arange(y_min, y_max + 0.01, y_tick_step))
    ax.set_xticks(range(len(model_names)))
    ax.set_xticklabels(x_tick_labels, rotation=90)
    ax.set_xlabel(x_label)
    if plot_title is not None:
        ax.set_title(plot_title)

    # Annotate each box with its mean AUC near the top of the axis.
    # y_max - 0.05 reproduces the previous fixed 0.95 when y_max == 1.0.
    means = _mean_aucs(aucs_to_plot)
    text_x_offset = -0.25
    text_y = y_max - 0.05
    for x_pos, model in enumerate(model_names):
        if model not in means:
            # No rows for this model: leave its (empty) slot unannotated.
            continue
        ax.text(
            x_pos + text_x_offset,
            text_y,
            f"{means[model]:0.2f}",
            fontsize=font_size_txt_on_plot,
        )
    # TODO: significance asterisks (font_size_asterisk_on_plot) are not drawn yet.

    fig.tight_layout()

    if fpath is not None:
        fig.savefig(fpath, dpi=300)
        print(f"Saved {fpath}")

    return ax

In [None]:
def get_same_number_of_aucs(aucs1: list, aucs2: list) -> tuple:
    """Drop NaNs from two AUC sequences and truncate both to a common length.

    Intended for paired tests (e.g. ``scipy.stats.ttest_rel``), which need
    equal-length inputs.

    Args:
        aucs1, aucs2: Iterables of AUC values; NaN entries are removed first.

    Returns:
        tuple: ``(list1, list2)`` with the first ``min(len1, len2)`` non-NaN
        values of each input, in their original order.
    """
    # NaN is the only value that compares unequal to itself.
    aucs1 = [auc for auc in aucs1 if auc == auc]
    aucs2 = [auc for auc in aucs2 if auc == auc]

    common_len = min(len(aucs1), len(aucs2))
    # Slice ends are exclusive, so slicing to common_len keeps every pair.
    # The previous `common_len - 1` silently dropped the last pair.
    return aucs1[:common_len], aucs2[:common_len]

In [None]:
# Progress marker printed once the helper-definition cells above have been executed.
print("All functions initialized")

## Set paths and create dirs

In [None]:
# Project root; override the default location with the ECGNET_STS_ROOT env var.
root = os.path.expanduser(os.environ.get("ECGNET_STS_ROOT", "~/dropbox/ecgnet-sts"))
print(f"Set root path to: {root}")
if not os.path.isdir(root):
    raise FileNotFoundError(
        f"Project root not found: {root} (set ECGNET_STS_ROOT to override)"
    )

# Output directory for figures; created if missing, safe to re-run.
dirpath_figures = os.path.join(root, "figures-and-tables")
os.makedirs(dirpath_figures, exist_ok=True)
print(f"Set figures path to: {dirpath_figures}")

# Input directory holding the AUC spreadsheets.
dirpath_auc_spreadsheets = os.path.join(root, "auc-spreadsheets")
print(f"Set directory path to AUC CSVs to: {dirpath_auc_spreadsheets}")

## Load CSVs and concatenate horizontally into one wide dataframe

In [None]:
# File stems (without ".csv") of the AUC spreadsheets in dirpath_auc_spreadsheets.
csv_filenames = [
    "ecgnet-sts-results"
]

# Read every spreadsheet, then concatenate once, side by side on the row index
# (first CSV column). This avoids re-copying a growing frame on each iteration.
auc_frames = [
    pd.read_csv(
        os.path.join(dirpath_auc_spreadsheets, f"{csv_filename}.csv"),
        low_memory=False,
        index_col=0,
    )
    for csv_filename in csv_filenames
]
df = pd.concat(auc_frames, axis=1) if auc_frames else pd.DataFrame()

# NOTE(review): with index_col=0 the first column becomes the index, so an
# "Unnamed: 0" column is probably never present and this rename is a no-op —
# confirm intent (perhaps df.rename_axis("parameter") was meant).
df = df.rename(columns={'Unnamed: 0':'parameter'})

print(f"DataFrame generated with shape {df.shape}")

## Format AUCs from dataframe
1. Isolate the bootstrap AUC rows
2. Unpivot from wide to long format (one row per model and bootstrap)
3. Drop NaNs
4. Cast AUCs to floats

In [None]:
# Bootstrap rows are assumed to be labelled "1" .. str(num_bootstraps) in the
# CSV index — TODO confirm against the spreadsheet. The previous
# range(1, num_bootstraps) selected only num_bootstraps - 1 of them.
bootstrap_labels = [str(bootstrap) for bootstrap in range(1, num_bootstraps + 1)]
aucs = (
    df.loc[df.index.isin(bootstrap_labels)]
    .melt(var_name="model", value_name="auc")  # wide -> long: one row per (model, bootstrap)
    .dropna()
    .astype({"auc": float})
)
assert not aucs.empty, "No bootstrap AUC rows found; check the CSV index labels"

In [None]:
# pvals = []
# for model in models:   
#     aucs1, aucs2 = get_same_number_of_aucs(
#         aucs1=aucs[models[0]],
#         aucs2=aucs[model],
#     )
#     _, pval = stats.ttest_rel(aucs1, aucs2)
#     pvals.append(pval)

## Generate box plots of AUCs

In [None]:
# MLP on age, sex, and metadata: versions v001 through v008.
padded_versions = [pad_with_leading_zeros(num=version, leading_zeros=3) for version in range(1, 9)]
models = [f"mlp-v{padded}-age-sex-metadata" for padded in padded_versions]
model_names = [f"v{padded}" for padded in padded_versions]
figure_path = os.path.join(dirpath_figures, f"{today}-mlp-age-sex-metadata{image_ext}")

box_plot_aucs(
    model_names=models,
    x_tick_labels=model_names,
    aucs=aucs,
    y_min=0.5,
    y_max=1.0,
    plot_title="MLP: age, sex, metadata",
    palette=palette,
    fpath=figure_path,
)

In [None]:
# ECGNet on voltage only: versions v001 through v018.
padded_versions = [pad_with_leading_zeros(num=version, leading_zeros=3) for version in range(1, 19)]
models = [f"ecgnet-v{padded}" for padded in padded_versions]
model_names = [f"v{padded}" for padded in padded_versions]
figure_path = os.path.join(dirpath_figures, f"{today}-ecgnet{image_ext}")

box_plot_aucs(
    model_names=models,
    x_tick_labels=model_names,
    aucs=aucs,
    y_min=0.5,
    y_max=1.0,
    plot_title="ECGNet: voltage",
    palette=palette,
    fpath=figure_path,
)

In [None]:
# ResNet on voltage only: versions v001 through v011.
padded_versions = [pad_with_leading_zeros(num=version, leading_zeros=3) for version in range(1, 12)]
models = [f"resnet-v{padded}" for padded in padded_versions]
model_names = [f"v{padded}" for padded in padded_versions]
figure_path = os.path.join(dirpath_figures, f"{today}-resnet{image_ext}")

box_plot_aucs(
    model_names=models,
    x_tick_labels=model_names,
    aucs=aucs,
    y_min=0.5,
    y_max=1.0,
    plot_title="ResNet: voltage",
    palette=palette,
    fpath=figure_path,
)

In [None]:
# ECGNet models that also take age, sex, and metadata as inputs.
ecgnet_prefix = "ecgnet-"
metadata_suffix = "-age-sex-metadata"

# Distinct matching model ids, sorted (zero-padded versions sort numerically).
models = sorted(
    m for m in aucs["model"].unique()
    if m.startswith(ecgnet_prefix) and m.endswith(metadata_suffix)
)

# Tick label = the part between prefix and suffix, e.g. "v018". Slicing (instead
# of a regex .group(1)) cannot fail on an id with an empty middle part.
x_tick_labels = [m[len(ecgnet_prefix):len(m) - len(metadata_suffix)] for m in models]

box_plot_aucs(
    model_names=models,
    aucs=aucs,
    y_min=0.5,
    y_max=1.0,
    plot_title="ECGNet: voltage, age, sex, and metadata",
    x_tick_labels=x_tick_labels,
    palette=palette,
    fpath=os.path.join(dirpath_figures, f"{today}-ecgnet-age-sex-metadata{image_ext}"),
)

In [None]:
# Comparison of hand-picked models across architectures and input sets.
# Each model id is paired with its tick label so the two lists cannot drift apart.
comparison = [
    ("mlp-v002-age-sex-metadata", "MLP age, sex, metadata"),
    ("mlp-v001-sts-features", "MLP STS features"),
    ("ecgnet-v011", "ECGNet voltage"),
    ("ecgnet-v018-age-sex-metadata", "ECGNet voltage, age, sex, metadata"),
    ("ecgnet-v001-sts-features", "ECGNet voltage, STS features"),
    ("resnet-v001", "ResNet voltage"),
    ("resnet-v001-age-sex-metadata", "ResNet voltage, age, sex, metadata"),
    ("resnet-v001-sts-features", "ResNet voltage, STS features"),
]
models = [model_id for model_id, _ in comparison]
comparison_labels = [label for _, label in comparison]

box_plot_aucs(
    model_names=models,
    aucs=aucs,
    y_min=0.5,
    y_max=1.0,
    plot_title="Model comparison",
    x_tick_labels=comparison_labels,
    palette=palette,
    fpath=os.path.join(dirpath_figures, f"{today}-compare-models{image_ext}"),
)