In [1]:
%%capture
# %%capture silences this cell's output; it has to stay on the first line of the cell.
# Move the working directory two levels up (presumably to the repository root, so the
# relative data/, logs/ and result/ paths used later resolve -- TODO confirm).
# NOTE: the path is relative, so re-running this cell without restarting the kernel
# moves up two more levels.
%cd ../../
# Reload imported modules automatically so edits to project code are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
from datetime import datetime
import logging

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score
from tqdm import tqdm
from xgboost import XGBClassifier
import pandas as pd

from ml_common.constants import SYMP_COLS, SYMP_CHANGE_COLS
from ml_common.filter import (
    drop_highly_missing_features, 
    drop_samples_with_no_targets, 
    drop_unused_drug_features,
    drop_samples_outside_study_date, 
    keep_only_one_per_week
)
from ml_common.engineer import collapse_rare_categories, get_change_since_prev_session, get_missingness_features
from ml_common.summary import get_label_distribution
from ml_common.util import get_nunique_categories, get_nmissing

from preduce.symp.label import convert_to_binary_symptom_labels, get_symptom_labels
from preduce.symp.pipeline import PrepSympData
from preduce.prepare.prep import fill_missing_data
from preduce.summarize import feature_summary, get_patient_characteristics
from preduce.util import initialize_folders

from sklearn.exceptions import ConvergenceWarning
import warnings

# show up to 150 rows when a DataFrame is displayed
pd.set_option('display.max_rows', 150)

# silence scikit-learn ConvergenceWarning messages for the rest of the session
warnings.filterwarnings("ignore", category=ConvergenceWarning)

# project helper; presumably creates the output folders (e.g. ./logs) used below -- TODO confirm
initialize_folders()

# send INFO-and-above log records to a timestamped file under ./logs
# (the timestamp is taken once, when this cell runs)
logging.basicConfig(
    filename=f"./logs/{datetime.now().strftime('%Y-%m-%d %H.%M.%S')}_symptom_target.log",
    level=logging.INFO, 
    format='%(asctime)s %(levelname)s:%(message)s', 
    datefmt='%Y-%m-%d %H:%M:%S'
)

In [3]:
# Load the treatment-centered dataset; each row's treatment date is copied into
# an `assessment_date` column.
df = pd.read_parquet(
    'data/processed/treatment_centered_clinical_dataset.parquet.gzip'
).assign(assessment_date=lambda frame: frame['treatment_date'])

In [4]:
# scoring increase thresholds for determining symptom deterioration;
# each value is used in the binary label conversion loop and passed to the prep pipeline
target_pt_increases = [1, 3]

# Prep Data

In [5]:
# Row-level filtering and feature engineering, applied in order:
#   1. filter out dates before 2014 and after 2020
#   2. keep only the first treatment session of a given week
#   3. get the change in measurement since previous assessment
df = (
    df
    .pipe(drop_samples_outside_study_date)
    .pipe(keep_only_one_per_week)
    .pipe(get_change_since_prev_session)
)

# extract labels, then binarize them once per point-increase threshold
symp = pd.read_parquet('./data/interim/symptom.parquet.gzip')
df = get_symptom_labels(df, symp)
for pt_increase in target_pt_increases:
    # every symptom column shares the same threshold within a pass
    scoring_map = dict.fromkeys(SYMP_COLS, pt_increase)
    df = convert_to_binary_symptom_labels(df, scoring_map=scoring_map)

target_cols = 'target_' + pd.Index(SYMP_CHANGE_COLS)
df = (
    df
    # filter out sessions without any labels
    .pipe(drop_samples_with_no_targets, target_cols)
    # drop drug features that were never used
    .pipe(drop_unused_drug_features)
)

