# Build Task DataFrames over MIMIC-IV

In [1]:
# Load the autoreload extension (driven by `%autoreload` in the next cell) and the
# memory_profiler extension, which provides the `%%memit` cell magic used below.
%load_ext autoreload
%load_ext memory_profiler

import os
import sys

import rootutils

# Locate the project root starting from the notebook's working directory. With
# dotenv=True this presumably loads the project's .env file (expected to define
# EVENT_STREAM_PATH and PROJECT_DIR, both read below) — confirm against rootutils docs.
# cwd=False leaves the current working directory unchanged.
root = rootutils.setup_root(os.path.abspath(""), dotenv=True, pythonpath=True, cwd=False)
# Make the EventStream package (imported in the next cell) importable from its checkout.
sys.path.append(os.environ["EVENT_STREAM_PATH"])

In [2]:
%autoreload

from pathlib import Path

import polars as pl
from EventStream.data.dataset_polars import Dataset

In [28]:
COHORT_NAME = "MIMIC_IV/12-4-23-1"
PROJECT_DIR = Path(os.environ["PROJECT_DIR"])
DATA_DIR = PROJECT_DIR / "data" / COHORT_NAME
# Validate explicitly rather than with `assert`: asserts are stripped under `python -O`
# and a bare assert gives no hint about which path was wrong.
if not DATA_DIR.is_dir():
    raise FileNotFoundError(
        f"Cohort data directory {DATA_DIR} does not exist; check PROJECT_DIR and COHORT_NAME."
    )

# Task dataframes (subject_id, start_time, end_time, label) are written here.
TASK_DF_DIR = DATA_DIR / "task_dfs"
TASK_DF_DIR.mkdir(exist_ok=True, parents=False)

ESD = Dataset.load(DATA_DIR)

# When enabled, each task is additionally exported in ESDS form
# (patient_id, end_time, label) under as_ESDS/tasks.
DO_ESDS = True
if DO_ESDS:
    ESDS_TASKS_DIR = DATA_DIR / "as_ESDS" / "tasks"
    ESDS_TASKS_DIR.mkdir(exist_ok=True, parents=True)

# Event Timing Tasks

In [8]:
def has_event_type(type_str: str) -> pl.Expr:
    """Return a boolean expression: does ``type_str`` appear among a row's event types?

    ``event_type`` may hold several types joined by ``"&"`` (e.g. ``"A&B"``). The column
    is cast to a string, split on ``"&"``, and the resulting list is tested for an exact
    element match with ``type_str``.
    """
    return (
        pl.col("event_type")
        .cast(pl.Utf8)
        .str.split("&")
        .list.contains(type_str)
    )

## Readmission Risk Prediction

In [19]:
%%time
%%memit

events_df = ESD.events_df.lazy()

readmission_30d = events_df.with_columns(
    has_event_type('DISCHARGE').alias('is_discharge'),
    has_event_type('ADMISSION').alias('is_admission')
).filter(
    pl.col('is_discharge') | pl.col('is_admission')
).sort(
    ['subject_id', 'timestamp'], descending=False
).with_columns(
    pl.when(
        pl.col('is_admission')
    ).then(
        pl.col('timestamp')
    ).otherwise(
        None
    ).alias(
        'admission_time'
    ).cast(
        pl.Datetime
    )
).with_columns(
    pl.col('admission_time').fill_null(strategy='backward').over('subject_id').alias('next_admission_time'),
    pl.col('admission_time').fill_null(strategy='forward').over('subject_id').alias('prev_admission_time'),
).with_columns(
    (
        (pl.col('next_admission_time') - pl.col('timestamp')) < pl.duration(days=30)
    ).fill_null(False).alias('30d_readmission')
).filter(
    pl.col('is_discharge')
)

readmission_30d_all = readmission_30d.select(
    'subject_id', pl.lit(None).cast(pl.Datetime).alias('start_time'), pl.col('timestamp').alias('end_time'), 
    '30d_readmission'
)

readmission_30d_admission_only = readmission_30d.select(
    'subject_id', pl.col('prev_admission_time').alias('start_time'), pl.col('timestamp').alias('end_time'),
    '30d_readmission'
)

readmission_30d_all.collect().write_parquet(
    TASK_DF_DIR / 'readmission_30d_all.parquet', use_pyarrow=True
)
readmission_30d_admission_only.collect().write_parquet(
    TASK_DF_DIR / 'readmission_30d_admission_only.parquet', use_pyarrow=True
)

prevalence = readmission_30d_all.select(pl.col("30d_readmission").mean()).collect().item()
print(f"The {COHORT_NAME} cohort has a {prevalence*100:.1f}% 30d readmission prevalence.")

The MIMIC_IV/12-4-23-1 cohort has a 32.6% 30d readmission prevalence.
peak memory: 1155.97 MiB, increment: 217.58 MiB
CPU times: user 8.73 s, sys: 480 ms, total: 9.21 s
Wall time: 5.64 s


In [18]:
if DO_ESDS:
    # ESDS export: rename subject_id -> patient_id and keep only prediction time + label.
    readmission_30d_all.select(
        pl.col("subject_id").alias("patient_id"),
        pl.col("end_time"),
        pl.col("30d_readmission"),
    ).collect().write_parquet(
        ESDS_TASKS_DIR / "readmission_30d.parquet",
        use_pyarrow=True,
    )

