NOTE: hormonal therapy is mostly administered orally, so it is hard to determine whether a patient actually followed the treatment plan. We therefore decided not to include hormonal treatments in our dataset, but it is still worth doing EDA on them.

NOTE: there are many types of therapies (chemo, hormonal, radiation, immuno, targeted). We will utilize only chemo and radiation therapy.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import polars as pl
import seaborn as sns
from matplotlib.ticker import MaxNLocator
from make_clinical_dataset.shared.constants import INFO_DIR, ROOT_DIR

# Raise the pandas display limits (rows, columns, characters per column) for notebook output
pd.set_option('display.max_rows', 200)
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_colwidth', 300)

# polars: print up to 100 rows per table
pl.Config.set_tbl_rows(100)
# NOTE(review): -1 presumably removes the cap on list elements printed per cell - confirm against the polars Config docs
pl.Config.set_fmt_table_cell_list_len(-1) 

In [None]:
# `date` selects which `data_{date}` directory under ROOT_DIR/data/final the chemo table is read from
date = '2025-03-29'
# chemo = pl.read_parquet(f'{ROOT_DIR}/data/final/data_{date}/interim/chemo.parquet')
# `chemo` is loaded as a pandas DataFrame; the pandas-based cells below (groupby, .dt, query) rely on that
chemo = pd.read_parquet(f'{ROOT_DIR}/data/final/data_{date}/interim/chemo.parquet')

## Years

In [None]:
# Per-year summary: distinct patients, distinct (patient, treatment date) pairs, and raw row count
grouped = chemo.groupby(chemo['treatment_date'].dt.year)
# De-duplicate once up front instead of running a Python lambda per group through `apply`.
# The year is derived from `treatment_date`, so each (mrn, treatment_date) pair can only
# belong to a single year and the per-year counts are unchanged.
unique_treatments = chemo.drop_duplicates(subset=['mrn', 'treatment_date'])
pd.DataFrame({
    'num_unique_patients': grouped['mrn'].nunique(),
    'num_unique_treatments': unique_treatments.groupby(
        unique_treatments['treatment_date'].dt.year
    ).size(),
    'num_rows': grouped.size(),
})

## Partial Duplicates

In [None]:
"""Count the number of partial duplicates by their non-matching columns"""

# Scan the file lazily instead of reading it eagerly and wrapping the result in a LazyFrame
chemo_pl = pl.scan_parquet(f'{ROOT_DIR}/data/final/data_{date}/interim/chemo.parquet')
keys = ["mrn", "treatment_date", "drug_name"]
cols = [c for c in chemo_pl.collect_schema().names() if c not in keys]

# for each group, flag (True/False) which columns had more than 1 unique value
diffs = (
    chemo_pl
    .group_by(keys)
    .agg([
        (pl.col(c).n_unique() > 1).alias(c) for c in cols
    ])
)

# keep groups with at least one column with more than 1 unique value - the partial duplicates
# (`mask` is True for groups where no column differs, hence the negation in the filter)
mask = pl.sum_horizontal(cols) == 0
diffs = diffs.filter(~mask)

# for each group, aggregate the differing columns into a list
diffs = (
    diffs
    .unpivot(index=keys, on=cols)               # convert to long format: one row per (group, column)
    .filter(pl.col("value"))                    # keep only the columns with more than 1 unique value
    .group_by(keys)                             # regroup per original key combination
    .agg(pl.col("variable").alias("diff_cols")) # collect the differing column names into a list
)

# get the freq count for each differing column subsets
freq_count = (
    diffs
    .group_by("diff_cols")
    .len()
    .sort("len", descending=True)
)

In [None]:
# run the lazy query: one row per distinct set of differing columns, sorted by how often it occurs
freq_count.collect()

In [None]:
# check examples of partial duplicates
# `.lazy()` returns a LazyFrame whether `chemo_pl` is still lazy or has already been
# collected, so re-running this cell no longer fails on the rebound DataFrame
chemo_pl = chemo_pl.lazy().collect()
no_cols = ["given_dose", "dose_ordered", "route"]
cols = [c for c in chemo_pl.columns if c not in no_cols]
# rows that have at least one other row identical on every column except those in `no_cols`
chemo_pl.filter(chemo_pl[cols].is_duplicated())

## Missing Doses

In [None]:
# how many doses are missing?
# chemo.group_by(["data_source", "drug_type"]).agg([
#     pl.col("given_dose").is_null().sum().alias("dose_missing"),
#     pl.col("given_dose").is_not_null().sum().alias("dose_present"),
# ])
# Flag each row once, then sum the flags per (data_source, drug_type) group
dose_flags = chemo.assign(
    dose_missing=chemo["given_dose"].isnull(),
    dose_present=chemo["given_dose"].notnull(),
)
(
    dose_flags
    .groupby(["data_source", "drug_type"])[["dose_missing", "dose_present"]]
    .sum()
    .reset_index()
)

