# Predicting emergency department visits anchored on clinic dates
---
## Background
Previously, we built a model to predict emergency department (ED) visits anchored on treatment dates.

The problem with that approach is that primary physicians do not interact with their patients during treatment sessions; they only meet during clinic visits. Clinic visits are therefore the best time for the model to nudge the physician toward an intervention. Thus, we now build a model that predicts a patient's risk of an ED visit prior to each clinic date instead of prior to each treatment session.

---

In [6]:
%%capture
# %%capture silences this cell's output (e.g. the directory printed by %cd).
# Move two levels up, presumably to the repository root, so the relative ./data, ./models and ./result paths below resolve.
%cd ../../
# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [7]:
import logging

import pandas as pd
from datetime import datetime

from ml_common.util import load_pickle, save_pickle

from preduce.acu.eval import evaluate_valid, evaluate_test, predict
from preduce.acu.pipeline import PrepACUData
from preduce.acu.train import train_models, tune_params
from preduce.summarize import feature_summary, get_label_distribution
from preduce.util import compute_threshold

# Raise pandas' display limits so the long feature/label tables below render in full.
pd.options.display.max_rows = 150
pd.options.display.max_columns = 100

# Show INFO-level log records (e.g. the row-filtering summaries emitted during data preparation).
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)

## Load feature data

In [8]:
feature_path = './data/processed/clinic_centered_feature_dataset.parquet.gzip'
ed_visit_path = './data/interim/emergency_room_visit.parquet.gzip'

# Each row is anchored on its clinic visit: alias clinic_date as assessment_date
# (NOTE(review): presumably the column PrepACUData anchors on — confirm).
df = pd.read_parquet(feature_path).assign(assessment_date=lambda frame: frame['clinic_date'])
emerg = pd.read_parquet(ed_visit_path)

## Prepare Data

In [9]:
# Filter/clean the clinic-anchored features; `emerg` presumably supplies the ED_visit target (confirm in PrepACUData).
prep = PrepACUData()
df = prep.preprocess(df, emerg)

# Build the model inputs (X), targets (Y) and per-row metadata (split, mrn, ...).
X, Y, metainfo = prep.prepare(df, event_name='ED_visit')

# Keep `df` restricted to exactly the rows that made it into X, in the same order.
df = df.loc[X.index, :]

INFO:Removing 0 patients and 2555 sessions not first of a given week
INFO:Removing 2858 patients and 13029 sessions before 2012-01-01 and after 2019-12-31
INFO:Removing the following features for drugs given less than 10 times: ['%_ideal_dose_given_DURVALUMAB', '%_ideal_dose_given_IPILIMUMAB', '%_ideal_dose_given_CAPECITABINE', '%_ideal_dose_given_ERLOTINIB']
INFO:Removing 2671 patients and 13461 sessions not from GI department
INFO:Dropping the following 13 features for missingness over 80%: ['bicarbonate_change', 'basophil', 'carbohydrate_antigen_19-9', 'basophil_change', 'prothrombin_time_international_normalized_ratio', 'activated_partial_thromboplastin_time', 'carbohydrate_antigen_19-9_change', 'prothrombin_time_international_normalized_ratio_change', 'activated_partial_thromboplastin_time_change', 'carcinoembryonic_antigen', 'esas_diarrhea', 'esas_vomiting', 'esas_constipation']
INFO:Removing 115 patients and 1387 sessions with at least 80 percent of features missing
INFO:Reassig

In [10]:
# Partition rows using the split label carried in the metadata.
split_label = metainfo['split']
train_mask = split_label.eq('Train')
test_mask = split_label.eq('Test')

X_train, Y_train, metainfo_train = X[train_mask], Y[train_mask], metainfo[train_mask]
X_test, Y_test, metainfo_test = X[test_mask], Y[test_mask], metainfo[test_mask]

In [11]:
# Pickle the fitted data preparer for silent deployment, so new incoming data
# is transformed by the same preparer that produced the training data.
# NOTE(review): the LightGBM column-name sanitization done later in this notebook
# is not part of `prep`; deployment code presumably has to repeat it — confirm.
result_dir = './result'
save_pickle(prep, result_dir, 'prep_ED_visit_clinic_anchored')

## Describe Data