Getting the first sessions of a given week...: 100%|██████████| 4388/4388 [00:00<00:00, 24466.56it/s]
Getting change since last session...: 100%|██████████| 4388/4388 [00:01<00:00, 3399.30it/s]
Getting symptom labels...: 100%|██████████| 1097/1097 [00:05<00:00, 189.21it/s]
Getting symptom labels...: 100%|██████████| 1097/1097 [00:06<00:00, 172.92it/s]
Getting symptom labels...: 100%|██████████| 1097/1097 [00:06<00:00, 170.36it/s]
Getting symptom labels...: 100%|██████████| 1097/1097 [00:06<00:00, 163.31it/s]


In [6]:
# number of distinct categories in each categorical column
get_nunique_categories(df)

Unnamed: 0,regimen,intent
Number of Unique Categories,107,4


In [7]:
# missing-value count and percentage per column
nmissing = get_nmissing(df)
# hide the *_date columns and show the last 20 rows -- presumably the most-missing
# features, assuming get_nmissing sorts ascending (TODO confirm)
nmissing[~nmissing.index.str.endswith('_date')].tail(20)

Unnamed: 0,Missing (N),Missing (%)
monocyte_change,14311,62.737
albumin_change,14450,63.347
basophil_change,14752,64.671
lactate_dehydrogenase_change,14927,65.438
esas_nausea_change,15220,66.722
alanine_aminotransferase_change,16140,70.755
bicarbonate,17600,77.156
basophil,18230,79.918
esas_appetite_change,19729,86.489
esas_drowsiness_change,20886,91.561


In [6]:
# fill missing data that can be filled heuristically
df = fill_missing_data(df)

# target columns are passed as keep_cols (presumably exempting them from the
# missingness filter); they are selected after the fill step
keep_cols = df.columns[df.columns.str.contains('target_')]

df = (
    df
    # drop features with high missingness
    .pipe(drop_highly_missing_features, missing_thresh=75, keep_cols=keep_cols)
    # create missingness features
    .pipe(get_missingness_features)
    # collapse rare morphology and cancer sites into 'Other' category
    .pipe(collapse_rare_categories, catcols=['cancer_site', 'morphology'])
)

In [7]:
prep = PrepSympData()
X, Y, metainfo = prep.run_pipeline(
    df,
    split_date='2017-09-30',
    target_pt_increases=target_pt_increases,
)
# keep the raw frame restricted to the rows the pipeline returned
df = df.loc[X.index]

# clean up Y: keep only the '*pt_change' targets, then strip the name prefixes
Y = Y.loc[:, Y.columns.str.endswith('pt_change')]
for substr in ['target_', 'esas_']:
    Y.columns = Y.columns.str.replace(substr, '')

In [8]:
# boolean row masks per cohort, taken from the split column of metainfo
train_mask, valid_mask, test_mask = (
    metainfo['split'] == cohort for cohort in ('Train', 'Valid', 'Test')
)
# slice features and labels with the same masks
X_train, X_valid, X_test = (X[mask] for mask in (train_mask, valid_mask, test_mask))
Y_train, Y_valid, Y_test = (Y[mask] for mask in (train_mask, valid_mask, test_mask))

# Describe Data

In [16]:
# Cohort sizes per split: sessions are rows of metainfo, patients are distinct MRNs.
# GroupBy.size() counts rows directly, so it does not rely on the `include_groups`
# keyword of GroupBy.apply, which only exists in recent pandas releases.
count = pd.DataFrame({
    'Number of sessions': metainfo.groupby('split').size(),
    'Number of patients': metainfo.groupby('split')['mrn'].nunique(),
}).T
# NOTE: the patient total is a plain sum over splits, so it equals the number of
# unique patients only if no patient appears in more than one split.
count['Total'] = count.sum(axis=1)
print(f'\n{count.to_string()}')


split               Test  Train  Valid  Total
Number of sessions  5570  11723   2974  20267
Number of patients   960   1815    454   3229


In [17]:
# label counts per target and split, counted over sessions
# (-1 is the label value that is masked out before fitting and scoring the models)
get_label_distribution(Y, metainfo, with_respect_to='sessions').sort_index()

