In [1]:
import numpy as np
import pandas as pd
import psycopg2
import os 
import random
import datetime
from sqlalchemy import create_engine
import mimic_utils as mimic
import magec_utils as mg
import matplotlib.pyplot as plt
from adjustText import adjust_text

pd.set_option('display.max_columns', None)

random.seed(22891)

%matplotlib inline

Using TensorFlow backend.


In [2]:
def get_magecs(p=None, models=('lr', 'mlp', 'svm', 'lstm')):
    """Read one MAgEC CSV per model from the current working directory.

    File names are ``magec_<model>_0.csv`` when ``p`` is None and
    ``magec_<model>_p<p>.csv`` otherwise.

    Args:
        p: Optional value embedded in the file-name suffix.
        models: Iterable of model tags; one file is read per tag.

    Returns:
        list of DataFrames, in the same order as ``models``.
    """
    suffix = '_p' + str(p) + '.csv' if p is not None else '_0.csv'
    return [pd.read_csv('magec_' + tag + suffix) for tag in models]

In [3]:
# Feature name groups. vitals and labs are used by later cells as DataFrame
# column names (e.g. df_cohort[features]).
vitals = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 'meanbp_mean',
          'resprate_mean', 'tempc_mean', 'spo2_mean', 'glucose_mean']

labs = ['aniongap', 'albumin', 'bicarbonate', 'bilirubin', 'creatinine', 
        'chloride', 'glucose', 'hemoglobin', 'lactate', 
        'magnesium', 'phosphate', 'platelet', 'potassium', 'ptt', 'inr', 
        'pt', 'sodium', 'bun', 'wbc']  # -hematocrit

# comobs and others are defined here but are not part of `features`.
comobs = ['congestive_heart_failure', 'chronic_pulmonary', 'pulmonary_circulation']

others = ['age', 'gender']

# Feature list handed to the MAgEC helpers: vitals followed by labs.
features = vitals+labs

# The loaders below are project helpers from mimic_utils; what they return is
# not visible in this notebook.
df_cohort = mimic.get_mimic_data()

df_ml = mimic.get_ml_data(df_cohort)

df_time = mimic.get_ml_series_data(df_cohort)

# Only the last of the seven values returned by train_valid_ml is kept.
_, _, _, _, _, _, Y_validation = mimic.train_valid_ml(df_ml)

# NOTE(review): presumably a scaler, per-feature means and the validation
# frames/arrays for the time-series models -- confirm against mimic_utils.
stsc2, series_means, _, df_series_valid, _, _, xt_valid, Yt_valid = mimic.train_valid_series(df_time, Y_validation)

In [4]:
df_notes = mimic.get_cohort_notes()

In [5]:
magecs = get_magecs()

In [6]:
# Every column except the outcome is a model input. The list is built by
# iterating over the columns (rather than a set difference) so that the
# column order is the same on every kernel restart.
x_magec_cols = [c for c in df_series_valid.columns if c != 'label']
x_magec = df_series_valid[x_magec_cols]
y_magec = df_series_valid['label']

In [7]:
joined = mg.magec_models(*magecs, Xdata=x_magec, Ydata=y_magec, features=features)

In [8]:
def most_abnormal(x, features):
    """Return the name of the feature with the largest absolute value in ``x``.

    Args:
        x: Mapping-like row (e.g. a pandas Series) indexed by feature name.
        features: Iterable of feature names to compare.

    Returns:
        The feature whose ``abs(x[f])`` is largest; the first such feature
        wins ties. ``None`` when ``features`` is empty or every value is NaN.
    """
    res = None
    feat = None
    for f in features:
        val = abs(x[f])
        # NaN compares False against everything, so a NaN stored as the
        # running maximum would block every later (real) value. Skip it.
        if val != val:
            continue
        if res is None or val > res:
            res = val
            feat = f
    return feat

