# Imports and Data

In [1]:
%load_ext autoreload

import sys
sys.path.append('../../EventStreamGPT')

In [2]:
%autoreload
import gc, humanize, polars as pl, pandas as pd
import dataclasses
from pathlib import Path
from contextlib import contextmanager
from datetime import datetime, timedelta

from EventStream.data.dataset_polars import Dataset
pl.enable_string_cache(True)

In [3]:
COHORT_NAME = 'ESD_07-23-23'
RAW_DATA_DIR = Path('/storage/shared/mgh-hf-dataset/interim')
OUT_DATA_DIR = Path('/storage/shared/mgh-hf-dataset/processed/') / COHORT_NAME

TASK_DF_DIR = OUT_DATA_DIR / 'task_dfs'
TASK_DF_DIR.mkdir(exist_ok=True, parents=False)

In [4]:
%%time
ESD = Dataset.load(OUT_DATA_DIR)

CPU times: user 119 ms, sys: 32.3 ms, total: 151 ms
Wall time: 150 ms


# Task DataFrames

1. 30-day readmission
2. Lab test values:
  - A few days to a week. Specific ones:
    - Potassium, Creatinine, 'Troponin T cardiac', 'Urea nitrogen', 'Glomerular filtration rate'
    - 'N-terminal pro-brain natriuretic peptide', 
  - Classification of "low,normal,high".
3. Echo parameters
  - lv_ef (classification) vs. lv_ef_value regression, 'av_stenosis', 'mv_regurg'
  - Up to 3 months in advance is ok.
4. Elevated pressures:
  - wedge pressure, re pressure, pa pressure
  - Payal to send ranges
  - gaps at the same-day level. Mostly predicted from ECG and ECHO
  - need to think about 
  
Real tasks:
1. Predict (1) what is potassium now, (2) what is potassium in a week, (3) was there an associated change in medications between those two values. First do this with ev.

## Direct Measurment Prediction 

In [34]:
import dataclasses
from datetime import timedelta
import math
import numpy as np
import omegaconf
from tqdm.auto import tqdm

from EventStream.data.types import DataModality


@dataclasses.dataclass
class MeasurementTaskSpec:
    DEFAULT_NAMES = ['LOW', 'NORMAL', 'HIGH']
    
    gap_time: timedelta | None = None
    window_size: timedelta | None = None
    bounds: list[float] | dict[tuple[str, str], list[float]] | None = None
    bound_names: list[str] | None = None
        
    def __post_init__(self):
        if self.bounds is None:
            if self.bound_names is not None:
                raise ValueError("Bound names shouldn't be set in classification mode")
            return
        
        match self.bounds:
            case dict():
                L = None
                for k, v in self.bounds.items():
                    if type(k) is not tuple and len(k) != 2: raise TypeError(f"Bounds malformed with key {k}")
                    if L is None: L = len(v)
                    elif len(v) != L: raise ValueError("All bounds must have the same length!")
            case list: L = len(self.bounds)

        if self.bound_names is None and L == len(self.DEFAULT_NAMES) - 1:
            self.bound_names = self.DEFAULT_NAMES

        if L != len(self.bound_names) - 1:
            raise ValueError(f"Bound names {self.bound_names} and bounds {self.bounds} misaligned.")

TASK_SCHEMA_T = tuple[str, MeasurementTaskSpec | dict[str, MeasurementTaskSpec]]
            
@dataclasses.dataclass
class BuildForecastingTaskDfConfig:
    data_dir: Path = omegaconf.MISSING
    task_df_name_template: str = "{measurement}/{time_str}.parquet"
    task_schemas: list[TASK_SCHEMA_T] = dataclasses.field(default_factory=list)

        
def make_time_str(td: timedelta) -> str:
    time_str = ''
    sub = {}
    for unit, mult, n in [('days', 1, 'd'), ('seconds', 60*60, 'h'), ('seconds', 60, 'm'), ('seconds', 1, 's')]:
        val = getattr(td, unit)
        if unit in sub: val -= sub[unit]
        else: sub[unit] = 0
        
        val = int(math.floor(val / mult))
        
        sub[unit] += val * mult
        
        if val != 0:
            time_str += f"{val}{n}"
    return time_str

def norm(bounds: list[float], mean_: float = None, std_: float = None): 
    return (np.array(bounds) - mean_) / std_