In [12]:
# Number of sessions (rows) and unique patients in each split, plus a Total column.
by_split = metainfo.groupby('split')
count = pd.DataFrame({
    # .size() counts rows per group; unlike apply(len, include_groups=False) it does not
    # depend on the `include_groups` keyword, which only exists in pandas >= 2.2.
    'Number of sessions': by_split.size(),
    'Number of patients': by_split['mrn'].nunique(),
}).T
count['Total'] = count.sum(axis=1)
print(f'\n{count.to_string()}')


split               Test  Train  Total
Number of sessions  1988   8017  10005
Number of patients   424   1804   2228


In [13]:
# ED_visit label counts per session, for rows with no treatment_date ('First Visit'),
# rows with one ('Subsequent Visit'), and all rows.
no_trts_prior = df['treatment_date'].isnull()
session_subsets = {'First Visit': no_trts_prior, 'Subsequent Visit': ~no_trts_prior}
session_dists = {
    name: get_label_distribution(Y[keep], metainfo[keep], with_respect_to='sessions')
    for name, keep in session_subsets.items()
}
session_dists['All'] = get_label_distribution(Y, metainfo, with_respect_to='sessions')
pd.concat(list(session_dists.values()), keys=list(session_dists))

Unnamed: 0_level_0,Unnamed: 1_level_0,Total,Total,Test,Test,Train,Train
Unnamed: 0_level_1,ED_visit,False,True,False,True,False,True
First Visit,ED_visit,1563,294,293,56,1270,238
Subsequent Visit,ED_visit,7397,751,1475,164,5922,587
All,ED_visit,8960,1045,1768,220,7192,825


In [14]:
# The same First/Subsequent/All breakdown, counted per patient instead of per session.
patient_dists = {}
for name, keep in [('First Visit', no_trts_prior), ('Subsequent Visit', ~no_trts_prior)]:
    patient_dists[name] = get_label_distribution(Y[keep], metainfo[keep], with_respect_to='patients')
patient_dists['All'] = get_label_distribution(Y, metainfo, with_respect_to='patients')
pd.concat(list(patient_dists.values()), keys=list(patient_dists))

Unnamed: 0_level_0,Unnamed: 1_level_0,Total,Total,Test,Test,Train,Train
Unnamed: 0_level_1,Unnamed: 1_level_1,1,0,1,0,1,0
First Visit,ED_visit,278,1167,53,221,225,946
Subsequent Visit,ED_visit,476,1523,102,285,374,1238
All,ED_visit,627,1601,132,292,495,1309


In [15]:
# Feature Characteristics of the training rows on their original scale:
# one-hot encode only (no normalization, no imputation), then drop metadata and target columns.
x = prep.ohe.encode(df.loc[X_train.index].copy(), verbose=False)
excluded_cols = set(metainfo.columns)
feature_cols = [col for col in x.columns if col not in excluded_cols and not col.startswith('target')]
x = x[feature_cols]
feature_summary(x, save_path='result/tables/feature_summary_ED_clinic_anchored.csv').sample(10, random_state=42)

Unnamed: 0,Features,Group,Mean (SD),Missingness (%)
20,ESAS Pain Score,Symptoms,2.056 (2.366),9.6
103,Regimen GI-CISPFU + TRAS(MAIN),Treatment,0.003 (0.053),0.0
17,"Topography ICD-0-3 C25, Pancreas",Cancer,0.319 (0.466),0.0
33,Aspartate Aminotransferase (U/L),Laboratory,29.593 (24.500),45.7
84,Potassium Change,Laboratory,-0.008 (0.432),61.9
54,Red Cell Distribution Width (%CV),Laboratory,16.997 (3.382),43.5
53,Red Blood Cell (x10e12/L),Laboratory,3.773 (0.580),43.5
115,Regimen GI-FUFA C2 (GASTRIC),Treatment,0.001 (0.039),0.0
106,Regimen GI-ECX,Treatment,0.051 (0.220),0.0
57,White Blood Cell (x10e9/L),Laboratory,5.952 (3.331),43.5


## Train Models

In [16]:
# LightGBM rejects some special characters in feature names, so replace these with '_'.
# regex=False makes the replacement literal on every pandas version: before pandas 1.4,
# str.replace defaulted to regex=True, and '(' / '+' are invalid regular expressions.
for char in ['(', ')', '+', '-', '/', ',']:
    X_train.columns = X_train.columns.str.replace(char, '_', regex=False)
    X_test.columns = X_test.columns.str.replace(char, '_', regex=False)