## Intent

In [None]:
"""
Usually, after palliative intent treatment, rest of the treatment remains palliative.

Check the number of cases where that is not true.
"""
def check_intent_stays_palliative(df):
    """Return True if every row from the first palliative row onward is palliative.

    Args:
        df: DataFrame with an 'intent' column (e.g. the rows of one patient).
            Rows are evaluated in their existing order, so the caller decides
            what "after" means by how the rows are ordered.

    Returns:
        bool: True when there is no palliative row at all, or when no
        non-palliative row follows the first palliative one; False otherwise.
    """
    # Work positionally on a plain boolean array. Label-based slicing
    # (`mask.loc[idx:]`) selects the wrong rows, or raises, when the index
    # contains repeated labels. Missing intents count as "not palliative".
    is_palliative = (df['intent'] == 'palliative').to_numpy(dtype=bool, na_value=False)
    if not is_palliative.any():
        return True
    first = is_palliative.argmax()  # position of the first palliative row
    return bool(is_palliative[first:].all())

# One boolean per patient (True = no non-palliative row after the first palliative row), then tally them
# NOTE(review): the check walks each patient's rows in their existing order - assumes `chemo` is sorted chronologically; confirm
mask = chemo.groupby('mrn').apply(check_intent_stays_palliative, include_groups=False)
mask.value_counts()

In [None]:
"""Number of treatments over time"""
chemo['treatment_year'] = chemo['treatment_date'].dt.year
# Year x intent table of row counts (NaN where an intent does not occur in a year).
# `unstack` replaces the DataFrame -> rename -> reset_index -> pivot round-trip. The old
# rename of 'intent' to 'count' did nothing on pandas >= 2.0 (already required by the
# `include_groups` argument used in this notebook), where value_counts names its result 'count'.
counts = chemo.groupby('treatment_year')['intent'].value_counts().unstack('intent')
ax = sns.lineplot(data=counts)
# years are integers; keep the x ticks on whole numbers
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

## First treatment date

In [None]:
"""
Usually, the first treatment date is monotonically increasing (patients do not continue old treatment after starting new treatment)

Check the number of cases where that is not true.
"""
# One boolean per patient: is `first_treatment_date` non-decreasing in row order?
# Uses the built-in per-group monotonicity check instead of comparing every group
# against a Python-sorted copy of itself inside `apply`.
mask = chemo.groupby('mrn')['first_treatment_date'].is_monotonic_increasing
mask.value_counts()

## Height and Weight

In [None]:
"""Height/weight distribution of patients undertaking treatment""" 
# one point per patient: mean height and mean weight across that patient's rows
height_and_weight = chemo.groupby('mrn')[['height', 'weight']].mean()
sns.displot(data=height_and_weight, x='height', y='weight')

## Direct Drugs

In [None]:
# rows whose drug_type is "direct", the per-drug row counts, and the 30 most frequent drug names
direct = chemo[chemo['drug_type'] == 'direct']
counts = direct['drug_name'].value_counts()
top_drugs = counts.head(30).index

In [None]:
"""Number of drugs"""
# number of entries in the per-drug frequency table built in the previous cell
counts.size

In [None]:
"""Number of top direct drugs administered over time"""
chemo['treatment_year'] = chemo['treatment_date'].dt.year
# Derive the year from `direct` itself. `direct` was filtered out of `chemo` in an earlier
# cell, so the assignment above never reaches it; grouping by a 'treatment_year' column of
# `direct` only worked if `chemo` already had that column when `direct` was created.
treatment_year = direct['treatment_date'].dt.year.rename('treatment_year')
annual_counts = direct.groupby(treatment_year)['drug_name'].value_counts().reset_index()
annual_counts = annual_counts[annual_counts['drug_name'].isin(top_drugs)]
# one small line plot per drug, each with its own axes
g = sns.relplot(
    data=annual_counts, x='treatment_year', y='count', col='drug_name', col_wrap=3, kind='line',
    facet_kws={'sharex': False, 'sharey': False}
)
for ax in g.axes.flat:
    # years are integers; keep the x ticks on whole numbers
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

In [None]:
"""
Check the number of cases where the given dose differed from ordered dose
"""
# only rows with a recorded given dose take part in the comparison
tmp = direct.dropna(subset=['given_dose'])
# True where the given dose is exactly equal to the ordered dose
mask = tmp['given_dose'].eq(tmp['dose_ordered'])
mask.value_counts()