In [None]:
import pickle
import pandas as pd
import numpy as np
import shap
from copy import deepcopy
import matplotlib.pyplot as plt
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, make_scorer, confusion_matrix, plot_confusion_matrix
from sklearn.metrics import classification_report, auc, roc_curve, precision_recall_curve

In [None]:
# Load the fitted XGBoost classifier together with the data it is explained on.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load model files that come from a trusted source.
fpath = "./models/xgb_classifier_train_test_without_specialty_5y.pkl"
with open(fpath, "rb") as open_file:
    # Unpack directly rather than binding the tuple to `vars`, which shadowed
    # the builtin vars() for the rest of the notebook.
    # TODO (from the original author): x_train and y_train should just be x and y.
    classify_xgb, X_train, y_train, X_train_ids, train_ids = pickle.load(open_file)

# Hard class predictions and per-class probabilities for every row of X_train.
y_train_pred = classify_xgb.predict(X_train)
y_train_pred_prob = classify_xgb.predict_proba(X_train)

# Figure 2

In [None]:
# Build a SHAP tree explainer for the fitted classifier (default settings).
explainer = shap.TreeExplainer(classify_xgb)
# Alternative configuration kept for reference:
# explainer = shap.TreeExplainer(classify_xgb, feature_perturbation='tree_path_dependent', model_output="raw")
# Explanation object for every row of X_train; feeds the beeswarm and scatter plots below.
shap_values = explainer(X_train)
# Pairwise SHAP interaction values; feeds the dependence plots below.
# NOTE(review): this call is presumably the slowest step of the notebook — consider caching.
shap_ival = explainer.shap_interaction_values(X_train)

In [None]:
# Maps raw model column names to the human-readable labels used in the figures.
# Columns that are not listed here keep their raw name wherever this mapping is applied.
# The original literal listed "physician_demand" twice; the later, lowercase
# duplicate silently overrode the first entry. Only one entry is kept, with the
# same capitalisation as every other label.
feature_name_dict = {
    "tenure": "Tenure",
    "age_group": "Age group",
    "EWA_avg_risk_avg": "Exp. weighted panel complexity",
    "EWA_avg_note_quality_manual_value": "Exp. weighted note quality",
    "EWA_avg_note_quality_contribution_value": "Exp. weighted note quality contribution",
    "note_quality_manual_value": "Note quality",
    "panel_cnt": "Panel count",
    "risk_avg": "Panel complexity",
    "EWA_avg_teamwork_on_inbox_value": "Exp. weighted teamwork on inbox - value",
    "r_slope_panel_cnt": "Roll. slope panel count",
    "teamwork_on_inbox_value": "Teamwork on inbox - value",
    "gender": "Gender",
    "calendar_month": "Calendar month",
    "covid_wave": "Covid wave",
    "patient_volume": "Patient volume",
    "physician_demand": "Physician demand",
    "EWA_avg_order_time_8": "Exp. weighted order time",
    "EWA_avg_wow_time_8": "Exp. weighted work outside of work time",
    "EWA_avg_physician_demand": "Exp. weighted physician demand",
    "EWA_avg_ib_time_8": "Exp. weighted inbox time",
    "EWA_avg_note_time_8": "Exp. weighted note time",
    "EWA_avg_ehr_time_8": "Exp. weighted EHR time",
    "r_slope_wow_time_8": "Roll. slope work outside of work time",
    "EWA_avg_patient_volume": "Exp. weighted patient volume",
}

In [None]:
# Default size (width, height in inches) for matplotlib figures created after this cell runs.
plt.rcParams["figure.figsize"] = (20,20)

In [None]:
# Figure 2: beeswarm summary across all features (ten rows displayed).
# Relabel a copy so the original Explanation keeps the raw column names.
dc_shap_obj = deepcopy(shap_values)
dc_shap_obj.feature_names = [
    feature_name_dict.get(raw_name, raw_name) for raw_name in dc_shap_obj.feature_names
]
shap.plots.beeswarm(
    dc_shap_obj,
    max_display=10,
    plot_size=(20, 20),
    order=shap_values.abs.mean(0),  # order features by mean absolute SHAP value
)


In [None]:
# Copy of the Explanation with display labels in place of raw column names
# (names missing from feature_name_dict are left unchanged).
dc_shap_obj = deepcopy(shap_values)
dc_shap_obj.feature_names = [
    feature_name_dict.get(raw_name, raw_name) for raw_name in dc_shap_obj.feature_names
]

