In [None]:
%%capture
# %%capture hides this cell's output (e.g. the directory echoed by %cd).
# Move two levels up from the notebook's folder (presumably the repository root) so relative paths resolve.
# NOTE(review): %cd is relative, so re-running this cell climbs two more levels — run it once per kernel.
%cd ../../
# Automatically reload edited project modules (e.g. make_clinical_dataset) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import pandas as pd
import polars as pl

from make_clinical_dataset.epic.combine import (
    merge_closest_measurements, 
    combine_chemo_to_main_data, 
    combine_demographic_to_main_data,
    combine_event_to_main_data,
    combine_radiation_to_main_data
)
from make_clinical_dataset.epic.label import get_acu_labels, get_CTCAE_labels, get_symptom_labels
from make_clinical_dataset.epic.preprocess.demographic import get_demographic_data
from make_clinical_dataset.shared.constants import INFO_DIR, ROOT_DIR, SYMP_COLS

# Display settings: print up to 205 rows of a polars frame and up to 100 columns of a pandas frame.
pl.Config.set_tbl_rows(205)
# Use the full option key. pandas resolves abbreviated keys (e.g. 'display.max_column') by regex
# search, which raises OptionError as soon as another option also matches the abbreviation.
pd.set_option('display.max_columns', 100)

In [None]:
# Version date of the final dataset. DATA_DIR holds the interim/ inputs read below and the
# processed/ outputs written at the end of each section.
DATE = '2025-03-29'
DATA_DIR = f"{ROOT_DIR}/data/final/data_{DATE}"

# Combine the features & targets

In [None]:
# Load the interim feature and target tables for this dataset version.
chemo = pl.read_parquet(f'{DATA_DIR}/interim/chemo.parquet')
rad = pl.read_parquet(f'{DATA_DIR}/interim/radiation.parquet')
lab = (
    pl.read_parquet(f'{DATA_DIR}/interim/lab.parquet')
    # force mrn to Int64 (presumably to match the other tables' mrn dtype)
    # TODO: do this in lab preprocessing
    .with_columns(pl.col('mrn').cast(pl.Int64))
)
sym = pl.read_parquet(f'{DATA_DIR}/interim/symptom.parquet')
acu = pl.read_parquet(f'{DATA_DIR}/interim/acute_care_use.parquet')

# The demographic loader yields a pandas frame; convert it to polars like the other tables.
# TODO: EDA - show number of patients with multiple birth dates
demog = pl.from_pandas(get_demographic_data())

## Align on chemo sessions

In [None]:
%%time
# select anchor
main_date_col = 'assessment_date'
main = (
    chemo
    .filter(pl.col('drug_type') == "direct")
    .select('mrn', 'treatment_date').unique()
    .rename({'treatment_date': main_date_col})
    .sort('mrn', main_date_col)
)

# merge demographics
main = combine_demographic_to_main_data(main, demog, main_date_col)

# merge chemotherapy treatments
main = combine_chemo_to_main_data(main, chemo, main_date_col, time_window=(-28,0))

# merge radiation treatments
main = combine_radiation_to_main_data(main, rad, main_date_col, time_window=(-28,0))

# merge laboratory tests
main = merge_closest_measurements(main, lab, main_date_col, "obs_date", include_meas_date=True, time_window=(-5,0))

# merge symptom surveys
main = merge_closest_measurements(main, sym, main_date_col, "obs_date", include_meas_date=True, time_window=(-30,0))

# merge acute care use
main = combine_event_to_main_data(main, acu, main_date_col, "ED_visit", lookback_window=5)

# add lables
# 1) ED
main = get_acu_labels(main, acu, main_date_col, lookahead_window=[30, 60, 90])
# 2) CTCAE
main = get_CTCAE_labels(main.lazy(), lab.lazy(), main_date_col, lookahead_window=30).collect()
# 3) Symptoms
main = get_symptom_labels(main, sym, main_date_col)
# 4) Death

In [None]:
# Split the treatment-centered table into (a) mrn plus every column ending in 'date' and
# (b) identifiers, categorical string columns, and all remaining feature/label columns, then save both.
id_cols = ['mrn', main_date_col]
date_cols = ['mrn'] + [col for col in main.columns if col.endswith('date')]
str_cols = ['cancer_type', 'primary_site_desc', 'intent', 'drug_name', 'postal_code']
# Exclude the id columns explicitly so they can never be selected twice.
excluded_cols = {*date_cols, *str_cols, *id_cols}
feat_cols = id_cols + str_cols + [col for col in main.columns if col not in excluded_cols]