def get_reg_out_df(
    df: pl.LazyFrame,
    bounds: list[float] | dict[str, list[float]],
    norm_params: dict[str, float],
    vals_col: str,
    names: list[str],
) -> pl.LazyFrame:
    if type(bounds) is not dict: bounds = {(None, None): norm(bounds, **norm_params)}
    else: bounds = {k: norm(v, **norm_params) for k, v in bounds.items()}

    out_df = []
    out_col = f"{vals_col}_category"
    for (subj_col, subj_val), norm_bounds in bounds.items():
        if subj_col is not None:
            df_for_subj = df.filter(pl.col(subj_col) == subj_val)
        else:
            df_for_subj = df

        cat_expr = pl.when(pl.col(vals_col) < norm_bounds[0]).then(pl.lit(names[0]))
        old_b = norm_bounds[0]
        for b, n in zip(norm_bounds[1:], names[1:-1]):
            cat_expr = cat_expr.when(
                (pl.col(vals_col) >= old_b) & (pl.col(vals_col) < b)
            ).then(n)
            old_b = b

        cat_expr = cat_expr.otherwise(names[-1]).alias(out_col).cast(pl.Categorical)
        out_df.append(df_for_subj.with_columns(cat_expr))

    return pl.concat(out_df, how='vertical'), out_col

def reformat_task_df(
    df: pl.LazyFrame,
    gap_time: timedelta,
    label_col: str,
    window_size: timedelta
) -> pl.LazyFrame:
    
    df = df.with_columns(__indicator=pl.lit(1)).collect()
    
    labels = df.select(pl.col(label_col).drop_nulls().unique())[label_col].to_list()
    print(labels)
    
    df = (
        df
        .pivot(
            index=['subject_id', 'timestamp'],
            columns=label_col,
            values='__indicator',
            aggregate_function='sum',
        )
        .select('subject_id', 'timestamp', *labels)
        .sort(by=['subject_id', 'timestamp'])
        .fill_null(0)
    )
    
    label_cols = [c for c in df.columns if c not in ('subject_id', 'timestamp')]
    
    return (
        df
        .lazy()
        .groupby_rolling(
            'timestamp',
            period=gap_time + window_size,
            offset=timedelta(days=0),
            by=['subject_id'],
        )
        .agg(
            *[
                pl.col(c)
                .filter(pl.col('timestamp') - pl.col('timestamp').min() > gap_time)
                .sum()
                .fill_null(0)
                for c in label_cols
            ]
        )
        .filter(pl.any(pl.col(c) > 0 for c in label_cols))
        .select(
            'subject_id',
            *[pl.col(c) > 0 for c in label_cols],
            start_time=pl.lit(None).cast(pl.Datetime),
            end_time=pl.col('timestamp'),
        )
    )