Unnamed: 0_level_0,Test,Test,Test,Train,Train,Train,Valid,Valid,Valid,Total,Total,Total
Unnamed: 0_level_1,0,1,-1,0,1,-1,0,1,-1,0,1,-1
anxiety_1pt_change,3854,1662,54,8137,3392,194,2089,832,53,14080,5886,301
anxiety_3pt_change,4947,417,206,10281,919,523,2628,221,125,17856,1557,854
appetite_1pt_change,3663,1817,90,7464,4066,193,1938,988,48,13065,6871,331
appetite_3pt_change,4525,752,293,9531,1580,612,2431,389,154,16487,2721,1059
depression_1pt_change,4009,1519,42,8283,3270,170,2100,839,35,14392,5628,247
depression_3pt_change,5032,391,147,10438,889,396,2673,224,77,18143,1504,620
drowsiness_1pt_change,3438,2099,33,7198,4374,151,1855,1068,51,12491,7541,235
drowsiness_3pt_change,4599,756,215,9714,1559,450,2470,389,115,16783,2704,780
nausea_1pt_change,4031,1496,43,8202,3367,154,2113,839,22,14346,5702,219
nausea_3pt_change,4974,502,94,10305,1161,257,2654,276,44,17933,1939,395


In [18]:
# label counts per target and split, counted over patients
get_label_distribution(Y, metainfo, with_respect_to='patients').sort_index()

Unnamed: 0_level_0,Test,Test,Train,Train,Valid,Valid,Total,Total
Unnamed: 0_level_1,1,0,1,0,1,0,1,0
anxiety_1pt_change,611,349,1187,628,300,154,2098,1131
anxiety_3pt_change,234,726,510,1305,125,329,869,2360
appetite_1pt_change,636,324,1333,482,328,126,2297,932
appetite_3pt_change,386,574,777,1038,187,267,1350,1879
depression_1pt_change,569,391,1160,655,291,163,2020,1209
depression_3pt_change,226,734,497,1318,126,328,849,2380
drowsiness_1pt_change,723,237,1408,407,349,105,2480,749
drowsiness_3pt_change,403,557,785,1030,200,254,1388,1841
nausea_1pt_change,554,406,1173,642,283,171,2010,1219
nausea_3pt_change,276,684,609,1206,141,313,1026,2203


In [19]:
# Feature Characteristics
# get original (non-normalized, non-imputed) training rows one-hot encoded
x = prep.ohe.encode(df.loc[X_train.index].copy(), verbose=False)
# keep feature columns only: drop the metainfo columns and anything named 'target*'
x = x.loc[:, [
    col for col in x.columns
    if col not in metainfo.columns and not col.startswith('target')
]]
# summarize the features (save_path is presumably where the table is written) and
# preview a reproducible sample of 10 rows
feature_summary(x, save_path='result/tables/feature_summary.csv').sample(10, random_state=42)

Unnamed: 0,Features,Group,Mean (SD),Missingness (%)
141,Phosphate Change,Laboratory,0.059 (1.543),17.2
211,Regimen TRIAL,Treatment,0.084 (0.277),0.0
217,Intent of Systemic Treatment OTHER,Treatment,0.000 (0.000),0.0
175,Regimen GI-PANITUMUMAB,Treatment,0.011 (0.104),0.0
48,"Morphology ICD-0-3 848, Cystic, mucinous, and ...",Cancer,0.041 (0.198),0.0
105,Percentage of Ideal Dose Given RAMUCIRUMAB,Treatment,1.022 (0.038),0.0
167,Regimen GI-GEM 7-WEEKLY,Treatment,0.004 (0.066),0.0
72,Eosinophil (x10e9/L),Laboratory,0.175 (0.215),55.7
204,Regimen LU-PACLICARBO,Treatment,0.004 (0.060),0.0
215,Intent of Systemic Treatment Palliative,Treatment,0.732 (0.443),0.0


