In [1]:
%load_ext autoreload
%load_ext memory_profiler

import os
import sys

import rootutils

# Locate the project root, searching from the notebook's working directory.
# NOTE(review): the flags presumably load the project's .env file, add the root
# to the Python path and leave the working directory unchanged -- confirm
# against the rootutils documentation.
root = rootutils.setup_root(
    os.path.abspath(""),
    pythonpath=True,
    dotenv=True,
    cwd=False,
)

# Append the directory named by EVENT_STREAM_PATH to the import search path
# (a missing variable raises KeyError here).
sys.path.append(os.environ["EVENT_STREAM_PATH"])

In [12]:
%autoreload

from pathlib import Path

import polars as pl
from EventStream.data.dataset_polars import Dataset

In [13]:
# Cohort to export, given relative to `<PROJECT_DIR>/data`.
COHORT_NAME = "MIMIC_IV/ESD_new_schema_08-22-23-1"

# The project directory comes from the environment (a missing variable raises KeyError).
PROJECT_DIR = Path(os.environ["PROJECT_DIR"])
DATA_DIR = PROJECT_DIR.joinpath("data", COHORT_NAME)

# Fail early if the cohort directory does not exist.
assert DATA_DIR.is_dir()

# Load the dataset stored in the cohort directory.
ESD = Dataset.load(DATA_DIR)

In [15]:
from typing import Sequence
from EventStream.data.types import DataModality, TemporalityType, NumericDataModalitySubtype

def _HF_template_melt_df(
    self, source_df: pl.DataFrame, id_cols: Sequence[str], measures: list[str],
    default_struct_fields: dict[str, pl.Expr] | None = None,
) -> pl.DataFrame:
    """Melts the measure columns of `source_df` into a single struct-valued `measurement` column.

    Every measure in `measures` is turned into a struct with the fields `code`, `numeric_value`,
    `text_value`, `datetime_value`, followed by any further keys of `default_struct_fields` in sorted
    order. The per-measure struct columns are then melted into long format and rows whose `code` is
    null (i.e., the measure was not observed) are dropped.

    Args:
        self: The dataset object; only `self.measurement_configs` is read.
        source_df: The dataframe holding the id columns, the measure columns, and any value / modifier
            columns referenced by the measurement configs.
        id_cols: Columns carried through unchanged to identify each output row.
        measures: Names of the measure columns to convert. `"event_type"` is handled without a
            measurement config and treated as single-label classification.
        default_struct_fields: Mapping from struct field name to the expression used for that field
            when a measure does not provide its own value. Null `text_value` and `datetime_value`
            entries are supplied automatically when absent, so `None` is a usable default. The mapping
            passed in is never mutated. A modifier column is emitted only if it has an entry here.

    Returns:
        A dataframe with the columns `id_cols` plus `measurement` (one struct per observed measure).

    NOTE(review): an empty `measures` list is not special-cased here; confirm how `melt` behaves for
    it before relying on that path.
    """
    # Work on a copy, and guarantee the two always-emitted fields exist; without them the field
    # lookup below raises a KeyError whenever the caller relies on the `None` default.
    struct_defaults = dict(default_struct_fields or {})
    struct_defaults.setdefault('text_value', pl.lit(None, dtype=pl.Utf8).alias('text_value'))
    struct_defaults.setdefault('datetime_value', pl.lit(None, dtype=pl.Datetime).alias('datetime_value'))

    # All structs must share one field order so the melted column has a single struct dtype.
    struct_field_order = ['code', 'numeric_value', 'text_value', 'datetime_value']
    struct_field_order += sorted(k for k in struct_defaults if k not in struct_field_order)

    struct_exprs = []
    for m in measures:
        if m == "event_type":
            cfg = None
            modality = DataModality.SINGLE_LABEL_CLASSIFICATION
        else:
            cfg = self.measurement_configs[m]
            modality = cfg.modality

        is_univariate = modality == DataModality.UNIVARIATE_REGRESSION

        # Univariate regression measures are coded by the measure name alone; every other modality
        # is coded as "<measure>/<observed value>".
        if is_univariate:
            observed_code = pl.lit(f"{m}", dtype=pl.Utf8)
        else:
            observed_code = f"{m}/" + pl.col(m).cast(pl.Utf8)

        code_expr = (
            pl.when(pl.col(m).is_not_null())
            .then(observed_code)
            .otherwise(pl.lit(None, dtype=pl.Utf8))
            .alias("code")
        )

        if is_univariate and (
            cfg.measurement_metadata.value_type
            in (NumericDataModalitySubtype.FLOAT, NumericDataModalitySubtype.INTEGER)
        ):
            # The measure column itself holds the numeric value.
            val_expr = pl.col(m).cast(pl.Float32)
        elif modality == DataModality.MULTIVARIATE_REGRESSION:
            # The numeric value lives in the config's separate values column.
            val_expr = pl.col(cfg.values_column).cast(pl.Float32)
        else:
            val_expr = pl.lit(None, dtype=pl.Float32)

        struct_fields = {
            **struct_defaults,
            'code': code_expr,
            'numeric_value': val_expr.alias("numeric_value"),
        }

        if cfg is not None and cfg.modifiers is not None:
            for mod_col in cfg.modifiers:
                mod_col_expr = pl.col(mod_col)
                # Categoricals are emitted as plain strings.
                if source_df[mod_col].dtype == pl.Categorical:
                    mod_col_expr = mod_col_expr.cast(pl.Utf8)

                struct_fields[mod_col] = mod_col_expr.alias(mod_col)

        struct_exprs.append(pl.struct([struct_fields[k] for k in struct_field_order]).alias(m))

    return (
        source_df.select(*id_cols, *struct_exprs)
        .melt(
            id_vars=id_cols,
            value_vars=measures,
            variable_name="_to_drop",
            value_name="measurement",
        )
        # A null code means the measure was not observed in that row.
        .filter(pl.col("measurement").struct.field("code").is_not_null())
        .select(*id_cols, "measurement")
    )