In [9]:
# Columns whose name starts with 'perturb' are handed to mimic.best_feature.
prob_cols = [c for c in joined.columns if c.startswith('perturb')]
# Ensemble risk is the unweighted mean of the four model probabilities.
# DataFrame.mean(axis=1) is vectorised; the row-wise apply(np.mean, 1) it
# replaces gives the same values but runs a Python call per row.
joined['orig_prob_ensemble'] = joined[['orig_prob_mlp', 'orig_prob_lr',
                                       'orig_prob_svm', 'orig_prob_lstm']].mean(axis=1)
joined[['best_feat', 'new_risk', 'rank_feat', 'rank_val']] = joined.apply(
    lambda x: mimic.best_feature(x, prob_cols), axis=1)
joined['most_abnormal'] = joined.apply(lambda x: most_abnormal(x, features), axis=1)

KeyboardInterrupt: 

In [None]:
drivers = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 
           'meanbp_mean', 'resprate_mean', 'spo2_mean']

In [None]:
joined[drivers+['most_abnormal','best_feat', 'rank_feat', 'rank_val']].head()

In [None]:
np.sum(joined['best_feat'] != joined['rank_feat']) / len(joined)

In [None]:
def expected(x, drivers, sigmas=0.5, threshold=0.5):
    """Classify one row by ensemble prediction, driver normality and outcome.

    Args:
        x: Row (pandas Series) with 'orig_prob_ensemble', 'best_feat',
            'label' and one entry per driver.
        drivers: List of driver feature names.
        sigmas: A driver counts as 'normal' when ``abs(value) <= sigmas``.
        threshold: The ensemble predicts ventilation when
            ``orig_prob_ensemble > threshold``.

    Returns:
        One of the following labels.

        Ventilated (label == 1):
          * 'unexpected_ventilated_driver' / 'unexpected_ventilated_nondriver':
            predicted, all drivers normal; the suffix says whether the MAgEC
            best feature is a driver.
          * 'missed_unexpected_ventilated': not predicted, all drivers normal.
          * 'expected_ventilated_driver' / 'expected_ventilated_nondriver':
            predicted, one or more drivers abnormal.
          * 'missed_expected_ventilated': not predicted, one or more drivers
            abnormal.

        Not ventilated:
          * 'unexpected_notventilated': not predicted, one or more drivers
            abnormal.
          * 'expected_notventilated': not predicted, all drivers normal.
          * 'other_notventilated': predicted but not ventilated.
    """
    predicted = x['orig_prob_ensemble'] > threshold  # models predict MV (ventilated)
    all_normal = np.all(abs(x[drivers]) <= sigmas)  # all drivers are 'normal'
    best_is_driver = x['best_feat'] in list(drivers)  # MAgEC 'best feature' is a driver
    ventilated = x['label'] == 1  # patient was ventilated

    # The row's value at x[best_feat] is deliberately not read: it never
    # influenced the result and raised KeyError when best_feat was not a key.
    if ventilated:
        group = 'unexpected' if all_normal else 'expected'
        if not predicted:
            return 'missed_' + group + '_ventilated'
        return group + '_ventilated_' + ('driver' if best_is_driver else 'nondriver')

    if predicted:
        return 'other_notventilated'
    return 'expected_notventilated' if all_normal else 'unexpected_notventilated'

In [None]:
joined['stats'] = joined.apply(lambda x: expected(x, drivers), axis=1)

In [15]:
joined['stats'].value_counts()

unexpected_notventilated           32871
missed_expected_ventilated          4064
other_notventilated                 2130
expected_ventilated_driver           880
expected_notventilated               485
expected_ventilated_nondriver        304
unexpected_ventilated_nondriver      232
missed_unexpected_ventilated          62
unexpected_ventilated_driver           1
Name: stats, dtype: int64

In [16]:
excluded = set(df_cohort[np.all(np.isnan(df_cohort[drivers]), axis=1)].subject_id.unique())
filtered = joined[~np.isin(joined.case, list(excluded))]
len(joined), len(filtered)

(41029, 32944)

In [17]:
joined.case.nunique(), filtered.case.nunique()

(2083, 1557)

In [None]:
filtered['stats'].value_counts()

### Unexpected Ventilated Classes (normal drivers w/ MV=True)

