In [None]:
%%capture
# %%capture silences the output of this cell (it must stay on the first line).
# Move the working directory two levels up so that relative paths used later
# (e.g. './config.yaml') resolve from there — presumably the repository root.
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel moves up two more levels.
%cd ../../
# Enable autoreload so edits to imported modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from itertools import product

import pandas as pd
import yaml

from make_clinical_dataset.combine import (
    add_engineered_features,
    combine_demographic_to_main_data,
    combine_event_to_main_data,
    combine_meas_to_main_data,
    combine_perc_dose_to_main_data,
    combine_treatment_to_main_data,
)
from make_clinical_dataset.label import get_CTCAE_labels, get_death_labels, get_ED_labels, get_symptom_labels
from make_clinical_dataset.preprocess.epr.cancer_registry import get_demographic_data
from make_clinical_dataset.preprocess.epr.clinic import get_clinical_notes_data, get_clinic_visits_during_treatment, backfill_treatment_info
from make_clinical_dataset.preprocess.epr.dart import get_symptoms_data
from make_clinical_dataset.preprocess.epr.emergency import get_emergency_room_data
from make_clinical_dataset.preprocess.epr.lab import get_lab_data
from make_clinical_dataset.preprocess.epr.opis import get_treatment_data
from make_clinical_dataset.preprocess.epr.radiology import get_radiology_data
from make_clinical_dataset.preprocess.epr.recist import get_recist_data
from make_clinical_dataset.util import load_included_drugs, load_included_regimens

from ml_common.anchor import merge_closest_measurements

# Let DataFrame previews show up to 100 columns and 100 rows before truncating.
DISPLAY_LIMIT = 100
for display_option in ('display.max_columns', 'display.max_rows'):
    pd.set_option(display_option, DISPLAY_LIMIT)

In [None]:
def quick_summary(df):
    """Print the size and date span of a cohort.

    Prints three lines: the number of rows (sessions), the number of distinct
    values in the ``mrn`` column (patients), and the earliest and latest
    ``treatment_date`` as calendar dates.

    Args:
        df: DataFrame with ``mrn`` and datetime ``treatment_date`` columns.
    """
    n_sessions = len(df)
    print(f'Number of sessions = {n_sessions}')

    n_patients = df['mrn'].nunique()
    print(f'Number of patients = {n_patients}')

    treatment_dates = df['treatment_date']
    first_day = treatment_dates.min().date()
    last_day = treatment_dates.max().date()
    print(f'Cohort from {first_day} to {last_day}')

def check_overlap(main, feat, main_name, feat_name):
    """Report how much of ``main`` has no matching ``mrn`` in ``feat``.

    Prints the percentage and count of rows (sessions) in ``main`` whose
    ``mrn`` does not appear in ``feat['mrn']``, together with the percentage
    and count of distinct ``mrn`` values (patients) those rows represent.

    Args:
        main: DataFrame with an ``mrn`` column; the dataset being checked.
        feat: DataFrame with an ``mrn`` column; the dataset checked against.
        main_name: Label for ``main`` used in the printed message.
        feat_name: Label for ``feat`` used in the printed message.

    Returns:
        None. The result is printed.
    """
    mask = ~main['mrn'].isin(feat['mrn'])
    # Vectorized count; the builtin sum() would walk the Series element by element.
    n_sessions = int(mask.sum())
    n_patients = main.loc[mask, 'mrn'].nunique()
    total_patients = main['mrn'].nunique()
    # Guard the empty case: the mean of an empty mask is NaN and 0 / 0 raises
    # ZeroDivisionError, so report 0% instead of failing.
    perc_sessions = mask.mean() * 100 if len(mask) else 0.0
    perc_patients = (n_patients / total_patients) * 100 if total_patients else 0.0
    print(f'{perc_sessions:.1f}% (N={n_sessions}) of sessions and {perc_patients:.1f}% (N={n_patients}) of patients '
          f'in the {main_name} do not have overlapping mrns with the {feat_name}')

In [None]:
# Read the notebook settings (e.g. the look-back windows used below) from the YAML config
with open('./config.yaml') as file:
    cfg = yaml.safe_load(file)