def build_measurement_forecast_task_df(cfg: BuildForecastingTaskDfConfig, ESD: Dataset | None = None):
    if ESD is None: ESD = Dataset._load(cfg.data_dir)
        
    task_df_dir = cfg.data_dir / 'task_dfs'
    task_df_dir.mkdir(exist_ok=True, parents=False)
    
    out_dfs_all = {}
    
    for measurement, task_schema in cfg.task_schemas:
        subj_cols_needed = set()
        meas_cfg = ESD.measurement_configs[measurement]
        
        if type(task_schema) is dict:
            if meas_cfg.modality != DataModality.MULTIVARIATE_REGRESSION:
                raise ValueError(f"Misconfigured schema for {measurement}")
                
            for key, schema in task_schema.items():
                if key not in ESD.measurement_vocabs[measurement]:
                    raise KeyError(f"Can't find {key} in {measurement} vocabulary!")
                if type(schema.bounds) is dict:
                    for (subj_col, _) in schema.bounds.keys():
                        subj_cols_needed.add(subj_col)
            
            meas_filter = pl.col(measurement).is_in(list(task_schema.keys()))
        else:
            if ESD.measurement_configs[measurement].modality == DataModality.MULTIVARIATE_REGRESSION:
                raise ValueError(f"Misconfigured schema for {measurement}")
            if type(schema.bounds) is dict:
                for (subj_col, _) in schema.bounds.keys():
                    subj_cols_needed.add(subj_col)
                    
            meas_filter = pl.col(measurement).is_not_null()

        subjects_df = ESD.subjects_df.lazy().select('subject_id', *subj_cols_needed)

        events_df = ESD.events_df.lazy().select(
            'event_id', 'subject_id', 'timestamp'
        ).join(
            subjects_df, on='subject_id'
        )
        
        meas_df = ESD.dynamic_measurements_df.lazy().filter(
            meas_filter
        ).join(events_df, on='event_id', how='left')

        match meas_cfg.modality:
            case DataModality.MULTIVARIATE_REGRESSION:
                for key_val, schema in task_schema.items():
                    df_for_test = meas_df.filter(pl.col(measurement) == key_val)
                    norm_params = meas_cfg.measurement_metadata.loc[key_val, 'normalizer']
                    
                    labels_df, label_col = get_reg_out_df(
                        df_for_test, schema.bounds, norm_params, meas_cfg.values_column, names=schema.bound_names
                    )
                    
                    events_with_labels = (
                        events_df
                        .join(labels_df, on='event_id', how='left')
                        .select('subject_id', 'timestamp', label_col)
                    )
                    
                    out_df = reformat_task_df(
                        df=events_with_labels,
                        gap_time=schema.gap_time,
                        label_col=label_col,
                        window_size=schema.window_size
                    )
                    task_df_fp = cfg.task_df_name_template.format(
                        measurement=f"{measurement}/{key_val}", time_str=make_time_str(schema.gap_time)
                    )
                    
                    out_dfs_all[(measurement, key_val)] = (out_df, task_df_fp)
                    
            case DataModality.UNIVARIATE_REGRESSION:
                norm_params = meas_cfg.measurement_metadata.loc['normalizer']
                out_df, cat_col = get_reg_out_df(
                    meas_df, task_schema.bounds, norm_params, measurement, names=task_schema.bound_names
                )
                out_df = reformat_task_df(
                    df=out_df, gap_time=task_schema.gap_time, label_col=cat_col,
                    window_size=task_schema.window_size, 
                )
                task_df_fp = cfg.task_df_name_template.format(
                    measurement=measurement, time_str=make_time_str(task_schema.gap_time)
                )
                out_dfs_all[measurement] = (out_df, task_df_fp)

            case DataModality.MULTI_LABEL_CLASSIFICATION | DataModality.SINGLE_LABEL_CLASSIFICATION:
                if task_schema.bounds is not None: 
                    raise ValueError(f"Bounds must be none for classification! Got {schema.bounds}")
                
                out_df = reformat_task_df(
                    df=meas_df, gap_time=task_schema.gap_time, label_col=measurement,
                    window_size=task_schema.window_size
                )
                task_df_fp = cfg.task_df_name_template.format(
                    measurement=measurement, time_str=make_time_str(task_schema.gap_time)
                )
                out_dfs_all[measurement] = (out_df, task_df_fp)
                
            case _: raise ValueError(f"Modality {meas_cfg.modality} invalid for {measurement}!")

    for key, (df, fn) in tqdm(out_dfs_all.items(), leave=False, desc="Task DF"):
        try:
            fp = task_df_dir / fn
            Path(fp).parent.mkdir(exist_ok=True, parents=True)
            df.collect().write_parquet(fp)
        except Exception as e:
            raise ValueError(f"Failed to construct df for {key}") from e

    return out_dfs_all

In [35]:
%%time

pl.Config.set_fmt_str_lengths(100)
pl.Config.set_tbl_rows(100)
df = pl.scan_csv(RAW_DATA_DIR / 'lab.csv').filter(
    pl.col('name').is_in(ESD.measurement_vocabs['lab_test'])
).groupby('name').agg(
    pl.col('units').n_unique().alias('n_units'),
    pl.col('units').unique().first().alias('unit'),
).collect()

assert df['n_units'].max() == 1

display(df[['name', 'unit']])

name,unit
str,str
"""Oxygen partial pressure in venous blood""","""mmhg"""
"""Calcium""","""mg/dl"""
"""Oxygen partial pressure in blood""","""mmhg"""
"""Hematocrit""","""%"""
"""Alkaline phosphatase""","""u/l"""
"""Glomerular filtration rate""","""ml/min/1.73m2"""
"""Cholesterol total/Cholesterol in HDL""",
"""Cholesterol in LDL""","""mg/dl"""
"""Bicarbonate""","""mmol/l"""
"""Urea nitrogen""","""mg/dl"""


CPU times: user 1min 44s, sys: 36 s, total: 2min 20s
Wall time: 5.98 s


