In [None]:
%%capture
# %%capture silences everything this cell would print (e.g. the directory echoed by %cd).
# Move the kernel's working directory two levels up.
# NOTE(review): the path is relative, so re-running this cell without restarting the
# kernel moves up two MORE levels; presumably the target is the repository root - confirm.
%cd ../../
# Reload imported modules automatically before each execution (mode 2 = all modules),
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import polars as pl

from make_clinical_dataset.epic.combine import (
    merge_closest_measurements, 
    combine_acu_to_main_data,
    combine_chemo_to_main_data, 
    combine_demographic_to_main_data,
    combine_radiation_to_main_data,
    get_clinic_prior_to_treatment
)
from make_clinical_dataset.epic.label import get_acu_labels, get_death_labels, get_CTCAE_labels, get_symptom_labels
from make_clinical_dataset.shared.constants import INFO_DIR, ROOT_DIR, SYMP_COLS

# Let polars print up to 100 rows when a table is displayed.
pl.Config.set_tbl_rows(100)
# Show up to 100 columns when pandas displays a DataFrame.
# The full option name is spelled out: pandas resolves abbreviated names by pattern
# matching, which raises as soon as the pattern matches more than one option.
pd.set_option('display.max_columns', 100)

In [None]:
# Date tag identifying which `data_<DATE>` folder this notebook works on.
DATE = '2025-03-29'
# Base folder for that data version; the cells below read from its `interim/`
# subfolder and read/write under its `processed/` subfolder.
DATA_DIR = f"{ROOT_DIR}/data/final/data_{DATE}"

# Combine the features & targets

In [None]:
# Load every interim table used to build the features and targets.

# -- treatments --
all_drugs = pl.read_parquet(f'{DATA_DIR}/interim/chemo.parquet')
rad = pl.read_parquet(f'{DATA_DIR}/interim/radiation.parquet')

# -- measurements and events --
lab = pl.read_parquet(f'{DATA_DIR}/interim/lab.parquet')
sym = pl.read_parquet(f'{DATA_DIR}/interim/symptom.parquet')
acu = pl.read_parquet(f'{DATA_DIR}/interim/acute_care_use.parquet')
clinic = pl.read_parquet(f'{DATA_DIR}/interim/clinic_visits.parquet')

# -- patient-level tables --
# TODO: EDA - show number of patients with multiple birth dates
demog = pl.read_parquet(f'{DATA_DIR}/interim/demographic.parquet')
last_seen = pl.read_parquet(f'{DATA_DIR}/interim/last_seen_dates.parquet')

# Split the drug records by `drug_type`. Rows whose type is neither "supportive"
# nor "direct" (including nulls) end up in neither frame.
drug_type = pl.col('drug_type')
supp = all_drugs.filter(drug_type == "supportive")
chemo = all_drugs.filter(drug_type == "direct")

## Align on chemo sessions

In [None]:
%%time
# select anchor
main_date_col = 'assessment_date'
main = (
    chemo
    .select('mrn', 'treatment_date')
    .unique().sort('mrn', 'treatment_date')
    .rename({'treatment_date': main_date_col})
)

# merge last seen date
main = main.join(last_seen.select('mrn', 'last_seen_date'), on="mrn", how="left")

# merge demographics
main = combine_demographic_to_main_data(main, demog, main_date_col)

# merge chemotherapy treatments
main = combine_chemo_to_main_data(main, chemo, main_date_col, time_window=(-28,0))

# merge radiation treatments
main = combine_radiation_to_main_data(main, rad, main_date_col, time_window=(-28,0))

# merge acute care use
main = combine_acu_to_main_data(main, acu, main_date_col, lookback_window=5)

# merge laboratory tests
main = merge_closest_measurements(main, lab, main_date_col, "obs_date", include_meas_date=True, time_window=(-5,0))

# merge symptom surveys
main = merge_closest_measurements(main, sym, main_date_col, "obs_date", include_meas_date=True, time_window=(-30,0))

# add lables
# 1) Acute Care Use
main = get_acu_labels(main, acu, main_date_col, lookahead_window=[30, 60, 90])
# 2) CTCAE
main = get_CTCAE_labels(main.lazy(), lab.lazy(), main_date_col, lookahead_window=30).collect()
# 3) Symptoms
main = get_symptom_labels(main, sym, main_date_col)
# 4) Death
main = get_death_labels(main, lookahead_window=[30, 365])

In [None]:
# Sort the columns of `main` into three groups in one pass over the schema
# (schema order is kept inside each group). 'mrn' leads the date group so the
# dates file can be linked back to the data file.
date_cols, str_cols, other_cols = ['mrn'], [], []
for name, dtype in main.schema.items():
    if dtype in (pl.Datetime, pl.Date):
        date_cols.append(name)
    elif dtype == pl.String:
        str_cols.append(name)
    elif name != 'mrn':
        other_cols.append(name)

# Data file layout: identifiers, then string columns, then everything else.
feat_cols = ['mrn', 'assessment_date'] + str_cols + other_cols

# Dates and features are written to separate parquet files.
main_dates = main.select(date_cols)
main_data = main.select(feat_cols)
main_dates.write_parquet(f'{DATA_DIR}/processed/treatment_centered_dates.parquet')
main_data.write_parquet(f'{DATA_DIR}/processed/treatment_centered_data.parquet')

## Align on clinic dates

In [None]:
# Anchor: the rows returned by get_clinic_prior_to_treatment, reduced to three
# columns with the clinic date exposed under the common anchor column name.
main_date_col = "assessment_date"
main = (
    get_clinic_prior_to_treatment(clinic, chemo, lookback_window=5)
    .select(
        "mrn",
        pl.col("clinic_date").alias(main_date_col),
        "next_sched_trt_date",
    )
)