# Root folder of the dataset; the commented line is the relative alternative
# data_dir = "./data"
data_dir = "/cluster/projects/gliugroup/2BLAST/data/final/data_2023-02-21"

# External reference tables: the drugs and regimens to include
external_dir = f'{data_dir}/external'
included_drugs = load_included_drugs(data_dir=external_dir)
included_regimens = load_included_regimens(data_dir=external_dir)

# Dictionary keyed by RESEARCH_ID with PATIENT_MRN as the value
mrn_map = (
    pd.read_csv(f'{external_dir}/MRN_map.csv')
    .set_index('RESEARCH_ID')['PATIENT_MRN']
    .to_dict()
)

# Build the features & targets

## DART

In [None]:
# Build the symptom table from the raw folder and cache it as an interim parquet file.
dart = get_symptoms_data(data_dir=f'{data_dir}/raw')
dart.to_parquet(
    f'{data_dir}/interim/symptom.parquet',
    index=False,
    compression='zstd',
)

## Cancer Registry

In [None]:
# Build the demographic table from the raw folder and cache it as an interim parquet file.
canc_reg = get_demographic_data(data_dir=f'{data_dir}/raw')
canc_reg.to_parquet(
    f'{data_dir}/interim/demographic.parquet',
    index=False,
    compression='zstd',
)

## OPIS

In [None]:
# Build the treatment table, restricted via the included drugs and regimens,
# cache it as an interim parquet file, then print a short cohort overview.
opis = get_treatment_data(
    included_drugs,
    included_regimens,
    data_dir=f'{data_dir}/raw',
)
opis.to_parquet(
    f'{data_dir}/interim/treatment.parquet',
    index=False,
    compression='zstd',
)
quick_summary(opis)
n_regimens = opis['regimen'].nunique()
print(f'Number of unique regimens: {n_regimens}')

## Laboratory Tests 
Hematology and Biochemistry

In [None]:
# Build the laboratory table (mrn_map is passed to the loader) and cache it
# as an interim parquet file.
lab = get_lab_data(mrn_map, data_dir=f'{data_dir}/raw')
lab.to_parquet(
    f'{data_dir}/interim/lab.parquet',
    index=False,
    compression='zstd',
)

## Emergency Room Visits

In [None]:
# Build the emergency room visit table and cache it as an interim parquet file.
er_visit = get_emergency_room_data(data_dir=f'{data_dir}/raw')
er_visit.to_parquet(
    f'{data_dir}/interim/emergency_room_visit.parquet',
    index=False,
    compression='zstd',
)

## Radiology Reports

In [None]:
# Build the radiology report table (mrn_map is passed to the loader) and cache
# it as an interim parquet file.
reports = get_radiology_data(mrn_map, data_dir=f'{data_dir}/raw')
reports.to_parquet(
    f'{data_dir}/interim/reports.parquet',
    index=False,
    compression='zstd',
)

## Clinical Notes

In [None]:
# Build the clinical notes table and cache it as an interim parquet file.
clinical_notes = get_clinical_notes_data(data_dir=f'{data_dir}/raw')
clinical_notes.to_parquet(
    f'{data_dir}/interim/clinical_notes.parquet',
    index=False,
    compression='zstd',
)

## RECIST - COMPASS

In [None]:
# Build the RECIST table (read from the external folder, unlike the raw-folder
# loaders) and cache it as an interim parquet file.
recist = get_recist_data(data_dir=f'{data_dir}/external')
recist.to_parquet(
    f'{data_dir}/interim/recist.parquet',
    index=False,
    compression='zstd',
)

# Combine the features & targets

In [None]:
# Reload the cached interim tables under the short names used in the rest of the notebook.
interim_dir = f'{data_dir}/interim'
lab = pd.read_parquet(f'{interim_dir}/lab.parquet')
trt = pd.read_parquet(f'{interim_dir}/treatment.parquet')
dmg = pd.read_parquet(f'{interim_dir}/demographic.parquet')
sym = pd.read_parquet(f'{interim_dir}/symptom.parquet')
erv = pd.read_parquet(f'{interim_dir}/emergency_room_visit.parquet')
# NOTE: no cell shown above writes last_seen_dates.parquet; it has to exist
# in the interim folder before this cell runs.
last_seen = pd.read_parquet(f'{interim_dir}/last_seen_dates.parquet')