# The processed/ folder may not exist yet in a fresh data directory; create it so the writes don't fail.
os.makedirs(f'{DATA_DIR}/processed', exist_ok=True)
main_dates = main.select(date_cols)
main_dates.write_parquet(f'{DATA_DIR}/processed/treatment_centered_dates.parquet')
main_data = main.select(feat_cols)
main_data.write_parquet(f'{DATA_DIR}/processed/treatment_centered_data.parquet')

## Align on clinic dates

In [None]:
%%time
PATH = f'{ROOT_DIR}/data/processed/clinical_notes/data_pull_2025-01-08/merged_processed_cleaned_clinical_notes.parquet.gzip'
main = pl.scan_parquet(PATH)
main = main.with_columns(pl.col('processed_date').cast(chemo.schema["treatment_date"])) # ensure date types match

main_date_col = 'assessment_date'
main = (
    main
    .filter(pl.col('Observations.ProcName').is_in(["Clinic Note", "Clinic Note (Non-dictated)", "Consultation Note"]))
    .join(chemo.lazy(), on="mrn", how="inner")
    .filter(
        # only keep clinic visits that has a treatment scheduled within the next 5 days
        (pl.col("treatment_date") > pl.col('processed_date')) &
        (pl.col("treatment_date") <= pl.col('processed_date') + pl.duration(days=5))
    )
    .rename({'processed_date': main_date_col, 'treatment_date': 'next_sched_trt_date'})
    .select('mrn', main_date_col, 'next_sched_trt_date')
    .unique(subset=['mrn', 'next_sched_trt_date'])
    .sort(by=['mrn', main_date_col])
    .collect()
)

In [None]:
# Attach features and labels to the clinic-visit anchor, using the same steps and windows as the
# treatment-centered dataset.
# NOTE(review): this cell rebinds `main` from itself, so re-run the anchor cell above before re-running it.
main = (
    main
    # ---- features ----
    .pipe(combine_demographic_to_main_data, demog, main_date_col)
    .pipe(combine_chemo_to_main_data, chemo, main_date_col, time_window=(-28, 0))
    .pipe(combine_radiation_to_main_data, rad, main_date_col, time_window=(-28, 0))
    .pipe(merge_closest_measurements, lab, main_date_col, "obs_date", include_meas_date=True, time_window=(-5, 0))
    .pipe(merge_closest_measurements, sym, main_date_col, "obs_date", include_meas_date=True, time_window=(-30, 0))
    .pipe(combine_event_to_main_data, acu, main_date_col, "ED_visit", lookback_window=5)
    # ---- targets ----
    .pipe(get_acu_labels, acu, main_date_col, lookahead_window=[30, 60, 90])
    .lazy()
    .pipe(get_CTCAE_labels, lab.lazy(), main_date_col, lookahead_window=30)
    .collect()
    .pipe(get_symptom_labels, sym, main_date_col)
)

In [None]:
# Split the clinic-centered table into (a) mrn plus every column ending in 'date' and
# (b) identifiers, categorical string columns, and all remaining feature/label columns, then save both.
id_cols = ['mrn', main_date_col]
date_cols = ['mrn'] + [col for col in main.columns if col.endswith('date')]
str_cols = ['cancer_type', 'primary_site_desc', 'intent', 'drug_name', 'postal_code']
# Exclude the id columns explicitly so they can never be selected twice.
excluded_cols = {*date_cols, *str_cols, *id_cols}
feat_cols = id_cols + str_cols + [col for col in main.columns if col not in excluded_cols]

# The processed/ folder may not exist yet in a fresh data directory; create it so the writes don't fail.
os.makedirs(f'{DATA_DIR}/processed', exist_ok=True)
main_dates = main.select(date_cols)
main_dates.write_parquet(f'{DATA_DIR}/processed/clinic_centered_dates.parquet')
main_data = main.select(feat_cols)
main_data.write_parquet(f'{DATA_DIR}/processed/clinic_centered_data.parquet')