## In-Hospital Mortality Risk Prediction after 24 Hours in the ICU

In [21]:
# Exploratory draft of the per-ICU-stay aggregation: earliest death and discharge time per
# (subject, current ICU stay start). The following cell re-binds `in_hosp_mort` before
# anything reads it, so this lazy frame is never collected.
in_hosp_mort = (
    events_df
    .with_columns(
        has_event_type('DEATH').alias('is_death'),
        has_event_type('DISCHARGE').alias('is_discharge'),
        has_event_type('ICU_STAY_START').alias('is_icustay_admission'),
    )
    .filter(
        pl.col('is_death')
        | pl.col('is_icustay_admission')
        | pl.col('is_discharge')
    )
    # Each *_time column holds the row's timestamp when the row is of that kind, else null.
    .with_columns(
        pl.when(pl.col('is_icustay_admission'))
        .then(pl.col('timestamp'))
        .otherwise(pl.lit(None, dtype=pl.Datetime))
        .cast(pl.Datetime)
        .alias('icustay_admission_time'),
        pl.when(pl.col('is_death'))
        .then(pl.col('timestamp'))
        .otherwise(pl.lit(None, dtype=pl.Datetime))
        .cast(pl.Datetime)
        .alias('death_time'),
        pl.when(pl.col('is_discharge'))
        .then(pl.col('timestamp'))
        .otherwise(pl.lit(None, dtype=pl.Datetime))
        .cast(pl.Datetime)
        .alias('discharge_time'),
    )
    .sort(['subject_id', 'timestamp'], descending=False)
    # Carry the most recent ICU stay start forward within each subject.
    .with_columns(
        pl.col('icustay_admission_time')
        .fill_null(strategy='forward')
        .over('subject_id')
        .alias('curr_icustay_admission_start_time'),
    )
    .group_by('subject_id', 'curr_icustay_admission_start_time')
    .agg(
        pl.col('death_time').min(),
        pl.col('discharge_time').min(),
    )
)

In [27]:
%%time
%%memit

events_df = ESD.events_df.lazy()
window_size = 24
gap_hours = 24

task_name = f"in_hosp_mort/{window_size}h_in_{gap_hours}h_gap"

in_hosp_mort = (
    events_df.with_columns(
        has_event_type('DEATH').alias('is_death'),
        has_event_type('DISCHARGE').alias('is_discharge'),
        has_event_type('ICU_STAY_START').alias('is_icustay_admission')
    )
    .filter(pl.col('is_death') | pl.col('is_icustay_admission') | pl.col('is_discharge'))
    .with_columns([
        (
            pl.when(pl.col(f"is_{c}"))
            .then(pl.col('timestamp'))
            .otherwise(pl.lit(None, dtype=pl.Datetime))
            .cast(pl.Datetime)
            .alias(f"{c}_time")
        ) for c in ("icustay_admission", "death", "discharge")
    ])
    .sort(['subject_id', 'timestamp'], descending=False)
    .with_columns(
        pl.col('icustay_admission_time')
        .fill_null(strategy='forward')
        .over('subject_id')
        .alias('curr_icustay_admission_start_time'), 
    )
    .group_by('subject_id', 'curr_icustay_admission_start_time')
    .agg(
        pl.col('death_time').min(),
        pl.col('discharge_time').min(),
    )
    .filter(
        (
            pl.min_horizontal('death_time', 'discharge_time') -
            pl.col('curr_icustay_admission_start_time')
        ) > pl.duration(hours=(window_size + gap_hours))
    )
    .with_columns(
        (pl.col('death_time').is_not_null() & (pl.col('death_time') <= pl.col('discharge_time')))
        .alias('in_hosp_mortality')
    )
    .select(
        'subject_id',
        pl.lit(None, dtype=pl.Datetime).alias('start_time'),
        (pl.col('curr_icustay_admission_start_time') + pl.duration(hours=window_size)).alias('end_time'),
        'in_hosp_mortality',
    )
)

task_fp = TASK_DF_DIR / f"{task_name}.parquet"
task_fp.parent.mkdir(exist_ok=True, parents=True)

in_hosp_mort.collect().write_parquet(task_fp, use_pyarrow=True)

prevalence = in_hosp_mort.select(pl.col("in_hosp_mortality").mean()).collect().item()
print(
    f"The {COHORT_NAME} cohort has a {prevalence*100:.1f}% in-hospital mortality prevalence "
    f"in the {task_name} sub-cohort."
)

The MIMIC_IV/12-4-23-1 cohort has a 8.7% in-hospital mortality prevalence in the in_hosp_mort/24h_in_24h_gap sub-cohort.
peak memory: 1126.54 MiB, increment: 179.90 MiB
CPU times: user 10.6 s, sys: 346 ms, total: 11 s
Wall time: 4.71 s


In [30]:
if DO_ESDS:
    esds_task_fp = ESDS_TASKS_DIR / f"{task_name}.parquet"
    # task_name includes a sub-directory component, so make sure it exists first.
    esds_task_fp.parent.mkdir(parents=True, exist_ok=True)
    # ESDS export: rename subject_id -> patient_id and keep only prediction time + label.
    in_hosp_mort.select(
        pl.col("subject_id").alias("patient_id"),
        pl.col("end_time"),
        pl.col("in_hosp_mortality"),
    ).collect().write_parquet(esds_task_fp, use_pyarrow=True)