In [None]:
missed_unexpected_ventilated = set(filtered[filtered['stats'] == 'missed_unexpected_ventilated'].case.unique())
unexpected_ventilated_nondriver = set(filtered[filtered['stats'] == 'unexpected_ventilated_nondriver'].case.unique())
unexpected_ventilated_driver = set(filtered[filtered['stats'] == 'unexpected_ventilated_driver'].case.unique())

### There were 8 patients that MAgEC correctly identified as ventilated that had all drivers 'normal' and 18 such patients that MAgEC missed

In [None]:
len(missed_unexpected_ventilated), len(unexpected_ventilated_nondriver), len(unexpected_ventilated_driver)

### 2 of the missed cases were correctly identified at other time points in their trajectories

In [None]:
missed_unexpected_ventilated.intersection(unexpected_ventilated_nondriver.union(unexpected_ventilated_driver))

### Expected Ventilated Classes (one or more abnormal driver w/ MV=True)

In [None]:
missed_expected_ventilated = set(filtered[filtered['stats'] == 'missed_expected_ventilated'].case.unique())
expected_ventilated_nondriver = set(filtered[filtered['stats'] == 'expected_ventilated_nondriver'].case.unique())
expected_ventilated_driver = set(filtered[filtered['stats'] == 'expected_ventilated_driver'].case.unique())

### There were 186 patients that MAgEC missed, but 93 of them were identified at some point in their trajectories

In [None]:
len(missed_expected_ventilated), len(expected_ventilated_nondriver), len(expected_ventilated_driver)

In [None]:
len(missed_expected_ventilated.intersection(expected_ventilated_nondriver.union(expected_ventilated_driver)))

In [None]:
len(missed_expected_ventilated.union(expected_ventilated_nondriver).\
    union(expected_ventilated_driver).union(missed_unexpected_ventilated).\
    union(unexpected_ventilated_nondriver).union(unexpected_ventilated_driver))

In [None]:
filtered[filtered.label==1].case.nunique()

### Cases

In [None]:
unexpected_ventilated_nondriver.union(unexpected_ventilated_driver)

In [None]:
missed_unexpected_ventilated - unexpected_ventilated_nondriver.union(unexpected_ventilated_driver)

In [None]:
index = 1000
joined[joined.case == index][drivers+['timepoint','orig_prob_ensemble',
                                      'best_feat','new_risk', 'rank_feat', 'rank_val', 'stats','label']]

In [None]:
x1 = filtered[filtered.label == 1]['orig_prob_lstm']
x2 = filtered[filtered.label == 0]['orig_prob_lstm']
bins = np.linspace(0, 1, 100)
plt.figure(figsize=[10,8])
plt.hist(x1, bins, alpha=0.5, label='ventilated', density=False)
plt.hist(x2, bins, alpha=0.5, label='not ventilated', density=False)
plt.legend(loc='upper right');

In [None]:
mort = df_cohort[['subject_id', 'mort_icu']].drop_duplicates()

In [None]:
len(mort), df_cohort.subject_id.nunique()

In [None]:
stats = joined.merge(mort, left_on='case', right_on='subject_id', how='inner')

In [None]:
stats[['case','stats']].groupby('stats')['case'].nunique()

In [None]:
stats[['case','stats','mort_icu']].groupby(['stats','mort_icu'])['case'].nunique()

In [None]:
# `expected` only emits 'unexpected_ventilated_nondriver' and
# 'unexpected_ventilated_driver', so an equality test against
# 'unexpected_ventilated' never matches a row. Match on the shared prefix.
stats[(stats['stats'].str.startswith('unexpected_ventilated'))&(stats['mort_icu']==1)]['case'].unique()

In [None]:
mimic.print_notes(df_notes, 32505)

In [None]:
unexpected_ventilated_nondriver

In [None]:
joined[joined.case == 44775][drivers+['timepoint','orig_prob_ensemble','stats','label','best_feat','new_risk']]

In [None]:
df_cohort[df_cohort.subject_id==1000][drivers]

In [None]:
excluded = set(df_cohort[np.all(np.isnan(df_cohort[drivers]), axis=1)].subject_id.unique())

