In [1]:
import pickle
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

# Global figure styling: Computer Modern for math text and a larger base font.
matplotlib.rcParams.update({
    'mathtext.fontset': 'cm',
    'font.size': 15,
})

### Load data and results

In [3]:
result_file = "../results/mimic_4bis_mlp_rho=0.05_p1=6_p2=0.35_p3=0.002.pkl"

# The scenario name sits between the first underscore of the file name and the
# "_mlp" marker. Only the file name is parsed so that underscores (or "mlp")
# appearing in a directory name cannot corrupt the derived dataset path.
result_name = result_file.rsplit('/', 1)[-1]
scenario = result_name[result_name.index('_') + 1:result_name.index('_mlp')]
dataset = "../data/triage_scenario_{}.csv".format(scenario)

In [4]:
# The first two CSV columns form the (two-level) row index.
data = pd.read_csv(dataset, index_col = [0, 1])

# Outcomes that are evaluated below.
target = data[['D', 'Y1', 'YC']]

# Everything that is neither an outcome column nor the nurse column is kept as a covariate.
covariates = data.drop(columns = ['D', 'Y1', 'Y2', 'YC', 'nurse'])

In [7]:
# NOTE: unpickling can execute arbitrary code - only load result files you produced yourself.
# The context manager guarantees the file handle is closed once the results are read.
with open(result_file, 'rb') as result_handle:
    results = pickle.load(result_handle)

### Evaluate 

In [16]:
def evaluate(results, target, p):
    """
        Evaluate the predictions of every fold.

    Args:
        results (iterable): One entry per fold. Each entry is indexed by model name
            ('Observed', 'Human', 'Amalgamation', 'Hybrid', 'Defer') to obtain that
            model's predictions, and its `index` attribute identifies the points of
            the fold (presumably a DataFrame with one column per model - TODO confirm).
        target (pd.DataFrame): One column per outcome, indexed like the predictions.
        p (float): Forwarded to `compute_metrics`, which uses it as the fraction of
            lowest-scored points forming the bottom group.

    Returns:
        pd.DataFrame: One column per model, rows indexed by (Fold, Outcome, Metric).
    """
    import warnings

    # Display label -> key of the predictions in each fold's results.
    # The raw string keeps the LaTeX backslash without an invalid escape sequence.
    models = {
        '$f_Y$': 'Observed',
        '$f_h$': 'Human',
        r'$f_\mathcal{A}$': 'Amalgamation',
        '$f_{hyb}$': 'Hybrid',
        '$f_{def}$': 'Defer',
    }
    evaluation = {}

    for fold, result in enumerate(results):
        fold_eval = {label: compute_metrics(result[key], target, p) for label, key in models.items()}

        # Best-effort reference value: one minus the mean outcome among the points of
        # THIS fold whose `anchor_age` covariate equals 1. It relies on the notebook-level
        # `covariates` frame and is skipped (with a warning) when it cannot be computed.
        try:
            fold_index = result.index
            selected = covariates.loc[fold_index].anchor_age == 1
            fold_eval['Observed Negative Outcome'] = {
                (tar, 'TNR'): (1 - target[tar].loc[fold_index][selected]).mean()
                for tar in target.columns
            }
        except Exception as error:
            warnings.warn("Skipping 'Observed Negative Outcome' for fold {}: {!r}".format(fold, error))

        evaluation[fold] = pd.DataFrame.from_dict(fold_eval)

    # Stack the folds: the dictionary keys become the outermost index level.
    evaluation = pd.concat(evaluation)
    evaluation.index = evaluation.index.rename(['Fold', 'Outcome', 'Metric'])
    return evaluation