def build_HF_representation(
    self, subject_ids: list[int] | None = None, do_sort_outputs: bool = False
) -> pl.DataFrame:
    # Identify the measurements sourced from each dataframe:
    subject_measures, event_measures, dynamic_measures = [], ["event_type"], []
    default_struct_fields = {
        'text_value': pl.lit(None, dtype=pl.Utf8).alias('text_value'),
        'datetime_value': pl.lit(None, dtype=pl.Datetime).alias("datetime_value"),
    }
    for m in self.unified_measurements_vocab[1:]:
        cfg = self.measurement_configs[m]
        match cfg.temporality:
            case TemporalityType.STATIC:
                source_df = self.subjects_df
                subject_measures.append(m)
            case TemporalityType.FUNCTIONAL_TIME_DEPENDENT:
                source_df = self.events_df
                event_measures.append(m)
            case TemporalityType.DYNAMIC:
                source_df = self.dynamic_measurements_df
                dynamic_measures.append(m)
            case _:
                raise ValueError(f"Unknown temporality type {cfg.temporality} for {m}")
        
        if cfg.modifiers is None: continue
            
        for mod_col in cfg.modifiers:
            if mod_col not in source_df:
                raise IndexError(f"mod_col {mod_col} missing!")
            
            out_dt = source_df[mod_col].dtype
            if out_dt == pl.Categorical:
                out_dt = pl.Utf8
            default_struct_fields[mod_col] = pl.lit(None, dtype=out_dt).alias(mod_col)

    # 1. Process subject data into the right format.
    if subject_ids:
        subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
    else:
        subjects_df = self.subjects_df
        
    static_data = (
        _HF_template_melt_df(
            self, subjects_df, ["subject_id"], subject_measures,
            default_struct_fields=default_struct_fields
        )
        .groupby("subject_id")
        .agg(pl.col("measurement").alias("static_measurements"))
    )

    # 2. Process event data into the right format.
    if subject_ids:
        events_df = self._filter_col_inclusion(self.events_df, {"subject_id": subject_ids})
        event_ids = list(events_df["event_id"])
    else:
        events_df = self.events_df
        event_ids = None
    event_data = _HF_template_melt_df(
        self, events_df, ["subject_id", "timestamp", "event_id"], event_measures,
        default_struct_fields=default_struct_fields
    )

    # 3. Process measurement data into the right base format:
    if event_ids:
        dynamic_measurements_df = self._filter_col_inclusion(
            self.dynamic_measurements_df, {"event_id": event_ids}
        )
    else:
        dynamic_measurements_df = self.dynamic_measurements_df

    dynamic_ids = ["event_id", "measurement_id"] if do_sort_outputs else ["event_id"]
    dynamic_data = _HF_template_melt_df(
        self, dynamic_measurements_df, dynamic_ids, dynamic_measures,
        default_struct_fields=default_struct_fields
    )

    if do_sort_outputs:
        dynamic_data = dynamic_data.sort("event_id", "measurement_id")

    # 4. Join dynamic and event data.

    event_data = pl.concat([event_data, dynamic_data], how="diagonal")
    event_data = (
        event_data.groupby("event_id")
        .agg(
            pl.col('subject_id').drop_nulls().first(),
            pl.col('timestamp').drop_nulls().first(),
            pl.col("measurement").alias("measurements")
        )
        .with_columns(
            pl.struct([
                pl.col("timestamp").alias("time"),
                pl.col("measurements").alias("measurements")
            ]).alias("event")
        )
        .sort("subject_id", "timestamp")
        .groupby("subject_id")
        .agg(pl.col("event").alias("events"))
    )

    out = static_data.join(event_data, on="subject_id", how="outer")
    if do_sort_outputs:
        out = out.sort("subject_id")

    return out