In [None]:
# Report how many treatment sessions/patients have no mrn match in each feature table.
for feat, feat_name in ((lab, 'laboratory database'), (sym, 'symptoms database')):
    check_overlap(trt, feat, 'treatment database', feat_name)

## Align on treatment sessions

In [None]:
# Start the treatment-centered dataset by attaching demographics to the
# treatment table (presumably one row per treatment session).
df = combine_demographic_to_main_data(trt, dmg, 'treatment_date')
# Series.map with a Series argument looks values up by that Series' index, so
# this assumes `last_seen` is indexed by mrn — TODO confirm how
# last_seen_dates.parquet was written.
df['last_seen_date'] = df['mrn'].map(last_seen['last_seen_date'])
# In this dataset the assessment date is the treatment date itself.
df['assessment_date'] = df['treatment_date']
quick_summary(df)

In [None]:
# Extract features, anchored on each row's treatment_date.
# NOTE: every step rebinds `df`, so re-running this cell without first
# re-running the cell that creates `df` applies the steps a second time.
#
# Disabled alternative that used combine_meas_to_main_data with stats=['last']:
# df = combine_meas_to_main_data(df, sym, 'treatment_date', 'survey_date', time_window=cfg['symp_lookback_window'], stats=['last'])
# df = combine_meas_to_main_data(df, lab, 'treatment_date', 'obs_date', time_window=cfg['lab_lookback_window'], stats=['last'])
# df.columns = df.columns.str.replace('_LAST', '')
#
# Attach symptom surveys (by survey_date) and lab observations (by obs_date),
# each with its own look-back window from the config. The exact matching rule
# lives in ml_common.anchor and is not visible here.
df = merge_closest_measurements(df, sym, 'treatment_date', 'survey_date', time_window=cfg['symp_lookback_window'])
df = merge_closest_measurements(df, lab, 'treatment_date', 'obs_date', time_window=cfg['lab_lookback_window'])
# Attach emergency room visits (by event_date) under the event name 'ED_visit'.
df = combine_event_to_main_data(df, erv, 'treatment_date', 'event_date', event_name='ED_visit', lookback_window=cfg['ed_visit_lookback_window'])
# Presumably adds percentage-of-dose features for the included drugs — verify in make_clinical_dataset.combine.
df = combine_perc_dose_to_main_data(df, included_drugs)
df = add_engineered_features(df, 'treatment_date')

In [None]:
# Extract targets. The look-ahead windows are presumably in days — TODO confirm
# in make_clinical_dataset.label.
# Death labels use two look-ahead windows (30 and 365).
df = get_death_labels(df, lookahead_window=[30, 365])
# Only the mrn and event_date columns of the ER visits are passed (as a copy).
df = get_ED_labels(df, erv[['mrn', 'event_date']].copy(), lookahead_window=30) # 'CTAS_score' and 'CEDIS_complaint' are currently not passed
df = get_symptom_labels(df, sym, lookahead_window=30)
df = get_CTCAE_labels(df, lab, lookahead_window=30)

In [None]:
# Persist the treatment-centered dataset (features and targets) to the processed folder.
df.to_parquet(
    f'{data_dir}/processed/treatment_centered_dataset.parquet',
    index=False,
    compression='zstd',
)

## Align on clinic visits

In [None]:
# Reload the cached clinical notes, report the mrn overlap with the treatment
# table, then narrow the notes down with get_clinic_visits_during_treatment.
clinic = pd.read_parquet(
    f'{data_dir}/interim/clinical_notes.parquet'
)
check_overlap(
    trt,
    clinic,
    main_name='treatment database',
    feat_name='clinic database',
)
clinic = get_clinic_visits_during_treatment(clinic, trt)

