In [1]:
%%capture
%cd ../
%load_ext autoreload
%autoreload 2

In [2]:
from tqdm import tqdm
import pandas as pd
pd.set_option('display.max_rows', 150)

from src import logger
from src.constants import symp_cols
from src.label import convert_to_binary_symptom_labels, get_symptom_labels, get_label_distribution
from src.prepare.filter import (
    drop_highly_missing_features, 
    drop_samples_outside_study_date, 
    drop_samples_with_no_targets,
    drop_unused_drug_features
)
from src.prepare.engineer import collapse_rare_categories, get_change_since_prev_session, get_missingness_features
from src.prepare.pipeline import symptom_prep_pipeline
from src.prepare.prep import PrepData, fill_missing_data
from src.summarize import feature_summary
from src.util import get_nunique_categories, get_nmissing

In [3]:
# Load the treatment-centered clinical dataset and list its columns.
df = pd.read_parquet('data/treatment_centered_clinical_dataset.parquet.gzip')
# Note to self: why is morphology and cancer site one-hot encoded but regimen is not?
# Because patients can have multiple diagnoses at different dates.
# See make-clinical-dataset/preprocess/cancer_registry for more info.
list(df.columns)

['mrn',
 'treatment_date',
 'regimen',
 'height',
 'weight',
 'body_surface_area',
 'cycle_number',
 'first_treatment_date',
 'intent',
 'date_of_birth',
 'female',
 'cancer_site_C00',
 'cancer_site_C01',
 'cancer_site_C02',
 'cancer_site_C03',
 'cancer_site_C04',
 'cancer_site_C05',
 'cancer_site_C06',
 'cancer_site_C07',
 'cancer_site_C08',
 'cancer_site_C09',
 'cancer_site_C10',
 'cancer_site_C11',
 'cancer_site_C12',
 'cancer_site_C13',
 'cancer_site_C14',
 'cancer_site_C15',
 'cancer_site_C16',
 'cancer_site_C17',
 'cancer_site_C18',
 'cancer_site_C19',
 'cancer_site_C20',
 'cancer_site_C21',
 'cancer_site_C22',
 'cancer_site_C23',
 'cancer_site_C24',
 'cancer_site_C25',
 'cancer_site_C26',
 'cancer_site_C30',
 'cancer_site_C31',
 'cancer_site_C32',
 'cancer_site_C34',
 'cancer_site_C37',
 'cancer_site_C38',
 'cancer_site_C48',
 'cancer_site_C62',
 'cancer_site_C76',
 'morphology_800',
 'morphology_801',
 'morphology_802',
 'morphology_803',
 'morphology_804',
 'morphology_805',
 

In [4]:
# Score-increase thresholds (in points) used to define symptom deterioration.
# The labelling cell loops over this list and builds one set of binary labels per threshold,
# and the same list is passed to symptom_prep_pipeline.
target_pt_increases = [1, 3]

# Prep Data - Part 1

In [5]:
# Get the change in measurement since the previous assessment.
# NOTE(review): `df` is reassigned in place of the raw frame, so this cell is not idempotent;
# presumably it appends `*_change` columns (names such as 'bicarbonate_change' show up in later logs) - confirm in src.prepare.engineer.
df = get_change_since_prev_session(df)

100%|██████████| 9297/9297 [00:11<00:00, 836.78it/s]


In [6]:
# Attach symptom labels, then binarize them once per point-increase threshold.
symp = pd.read_parquet('./data/external/symptom.parquet.gzip')
df = get_symptom_labels(df, symp)
# 'patient_ecog' is left out of the scoring map; every other symptom column shares the threshold.
scored_cols = [col for col in symp_cols if col != 'patient_ecog']
for pt_increase in target_pt_increases:
    scoring_map = dict.fromkeys(scored_cols, pt_increase)
    df = convert_to_binary_symptom_labels(df, scoring_map=scoring_map)

In [7]:
# Drop sessions that have none of the per-symptom change targets.
target_cols = pd.Index([f'target_{name}_change' for name in symp_cols])
df = drop_samples_with_no_targets(df, target_cols)

11:59:01 INFO:Removing 5069 patients and 76530 sessions with no targets


In [8]:
# Filter out sessions outside the study period.
# The helper's log message reports the bounds as "before 2014-01-01 and after 2019-12-31",
# i.e. sessions from 2020 onward are removed (the bounds are set inside the helper, not here).
df = drop_samples_outside_study_date(df)

11:59:01 INFO:Removing 997 patients and 8147 sessions before 2014-01-01 and after 2019-12-31


In [9]:
# Drop dose features for drugs that were rarely given.
# The helper's log message states the cutoff as "given less than 10 times", so this is
# broader than "never used"; the cutoff itself is defined inside the helper.
df = drop_unused_drug_features(df)

11:59:01 INFO:Removing the following features for drugs given less than 10 times: ['%_ideal_dose_given_DURVALUMAB', '%_ideal_dose_given_RALTITREXED', '%_ideal_dose_given_IPILIMUMAB', '%_ideal_dose_given_CAPECITABINE', '%_ideal_dose_given_ERLOTINIB']


# Describe Data - Part 1

In [10]:
# Show the number of unique categories per categorical column (the displayed output covers 'regimen' and 'intent').
get_nunique_categories(df)

Unnamed: 0,regimen,intent
Number of Unique Categories,107,4


In [11]:
# Missingness per feature, hiding the rows whose feature name ends in '_date'.
nmissing = get_nmissing(df)
is_date_feature = nmissing.index.str.endswith('_date')
nmissing.loc[~is_date_feature]

Unnamed: 0,Missing (N),Missing (%)
esas_pain,29,0.109
esas_tiredness,41,0.155
target_esas_pain,49,0.185
esas_drowsiness,52,0.196
target_esas_tiredness,53,0.2
esas_appetite,54,0.204
esas_depression,57,0.215
esas_anxiety,58,0.219
esas_shortness_of_breath,59,0.222
target_esas_pain_change,76,0.286


# Prep Data - Part 2

In [10]:
# NOTE(review): every step below reassigns `df`, so the cell is order-dependent and not idempotent.

# fill missing data that can be filled heuristically
df = fill_missing_data(df)

# drop features with high missingness
# Columns whose name contains 'target_' are passed as keep_cols, presumably so they are never dropped - confirm in the helper.
# missing_thresh is a percentage: the helper logs "missingness over 75%".
keep_cols = df.columns[df.columns.str.contains('target_')]
df = drop_highly_missing_features(df, missing_thresh=75, keep_cols=keep_cols)

# create missingness features
df = get_missingness_features(df)

# collapse rare morphology and cancer sites into 'Other' category
# (the helper logs its rarity cutoff as "less than 6 patients")
df = collapse_rare_categories(df, catcols=['cancer_site', 'morphology'])

11:59:01 INFO:Dropping the following 11 features for missingness over 75%: ['bicarbonate', 'basophil', 'bicarbonate_change', 'basophil_change', 'carbohydrate_antigen_19-9', 'prothrombin_time_international_normalized_ratio', 'activated_partial_thromboplastin_time', 'carcinoembryonic_antigen', 'esas_constipation', 'esas_vomiting', 'esas_diarrhea']
11:59:01 INFO:Reassigning the following 6 indicators with less than 6 patients as other: ['cancer_site_C00', 'cancer_site_C14', 'cancer_site_C26', 'cancer_site_C48', 'cancer_site_C62', 'cancer_site_C76']
11:59:02 INFO:Reassigning the following 63 indicators with less than 6 patients as other: ['morphology_800', 'morphology_803', 'morphology_805', 'morphology_809', 'morphology_812', 'morphology_815', 'morphology_818', 'morphology_820', 'morphology_822', 'morphology_829', 'morphology_831', 'morphology_832', 'morphology_833', 'morphology_836', 'morphology_840', 'morphology_843', 'morphology_844', 'morphology_845', 'morphology_847', 'morphology_851

In [11]:
# Run the preparation pipeline; it returns features, labels and per-row metadata.
X, Y, metainfo = symptom_prep_pipeline(
    df,
    split_date='2017-10-01',
    target_pt_increases=target_pt_increases,
)
# Clean up Y: keep only the '*pt_change' label columns, then shorten their names
# by removing the 'target_' and 'esas_' substrings.
change_cols = [name for name in Y.columns if name.endswith('pt_change')]
Y = Y[change_cols]
Y.columns = Y.columns.str.replace('target_', '').str.replace('esas_', '')

11:59:02 INFO:Development Cohort: NSessions=20058. NPatients=2271. Contains all patients whose first visit was on or before 2017-10-01
11:59:02 INFO:Test Cohort: NSessions=6477. NPatients=960. Contains all patients whose first visit was after 2017-10-01
11:59:02 INFO:About 43-84 sessions had a target event (e.g. target_esas_pain_1pt_change) in less than 2 days.
11:59:02 INFO:About 12-20 sessions had a target event (e.g. target_esas_pain_1pt_change) in less than 2 days.
11:59:02 INFO:About 8-27 sessions had a target event (e.g. target_esas_pain_3pt_change) in less than 2 days.
11:59:02 INFO:About 3-10 sessions had a target event (e.g. target_esas_pain_3pt_change) in less than 2 days.
11:59:02 INFO:One-hot encoding training data
11:59:02 INFO:Separated and dropped 0 treatment set indicator columns, and added 0 new treatment indicator columns
11:59:02 INFO:One-hot encoding validation data


Reassigning the following indicators with less than 6 patients as other: ['regimen_GI-CISPFU + TRAS(LOAD)', 'regimen_GI-CISPFU + TRAS(MAIN)', 'regimen_GI-DOCEQ3W', 'regimen_GI-DOXO', 'regimen_GI-EOX', 'regimen_GI-FOLFNALIRI', 'regimen_GI-FOLFNALIRI (COMP)', 'regimen_GI-FUFA WEEKLY', 'regimen_GI-GEM D1,8 + CAPECIT', 'regimen_GI-GEMCAP', 'regimen_GI-GEMFU (BILIARY)', 'regimen_GI-IRINO Q3W', 'regimen_GI-PACLI WEEKLY', 'regimen_GI-PACLITAXEL', 'regimen_HN-DOCETAXEL WEEKLY', 'regimen_HN-ETOPCISP 3 DAY', 'regimen_HN-GEM/CIS + APREP', 'regimen_HN-NIVOLUMAB', 'regimen_LU-DOCECARBO', 'regimen_LU-DURVALUMAB (COMP)', 'regimen_LU-ETOPCARBO-NO RT', 'regimen_LU-GEM D1,8,15', 'regimen_LU-GEMCISP +APREPITANT', 'regimen_LU-IRINOCARBO NO RT', 'regimen_LU-PACLI/CARBO WEEKX6', 'regimen_LU-RALTICARBO', 'regimen_LU-RALTICISP', 'regimen_LU-TOPOTECAN', 'regimen_LU-VINO D1,8']


11:59:02 INFO:Separated and dropped 0 treatment set indicator columns, and added 0 new treatment indicator columns
11:59:02 INFO:Reassigning the following regimen indicator columns that did not exist in train set as other:
regimen_GI-CISPFU + TRAS(LOAD)     3
regimen_GI-CISPFU + TRAS(MAIN)    12
regimen_GI-DOXO                    2
regimen_GI-FLOT (GASTRIC)          3
regimen_GI-FOLFNALIRI (COMP)      14
regimen_GI-FUFA WEEKLY             5
regimen_GI-FUFA-5 DAYS            10
regimen_GI-GEM D1,8 + CAPECIT      2
regimen_GI-PACLI WEEKLY            2
regimen_HN-DOCE/CISP Q3W           4
regimen_HN-DOCETAXEL WEEKLY        3
regimen_HN-ETOPCISP 3 DAY          9
regimen_HN-GEM/CIS + APREP        15
regimen_HN-NIVO Q4WEEKS (CCO)      2
regimen_HN-NIVOLUMAB               8
regimen_LU-DOCECISP                2
regimen_LU-PACLI/CARBO WEEKX5      5
regimen_LU-VINO D1,8               9
dtype: int64
11:59:02 INFO:One-hot encoding testing data
11:59:02 INFO:Separated and dropped 1 treatment set in

In [12]:
# Boolean row masks for each cohort, taken from the 'split' column of metainfo.
split_labels = metainfo['split']
train_mask = split_labels == 'Train'
valid_mask = split_labels == 'Valid'
test_mask = split_labels == 'Test'

# Slice features and labels with the same masks so they stay row-aligned.
X_train, Y_train = X[train_mask], Y[train_mask]
X_valid, Y_valid = X[valid_mask], Y[valid_mask]
X_test, Y_test = X[test_mask], Y[test_mask]

# Describe Data - Part 2

In [15]:
# Summarize cohort sizes per split: rows of metainfo are counted as sessions,
# distinct 'mrn' values as patients.
# GroupBy.size() replaces apply(len): same per-group row counts, without routing
# every group through a Python callable (and without the groupby.apply deprecation
# warning newer pandas emits when the grouping column is part of each group).
split_groups = metainfo.groupby('split')
count = pd.DataFrame({
    'Number of sessions': split_groups.size(),
    'Number of patients': split_groups['mrn'].nunique(),
}).T
# Row-wise sum across the split columns gives the overall totals.
count['Total'] = count.sum(axis=1)
logger.info(f'\n{count.to_string()}')

08:51:12 INFO:
split               Test  Train  Valid  Total
Number of sessions  6477  16145   3913  26535
Number of patients   960   1816    455   3231


In [16]:
# UNIT TESTING
# The prepared feature matrix must not contain any missing values.
assert X.notnull().all().all()

In [17]:
# Label counts per split, counted by session.
# NOTE(review): the -1 column appears to mark sessions without a usable label; the training and
# evaluation code excludes rows where the label equals -1 - confirm the meaning in src.label.
get_label_distribution(Y, metainfo, with_respect_to='sessions').sort_index()

  dists = {split: group.apply(pd.value_counts)


Unnamed: 0_level_0,Test,Test,Test,Train,Train,Train,Valid,Valid,Valid,Total,Total,Total
Unnamed: 0_level_1,0,1,-1,0,1,-1,0,1,-1,0,1,-1
anxiety_1pt_change,4420,1999,58,11259,4643,243,2678,1167,68,18357,7809,369
anxiety_3pt_change,5686,539,252,14145,1298,702,3469,289,155,23300,2126,1109
appetite_1pt_change,4092,2286,99,10212,5681,252,2462,1386,65,16766,9353,416
appetite_3pt_change,5162,982,333,13138,2276,731,3133,580,200,21433,3838,1264
depression_1pt_change,4550,1881,46,11372,4568,205,2750,1107,56,18672,7556,307
depression_3pt_change,5809,495,173,14393,1262,490,3519,280,114,23721,2037,777
drowsiness_1pt_change,3822,2620,35,9782,6121,242,2393,1479,41,15997,10220,318
drowsiness_3pt_change,5223,996,258,13258,2233,654,3251,548,114,21732,3777,1026
nausea_1pt_change,4539,1891,47,11117,4803,225,2716,1165,32,18372,7859,304
nausea_3pt_change,5680,688,109,14099,1712,334,3447,413,53,23226,2813,496


In [18]:
# Label counts per split, counted by patient (the displayed output has only 1/0 columns here).
get_label_distribution(Y, metainfo, with_respect_to='patients').sort_index()

Unnamed: 0_level_0,Test,Test,Train,Train,Valid,Valid,Total,Total
Unnamed: 0_level_1,1,0,1,0,1,0,1,0
anxiety_1pt_change,615,345,1230,586,314,141,2159,1072
anxiety_3pt_change,235,725,530,1286,142,313,907,2324
appetite_1pt_change,639,321,1379,437,337,118,2355,876
appetite_3pt_change,388,572,816,1000,207,248,1411,1820
depression_1pt_change,571,389,1197,619,302,153,2070,1161
depression_3pt_change,227,733,522,1294,134,321,883,2348
drowsiness_1pt_change,726,234,1440,376,361,94,2527,704
drowsiness_3pt_change,405,555,830,986,211,244,1446,1785
nausea_1pt_change,556,404,1201,615,301,154,2058,1173
nausea_3pt_change,278,682,636,1180,147,308,1061,2170


In [54]:
# Feature Characteristics
# NOTE(review): this cell's execution count is out of sequence with its neighbours;
# re-check it under Restart Kernel -> Run All.
prep = PrepData()
# get original (non-normalized, non-imputed) data one-hot encoded, restricted to the training rows
train_rows = df.loc[X_train.index].copy()
x = prep.ohe.encode(train_rows, verbose=False)
# keep feature columns only: drop metainfo columns and anything starting with 'target'
feature_cols = [
    col for col in x.columns
    if col not in metainfo.columns and not col.startswith('target')
]
x = x[feature_cols]
feature_summary(x, save_path='result/tables/feature_summary.csv').head(100)

Reassigning the following indicators with less than 6 patients as other: ['regimen_GI-CISPFU + TRAS(LOAD)', 'regimen_GI-CISPFU + TRAS(MAIN)', 'regimen_GI-DOCEQ3W', 'regimen_GI-DOXO', 'regimen_GI-EOX', 'regimen_GI-FOLFNALIRI', 'regimen_GI-FOLFNALIRI (COMP)', 'regimen_GI-FUFA WEEKLY', 'regimen_GI-GEM D1,8 + CAPECIT', 'regimen_GI-GEMCAP', 'regimen_GI-GEMFU (BILIARY)', 'regimen_GI-IRINO Q3W', 'regimen_GI-PACLI WEEKLY', 'regimen_GI-PACLITAXEL', 'regimen_HN-DOCETAXEL WEEKLY', 'regimen_HN-ETOPCISP 3 DAY', 'regimen_HN-GEM/CIS + APREP', 'regimen_HN-NIVOLUMAB', 'regimen_LU-DOCECARBO', 'regimen_LU-DURVALUMAB (COMP)', 'regimen_LU-ETOPCARBO-NO RT', 'regimen_LU-GEM D1,8,15', 'regimen_LU-GEMCISP +APREPITANT', 'regimen_LU-IRINOCARBO NO RT', 'regimen_LU-PACLI/CARBO WEEKX6', 'regimen_LU-RALTICARBO', 'regimen_LU-RALTICISP', 'regimen_LU-TOPOTECAN', 'regimen_LU-VINO D1,8']


Unnamed: 0,Features,Group,Mean (SD),Missingness (%)
94,Days Since Previous ED Visit,Acute care use,1158.937 (800.847),0.0
93,Number of Prior ED Visits Within 5 Years,Acute care use,1.097 (2.187),0.0
35,"Morphology ICD-0-3 801, Epithelial neoplasms, NOS",Cancer,0.026 (0.159),0.0
36,"Morphology ICD-0-3 802, Epithelial neoplasms, NOS",Cancer,0.007 (0.081),0.0
37,"Morphology ICD-0-3 804, Epithelial neoplasms, NOS",Cancer,0.059 (0.236),0.0
38,"Morphology ICD-0-3 807, Squamous cell neoplasms",Cancer,0.151 (0.358),0.0
39,"Morphology ICD-0-3 808, Squamous cell neoplasms",Cancer,0.006 (0.076),0.0
40,"Morphology ICD-0-3 814, Adenomas and adenocarc...",Cancer,0.546 (0.498),0.0
41,"Morphology ICD-0-3 816, Adenomas and adenocarc...",Cancer,0.021 (0.144),0.0
42,"Morphology ICD-0-3 817, Adenomas and adenocarc...",Cancer,0.002 (0.040),0.0


# Train Model

In [19]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, roc_auc_score
from xgboost import XGBClassifier

from sklearn.exceptions import ConvergenceWarning
import warnings
warnings.filterwarnings("ignore", category=ConvergenceWarning)

In [20]:
# Train one Logistic Regression and one XGBoost classifier per target.
targets = Y.columns
# 'solver': 'saga' is deliberately not set here.
LR_params = dict(C=0.3, penalty='l2', class_weight='balanced', max_iter=2000, random_state=42)
XGB_params = {
    'n_estimators': 100,
    'max_depth': 6,
    'learning_rate': 0.01,
    'min_child_weight': 6,
    'random_state': 42,
}

# Instantiate every estimator up front so both dicts have an entry for each target.
LR_model, XGB_model = {}, {}
for target in targets:
    LR_model[target] = LogisticRegression(**LR_params)
    XGB_model[target] = XGBClassifier(**XGB_params)

for target in tqdm(targets):
    # Rows labelled -1 for this target are excluded from fitting.
    mask = Y_train[target] != -1
    train_x, train_y = X_train[mask], Y_train.loc[mask, target]
    for estimator in (LR_model[target], XGB_model[target]):
        estimator.fit(train_x, train_y)

100%|██████████| 18/18 [02:56<00:00,  9.79s/it]


In [21]:
def evaluate(model, X, Y):
    """Compute AUPRC and AUROC for each target's classifier.

    Args:
        model: Mapping from target name to a fitted classifier that exposes
            ``classes_`` and ``predict_proba``.
        X: Feature table, row-aligned with ``Y``.
        Y: Label table with one column per target. Rows whose label equals -1
            are excluded for that target.

    Returns:
        DataFrame indexed by target name with the columns 'AUPRC' and 'AUROC'.

    Raises:
        ValueError: If a classifier's ``classes_`` does not contain the
            positive label 1.
    """
    result = {}
    for target, label in Y.items():
        mask = label != -1
        clf = model[target]
        # Locate the positive-class column from classes_ instead of assuming it sits at index 1.
        classes = list(clf.classes_)
        if 1 not in classes:
            raise ValueError(
                f'Model for {target} has no positive class (1) in classes_: {classes}'
            )
        pos_col = classes.index(1)
        pred = clf.predict_proba(X[mask])[:, pos_col]
        truth = label[mask]
        result[target] = {
            'AUPRC': average_precision_score(truth, pred),
            'AUROC': roc_auc_score(truth, pred),
        }
    return pd.DataFrame(result).T

In [22]:
# Logistic regression models scored on the validation split.
evaluate(LR_model, X_valid, Y_valid)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.556448,0.684272
tiredness_1pt_change,0.656066,0.699594
nausea_1pt_change,0.485958,0.683714
depression_1pt_change,0.390373,0.63438
anxiety_1pt_change,0.409052,0.602399
drowsiness_1pt_change,0.583045,0.698058
appetite_1pt_change,0.567018,0.690192
well_being_1pt_change,0.610164,0.694019
shortness_of_breath_1pt_change,0.415077,0.652561
pain_3pt_change,0.334447,0.743771


In [23]:
# Logistic regression models scored on the test split.
evaluate(LR_model, X_test, Y_test)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.505272,0.662678
tiredness_1pt_change,0.655866,0.692875
nausea_1pt_change,0.468333,0.670905
depression_1pt_change,0.378636,0.597825
anxiety_1pt_change,0.393109,0.58365
drowsiness_1pt_change,0.593446,0.685212
appetite_1pt_change,0.54052,0.669799
well_being_1pt_change,0.594569,0.674719
shortness_of_breath_1pt_change,0.390876,0.610694
pain_3pt_change,0.251134,0.692169


In [24]:
# XGBoost models scored on the validation split.
evaluate(XGB_model, X_valid, Y_valid)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.544146,0.712755
tiredness_1pt_change,0.649622,0.700841
nausea_1pt_change,0.496063,0.704249
depression_1pt_change,0.418522,0.681586
anxiety_1pt_change,0.431,0.651861
drowsiness_1pt_change,0.614952,0.71619
appetite_1pt_change,0.589528,0.711749
well_being_1pt_change,0.596041,0.697586
shortness_of_breath_1pt_change,0.443365,0.680528
pain_3pt_change,0.308291,0.741611


In [25]:
# XGBoost models scored on the test split.
evaluate(XGB_model, X_test, Y_test)

Unnamed: 0,AUPRC,AUROC
pain_1pt_change,0.526741,0.683103
tiredness_1pt_change,0.67078,0.690211
nausea_1pt_change,0.486248,0.688041
depression_1pt_change,0.420131,0.662109
anxiety_1pt_change,0.424399,0.639286
drowsiness_1pt_change,0.621907,0.692188
appetite_1pt_change,0.543971,0.682613
well_being_1pt_change,0.588564,0.674237
shortness_of_breath_1pt_change,0.422715,0.646291
pain_3pt_change,0.241068,0.704686