In [14]:
# Patient Characteristics
# the five most frequent cancer-site columns and regimens, ranked on the training rows
cancer_cols = [col for col in df.columns if col.startswith('cancer_site')]
top_cancers = df.loc[train_mask, cancer_cols].sum().sort_values(ascending=False).head(5).index
top_regimens = df.loc[train_mask, 'regimen'].value_counts().head(5).index

# Dev = Train + Valid. The 'Treatments' columns describe every session; the
# 'Patients' columns describe the last row of each mrn group.
result = {
    ('Dev', 'Treatments'): get_patient_characteristics(
        df[train_mask | valid_mask], top_regimens, top_cancers
    ),
    ('Dev', 'Patients'): get_patient_characteristics(
        df[train_mask | valid_mask].groupby('mrn').last(), top_regimens, top_cancers
    ),
    ('Test', 'Treatments'): get_patient_characteristics(
        df[test_mask], top_regimens, top_cancers
    ),
    ('Test', 'Patients'): get_patient_characteristics(
        df[test_mask].groupby('mrn').last(), top_regimens, top_cancers
    ),
}
pd.DataFrame(result)

Unnamed: 0_level_0,Dev,Dev,Test,Test
Unnamed: 0_level_1,Treatments,Patients,Treatments,Patients
"Number of Treatments, Median (IQR)",4 (2-8),1 (1-1),4 (2-7),1 (1-1)
"Age (years), Median (IQR)",64 (56-70),64 (56-70),64 (57-71),65 (57-71)
"Height (cm), Median (IQR)",168.0 (161.0-175.0),168.0 (161.0-175.0),169.0 (162.0-177.0),169.0 (162.0-176.0)
"Weight (kg), Median (IQR)",69.3 (59.1-81.6),69.2 (58.8-81.5),71.6 (60.7-83.4),71.0 (60.0-82.9)
"Female, No. (%)",6304 (42.9),948 (41.8),2183 (39.2),372 (38.8)
"Regimen GI-FOLFIRI+BEVACIZUMAB, No. (%)",1607 (10.9),124 (5.5),198 (3.6),11 (1.1)
"Regimen GI-FOLFIRINOX, No. (%)",1239 (8.4),138 (6.1),41 (0.7),5 (0.5)
"Regimen TRIAL, No. (%)",1306 (8.9),167 (7.4),865 (15.5),102 (10.6)
"Regimen GI-GEM D1,8,15, No. (%)",982 (6.7),124 (5.5),163 (2.9),18 (1.9)
"Regimen GI-FOLFOX-6 MOD, No. (%)",954 (6.5),152 (6.7),256 (4.6),49 (5.1)


# Train Model - Quick and Dirty

In [22]:
targets = Y.columns

# hyperparameters shared by every per-target model ('solver': 'saga' is left disabled)
LR_params = dict(C=0.3, penalty='l2', class_weight='balanced', max_iter=2000, random_state=42)
XGB_params = {
    'n_estimators': 100,
    'max_depth': 6,
    'learning_rate': 0.01,
    'min_child_weight': 6,
    'random_state': 42,
}

# one independent classifier instance per target
LR_model = {target: LogisticRegression(**LR_params) for target in targets}
XGB_model = {target: XGBClassifier(**XGB_params) for target in targets}

for target in tqdm(targets):
    # rows labelled -1 for this target are left out of the fit
    mask = Y_train[target].ne(-1)
    LR_model[target].fit(X_train.loc[mask], Y_train.loc[mask, target])
    XGB_model[target].fit(X_train.loc[mask], Y_train.loc[mask, target])

100%|██████████| 18/18 [02:05<00:00,  6.98s/it]


