# Predicting emergency department visits anchored on treatment dates
---

## Background
We would like to build a model that predicts whether patients are at risk of visiting the emergency department (ED) within 30 days of receiving chemotherapy.
If we can identify patients who are at high risk, we could apply an intervention (e.g., a nursing call, reduced dosages, or pain medication) to mitigate the outcome.

---

In [1]:
%%capture
# Move the working directory two levels up (presumably to the project root — verify),
# so the relative paths used below (./data, ./logs, ./models, ./result) resolve.
# NOTE(review): the path is relative, so re-running this cell moves up two more levels;
# run it only once per kernel session.
%cd ../../
# Enable autoreload in mode 2 so edited modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import logging

import numpy as np
import pandas as pd
from datetime import datetime

from ml_common.util import load_pickle, save_pickle

from preduce.acu.eval import evaluate_valid, evaluate_test, predict
from preduce.acu.pipeline import PrepACUData
from preduce.acu.train import train_models, tune_params
from preduce.summarize import feature_summary, get_label_distribution
from preduce.util import initialize_folders, compute_threshold

# Display limits so long/wide tables are shown without truncation
for option, limit in {'display.max_rows': 150, 'display.max_columns': 100}.items():
    pd.set_option(option, limit)

# Project helper; presumably creates the output folders (e.g. ./logs) used below — verify
initialize_folders()