In [16]:
import pyarrow as pa
# Microsecond-resolution timestamp type shared by the value and event-time fields.
timestamp_us = pa.timestamp('us')

# One observation: a code plus optional numeric / text / datetime payloads.
# NOTE(review): build_HF_representation adds one extra struct field per measurement
# modifier column; only the four base fields are listed here -- confirm the cast
# below behaves as intended for cohorts that define modifiers.
measurement = pa.struct([
    ('code', pa.string()),
    ('numeric_value', pa.float32()),
    ('text_value', pa.string()),
    ('datetime_value', timestamp_us),
])
measurement_list = pa.list_(measurement)

# One event: its time and the measurements recorded at that time.
event = pa.struct([
    ('time', timestamp_us),
    ('measurements', measurement_list),
])

# One row per subject.
schema = pa.schema([
    ('subject_id', pa.int64()),
    ('static_measurements', measurement_list),
    ('events', pa.list_(event)), # Require ordered by time
])

In [19]:
import math, numpy as np
import pyarrow.parquet
from tqdm.auto import tqdm

# Upper bound on the number of subjects written to a single parquet shard.
SUBJECTS_PER_CHUNK = 20

# Write each split as a directory of parquet shards: <save_dir>/HF_Dataset/<split>/<i>.parquet
for sp, subjs in tqdm(list(ESD.split_subjects.items())):
    subject_list = list(subjs)
    if len(subject_list) == 0:
        # np.array_split raises ValueError for zero sections; an empty split has nothing to write.
        continue

    n_chunks = math.ceil(len(subject_list) / SUBJECTS_PER_CHUNK)
    split_dir = ESD.config.save_dir / "HF_Dataset" / sp
    split_dir.mkdir(exist_ok=True, parents=True)

    for i, subjs_chunk in enumerate(np.array_split(subject_list, n_chunks)):
        # `.tolist()` hands plain Python values (not numpy scalars) to the subject filter.
        df = build_HF_representation(ESD, do_sort_outputs=True, subject_ids=subjs_chunk.tolist())
        # Cast to the declared Arrow schema so every shard has identical column types.
        arr_table = df.to_arrow().cast(schema)
        fp = split_dir / f"{i}.parquet"
        pyarrow.parquet.write_table(arr_table, fp)

  0%|          | 0/3 [00:00<?, ?it/s]

Loading dynamic_measurements from /n/data1/hms/dbmi/zaklab/RAMMS/data/MIMIC_IV/ESD_new_schema_08-22-23-1/dynamic_measurements_df.parquet...
Loading subjects from /n/data1/hms/dbmi/zaklab/RAMMS/data/MIMIC_IV/ESD_new_schema_08-22-23-1/subjects_df.parquet...


In [None]:
# Lazily scan every shard matching <save_dir>/HF_Dataset/*/*.parquet.
df = pl.scan_parquet(ESD.config.save_dir / "HF_Dataset" / "*/*.parquet")

print("In raw form:")
display(df.head(2).collect())

print("Exploded out:")
# Take the first row, expand it to one row per event, then to one row per
# measurement, and show the first two measurement rows.
first_row = df.head(1)
per_event = first_row.explode('events').unnest('events')
per_measurement = per_event.explode('measurements').unnest('measurements')
display(per_measurement.head(2).collect())