In [None]:
joined[~np.isin(joined.case, list(excluded))].case.nunique()

In [None]:
# Prefix shared by the 'unexpected_ventilated_*' labels produced by `expected`
# (the previous value was misspelled and matched nothing).
category = 'unexpected_ventilated'
case = 1000

# `foo` is not defined anywhere in this notebook. The merged frame `stats`
# (joined + mort_icu) carries every column selected here.
stats[(stats['stats'].str.startswith(category))&(stats['mort_icu']==1)&(stats['case']==case)][drivers+
                                                                         ['orig_prob_ensemble',
                                                                          'most_abnormal',
                                                                          'best_feat',
                                                                          'label',
                                                                          'new_risk']]

In [None]:
# `foo` is not defined anywhere in this notebook; use the merged `stats` frame.
stats[stats['case']==1000]['case']

In [None]:
# `foo` is not defined anywhere in this notebook; use the merged `stats` frame.
len(stats), len(joined)

In [None]:
# `foo` is not defined anywhere in this notebook; use the merged `stats` frame.
stats[stats['case']==1000][drivers+['orig_prob_ensemble','most_abnormal','best_feat','label', 'new_risk','phosphate']]

In [None]:
12/143, 43/1645

In [None]:
df_cohort[df_cohort.subject_id == 1000]

In [None]:
# `foo` is not defined anywhere in this notebook; use the merged `stats` frame.
stats[stats['case']==1000][drivers+['orig_prob_ensemble','most_abnormal','best_feat','label', 'new_risk','phosphate']]

In [None]:
mimic.print_notes(df_notes, 1000)

In [None]:
def transitions(joined):
    """Count consecutive 'stats' label transitions within each case.

    For every case the rows are ordered by descending 'timepoint', and each
    adjacent pair of labels in that order is tallied.

    Args:
        joined: DataFrame with 'case', 'timepoint' and 'stats' columns.

    Returns:
        dict mapping ``(label_from, label_to)`` tuples to the number of times
        that pair occurs, summed over all cases. Cases with a single row
        contribute nothing.
    """
    counts = dict()
    # Columns are selected with a list before grouping: the tuple form
    # groupby(...)['timepoint', 'stats'] is rejected by recent pandas.
    for _, rows in joined[['case', 'timepoint', 'stats']].groupby('case'):
        order = np.argsort(rows['timepoint'].values)[::-1]  # larger value is earliest timepoint
        labels = np.asarray(rows['stats'])[order]
        for pair in zip(labels[:-1], labels[1:]):
            counts[pair] = counts.get(pair, 0) + 1
    return counts

In [None]:
tran_dict = transitions(joined)

In [None]:
tran_df = pd.DataFrame.from_dict(tran_dict, orient='index').reset_index()
tran_df.columns = ['transition', 'counts']

In [None]:
tran_df.sort_values('counts', ascending=False)

In [None]:
joined['case'].nunique()

In [None]:
np.all(abs(joined.iloc[1][drivers]) < 0.5)

In [None]:
joined.iloc[1][drivers]

In [None]:
def anomalies(x, drivers, sigmas=0.5, threshold=0.5):
    """Flag a row whose risk is not below threshold while all drivers are in range.

    Args:
        x: Mapping-like row with 'orig_prob_ensemble' and one entry per driver.
        drivers: Iterable of driver feature names.
        sigmas: A driver is out of range when ``abs(value) > sigmas``.
        threshold: Rows with ``orig_prob_ensemble < threshold`` are never flagged.

    Returns:
        True when the ensemble probability is not below ``threshold`` and no
        driver is out of range; False otherwise.
    """
    if x['orig_prob_ensemble'] < threshold:
        return False
    return not any(abs(x[name]) > sigmas for name in drivers)

In [None]:
joined['anomaly'] = joined.apply(lambda x: anomalies(x, drivers), axis=1)

In [None]:
np.sum(joined['anomaly'] == True), np.sum(joined['anomaly'] == False)

In [None]:
joined[(joined['anomaly'] == False) & 
       (np.isin(joined['best_feat'], drivers))][drivers+['best_feat']].head()