In [None]:
# Figure 3: SHAP scatter plots for the top four features.
# Each panel is (raw feature name, extra scatter kwargs, optional x tick labels).
# NOTE(review): the original repeated a "male is 0" remark on every plot; it
# presumably describes the gender encoding when gender is picked as the colour
# feature — confirm before relying on it.
top_feature_panels = [
    ('tenure', {'x_jitter': 0.5}, None),
    ('EWA_avg_risk_avg', {}, None),
    ('age_group', {'x_jitter': 0.25}, ['25-34', '35-44', '45-54', '55-64', '65+']),
    ('EWA_avg_physician_demand', {}, None),
]
for raw_name, scatter_kwargs, tick_labels in top_feature_panels:
    ax = plt.gca()
    shap.plots.scatter(
        dc_shap_obj[:, feature_name_dict[raw_name]],
        color=dc_shap_obj,
        ax=ax,
        show=False,
        **scatter_kwargs
    )
    if tick_labels is not None:
        # Replace the numeric codes on the x axis with the bracket labels.
        plt.xticks(ticks=list(range(len(tick_labels))), labels=tick_labels)
    plt.show()

In [None]:
# Replace X_train's raw column names with display labels; unmapped names are kept as they are.
X_train.columns = [feature_name_dict.get(col, col) for col in X_train.columns]

In [None]:
# Candidate Figure 4: interaction of tenure with key EHR-use metrics.
# TODO: justify which interactions are shown — is there a recommended way in
# the SHAP documentation to rank interactions? For now the key EHR-use metrics
# are plotted because they are the most interpretable.
ehr_metrics = [
    "EWA_avg_ehr_time_8",
    "EWA_avg_ib_time_8",
    "EWA_avg_order_time_8",
    "EWA_avg_note_time_8",
]
for metric in ehr_metrics:
    ax = plt.gca()
    # A (feature, feature) tuple selects that pair's interaction values from shap_ival.
    shap.dependence_plot(
        (feature_name_dict["tenure"], feature_name_dict[metric]),
        shap_ival,
        X_train,
        x_jitter=0.5,
        ax=ax,
    )
    plt.show()

In [None]:
# Tricky bit: strip the ' - value' suffix from the column display names.

In [None]:
# Scratch check of the suffix removal on one label; the expression evaluates to 'Teamwork on inbox'.
'Teamwork on inbox - value'.replace(' - value', '')

In [None]:
# Map any remaining raw names to display labels, then drop the ' - value' suffix from every label.
X_train.columns = [
    feature_name_dict.get(col, col).replace(' - value', '') for col in X_train.columns
]

In [None]:
def compile_physician_data(X,y,y_pred,y_prob,X_ids,ids):
    """Combine features, labels and predictions into one per-row table.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix; it is copied, not modified.
    y : array-like
        Observed label per row, stored as column ``depart``.
    y_pred : array-like
        Predicted label per row, stored as column ``pred``.
    y_prob : 2-D array
        Per-class probabilities; column 1 is stored as ``prob``
        (presumably the probability of the positive class — confirm).
    X_ids : array-like
        Identifier per row, stored as column ``id``. A pandas Series is
        aligned on its index by the assignment.
    ids : pandas.DataFrame
        Must contain ``id`` and a boolean ``depart`` column.

    Returns
    -------
    pandas.DataFrame
        Copy of ``X`` with the added columns ``id``, ``prob``, ``pred``,
        ``depart``, ``phys_depart`` (the row's id is among the ids flagged in
        ``ids['depart']``), ``month_sync`` (0-based running row count within
        each id) and ``prob_rm`` (3-row rolling mean of ``prob`` within each
        id; NaN for the first two rows of every id).
    """
    P = X.copy()
    P['id'] = X_ids
    P['prob'] = y_prob[:,1]
    P['pred'] = y_pred
    P['depart'] = y
    P['phys_depart'] = P['id'].isin(ids[ids['depart']]['id'])
    # P['month_sync'] = P.groupby('id')['study_day'].transform(lambda x: round((x-max(x))/30))
    P['month_sync'] = P.groupby('id').cumcount()
    # transform() returns values in the original row order. The previous
    # groupby(...).rolling(3).mean().to_list() returned them in group order, so
    # the positional assignment put rolling means on the wrong rows whenever
    # the rows were not already contiguous and sorted by id.
    P['prob_rm'] = (
        P.groupby('id')['prob']
        .transform(lambda probs: probs.rolling(3).mean())
        .to_numpy()
    )
    return P