def compute_metrics(predictions, target, p):
    """
        Score one set of predictions against every outcome.

    Args:
        predictions (pd.Series): Predicted scores, indexed by the evaluated points.
        target (pd.DataFrame): One column per outcome; must contain the index of `predictions`.
        p (float): Fraction of the predictions (lowest scores first) forming the bottom group.

    Returns:
        dict: Maps (outcome, metric name) to its value. 'AUC-ROC' is always present.
            The 'Female'/'Male' 'TNR' and 'PNR' entries are best-effort: they rely on the
            notebook-level `covariates` frame and are left out (with a warning) on failure.
    """
    import warnings

    metrics = {}
    tar_test = target.loc[predictions.index]
    for tar in target.columns:
        metrics[(tar, 'AUC-ROC')] = roc_auc_score(tar_test[tar], predictions)

    try:
        # Bottom group: the int(p * n) lowest scores, ties at the cut-off included.
        bot = predictions.nsmallest(n = int(p * len(predictions)), keep = 'all').index

        # Points with Group == 1 are reported as 'Female', all others as 'Male'.
        female = covariates.loc[predictions.index].Group == 1
        bot_female = bot.intersection(female[female].index)
        bot_male = bot.intersection(female[~female].index)

        # Share of each group that falls in the bottom group (identical for every outcome).
        female_pnr = len(bot_female) / female.sum()
        male_pnr = len(bot_male) / (~female).sum()

        for tar in target.columns:
            # One minus the mean outcome inside the bottom group.
            metrics[(tar, 'Female TNR')] = 1 - tar_test[tar].loc[bot_female].mean()
            metrics[(tar, 'Female PNR')] = female_pnr
            metrics[(tar, 'Male TNR')] = 1 - tar_test[tar].loc[bot_male].mean()
            metrics[(tar, 'Male PNR')] = male_pnr
    except Exception as error:
        # Keep the AUC-ROC results, but make the missing group metrics visible.
        warnings.warn("Group metrics could not be computed: {!r}".format(error))

    return metrics

In [17]:
# p = 0.3: compute_metrics treats the 30% lowest-scored points as the bottom group.
evaluation = evaluate(results, target, p = 0.3)

### Display

In [13]:
# Metric shown in the bar chart; must be one of the 'Metric' labels of `evaluation`
# (e.g. 'AUC-ROC', 'Female TNR', 'Male PNR').
metric = "AUC-ROC"

In [None]:
# Bar colour and hatch pattern for each plotted column, in column order.
colors = [
    'tab:green',
    'tab:red',
    'tab:blue',
    'tab:orange',
    'tab:brown',
    'tab:grey',
]
patterns = [
    '/',
    '-',
    '\\',
    '.',
    '|',
    '',
]

In [20]:
# Average and spread of every (Metric, Outcome) pair across folds.
by_metric = evaluation.groupby(['Metric', 'Outcome'])
mean = by_metric.mean()
std = by_metric.std()

# Columns without any value for the selected metric are left out of the chart.
hue = mean.loc[metric].dropna(axis = 1, how = 'all')
ax = hue.plot.bar(edgecolor = 'white', width = 0.8, figsize = (10, 5),
                  yerr = std.loc[metric].dropna(axis = 1, how = 'all'), color = colors)

# Add hatch
# Relies on ax.patches listing the bars column by column: each pattern is
# repeated once per outcome so that all bars of a column share the same hatch.
hatches = [pattern for pattern in patterns for _ in range(len(hue))]
for bar, hatch in zip(ax.patches, hatches):
    bar.set_hatch(hatch)

# Add separation lines
# Each line sits halfway between the right edge of the last bar of one outcome
# and the left edge of the first bar of the next outcome.
if ax.patches:
    lefts = np.array([patch.get_x() for patch in ax.patches])
    bar_width = ax.patches[0].get_width()
    last_right_edges = lefts[-len(hue):-1] + bar_width
    first_left_edges = lefts[1:len(hue)]
    for position in (last_right_edges + first_left_edges) / 2:
        ax.axvline(position, color = 'grey', linestyle = ':')

ax.set_ylabel(metric)
plt.xticks(rotation = 0)
ax.set_ylim(0., 1.)
ax.grid(alpha = 0.5)
ax.legend(bbox_to_anchor = (1.0, 1.0), loc = 'upper left')
plt.tight_layout()

In [None]:
# Summary table for the 'YC' outcome: one row per metric, one column per model,
# each cell formatted as "mean (std)" across folds.
pd.DataFrame.from_dict(
    {
        name: [
            "{:.3f} ({:.3f})".format(mean.loc[(name, 'YC'), model], std.loc[(name, 'YC'), model])
            for model in mean.columns
        ]
        for name in mean.index.get_level_values(0).unique()
    },
    orient = 'index',
    columns = mean.columns,
)