In [None]:
# Per-row value of the column named by that row's 'best_feat'.
# `joined[joined['best_feat']]` does not do this: it selects one (duplicated)
# column per row and cannot be combined with the row masks below.
best_is_driver = np.isin(joined['best_feat'], drivers)
# get_indexer returns -1 for names that are not drivers; those rows pick up an
# arbitrary driver value here and are removed by `best_is_driver` in the filter.
driver_pos = pd.Index(drivers).get_indexer(joined['best_feat'])
best_feat_value = joined[drivers].values[np.arange(len(joined)), driver_pos]

joined[(joined['label'] == 1) &
       best_is_driver &
       (np.abs(best_feat_value) < 0.5)
       ][drivers+['orig_prob_ensemble',
                  'most_abnormal',
                  'best_feat',
                  'new_risk',
                  'label']].sort_values(['orig_prob_ensemble'], ascending=False).head(20)

In [None]:
joined[joined['anomaly'] == True][features+['orig_prob_ensemble','most_abnormal','best_feat','label']].head()

In [None]:
joined[(joined['anomaly'] == True) & (joined['label'] == 1)]['case'].nunique()

In [None]:
len(joined[(joined['anomaly'] == True) & (joined['label'] == 1)])

In [None]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)].best_feat.value_counts()

In [None]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1) & 
       (np.isin(joined['best_feat'], drivers))][drivers+['orig_prob_ensemble',
                                         'most_abnormal',
                                         'best_feat', 
                                         'new_risk',
                                         'label']].sort_values(['orig_prob_ensemble'], ascending=False).head(20)

In [None]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)][drivers+['orig_prob_ensemble',
                                         'most_abnormal',
                                         'best_feat',
                                        'new_risk',
                                         'label']].sort_values(['orig_prob_ensemble'], ascending=False).head(20)

In [None]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)][drivers+['orig_prob_ensemble',
                                        'most_abnormal',
                                        'best_feat', 
                                        'new_risk',
                                        'label']].sort_values(['new_risk']).head(20)

In [None]:
index = 37836
# Read every field from the same row. The timepoint used to come from a
# hard-coded row (268), which did not match `case` in the cohort lookup.
case = joined.loc[index].case
timepoint = joined.loc[index].timepoint
best_feat = joined.loc[index].best_feat
abnormal = joined.loc[index].most_abnormal

In [None]:
df_cohort[(df_cohort.subject_id==case) & (df_cohort.timepoint==timepoint)][[best_feat, abnormal]]

In [None]:
mimic.print_notes(df_notes, case)

In [None]:
mimic.print_notes?

In [None]:
# At this point the column is called 'most_abnormal'; an 'abnormal' column is
# only created further down, in the scratchpad, on a re-loaded frame.
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)]['most_abnormal'].value_counts()

In [None]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)]['best_feat'].value_counts()

In [None]:
df_notes = mimic.get_cohort_notes()

## Scratchpad

In [None]:
rbos = pd.read_json('mimic_rbos_valid.json')

In [None]:
rbos[['lstm_lr', 'mlp_lr', 'mlp_lstm', 'rf_lr', 'rf_lstm', 'rf_mlp']].mean()

In [None]:
rbos[['lstm_lr', 'mlp_lr', 'mlp_lstm', 'rf_lr', 'rf_lstm', 'rf_mlp']].std()

In [None]:
joined = pd.read_csv('time_mimic_magecs.csv')
prob_cols = [c for c in joined.columns if c.startswith('perturb') and 'resprate_mean' not in c]
joined[['best_feat', 'new_risk']] = joined.apply(lambda x: mimic.best_feature(x, prob_cols), axis=1)

In [None]:
joined['orig_prob_ensemble'] = joined[['orig_prob_mlp', 'orig_prob_lr', 
                                       'orig_prob_rf', 'orig_prob_lstm']].apply(np.mean, 1)

In [None]:
vitals = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 'meanbp_mean',
          'resprate_mean', 'tempc_mean', 'spo2_mean', 'glucose_mean']