In [None]:
# Extract features: each step receives the frame built so far and returns it
# with the corresponding columns merged in.
features = (
    main
    .join(last_seen.select('mrn', 'last_seen_date'), on="mrn", how="left")
    .pipe(combine_demographic_to_main_data, demog, main_date_col)
    .pipe(combine_chemo_to_main_data, chemo, main_date_col, time_window=(-28,0))
    .pipe(combine_radiation_to_main_data, rad, main_date_col, time_window=(-28,0))
    .pipe(combine_acu_to_main_data, acu, main_date_col, lookback_window=5)
    .pipe(merge_closest_measurements, lab, main_date_col, "obs_date", include_meas_date=True, time_window=(-5,0))
    .pipe(merge_closest_measurements, sym, main_date_col, "obs_date", include_meas_date=True, time_window=(-30,0))
)

# Extract targets. The CTCAE helper is handed lazy frames, so collect right after it.
main = (
    features
    .pipe(get_acu_labels, acu, main_date_col, lookahead_window=[30, 60, 90])
    .lazy()
    .pipe(get_CTCAE_labels, lab.lazy(), main_date_col, lookahead_window=30)
    .collect()
    .pipe(get_symptom_labels, sym, main_date_col)
    .pipe(get_death_labels, lookahead_window=[30, 365])
)

In [None]:
# Sort the columns of `main` into three groups in one pass over the schema
# (schema order is kept inside each group). 'mrn' leads the date group so the
# dates file can be linked back to the data file.
date_cols, str_cols, other_cols = ['mrn'], [], []
for name, dtype in main.schema.items():
    if dtype in (pl.Datetime, pl.Date):
        date_cols.append(name)
    elif dtype == pl.String:
        str_cols.append(name)
    elif name != 'mrn':
        other_cols.append(name)

# Data file layout: identifiers, then string columns, then everything else.
feat_cols = ['mrn', 'assessment_date'] + str_cols + other_cols

# Dates and features are written to separate parquet files.
main_dates = main.select(date_cols)
main_data = main.select(feat_cols)
main_dates.write_parquet(f'{DATA_DIR}/processed/clinic_centered_dates.parquet')
main_data.write_parquet(f'{DATA_DIR}/processed/clinic_centered_data.parquet')

## Align on hospital admission dates

In [None]:
main_date_col = 'assessment_date'

# Read the admissions table and tag every original column with an `_orig` suffix,
# so the typed key columns added below cannot collide with them.
hosp_raw = pl.read_csv(f'{DATA_DIR}/processed/hosp/active_chemo_only_past90days.csv')
hosp_raw = hosp_raw.rename({name: name + '_orig' for name in hosp_raw.columns})

# Typed join keys: integer patient id, and the admission date as a Datetime
# (cast through Date first) stored under the common anchor column name.
mrn_expr = pl.col('mrn_orig').cast(pl.Int64).alias('mrn')
admission_expr = (
    pl.col('admission_date_orig')
    .cast(pl.Date)
    .cast(pl.Datetime)
    .alias(main_date_col)
)

main = hosp_raw.with_columns(mrn_expr, admission_expr).sort('mrn', main_date_col)

In [None]:
# Extract features: each step receives the frame built so far and returns it
# with the corresponding columns merged in. Unlike the other two alignments, the
# demographic step is told not to exclude underage patients or missing ages, and
# the treatment windows are (-90, 0) instead of (-28, 0).
features = (
    main
    .join(last_seen.select('mrn', 'last_seen_date'), on="mrn", how="left")
    .pipe(combine_demographic_to_main_data, demog, main_date_col, exclude_underage=False, exclude_missing_age=False)
    .pipe(combine_chemo_to_main_data, chemo, main_date_col, time_window=(-90,0))
    .pipe(combine_radiation_to_main_data, rad, main_date_col, time_window=(-90,0))
    .pipe(combine_acu_to_main_data, acu, main_date_col, lookback_window=5)
    .pipe(merge_closest_measurements, lab, main_date_col, "obs_date", include_meas_date=True, time_window=(-5,0))
    .pipe(merge_closest_measurements, sym, main_date_col, "obs_date", include_meas_date=True, time_window=(-30,0))
)

# Extract targets. The CTCAE helper is handed lazy frames, so collect right after it.
main = (
    features
    .pipe(get_acu_labels, acu, main_date_col, lookahead_window=[30, 60, 90])
    .lazy()
    .pipe(get_CTCAE_labels, lab.lazy(), main_date_col, lookahead_window=30)
    .collect()
    .pipe(get_symptom_labels, sym, main_date_col)
    .pipe(get_death_labels, lookahead_window=[30, 365])
)

In [None]:
# Split the columns of `main` into date-typed columns and everything else in one
# pass over the schema (schema order is kept inside each group). 'mrn' leads the
# date group so the dates file can be linked back to the data file.
date_cols, non_date_cols = ['mrn'], []
for name, dtype in main.schema.items():
    if dtype in (pl.Datetime, pl.Date):
        date_cols.append(name)
    elif name != 'mrn':
        non_date_cols.append(name)

# Data file layout: identifiers first, then every non-date column.
feat_cols = ['mrn', 'assessment_date'] + non_date_cols

# Dates and features are written to separate parquet files.
main_dates = main.select(date_cols)
main_data = main.select(feat_cols)
main_dates.write_parquet(f'{DATA_DIR}/processed/hosp/hosp_centered_dates.parquet')
main_data.write_parquet(f'{DATA_DIR}/processed/hosp/hosp_centered_data.parquet')