# Sanitizing can map distinct names onto the same string (e.g. 'a(b)' and 'a_b_');
# fail loudly instead of silently training on duplicated column names.
assert X_train.columns.is_unique, 'feature-name sanitization produced duplicate column names'
assert X_train.columns.equals(X_test.columns), 'train/test feature names diverged after sanitization'

In [81]:
%%capture
# Hyperparameter tuning
# TODO: try greater kappa for greater exploration
algs = ['LASSO', 'RF', 'Ridge', 'XGB', 'LGBM']
best_params = {}
for alg in algs:
    best_params[alg] = tune_params(alg, X_train, Y_train['ED_visit'], metainfo_train)
save_pickle(best_params, './models', 'best_params_clinic_anchored')
save_pickle(best_params, './models', f'best_params_clinic_anchored-{datetime.now()}')

In [17]:
# Reuse the tuned hyperparameters saved by the tuning cell, so training does not require re-tuning.
model_dir = './models'
best_params = load_pickle(model_dir, 'best_params_clinic_anchored')
models = train_models(X_train, Y_train, metainfo_train, best_params)

## Model Selection
Select the final model based on its average performance across the validation folds.

In [18]:
# AUPRC/AUROC per algorithm averaged across the validation folds (the model-selection criterion).
valid_scores = evaluate_valid(models, X_train, Y_train, metainfo_train)
valid_scores

Unnamed: 0_level_0,Ridge,Ridge,LASSO,LASSO,XGB,XGB,LGBM,LGBM,RF,RF
Unnamed: 0_level_1,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC
ED_visit,0.294047,0.798646,0.28843,0.797692,0.605714,0.936005,0.285457,0.775399,0.369115,0.819199


## Evaluate Model

In [19]:
# Held-out test AUPRC/AUROC for every trained algorithm, one column group per algorithm.
test_scores = {alg: evaluate_test(alg_models, X_test, Y_test) for alg, alg_models in models.items()}
pd.concat(list(test_scores.values()), keys=list(test_scores)).T

Unnamed: 0_level_0,Ridge,Ridge,LASSO,LASSO,XGB,XGB,LGBM,LGBM,RF,RF
Unnamed: 0_level_1,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC
ED_visit,0.226223,0.694168,0.21776,0.689989,0.199263,0.695302,0.186826,0.675679,0.219374,0.694317


In [20]:
# Test performance of XGB on visits with no prior treatment ('First Visit') vs. the rest.
# NOTE(review): in the saved outputs XGB's validation AUPRC/AUROC (~0.61/~0.94) far exceed its
# test scores (~0.20/~0.70), and on test its AUPRC is below Ridge, LASSO and RF. A gap this large
# may indicate leakage across the validation folds — worth checking before relying on this choice.
model = models['XGB']
mask = metainfo_test['treatment_date'].isnull()
subgroup_scores = [evaluate_test(model, X_test[rows], Y_test[rows]) for rows in (mask, ~mask)]
pd.concat(subgroup_scores, keys=['First Visit', 'Subsequent Visits']).T

Unnamed: 0_level_0,First Visit,First Visit,Subsequent Visits,Subsequent Visits
Unnamed: 0_level_1,AUPRC,AUROC,AUPRC,AUROC
ED_visit,0.237066,0.627529,0.1872,0.704117


In [21]:
# Prediction thresholds that achieve roughly a 10% and a 20% alarm rate.
# NOTE(review): thresholds are fit on the same test predictions used to report performance;
# deriving them from training/validation predictions would keep the test set untouched.
ALARM_RATES = [0.1, 0.2]
pred = predict(model['ED_visit'], X_test)
threshold_rows = []
for alarm_rate in ALARM_RATES:
    threshold_rows.append(compute_threshold(pred, alarm_rate))
pd.DataFrame(threshold_rows, columns=['Prediction Threshold', 'Alarm Rate'])

Unnamed: 0,Prediction Threshold,Alarm Rate
0,0.202,0.104628
1,0.16,0.204225


In [22]:
# Persist the XGB ED_visit model (the one evaluated and thresholded above), presumably for
# silent deployment alongside the pickled data preparer.
xgb_ed_model = models['XGB']['ED_visit']
save_pickle(xgb_ed_model, './models', 'XGB_ED_visit_clinic_anchored')