labs = ['aniongap', 'albumin', 'bicarbonate', 'bilirubin', 'creatinine', 
        'chloride', 'glucose', 'hemoglobin', 'lactate', 
        'magnesium', 'phosphate', 'platelet', 'potassium', 'ptt', 'inr', 
        'pt', 'sodium', 'bun', 'wbc']  # -hematocrit

comobs = ['congestive_heart_failure', 'chronic_pulmonary', 'pulmonary_circulation']

others = ['age', 'gender']

features = vitals+labs

In [None]:
def most_abnormal(x, features, exclude=('resprate_mean',)):
    """Return the non-excluded feature with the largest absolute value in ``x``.

    Args:
        x: Mapping-like row (e.g. a pandas Series) indexed by feature name.
        features: Iterable of feature names to compare.
        exclude: Feature names that are never considered (and never read
            from ``x``). Defaults to skipping 'resprate_mean'.

    Returns:
        The feature whose ``abs(x[f])`` is largest; the first such feature
        wins ties. ``None`` when nothing is left to compare or every
        remaining value is NaN.
    """
    res = None
    feat = None
    for f in features:
        if f in exclude:
            continue
        val = abs(x[f])
        # NaN compares False against everything, so a NaN stored as the
        # running maximum would block every later (real) value. Skip it.
        if val != val:
            continue
        if res is None or val > res:
            res = val
            feat = f
    return feat

In [None]:
joined['abnormal'] = joined.apply(lambda x: most_abnormal(x, features), axis=1)

In [None]:
joined[features+['abnormal','best_feat']].head()

In [None]:
np.sum(joined.best_feat=='resprate_mean'), np.sum(joined.abnormal=='resprate_mean')

In [None]:
np.sum(joined['abnormal'] == joined['best_feat']), np.sum(joined['abnormal'] != joined['best_feat'])

In [None]:
7736 / (7736+33293), (7736+33293) == len(joined)

In [None]:
len(joined) == 8310 + 32719

In [None]:
8310 / len(joined)

In [None]:
# All four conditions are AND-ed. Without the '&' after the third condition
# Python tries to call the boolean Series as a function and raises TypeError.
print(np.sum((joined.new_risk < 0.5) & 
             (joined.best_feat != 'resprate_mean') & 
             (joined.abnormal != 'resprate_mean') &
             (joined.orig_prob_ensemble > 0.5)))

In [None]:
print(np.sum((joined.new_risk < 0.5) & 
             (joined.best_feat != 'resprate_mean') & 
             (joined.orig_prob_ensemble > 0.5) & 
             (joined.abnormal != joined.best_feat)))

In [None]:
906 / 1425

In [None]:
print(np.sum((joined.label==1)& (joined.orig_prob_ensemble > 0.5))) 
print(np.sum((joined.label==1) & 
             (joined.new_risk < 0.5) & 
             (joined.best_feat != 'resprate_mean') & 
             (joined.orig_prob_ensemble > 0.5)))

In [None]:
554 / 1540

In [None]:
joined[joined.label == 1]['best_feat'].value_counts()

In [None]:
index = 31760
fig1, ax1 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
joined.loc[index][['case','timepoint','abnormal','best_feat','orig_prob_ensemble']]

In [None]:
df_series_valid.loc[joined.loc[index]['case'], joined.loc[index]['timepoint']].sort_values()

In [None]:
index = 10406
fig2, ax2 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
joined.loc[index][['case','timepoint','abnormal','best_feat','orig_prob_ensemble']]

In [None]:
joined[(joined.case==15396) & 
       (joined.timepoint==11)][['case','timepoint','abnormal','best_feat','orig_prob_ensemble']]

In [None]:
df_cohort[(df_cohort.subject_id==joined.loc[index]['case']) & 
          (df_cohort.timepoint==11)][features]

In [None]:
series_means['sysbp_mean']

In [None]:
df_series_valid.loc[joined.loc[index]['case'], joined.loc[index]['timepoint']].sort_values()

In [None]:
df_time[(df_time.subject_id==joined.loc[index]['case']) & 
          (df_time.timepoint==joined.loc[index]['timepoint'])]