In [23]:
def evaluate(model, X, Y, metrics=None):
    """Score per-target classifiers on one data split.

    Parameters
    ----------
    model : mapping of target name -> fitted classifier
        Each classifier must expose ``predict_proba`` and ``classes_``.
    X : pd.DataFrame
        Feature matrix; it is filtered with a boolean mask built from each
        column of ``Y``, so both are expected to share the same index.
    Y : pd.DataFrame
        One column per target. Rows equal to -1 are excluded for that target.
    metrics : mapping of metric name -> callable(y_true, y_score), optional
        Metrics to compute. Defaults to AUPRC (``average_precision_score``)
        and AUROC (``roc_auc_score``).

    Returns
    -------
    pd.DataFrame
        One row per target, one column per metric.

    Raises
    ------
    ValueError
        If a classifier does not list the positive label (1) in ``classes_``.
    """
    if metrics is None:
        metrics = {'AUPRC': average_precision_score, 'AUROC': roc_auc_score}

    positive_label = 1
    result = {}
    for target, label in Y.items():
        mask = label != -1
        clf = model[target]
        # predict_proba columns follow clf.classes_, so look the positive class up
        # instead of assuming it sits at index 1
        classes = list(clf.classes_)
        if positive_label not in classes:
            raise ValueError(
                f'model for {target!r} has no positive class {positive_label}; '
                f'classes_={classes}'
            )
        pred = clf.predict_proba(X[mask])[:, classes.index(positive_label)]
        y_true = label[mask]
        result[target] = {name: score(y_true, pred) for name, score in metrics.items()}
    return pd.DataFrame(result).T

In [20]:
# logistic regression on the validation split: AUPRC / AUROC per target
evaluate(LR_model, X_valid, Y_valid)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.475862,0.655851
tiredness_1pt_change,0.627286,0.678208
nausea_1pt_change,0.453932,0.671093
depression_1pt_change,0.401786,0.651748
anxiety_1pt_change,0.388799,0.617675
drowsiness_1pt_change,0.561692,0.685365
appetite_1pt_change,0.563517,0.700907
well_being_1pt_change,0.563339,0.680049
shortness_of_breath_1pt_change,0.408388,0.653472
pain_3pt_change,0.292664,0.730638


In [21]:
# logistic regression on the test split: AUPRC / AUROC per target
evaluate(LR_model, X_test, Y_test)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.464188,0.637326
tiredness_1pt_change,0.611927,0.673078
nausea_1pt_change,0.443644,0.671883
depression_1pt_change,0.384616,0.626949
anxiety_1pt_change,0.401322,0.610335
drowsiness_1pt_change,0.564664,0.675702
appetite_1pt_change,0.517947,0.66364
well_being_1pt_change,0.570521,0.665429
shortness_of_breath_1pt_change,0.368694,0.600363
pain_3pt_change,0.226392,0.675176


In [22]:
# XGBoost on the validation split: AUPRC / AUROC per target
evaluate(XGB_model, X_valid, Y_valid)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.497675,0.683565
tiredness_1pt_change,0.62082,0.689137
nausea_1pt_change,0.498506,0.706147
depression_1pt_change,0.475036,0.707769
anxiety_1pt_change,0.434293,0.67669
drowsiness_1pt_change,0.558526,0.703425
appetite_1pt_change,0.584689,0.728845
well_being_1pt_change,0.550627,0.692945
shortness_of_breath_1pt_change,0.410443,0.650389
pain_3pt_change,0.299741,0.751617


In [19]:
# XGBoost on the test split: AUPRC / AUROC per target
evaluate(XGB_model, X_test, Y_test)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.479161,0.662992
tiredness_1pt_change,0.604046,0.666222
nausea_1pt_change,0.445009,0.675159
depression_1pt_change,0.399748,0.652901
anxiety_1pt_change,0.430365,0.644815
drowsiness_1pt_change,0.5694,0.683269
appetite_1pt_change,0.516838,0.676955
well_being_1pt_change,0.554816,0.666361
shortness_of_breath_1pt_change,0.405333,0.637055
pain_3pt_change,0.244697,0.700301