In [None]:
# Per-row table of features, observed labels, predictions and rolling probability.
P_train = compile_physician_data(
    X=X_train,
    y=y_train,
    y_pred=y_train_pred,
    y_prob=y_train_pred_prob,
    X_ids=X_train_ids,
    ids=train_ids,
)


In [None]:
# SHAP values via the array-returning API of the same explainer.
shap_values_2 = explainer.shap_values(X_train)
# Reuse the interaction values computed when the explainer was built instead of
# re-running explainer.shap_interaction_values(X_train). Between the two calls
# only X_train's column labels are changed, not its values, so the recomputation
# produced the same array at considerable cost.
shap_ival2 = shap_ival

In [None]:
# Flatten every row's (n_features x n_features) interaction matrix to its upper
# triangle, diagonal included, so each feature pair appears exactly once.
# The matrix size is read from the data rather than hard-coded (it was 76), and
# the triangle indices are computed once instead of on every loop iteration.
shap_ival2 = np.asarray(shap_ival2)
n_features = shap_ival2.shape[-1]
assert n_features == len(X_train.columns), "interaction matrix does not match X_train columns"
triu_row_idx, triu_col_idx = np.triu_indices(n_features, k=0)

# Shape: (n_rows, n_features * (n_features + 1) / 2)
shap_ival_flat = shap_ival2[:, triu_row_idx, triu_col_idx]
# Column label for each pair: "<feature i><><feature j>".
triu_cols = X_train.columns[triu_row_idx] + '<>' + X_train.columns[triu_col_idx]
shap_ival_flat

In [None]:
## Subset the SHAP data set to physicians for whom we have predictions of both stay and leave

# Build an id'd SHAP matrix; keep the prediction for later grouping.
shap_values_pd = pd.DataFrame(shap_ival_flat, columns = triu_cols)
shap_values_pd['id'] = X_train_ids.values
shap_values_pd['pred'] = y_train_pred

# Coerce the predictions to an explicit boolean mask. Indexing with the raw
# predictions and negating them with `~` is only correct for boolean arrays:
# with 0/1 integers `~` is a bitwise NOT (giving -1/-2) and the indexing turns
# into a label lookup instead of a row filter.
pred_depart = np.asarray(y_train_pred).astype(bool)
row_ids = np.asarray(X_train_ids)

# Ids with at least one predicted-depart row and at least one predicted-stay row.
depart_ids = pd.unique(row_ids[pred_depart])
stay_ids = pd.unique(row_ids[~pred_depart])

select_ids = set(depart_ids).intersection(set(stay_ids))
select_ids = pd.Series(list(select_ids))

# Subset the SHAP matrix to those ids.
shap_values_pd_sub = shap_values_pd.loc[shap_values_pd['id'].isin(select_ids)]

In [None]:
## Per physician: mean SHAP interaction values for each predicted class, then their difference

# One row per (id, pred) pair holding the column-wise mean.
shap_paired = shap_values_pd_sub.groupby(['id', 'pred']).mean()
# diff() subtracts the previous row within each id, so the second row of an id
# holds (second pred group - first pred group); nth(1) keeps only that row.
# NOTE(review): with sorted group keys this is presumably leave minus stay — confirm.
shap_diff = shap_paired.groupby('id').diff().groupby('id').nth(1)

In [None]:
# Rank the columns by their mean change across physicians, ascending.
shap_diff_np = shap_diff.to_numpy()
# The signed mean is used here; the absolute-value alternative would be
# np.mean(np.abs(shap_diff_np), axis=0).
feat_vals = shap_diff_np.mean(axis=0)
feat_sort = feat_vals.argsort()

# Overall means in ranked order.
feat_vals_sort = feat_vals[feat_sort]

# Per-physician values with columns in the same ranked order.
shap_diff_np_sort = shap_diff_np[:, feat_sort]


In [None]:
# Individual contribution version: horizontal bars for the num_feats columns
# with the largest mean SHAP change (the tail of the ascending ranking).

# Set the figure size before the figure is created; the original assigned it
# after drawing, where it could not affect this figure.
plt.rcParams['figure.figsize'] = [20, 20]

num_feats = 15
x_vec = feat_vals_sort
top_vals = x_vec[-num_feats:]
top_labels = triu_cols[feat_sort][-num_feats:]
# Size the bar positions from the data so fewer than num_feats columns still plot.
y_vec = np.arange(len(top_vals))

fig, ax = plt.subplots()
ax.barh(y_vec, top_vals, color='#4286DE', lw=2)
ax.set_yticks(y_vec)
ax.set_yticklabels(top_labels)
ax.set_xlabel('SHAP value change')
plt.show()