In [None]:
df_cohort = mimic.get_mimic_data()

df_ml = mimic.get_ml_data(df_cohort)

df_time = mimic.get_ml_series_data(df_cohort)

_, x_validation, stsc, _, xst_validation, _, Y_validation = mimic.train_valid_ml(df_ml)

stsc2, series_means, _, df_series_valid, _, _, xt_valid, Yt_valid = mimic.train_valid_series(df_time, Y_validation)

In [None]:
df_series_valid.loc[joined.loc[index]['case'], joined.loc[index]['timepoint']].sort_values()

In [None]:
len(df_cohort)

In [None]:
vitals = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 'meanbp_mean',
          'resprate_mean', 'tempc_mean', 'spo2_mean', 'glucose_mean']

labs = ['aniongap', 'albumin', 'bicarbonate', 'bilirubin', 'creatinine', 
        'chloride', 'glucose', 'hemoglobin', 'lactate', 
        'magnesium', 'phosphate', 'platelet', 'potassium', 'ptt', 'inr', 
        'pt', 'sodium', 'bun', 'wbc']  # -hematocrit

comobs = ['congestive_heart_failure', 'chronic_pulmonary', 'pulmonary_circulation']

others = ['age', 'gender']

features = vitals+labs

In [None]:
def get_magecs(p=None, models=('lr', 'mlp', 'svm', 'lstm')):
    """Read the per-model MAgEC CSV files from the current working directory.

    Args:
        p: When None the files end in '_0.csv'; otherwise in '_p<p>.csv'.
        models: Iterable of model tags inserted after the 'magec_' prefix.

    Returns:
        list with one DataFrame per entry of ``models``, in the same order.
    """
    if p is None:
        postfix = '_0.csv'
    else:
        postfix = ''.join(('_p', str(p), '.csv'))
    frames = []
    for model in models:
        path = ''.join(('magec_', model, postfix))
        frames.append(pd.read_csv(path))
    return frames

In [None]:
# Every column except the outcome is a model input. The list is built by
# iterating over the columns (rather than a set difference) so that the
# column order is the same on every kernel restart.
x_magec_cols = [c for c in df_series_valid.columns if c != 'label']
x_magec = df_series_valid[x_magec_cols]
y_magec = df_series_valid['label']

In [None]:
magecs = get_magecs()

In [None]:
joined = mg.magec_models(*magecs, Xdata=x_magec, Ydata=y_magec, features=features)

In [None]:
prob_cols = [c for c in joined.columns if c.startswith('perturb')]
joined['orig_prob_ensemble'] = joined[['orig_prob_mlp', 'orig_prob_lr', 
                                       'orig_prob_svm', 'orig_prob_lstm']].apply(np.mean, 1)
joined[['best_feat', 'new_risk']] = joined.apply(lambda x: mimic.best_feature(x, prob_cols), axis=1)

In [None]:
len(joined)

In [None]:
joined.loc[31760]

In [None]:
index = 31760
fig1, ax1 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
timepoint = 14
txt = 'Clinical values and ranked MAgECs at time point {}'
fig  = mg.panel_plot(xst_validation.columns, features, stsc2, joined, 
                                                      joined.loc[index].case, timepoint, 
                     models=('lr','svm','mlp','lstm', 'ensemble'), label='label', limit=6, rotate=25, 
                  title=txt.format(timepoint))

In [None]:
index = 10406
fig2, ax2 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
from scipy import interpolate
from adjustText import adjust_text