In [36]:
lab_test_schema = {
    'N-terminal pro-brain natriuretic peptide': MeasurementTaskSpec(
        gap_time=timedelta(days=7),
        window_size=timedelta(days=3),
        bounds={('sex', 'Male'): [0, 300], ('sex', 'Female'): [0, 450]}  # units pg/ml
    ),
    'Troponin T cardiac': MeasurementTaskSpec(
        gap_time=timedelta(days=2),
        window_size=timedelta(days=1),
        bounds={('sex', 'Male'): [0, 15], ('sex', 'Female'): [0, 10]}  # units ng/ml
    ),
    'Troponin I cardiac': MeasurementTaskSpec(
        gap_time=timedelta(days=2),
        window_size=timedelta(days=1),
        bounds={('sex', 'Male'): [0, 34], ('sex', 'Female'): [0, 16]}  # units ng/ml
    ),
    'Creatinine': MeasurementTaskSpec(
        gap_time=timedelta(days=7),
        window_size=timedelta(days=3),
        bounds={('sex', 'Male'): [0.7, 1.3], ('sex', 'Female'): [0.6, 1.1]}  # units mg/dl
    ),
    'Potassium': MeasurementTaskSpec(
        gap_time=timedelta(days=1),
        window_size=timedelta(days=1),
        bounds=[3.5, 5.2]  # units mmol/l
    ),
    'Sodium': MeasurementTaskSpec(
        gap_time=timedelta(days=1),
        window_size=timedelta(days=1),
        bounds=[135, 145]  # units mmol/l
    ),
    'C reactive protein': MeasurementTaskSpec(
        gap_time=timedelta(days=7),
        window_size=timedelta(days=3),
        bounds=[0, 10]  # units mg/l
    ),
    'Lactate dehydrogenase': MeasurementTaskSpec(
        gap_time=timedelta(days=7),
        window_size=timedelta(days=3),
        bounds=[140, 280]  # units u/l
    ),
    'Glomerular filtration rate': MeasurementTaskSpec(
        gap_time=timedelta(days=30),
        window_size=timedelta(days=7),
        bounds={('sex', 'Male'): [60, 120], ('sex', 'Female'): [45, 105]}  # units ml/min/1.73m2
    ),
    'Hemoglobin': MeasurementTaskSpec(
        gap_time=timedelta(days=2),
        window_size=timedelta(days=1),
        bounds={('sex', 'Male'): [13.5, 17.5], ('sex', 'Female'): [12, 15.5]}
    ),
}

cfg = BuildForecastingTaskDfConfig(
    data_dir = OUT_DATA_DIR,
    task_schemas = [
        ('lab_test', lab_test_schema),
        ('lv_ef', MeasurementTaskSpec(gap_time=timedelta(days=60), window_size=timedelta(days=14))),
        ('av_stenosis', MeasurementTaskSpec(gap_time=timedelta(days=60), window_size=timedelta(days=14))),
        ('mv_regurg', MeasurementTaskSpec(gap_time=timedelta(days=60), window_size=timedelta(days=14))),
        (
            'mean_wedge_pressure',
            MeasurementTaskSpec(gap_time=timedelta(hours=6), window_size=timedelta(hours=2), bounds=[4, 12])
        ),
        (
            'mean_pa_pressure',
            MeasurementTaskSpec(gap_time=timedelta(hours=6), window_size=timedelta(hours=2), bounds=[18, 25])
        ),
    ]
)

In [37]:
%%time

out_dfs = build_measurement_forecast_task_df(cfg, ESD=ESD)

['HIGH', 'NORMAL']
['HIGH', 'NORMAL']
['HIGH', 'NORMAL']
['HIGH', 'NORMAL', 'LOW']
['HIGH', 'NORMAL', 'LOW']
['HIGH', 'NORMAL', 'LOW']
['HIGH', 'NORMAL']
['HIGH', 'NORMAL', 'LOW']
['HIGH', 'NORMAL', 'LOW']
['HIGH', 'NORMAL', 'LOW']
['I', 'N', 'A', 'L', 'H']
['N', 'Y']
['N', 'Y']
['NORMAL', 'LOW', 'HIGH']
['NORMAL', 'LOW', 'HIGH']


Task DF:   0%|          | 0/15 [00:00<?, ?it/s]

CPU times: user 1h 44min 37s, sys: 8min 53s, total: 1h 53min 31s
Wall time: 12min 1s


In [38]:
from sparklines import sparklines