In [None]:
# Extract features, anchored on each row's clinic_date.
# Start from the clinic visits and attach treatment information within the
# configured treatment look-back window.
df = combine_treatment_to_main_data(clinic, trt, 'clinic_date', time_window=cfg['trt_lookback_window'])
# Series.map with a Series argument looks values up by that Series' index, so
# this assumes `last_seen` is indexed by mrn — TODO confirm.
df['last_seen_date'] = df['mrn'].map(last_seen['last_seen_date'])
# In this dataset the assessment date is the clinic visit date.
df['assessment_date'] = df['clinic_date']
df = backfill_treatment_info(df)
# quick_summary reads the 'treatment_date' column, so it must be present at this point.
quick_summary(df)
df = combine_demographic_to_main_data(df, dmg, 'clinic_date')
# Attach symptom surveys and lab observations, each with its own look-back window.
df = merge_closest_measurements(df, sym, 'clinic_date', 'survey_date', time_window=cfg['symp_lookback_window'])
df = merge_closest_measurements(df, lab, 'clinic_date', 'obs_date', time_window=cfg['lab_lookback_window'])
# Attach emergency room visits (by event_date) under the event name 'ED_visit'.
df = combine_event_to_main_data(df, erv, 'clinic_date', 'event_date', event_name='ED_visit', lookback_window=cfg['ed_visit_lookback_window'])
df = combine_perc_dose_to_main_data(df, included_drugs)
df = add_engineered_features(df, 'clinic_date')
# Extract targets (look-ahead windows presumably in days — TODO confirm).
df = get_death_labels(df, lookahead_window=[30, 365])
# Only the mrn and event_date columns of the ER visits are passed (as a copy).
df = get_ED_labels(df, erv[['mrn', 'event_date']].copy(), lookahead_window=30)
df = get_symptom_labels(df, sym, lookahead_window=30)
df = get_CTCAE_labels(df, lab, lookahead_window=30)
# Persist the clinic-centered dataset to the processed folder.
df.to_parquet(f'{data_dir}/processed/clinic_centered_dataset.parquet', compression='zstd', index=False)

## Align on every Monday

In [None]:
# Build the weekly assessment grid: one row for every (mrn, Monday) pair.
# `product` is imported with the other imports at the top of the notebook
# rather than in the middle of the analysis.
mrns = trt['mrn'].unique()
# 'W-MON' yields the Mondays between the two dates; the 2018 range is hard-coded.
dates = pd.date_range(start='2018-01-01', end='2018-12-31', freq='W-MON')
df = pd.DataFrame(product(mrns, dates), columns=['mrn', 'assessment_date'])
# Series.map with a Series argument looks values up by that Series' index, so
# this assumes `last_seen` is indexed by mrn — TODO confirm.
df['last_seen_date'] = df['mrn'].map(last_seen['last_seen_date'])

In [None]:
# Extract features, anchored on each row's assessment_date (a Monday).
# NOTE: every step rebinds `df`, so re-running this cell without first
# re-running the cell that creates `df` applies the steps a second time.
# Attach treatment information within the configured treatment look-back window.
df = combine_treatment_to_main_data(df, trt, 'assessment_date', time_window=cfg['trt_lookback_window'])
df = combine_demographic_to_main_data(df, dmg, 'assessment_date')
# Attach symptom surveys and lab observations, each with its own look-back window.
df = merge_closest_measurements(df, sym, 'assessment_date', 'survey_date', time_window=cfg['symp_lookback_window'])
df = merge_closest_measurements(df, lab, 'assessment_date', 'obs_date', time_window=cfg['lab_lookback_window'])
# Attach emergency room visits (by event_date) under the event name 'ED_visit'.
df = combine_event_to_main_data(df, erv, 'assessment_date', 'event_date', event_name='ED_visit', lookback_window=cfg['ed_visit_lookback_window'])
df = combine_perc_dose_to_main_data(df, included_drugs)
df = add_engineered_features(df, 'assessment_date')
# Extract targets (look-ahead windows presumably in days — TODO confirm).
df = get_death_labels(df, lookahead_window=[30, 365])
# Only the mrn and event_date columns of the ER visits are passed (as a copy).
df = get_ED_labels(df, erv[['mrn', 'event_date']].copy(), lookahead_window=30)
df = get_symptom_labels(df, sym, lookahead_window=30)
df = get_CTCAE_labels(df, lab, lookahead_window=30)
# Saving is currently disabled; uncomment to write the weekly dataset to the processed folder.
# df.to_parquet(f'{data_dir}/processed/weekly_monday_clinical_dataset.parquet', compression='zstd', index=False)