In [None]:
import os 
import pickle
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

# Figure defaults: Computer Modern for math text and a larger base font size
matplotlib.rcParams.update({
    'mathtext.fontset': 'cm',
    'font.size': 15,
})

### Load data and results

In [None]:
result_folder = "../results/mimic_1_mlp_delta=0.05_gamma1=6_gamma2=0.95_gamma3=0.002"

# The scenario id sits between the first '_' and the '_mlp' of the run name
# (e.g. "mimic_1_mlp_..." -> "1"). Parse the folder *name* only, so underscores
# or "mlp" in parent directories cannot shift the slice.
run_name = os.path.basename(os.path.normpath(result_folder))
dataset = "../data/triage_scenario_{}.csv".format(run_name[run_name.index('_') + 1:run_name.index('mlp') - 1])

In [None]:
# Two-level index (first two CSV columns); outcomes D/Y1/Y2/YC form the target,
# everything except outcomes and 'nurse' is kept as covariates
data = pd.read_csv(dataset, index_col=[0, 1])
target = data[['D', 'Y1', 'Y2', 'YC']]
covariates = data.drop(columns=['D', 'Y1', 'Y2', 'YC', 'nurse'])

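### Evaluate

For every fold and model, compute the AUC-ROC of the stored predictions against each outcome, plus per-group (`Group == 1` vs. the rest) rates among the fraction `p` of lowest-scored samples.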

In [None]:
# Map result-file stems (e.g. "f_Y.csv" -> "f_Y") to the LaTeX labels used in the
# plot legend and summary table; insertion order is the display order of models.
# Raw strings keep backslashes literal: in a plain string "\m" (from "\mathcal") is an
# invalid escape sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
matching = {
    'f_Y': r'$f_Y$',
    'f_D': r'$f_D$',

    'f_hyb': r'$f_{hyb}$',
    'f_ensemble': r'$f_{ens}$',
    'f_weak': r'$f_{weak}$',
    'f_robust': r'$f_{noise}$',

    'f_A': r'$f_\mathcal{A}$',
}

In [None]:
def evaluate(target, p, folder=None):
    """
    Evaluate the performance of every model stored in a result folder.

    The folder holds one sub-directory per fold; each contains one CSV per model,
    named ``<model>.csv`` with ``<model>`` starting with ``f_`` (e.g. ``f_Y.csv``),
    indexed by its first two columns and with predictions in column ``'0'``.

    Args:
        target: DataFrame of outcomes, forwarded to `compute_metrics`.
        p: Fraction of lowest-scored samples used for the group metrics.
        folder: Result folder to read; defaults to the notebook-level `result_folder`.

    Returns:
        DataFrame indexed by (Fold, Outcome, Metric) with one column per model,
        labelled through `matching` (the file stem is used for unknown models).

    Raises:
        ValueError: if no fold directory with model predictions is found.
    """
    if folder is None:
        folder = result_folder

    evaluation = {}
    # Sorted for a deterministic fold / model order (os.listdir order is arbitrary)
    for fold in sorted(os.listdir(folder)):
        fold_path = os.path.join(folder, fold)
        if not os.path.isdir(fold_path):
            continue  # stray files next to the folds (configs, .DS_Store, ...)

        fold_metrics = {}
        for file_name in sorted(os.listdir(fold_path)):
            model, ext = os.path.splitext(file_name)
            if ext != '.csv' or not model.startswith('f_'):
                continue
            predictions = pd.read_csv(os.path.join(fold_path, file_name), index_col=[0, 1])['0']
            fold_metrics[matching.get(model, model)] = compute_metrics(predictions, target, p)

        if fold_metrics:
            evaluation[fold] = pd.DataFrame.from_dict(fold_metrics)

    if not evaluation:
        raise ValueError("No fold with model predictions found in {}".format(folder))

    # Fold key + (Outcome, Metric) tuple keys from compute_metrics form the 3 index levels
    return pd.concat(evaluation).rename_axis(['Fold', 'Outcome', 'Metric'])

def compute_metrics(predictions, target, p, features=None):
    """
    Compute per-outcome metrics for one model's predictions.

    Args:
        predictions: Series of model scores; its index selects the matching rows
            of `target` and of the covariates.
        target: DataFrame with one column per outcome.
        p: Fraction of samples (those with the lowest scores) used for the group metrics.
        features: Covariate frame with a 'Group' column; defaults to the
            notebook-level `covariates`.

    Returns:
        dict mapping (outcome, metric name) -> value. 'AUC-ROC' is always present;
        'Female/Male TNR' and 'Female/Male PNR' are added only when the covariates
        contain a 'Group' column (rows with Group == 1 are treated as female).
    """
    if features is None:
        features = covariates

    metrics = {}
    tar_test = target.loc[predictions.index]
    for tar in target.columns:
        metrics[(tar, 'AUC-ROC')] = roc_auc_score(tar_test[tar], predictions)

    # Group metrics need a 'Group' covariate. Check for it explicitly rather than
    # wrapping everything in `except Exception: pass`, which also hid genuine errors.
    if 'Group' not in features.columns:
        return metrics

    # The p-fraction of lowest-scored samples (keep='all' keeps every tie at the cut-off)
    bottom = predictions.nsmallest(n=int(p * len(predictions)), keep='all').index
    is_female = features.loc[predictions.index, 'Group'] == 1
    bottom_female = bottom.intersection(is_female[is_female].index)
    bottom_male = bottom.intersection(is_female[~is_female].index)
    n_female = int(is_female.sum())
    n_male = int((~is_female).sum())

    for tar in target.columns:
        # NOTE(review): "TNR" here is 1 - mean outcome among the flagged samples, i.e. the
        # share of flagged samples with a 0 outcome (an NPV for binary outcomes), not
        # TN / (TN + FP). Key names are kept so downstream tables do not change.
        metrics[(tar, 'Female TNR')] = 1 - tar_test[tar].loc[bottom_female].mean()
        # Share of the group that falls in the flagged bottom-p set (NaN for an empty group)
        metrics[(tar, 'Female PNR')] = len(bottom_female) / n_female if n_female else float('nan')
        metrics[(tar, 'Male TNR')] = 1 - tar_test[tar].loc[bottom_male].mean()
        metrics[(tar, 'Male PNR')] = len(bottom_male) / n_male if n_male else float('nan')

    return metrics

In [None]:
# Group metrics use the 30% lowest-scored samples
evaluation = evaluate(target, 0.3)

# Keep only labelled models, ordered as in `matching`
evaluation = evaluation.loc[:, [label for label in matching.values() if label in evaluation.columns]]

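### Display

Bar plot of the selected metric (mean ± standard deviation across folds) for each outcome and model, followed by a summary table for outcome `YC`.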

In [None]:
# Metric shown in the bar plot
metric = "AUC-ROC"

In [None]:
# One colour and one hatch pattern per bar series (model), in plotting order
colors = [
    'tab:green', 'tab:red', 'tab:blue', 'tab:orange',
    'tab:brown', 'tab:grey', 'tab:purple', 'tab:olive',
]
patterns = [
    '/', '-', '\\', '.',
    '|', '', 'x', 'o',
]

In [None]:
# Evaluated models that have a display label (`Index.inter` does not exist; use `intersection`)
evaluation.columns.intersection(list(matching.values()))

In [None]:
# Display labels, in the order used to reorder the evaluation columns
matching.values()

In [None]:
# Aggregate across folds: mean and standard deviation per (Metric, Outcome)
mean = evaluation.groupby(['Metric', 'Outcome']).mean()
std = evaluation.groupby(['Metric', 'Outcome']).std()

# Rows = outcomes (x-axis groups), columns = models (one bar per model in each group);
# models with no value for this metric are left out
metric_mean = mean.loc[metric].dropna(axis=1, how='all')
metric_std = std.loc[metric].dropna(axis=1, how='all')
n_outcomes = len(mean.loc[metric])

fig, ax = plt.subplots(figsize=(10, 5))
metric_mean.plot.bar(ax=ax, edgecolor='white', width=0.8, yerr=metric_std, color=colors, rot=0)

# Hatch: bars are laid out model by model, so each pattern covers `n_outcomes` consecutive bars
hatches = [pattern for pattern in patterns for _ in range(n_outcomes)]
for bar, hatch in zip(ax.patches, hatches):
    bar.set_hatch(hatch)

# Separators midway between the last model's bar of one outcome group and the
# first model's bar of the next group (guarded: no bars -> nothing to separate)
if ax.patches:
    bar_x = np.array([patch.get_x() for patch in ax.patches])
    bar_width = ax.patches[0].get_width()
    last_in_group = bar_x[-n_outcomes:-1]
    first_in_next = bar_x[1:n_outcomes]
    for sep_x in (last_in_group + bar_width + first_in_next) / 2:
        # A single line style: the original passed both `ls=':'` and its alias
        # `linestyle='--'`; matplotlib applied the `ls` kwarg last, so dotted is kept.
        ax.axvline(sep_x, color='grey', linestyle=':')

ax.set_ylabel(metric)
ax.set_ylim(0., 1.)
ax.grid(alpha=0.5)
ax.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')
fig.tight_layout()

In [None]:
# Summary table for one outcome: "mean (std)" across folds, one row per metric,
# one column per model. Metrics are de-duplicated: the 'Metric' level repeats once
# per outcome, which previously rebuilt every row several times.
summary_outcome = 'YC'
pd.DataFrame.from_dict(
    {m: ["{:.3f} ({:.3f})".format(mean.loc[(m, summary_outcome), model], std.loc[(m, summary_outcome), model])
         for model in mean.columns]
     for m in mean.index.get_level_values('Metric').unique()},
    columns=mean.columns, orient='index')