out_lines = []
for key, (_, fn) in tqdm(out_dfs.items()):
    fp = TASK_DF_DIR / fn
    if not fp.is_file():
        print(f"{fn} does not exist @ {fp}!")
        continue
    df = pl.scan_parquet(fp)
    label_cols = [c for c in df.columns if c not in ('subject_id', 'start_time', 'end_time')]
    for label_col in label_cols:
        val_counts = df.select(pl.col(label_col).value_counts(sort=True)).collect()
        labels, cnts = [], []
        for r in val_counts[label_col]:
            labels.append(r[label_col])
            cnts.append(r['counts'])
    
        freqs = np.array(cnts) / sum(cnts)

        if type(key) is tuple: key = '/'.join(key)

        out_line = f"{key} {label_col}: "

        if len(labels) < 7:
            out_line += ' '.join(f"{l} ({f*100:.1f}%)" for l, f in zip(labels, freqs))
        else:
            N = len(out_line)
            out_line += ' '.join(f"{l} ({f*100:.1f}%)" for l, f in zip(labels[:4], freqs)) + '...'
            out_line += '\n' + ' '*N + str(sparklines(freqs)[0])
        out_lines.append(out_line)
    
print('\n'.join(out_lines))

  0%|          | 0/15 [00:00<?, ?it/s]

lab_test/N-terminal pro-brain natriuretic peptide HIGH: True (90.9%) False (9.1%)
lab_test/N-terminal pro-brain natriuretic peptide NORMAL: False (90.8%) True (9.2%)
lab_test/Troponin T cardiac HIGH: False (98.8%) True (1.2%)
lab_test/Troponin T cardiac NORMAL: True (98.9%) False (1.1%)
lab_test/Troponin I cardiac HIGH: False (94.6%) True (5.4%)
lab_test/Troponin I cardiac NORMAL: True (95.0%) False (5.0%)
lab_test/Creatinine HIGH: True (51.1%) False (48.9%)
lab_test/Creatinine NORMAL: False (53.0%) True (47.0%)
lab_test/Creatinine LOW: False (83.9%) True (16.1%)
lab_test/Potassium HIGH: False (96.8%) True (3.2%)
lab_test/Potassium NORMAL: True (92.2%) False (7.8%)
lab_test/Potassium LOW: False (85.0%) True (15.0%)
lab_test/Sodium HIGH: False (84.2%) True (15.8%)
lab_test/Sodium NORMAL: True (70.1%) False (29.9%)
lab_test/Sodium LOW: False (65.6%) True (34.4%)
lab_test/C reactive protein HIGH: True (71.8%) False (28.2%)
lab_test/C reactive protein NORMAL: False (71.3%) True (28.7%)
lab

## Readmission Risk

In [8]:
%%time

events_df = ESD.events_df.lazy()

readmission_30d = events_df.with_columns(
    pl.col('event_type').cast(pl.Utf8).str.contains('DISCHARGE').alias('is_discharge'),
    pl.col('event_type').cast(pl.Utf8).str.contains('ADMISSION').alias('is_admission')
).filter(
    pl.col('is_discharge') | pl.col('is_admission')
).sort(
    ['subject_id', 'timestamp'], descending=False
).with_columns(
    pl.when(
        pl.col('is_admission')
    ).then(
        pl.col('timestamp')
    ).otherwise(
        None
    ).alias(
        'admission_time'
    ).cast(
        pl.Datetime
    )
).with_columns(
    pl.col('admission_time').fill_null(strategy='backward').over('subject_id').alias('next_admission_time'),
    pl.col('admission_time').fill_null(strategy='forward').over('subject_id').alias('prev_admission_time'),
).with_columns(
    (
        (pl.col('next_admission_time') - pl.col('timestamp')) < pl.duration(days=30)
    ).fill_null(False).alias('30d_readmission')
).filter(
    pl.col('is_discharge')
)

readmission_30d_all = readmission_30d.select(
    'subject_id', pl.lit(None).cast(pl.Datetime).alias('start_time'), pl.col('timestamp').alias('end_time'), 
    '30d_readmission'
)

readmission_30d_admission_only = readmission_30d.select(
    'subject_id', pl.col('prev_admission_time').alias('start_time'), pl.col('timestamp').alias('end_time'),
    '30d_readmission'
)

readmission_30d_all.collect().write_parquet(TASK_DF_DIR / 'readmission_30d_all.parquet')
readmission_30d_admission_only.collect().write_parquet(TASK_DF_DIR / 'readmission_30d_admission_only.parquet')

Loading events from /storage/shared/mgh-hf-dataset/processed/ESD_07-23-23/events_df.parquet...
CPU times: user 14.5 s, sys: 5.37 s, total: 19.9 s
Wall time: 3.6 s