def best_feat_plot(joined, cohort, index, title='', save=False, feat=None):
    """Draw a two-panel figure for the case found at ``joined.loc[index]``.

    The left panel shows the selected feature over the case's timepoints
    (drawn by ``mimic.plot_feature``). The right panel shows the case's
    ensemble risk per timepoint (drawn by ``plot_risk``) with the x axis
    inverted.

    Args:
        joined: DataFrame with 'case', 'timepoint', 'label',
            'orig_prob_ensemble', 'new_risk' and 'best_feat' columns.
        cohort: DataFrame with 'subject_id', 'timepoint' and feature columns.
        index: Row label in ``joined`` that selects the case.
        title: Forwarded to ``mimic.plot_feature``.
        save: When True the figure is written to 'case_<case>_series.png'.
        feat: Feature to plot on the left; defaults to the row's 'best_feat'.
            Also forwarded to ``plot_risk``.

    Returns:
        ``(fig, ax)`` as returned by ``plt.subplots`` (``ax`` holds two axes).
    """
    row = joined.loc[index]
    case, _, label, _, _ = row[['case', 'timepoint', 'label',
                                'orig_prob_ensemble', 'new_risk']]
    best_feat = row['best_feat'] if feat is None else feat

    # Left panel input: (timepoint, feature value) pairs for this case.
    feature_rows = cohort[cohort['subject_id'] == case][['timepoint', best_feat]].values
    feature_hours = [int(rec[0]) for rec in feature_rows]
    feature_values = [rec[1] for rec in feature_rows]

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))

    mimic.plot_feature(ax[0], feature_hours, feature_values, best_feat, case, label, title=title)

    # Right panel input: one record per row of `joined` belonging to the case.
    risk_rows = joined[joined.case == case][['timepoint', 'orig_prob_ensemble', 'best_feat', 'new_risk']].values
    risk_hours = [int(rec[0]) for rec in risk_rows]
    risks = [rec[1] for rec in risk_rows]
    best_feats = [rec[2] for rec in risk_rows]
    new_risks = [rec[3] for rec in risk_rows]
    # Cohort value of each timepoint's own best feature (first matching row).
    cohort_values = [
        cohort[(cohort['subject_id'] == case) & (cohort['timepoint'] == hour)][name].values[0]
        for hour, name in zip(risk_hours, best_feats)
    ]

    plot_risk(ax[1], risk_hours, risks, best_feats, new_risks, cohort_values, case, label, feat=feat)
    ax[1].invert_xaxis()

    if save:
        plt.savefig('case_{}_series.png'.format(case))

    return fig, ax

def plot_risk(ax, x, y, z, w, yy, case, label, feat=None):
    """Plot the ensemble risk curve on ``ax`` and annotate selected points.

    Args:
        ax: Axes-like object to draw on.
        x: x positions (the axis is labelled 'hours to event').
        y: Ensemble risk at each position.
        z: Feature name at each position.
        w: Perturbed risk at each position (shown in the annotation).
        yy: Feature value at each position; positions where it is NaN are
            never annotated.
        case: Case identifier, used in the title.
        label: Outcome, used in the title.
        feat: When given, annotate only positions whose feature name equals
            ``feat``. When None, annotate positions where ``w < 0.5 < y`` and
            let ``adjust_text`` reposition the annotations.

    Returns:
        None.
    """
    ax.plot(x, y, 'rx--')
    left, right = ax.get_xlim()
    # Dashed horizontal reference line at risk 0.5 across the current x range.
    ax.plot(np.linspace(left, right, 10), 0.5 * np.ones(10), '--')
    ax.set_title('Case {}: Hourly Estimated Ensemble Risk (Outcome: {})'.format(case, label))
    ax.set_ylabel('Ensemble Risk')
    ax.set_xlabel('hours to event')
    ax.grid('on')

    auto_select = feat is None
    texts = []

    for i, name in enumerate(z):
        if np.isnan(yy[i]):
            continue
        if auto_select:
            wanted = w[i] < 0.5 < y[i]
        else:
            wanted = feat == name
        if wanted:
            msg = name + ' = {:.2f}\n perturb. risk = {:.2g}'.format(yy[i], w[i])
            texts.append(ax.text(x[i], y[i], msg))

    if auto_select:
        # Densely sampled risk curve handed to adjust_text together with the labels.
        curve = interpolate.interp1d(x, y)
        dense_x = np.linspace(min(x), max(x), 140)
        dense_y = curve(dense_x)
        adjust_text(texts, dense_x, dense_y, arrowprops=dict(arrowstyle="->", color='b', lw=0.5), autoalign='xy')

In [None]:
index = 10406
fig2, ax2 = best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure', feat='meanbp_mean');