# Write INFO-level (and above) log records to a timestamped file under ./logs.
# The timestamp uses dots instead of colons so it is safe inside a file name.
log_timestamp = datetime.now().strftime('%Y-%m-%d %H.%M.%S')
logging.basicConfig(
    filename=f"./logs/{log_timestamp}_ED_target.log",
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

## Load feature data

In [14]:
# Emergency room visit records (used later to build the prediction target)
emerg = pd.read_parquet('./data/interim/emergency_room_visit.parquet.gzip')

# Treatment-centered clinical features; in this treatment-anchored notebook the
# assessment date is simply the treatment date.
df = (
    pd.read_parquet('./data/processed/treatment_centered_clinical_dataset.parquet.gzip')
    .assign(assessment_date=lambda frame: frame['treatment_date'])
)

## Prepare Data

In [15]:
# Build the data preparer, preprocess the features together with the ER visit records,
# then derive the model inputs (X), labels (Y) and per-row meta information (metainfo).
prep = PrepACUData()
# NOTE(review): `df` is overwritten with its preprocessed version, so re-running this cell
# feeds already-preprocessed data back into preprocess() — re-run the loading cell first.
df = prep.preprocess(df, emerg)
X, Y, metainfo = prep.prepare(df)
# Keep only the rows of df whose index labels appear in X
# (presumably prepare() drops some rows — verify).
df = df.loc[X.index]

In [12]:
# Boolean row masks derived from the 'split' column of the meta information
train_mask = metainfo['split'].eq('Train')
test_mask = metainfo['split'].eq('Test')

# Partition features, labels and meta information with the same masks
X_train, Y_train, metainfo_train = X[train_mask], Y[train_mask], metainfo[train_mask]
X_test, Y_test, metainfo_test = X[test_mask], Y[test_mask], metainfo[test_mask]

In [6]:
# Save the data preparer for silent deployment,
# so that new incoming data is transformed with the original data preparer.
save_pickle(prep, './result', 'prep_ED_visit_trt_anchored')

# Disabled debug export of the prepared data (kept for reference only):
# X.to_csv('./data/debug/to_muammar/X.csv', index=False)
# Y.to_csv('./data/debug/to_muammar/Y.csv', index=False)
# metainfo.to_csv('./data/debug/to_muammar/metainfo.csv', index=False)
# df.loc[X.index].to_csv('./data/debug/to_muammar/orig.csv', index=False)

## Describe Data

In [9]:
# Count the rows (treatment sessions) and the unique patients (mrn) in each data split.
# GroupBy.size() is the built-in row count; it replaces apply(len, include_groups=False),
# which runs a Python-level apply and needs a pandas version that knows `include_groups`.
split_groups = metainfo.groupby('split')
count = pd.DataFrame({
    'Number of sessions': split_groups.size(),
    'Number of patients': split_groups['mrn'].nunique(),
}).T
# Row-wise total across the splits
count['Total'] = count.sum(axis=1)
print(f'\n{count.to_string()}')


split               Test  Train  Total
Number of sessions  3600  18406  22006
Number of patients   440   1961   2401


In [10]:
# Label counts per split, counted with respect to treatment sessions
get_label_distribution(Y, metainfo, with_respect_to='sessions')

Unnamed: 0_level_0,Total,Total,Test,Test,Train,Train
ED_visit,False,True,False,True,False,True
ED_visit,19903,2103,3206,394,16697,1709


In [11]:
# Label counts per split, counted with respect to patients
get_label_distribution(Y, metainfo, with_respect_to='patients')

Unnamed: 0_level_0,Total,Total,Test,Test,Train,Train
Unnamed: 0_level_1,1,0,1,0,1,0
ED_visit,780,1621,155,285,625,1336


In [12]:
# Feature Characteristics
# One-hot encode the original (non-normalized, non-imputed) rows of the training set
x = prep.ohe.encode(df.loc[X_train.index].copy(), verbose=False)
# Keep feature columns only: drop meta-information columns and anything named 'target*'
feature_cols = [
    col for col in x.columns
    if col not in metainfo.columns and not col.startswith('target')
]
x = x[feature_cols]
# Summarize the features, save the table, and display a reproducible sample of 10 rows
feature_summary(x, save_path='result/tables/feature_summary_ED_trt_anchored.csv').sample(10, random_state=42)

Unnamed: 0,Features,Group,Mean (SD),Missingness (%)
114,Regimen GI-FOLFOX+BEVACIZUMAB,Treatment,0.008 (0.088),0.0
0,Height (cm),Demographic,167.582 (9.657),0.0
24,ESAS Drowsiness Score,Symptoms,2.079 (2.352),14.7
62,Days Since Last Treatment,Treatment,218.101 (944.035),0.0
47,Monocyte (x10e9/L),Laboratory,0.560 (0.339),34.7
15,"Topography ICD-0-3 C24, Other and unspecified ...",Cancer,0.079 (0.270),0.0
129,Regimen GI-IRINO Q3W,Treatment,0.002 (0.045),0.0
84,Potassium Change,Laboratory,-0.008 (0.425),66.2
52,Red Blood Cell (x10e12/L),Laboratory,3.704 (0.570),34.1
5,Female (yes/no),Demographic,0.430 (0.495),0.0


## Train Models

In [12]:
# LGBM does not like non alphanumeric characters (except for _),
# so each offending character in the column names is replaced with an underscore.
for frame in (X_train, X_test):
    for char in '()+-/,':
        frame.columns = frame.columns.str.replace(char, '_')

In [15]:
%%capture
# Hyperparameter tuning
# TODO: try greater kappa for greater exploration
algs = ['LASSO', 'RF', 'Ridge', 'XGB', 'LGBM']
best_params = {}
for alg in algs:
    best_params[alg] = tune_params(alg, X_train, Y_train['ED_visit'], metainfo_train)
save_pickle(best_params, './models', 'best_params_trt_anchored')
save_pickle(best_params, './models', f'best_params_trt_anchored-{datetime.now()}')

In [13]:
# Reload the tuned hyperparameters saved by the tuning cell, so training can run without re-tuning.
# NOTE(review): load_pickle presumably unpickles the file — only load files produced by this project.
best_params = load_pickle('./models', 'best_params_trt_anchored')
# Train the models on the training split with the tuned parameters;
# the result is indexed by algorithm name (e.g. models['RF']).
models = train_models(X_train, Y_train, metainfo_train, best_params)

## Model Selection
Select the final model based on the average performance across the validation folds.

In [14]:
# Validation performance (AUPRC / AUROC) of every trained algorithm, computed on the training split
evaluate_valid(models, X_train, Y_train, metainfo_train)

Unnamed: 0_level_0,Ridge,Ridge,LASSO,LASSO,XGB,XGB,LGBM,LGBM,RF,RF
Unnamed: 0_level_1,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC
ED_visit,0.24557,0.765009,0.24949,0.765682,0.308746,0.7931,0.258847,0.775553,0.34843,0.796766


## Evaluate Model

In [16]:
# Score every trained model on the held-out test split, then put the per-algorithm
# results side by side (one column group per algorithm).
test_scores = [evaluate_test(model, X_test, Y_test) for model in models.values()]
pd.concat(test_scores, keys=list(models)).T

Unnamed: 0_level_0,Ridge,Ridge,LASSO,LASSO,XGB,XGB,LGBM,LGBM,RF,RF
Unnamed: 0_level_1,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC,AUPRC,AUROC
ED_visit,0.198605,0.669618,0.198633,0.668859,0.21095,0.689123,0.18891,0.67496,0.215938,0.704659


In [19]:
# compute threshold that achieves 10% and 20% alarm rate
# NOTE(review): the thresholds are derived from predictions on the test split —
# confirm this is intended rather than deriving them from validation predictions.
desired_alarm_rates = (0.1, 0.2)
model = models['RF']
pred = predict(model['ED_visit'], X_test)
res = [compute_threshold(pred, rate) for rate in desired_alarm_rates]
pd.DataFrame(res, columns=['Prediction Threshold', 'Alarm Rate'])

Unnamed: 0,Prediction Threshold,Alarm Rate
0,0.22,0.104722
1,0.138,0.204167


In [21]:
# Persist the ED_visit model of the selected algorithm (`model` is models['RF'] from the threshold cell)
save_pickle(model['ED_visit'], './models', 'RF_ED_visit_trt_anchored')