# UniHPF Style Linearization Example

## Synthetic Data
For this example, we'll build on the local sample data tutorial in the `sample_data/examine_synthetic_data.ipynb` notebook. We'll first highlight the differences from that tutorial, then skip to the new, UniHPF-specific components. The notable difference in setup is that there is now a `medications.csv` file, which we'll use to showcase a more complex data structure:

In [1]:
!ls --color -lah raw

total 5.5M
drwxrwxr-x 2 mmd mmd 4.0K Aug 31 11:31 .
drwxrwxr-x 6 mmd mmd 4.0K Aug 31 11:38 ..
-rw-rw-r-- 1 mmd mmd 3.6M Aug 31 11:37 admit_vitals.csv
-rw-rw-r-- 1 mmd mmd 2.0M Aug 31 11:37 labs.csv
-rw-rw-r-- 1 mmd mmd 4.8K Aug 31 11:37 medications.csv
-rw-rw-r-- 1 mmd mmd 4.2K Aug 31 11:37 subjects.csv


In [2]:
!csvlook -d "," --max-rows=5 raw/medications.csv

|       MRN | timestamp           | name     | dose | frequency |         duration | generic_name   |
| --------- | ------------------- | -------- | ---- | --------- | ---------------- | -------------- |
|    87,570 | 21:41:00-2010-09-08 | motrin   |  600 | 2x/day    |  9 days, 0:00:00 | Ibuprofen      |
| 1,180,380 | 11:21:37-2010-06-24 | Benadryl |  100 | 1x/day    | 17 days, 0:00:00 | Diphenydramine |
| 1,405,746 | 07:01:22-2010-07-24 | Tylenol  |  600 | 4x/day    |   1 day, 0:00:00 | Acetaminophen  |
| 1,084,237 | 02:51:59-2010-04-20 | Tylenol  |  500 | 4x/day    |  2 days, 0:00:00 | Acetaminophen  |
|    42,335 | 15:26:13-2010-03-14 | Tylenol  |  500 | 1x/day    |  2 days, 0:00:00 | Acetaminophen  |
|       ... | ...                 | ...      |  ... | ...       |              ... | ...            |
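
Note that the `timestamp` column puts the time of day before the date, so it needs an explicit format string when parsed. Here is a minimal sketch using only the standard library; the format string is the same `ts_format` value that appears for this file in the config later in this example, and the sample value is the first row of the preview above:

```python
from datetime import datetime

# Format string matching the `ts_format` used for medications.csv in the config.
TS_FORMAT = "%H:%M:%S-%Y-%m-%d"

# First timestamp from the preview above: time of day first, then the date.
parsed = datetime.strptime("21:41:00-2010-09-08", TS_FORMAT)
print(parsed.isoformat())
```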


## Processing Synthetic Data with ESGPT

Much like before, we'll parse this with ESGPT, though this time we'll turn off normalization and outlier detection, keep all measurements regardless of how frequently they are observed, and retain some extra columns for the medication source. Here's the config we'll use for that:

In [3]:
!cat dataset_unihpf.yaml

defaults:
  - dataset_base
  - _self_

# So that it can be run multiple times without issue.
do_overwrite: True

cohort_name: "unihpf_sample"
subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw/"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: null

inputs:
  subjects:
    input_df: "${raw_data_dir}/subjects.csv"
  admissions:
    input_df: "${raw_data_dir}/admit_vitals.csv"
    start_ts_col: "admit_date"
    end_ts_col: "disch_date"
    ts_format: "%m/%d/%Y, %H:%M:%S"
    event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"]
  vitals:
    input_df: "${raw_data_dir}/admit_vitals.csv"
    ts_col: "vitals_date"
    ts_format: "%m/%d/%Y, %H:%M:%S"
  labs:
    input_df: "${raw_data_dir}/labs.csv"
    ts_col: "timestamp"
    ts_format: "%H:%M:%S-%Y-%m-%d"
  medications:
    input_df: "${raw_data_dir}/medications.csv"
    ts_col: "timestamp"
    ts_format: "%H:%M:%S-%Y-%m-%d"
    columns: {"name": "medication"}


measurements:
  static:
    single_label_classificat

There are two major modifications of note. First, beyond adding the `medications` input source, we also add a `medication` measurement with an additional `modifiers` key, which lists the other columns that should be carried along with it:
```yaml
      medications:
        - name: name
          modifiers: 
            - [dose, "float"]
            - [frequency, "categorical"]
            - [duration, "categorical"]
            - [generic_name, "categorical"]
```
`modifiers` takes a list of `[column name, import type]` pairs naming the extra columns that should also be read.
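
To make the shape of that list concrete, here is a small plain-Python sketch that turns the pairs into a column-to-type mapping. The helper name `modifiers_to_schema` is hypothetical and is not part of ESGPT; it only illustrates how the pairs can be read:

```python
# Modifier pairs exactly as written in the config above: [column name, import type].
modifiers = [
    ["dose", "float"],
    ["frequency", "categorical"],
    ["duration", "categorical"],
    ["generic_name", "categorical"],
]


def modifiers_to_schema(pairs):
    """Hypothetical helper: map each modifier column to its import type."""
    return {column: import_type for column, import_type in pairs}


modifier_schema = modifiers_to_schema(modifiers)
# The bare column names, in config order.
modifier_names = list(modifier_schema)
print(modifier_schema)
```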

In addition, we've set many other parameters to `null` to turn those components of the pre-processing pipeline off:
```yaml
outlier_detector_config: null
normalizer_config: null
min_valid_vocab_element_observations: null
min_valid_column_observations: null
min_true_float_frequency: null
min_unique_numerical_observations: null
...
agg_by_time_scale: null
```

Now, with that, we can run the pipeline with this command:
```bash
PYTHONPATH=$(pwd):$PYTHONPATH ./scripts/build_dataset.py \
	--config-path="$(pwd)/sample_data/" \
	--config-name=dataset_unihpf \
	"hydra.searchpath=[$(pwd)/configs]"
```

In [4]:
import subprocess

command = """\
PYTHONPATH=$(pwd):$PYTHONPATH ./scripts/build_dataset.py \
 --config-path="$(pwd)/sample_data/" \
 --config-name=dataset_unihpf \
 "hydra.searchpath=[$(pwd)/configs]" """

command_out = subprocess.run(command, cwd="../", shell=True, capture_output=True)
print(command_out.stdout.decode())

if command_out.returncode != 0:
    print("Command Errored!")

print(command_out.stderr.decode())

Empty new events dataframe of type OUTPATIENT_VISIT!




After we do so, the output will look very similar to that of the typical ESD pipeline:

In [5]:
!du -sh processed/unihpf_sample/

4.3M	processed/unihpf_sample/


In [6]:
!ls --color -lah -R processed/unihpf_sample/

processed/unihpf_sample/:
total 2.3M
drwxrwxr-x 5 mmd mmd 4.0K Aug 31 11:38 .
drwxrwxr-x 5 mmd mmd 4.0K Aug 31 11:38 ..
-rw-rw-r-- 1 mmd mmd 2.4K Aug 31 11:39 config.json
drwxrwxr-x 2 mmd mmd 4.0K Aug 31 11:38 DL_reps
-rw-rw-r-- 1 mmd mmd 864K Aug 31 11:39 dynamic_measurements_df.parquet
-rw-rw-r-- 1 mmd mmd 5.0K Aug 31 11:39 E.pkl
-rw-rw-r-- 1 mmd mmd 1.4M Aug 31 11:39 events_df.parquet
-rw-rw-r-- 1 mmd mmd 1.9K Aug 31 11:39 hydra_config.yaml
-rw-rw-r-- 1 mmd mmd 3.1K Aug 31 11:39 inferred_measurement_configs.json
drwxrwxr-x 2 mmd mmd 4.0K Aug 31 11:38 inferred_measurement_metadata
-rw-rw-r-- 1 mmd mmd 2.2K Aug 31 11:39 input_schema.json
drwxrwxr-x 3 mmd mmd 4.0K Aug 31 11:38 .logs
-rw-rw-r-- 1 mmd mmd 2.7K Aug 31 11:39 subjects_df.parquet
-rw-rw-r-- 1 mmd mmd  742 Aug 31 11:39 vocabulary_config.json

processed/unihpf_sample/DL_reps:
total 2.0M
drwxrwxr-x 2 mmd mmd 4.0K Aug 31 11:38 .
drwxrwxr-x 5 mmd mmd 4.0K

Inspecting the input schema, we can see that it is being instructed to pull the modifier columns:

In [7]:
import json
with open('processed/unihpf_sample/input_schema.json', mode='r') as f:
    cfg = json.load(f)
    dynamic_measurements = cfg['dynamic']
    for m in dynamic_measurements:
        if m['input_df'].endswith('medications.csv'):
            print(m['data_schema'][0])

{'name': ['medication', 'categorical'], 'dose': 'float', 'frequency': 'categorical', 'duration': 'categorical', 'generic_name': 'categorical'}


And, looking at the inferred measurement configs, we can further see that the modifiers list is still stored:

In [8]:
import json
with open('processed/unihpf_sample/inferred_measurement_configs.json', mode='r') as f:
    cfg = json.load(f)
    print(cfg['medication'])

{'name': 'medication', 'temporality': 'dynamic', 'modality': 'multi_label_classification', 'observation_rate_over_cases': 0.0007905763301446755, 'observation_rate_per_case': 1.0, 'functor': None, 'vocabulary': {'vocabulary': ['UNK', 'Tylenol', 'Advil', 'Benadryl', 'Motrin', 'motrin'], 'obs_frequencies': [0.0, 0.26666666666666666, 0.26666666666666666, 0.18333333333333332, 0.16666666666666666, 0.11666666666666667]}, 'values_column': None, '_measurement_metadata': None, 'modifiers': ['dose', 'frequency', 'duration', 'generic_name']}
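
The `vocabulary` and `obs_frequencies` lists in this config are aligned by position. A short sketch, using the values printed above, pairs them up and shows that `Motrin` and `motrin` remain separate vocabulary entries. The case-folded merge at the end is only an illustration of what a downstream user might do; it is not something the pipeline performs:

```python
from collections import defaultdict

# Values copied from the printed `medication` config above.
vocab_cfg = {
    "vocabulary": ["UNK", "Tylenol", "Advil", "Benadryl", "Motrin", "motrin"],
    "obs_frequencies": [
        0.0,
        0.26666666666666666,
        0.26666666666666666,
        0.18333333333333332,
        0.16666666666666666,
        0.11666666666666667,
    ],
}

# The two lists are aligned by position, so zip pairs each name with its frequency.
freq_by_name = dict(zip(vocab_cfg["vocabulary"], vocab_cfg["obs_frequencies"]))

# Illustration only: merge entries that differ only by letter case.
merged = defaultdict(float)
for name, freq in freq_by_name.items():
    merged[name.lower()] += freq

print(freq_by_name)
print(dict(merged))
```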


Note that even though `dose` was configured as a float, there is no metadata file for it in the `inferred_measurement_metadata` folder:

In [9]:
!ls processed/unihpf_sample/inferred_measurement_metadata/

age.csv  HR.csv  lab_name.csv  temp.csv


Further, given the config's instructions, there are no outlier or normalizer parameters learned for the other measurements:

In [10]:
!csvlook -d "," processed/unihpf_sample/inferred_measurement_metadata/age.csv

| a             | age   |
| ------------- | ----- |
| value_type    | float |
| outlier_model |       |
| normalizer    |       |


Let's look at the actual dataset as well:

In [11]:
import sys

sys.path.append("../")

In [12]:
# Imports
import os, polars as pl
from pathlib import Path

from EventStream.data.dataset_polars import Dataset

In [13]:
dataset_dir = Path(os.getcwd()) / "processed/unihpf_sample"

Once the dataset is loaded, we can inspect its three core dataframes (subjects, events, and dynamic measurements):

In [14]:
ESD = Dataset.load(dataset_dir)
display(ESD.subjects_df.head(3))
display(ESD.events_df.head(3))
display(ESD.dynamic_measurements_df.head(3))

Loading subjects from /home/mmd/Projects/EventStreamGPT/sample_data/processed/unihpf_sample/subjects_df.parquet...


subject_id,MRN,eye_color,dob
u8,cat,cat,datetime[μs]
0,"""310243""","""GREEN""",1981-07-28 00:00:00
1,"""384198""","""BROWN""",1985-04-15 00:00:00
2,"""520533""","""BROWN""",1979-04-15 00:00:00


Loading events from /home/mmd/Projects/EventStreamGPT/sample_data/processed/unihpf_sample/events_df.parquet...


event_id,subject_id,timestamp,event_type,age
u32,u8,datetime[μs],cat,f64
0,0,2010-06-24 13:23:00,"""ADMISSION&VITA…",28.907755
1,0,2010-06-24 13:57:39,"""LAB""",28.907821
2,0,2010-06-24 14:02:54,"""LAB""",28.907831


Loading dynamic_measurements from /home/mmd/Projects/EventStreamGPT/sample_data/processed/unihpf_sample/dynamic_measurements_df.parquet...


measurement_id,department,HR,temp,lab_name,lab_value,medication,dose,frequency,duration,generic_name,event_id
u32,cat,f64,f64,str,f64,cat,f32,cat,cat,cat,u32
0,"""ORTHOPEDIC""",,,,,,,,,,79307
1,"""PULMONARY""",,,,,,,,,,77947
2,"""ORTHOPEDIC""",,,,,,,,,,61768


We can see that the additional columns tagged as modifiers have been extracted as well, and that lab values haven't been normalized.

In [15]:
display(ESD.dynamic_measurements_df.filter(pl.col('medication').is_not_null()).head(2))
display(ESD.dynamic_measurements_df.filter(pl.col('lab_name').is_not_null()).head(2))

measurement_id,department,HR,temp,lab_name,lab_value,medication,dose,frequency,duration,generic_name,event_id
u32,cat,f64,f64,str,f64,cat,f32,cat,cat,cat,u32
92938,,,,,,"""Tylenol""",600.0,"""4x/day""","""1 days""","""Acetaminophen""",5434
92939,,,,,,"""Advil""",400.0,"""2x/day""","""8 days""","""Ibuprofen""",55201


measurement_id,department,HR,temp,lab_name,lab_value,medication,dose,frequency,duration,generic_name,event_id
u32,cat,f64,f64,str,f64,cat,f32,cat,cat,cat,u32
38675,,,,"""SpO2""",50.0,,,,,,38233
38676,,,,"""SpO2""",54.0,,,,,,13322


However, the "canonical" deep learning output pipeline doesn't leverage these columns at all, so its output is unchanged: 

In [16]:
!ls --color -lah processed/unihpf_sample/DL_reps/

total 2.0M
drwxrwxr-x 2 mmd mmd 4.0K Aug 31 11:38 .
drwxrwxr-x 5 mmd mmd 4.0K Aug 31 11:38 ..
-rw-rw-r-- 1 mmd mmd 217K Aug 31 11:39 held_out_0.parquet
-rw-rw-r-- 1 mmd mmd 1.6M Aug 31 11:39 train_0.parquet
-rw-rw-r-- 1 mmd mmd 148K Aug 31 11:39 tuning_0.parquet


In [17]:
df = pl.scan_parquet('processed/unihpf_sample/DL_reps/tuning_*.parquet')
print("DL Dataframe Columns:\n  * " + '\n  * '.join(df.columns))
display(df.head(4).collect())

DL Dataframe Columns:
  * subject_id
  * static_measurement_indices
  * static_indices
  * start_time
  * time
  * dynamic_measurement_indices
  * dynamic_indices
  * dynamic_values


subject_id,static_measurement_indices,static_indices,start_time,time,dynamic_measurement_indices,dynamic_indices,dynamic_values
u8,list[u8],list[u8],datetime[μs],list[f64],list[list[u8]],list[list[u8]],list[list[f64]]
1,[5],[14],2010-02-12 20:16:13,"[0.0, 3.7, … 4154.633333]","[[1, 3, … 8], [1, 3, 6], … [1, 3, 4]]","[[3, 8, … 25], [1, 8, 18], … [4, 8, 12]]","[[null, 24.831881, … 98.900002], [null, 24.831888, 50.0], … [null, 24.83978, null]]"
5,[5],[14],2010-01-16 07:34:43,"[0.0, 24.75, … 6175.983333]","[[1, 3, … 8], [1, 3, 6], … [1, 3, 4]]","[[3, 8, … 25], [1, 8, 18], … [4, 8, 11]]","[[null, 38.637415, … 95.699997], [null, 38.637462, 50.0], … [null, 38.649157, null]]"
9,[5],[14],2010-05-25 03:00:54,"[0.0, 52.1, … 228996.383333]","[[1, 3, … 8], [1, 3, 6], … [1, 3, 4]]","[[3, 8, … 25], [1, 8, 18], … [4, 8, 12]]","[[null, 32.972281, … 96.300003], [null, 32.97238, 51.0], … [null, 33.407668, null]]"
12,[5],[16],2010-02-06 13:42:56,"[0.0, 11.45, … 186737.516667]","[[1, 3, … 8], [1, 3, … 8], … [1, 3, 4]]","[[3, 8, … 25], [2, 8, … 25], … [4, 8, 12]]","[[null, 24.653173, … 103.699997], [null, 24.653195, … 100.199997], … [null, 25.008214, null]]"


## Converting to a UniHPF-friendly DL representation

Let's convert this to something that permits a more natural UniHPF-style representation.

We'll make our target output schema the (currently) planned standardized HF dataset schema for this style of dataset:

```python
import pyarrow as pa
measurement = pa.struct([
    ('code', pa.string()),
    ('text_value', pa.string()),
    ('numeric_value', pa.float32()),
    ('datetime_value', pa.timestamp('us')),
])


event_nested = pa.struct([
    ('time', pa.timestamp('us')),
    ('measurements', pa.list_(measurement))
])


schema_nested = pa.schema([
    ('subject_id', pa.int64()),
    ('static_measurements', pa.list_(measurement)),
    ('events', pa.list_(event_nested)), # Require ordered by time
])
```
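
To see what a single row of this schema looks like, here is a plain-Python sketch of one subject record, plus a check for the "ordered by time" requirement noted in the schema comment. The helpers `make_measurement` and `events_are_time_ordered` are hypothetical names used only for illustration, and the values are illustrative:

```python
from datetime import datetime


def make_measurement(code, text_value=None, numeric_value=None, datetime_value=None):
    """Hypothetical helper: one `measurement` struct as a plain dict."""
    return {
        "code": code,
        "text_value": text_value,
        "numeric_value": numeric_value,
        "datetime_value": datetime_value,
    }


# One row of the nested schema: a subject with static measurements and events.
record = {
    "subject_id": 1,
    "static_measurements": [make_measurement("eye_color/BROWN")],
    "events": [
        {
            "time": datetime(2010, 2, 12, 20, 16, 13),
            "measurements": [make_measurement("temp", numeric_value=98.9)],
        },
        {
            "time": datetime(2010, 2, 12, 20, 19, 55),
            "measurements": [make_measurement("lab_name/SpO2", numeric_value=50.0)],
        },
    ],
}


def events_are_time_ordered(rec):
    """Return True if event times never decrease from one event to the next."""
    times = [event["time"] for event in rec["events"]]
    return all(earlier <= later for earlier, later in zip(times, times[1:]))


print(events_are_time_ordered(record))
```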


To do so, we'll modify the existing DL conversion code in `EventStream.data.dataset_polars`, which is reproduced below:
```python
def _melt_df(self, source_df: DF_T, id_cols: Sequence[str], measures: list[str]) -> pl.Expr:
    """Re-formats `source_df` into the desired deep-learning output format."""
    struct_exprs = []
    total_vocab_size = self.vocabulary_config.total_vocab_size
    idx_dt = self.get_smallest_valid_uint_type(total_vocab_size)

    for m in measures:
        if m == "event_type":
            cfg = None
            modality = DataModality.SINGLE_LABEL_CLASSIFICATION
        else:
            cfg = self.measurement_configs[m]
            modality = cfg.modality

        if m in self.measurement_vocabs:
            idx_present_expr = pl.col(m).is_not_null() & pl.col(m).is_in(self.measurement_vocabs[m])
            idx_value_expr = pl.col(m).map_dict(self.unified_vocabulary_idxmap[m], return_dtype=idx_dt)
        else:
            idx_present_expr = pl.col(m).is_not_null()
            idx_value_expr = pl.lit(self.unified_vocabulary_idxmap[m][m]).cast(idx_dt)

        idx_present_expr = idx_present_expr.cast(pl.Boolean).alias("present")
        idx_value_expr = idx_value_expr.alias("index")

        if (modality == DataModality.UNIVARIATE_REGRESSION) and (
            cfg.measurement_metadata.value_type
            in (NumericDataModalitySubtype.FLOAT, NumericDataModalitySubtype.INTEGER)
        ):
            val_expr = pl.col(m)
        elif modality == DataModality.MULTIVARIATE_REGRESSION:
            val_expr = pl.col(cfg.values_column)
        else:
            val_expr = pl.lit(None).cast(pl.Float64)

        struct_exprs.append(
            pl.struct([idx_present_expr, idx_value_expr, val_expr.alias("value")]).alias(m)
        )

    measurements_idx_dt = self.get_smallest_valid_uint_type(len(self.unified_measurements_idxmap))
    return (
        source_df.select(*id_cols, *struct_exprs)
        .melt(
            id_vars=id_cols,
            value_vars=measures,
            variable_name="measurement",
            value_name="value",
        )
        .filter(pl.col("value").struct.field("present"))
        .select(
            *id_cols,
            pl.col("measurement")
            .map_dict(self.unified_measurements_idxmap)
            .cast(measurements_idx_dt)
            .alias("measurement_index"),
            pl.col("value").struct.field("index").alias("index"),
            pl.col("value").struct.field("value").alias("value"),
        )
    )

def build_DL_cached_representation(
    self, subject_ids: list[int] | None = None, do_sort_outputs: bool = False
) -> DF_T:
    # Identify the measurements sourced from each dataframe:
    subject_measures, event_measures, dynamic_measures = [], ["event_type"], []
    for m in self.unified_measurements_vocab[1:]:
        temporality = self.measurement_configs[m].temporality
        match temporality:
            case TemporalityType.STATIC:
                subject_measures.append(m)
            case TemporalityType.FUNCTIONAL_TIME_DEPENDENT:
                event_measures.append(m)
            case TemporalityType.DYNAMIC:
                dynamic_measures.append(m)
            case _:
                raise ValueError(f"Unknown temporality type {temporality} for {m}")

    # 1. Process subject data into the right format.
    if subject_ids:
        subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
    else:
        subjects_df = self.subjects_df

    static_data = (
        self._melt_df(subjects_df, ["subject_id"], subject_measures)
        .groupby("subject_id")
        .agg(
            pl.col("measurement_index").alias("static_measurement_indices"),
            pl.col("index").alias("static_indices"),
        )
    )

    # 2. Process event data into the right format.
    if subject_ids:
        events_df = self._filter_col_inclusion(self.events_df, {"subject_id": subject_ids})
        event_ids = list(events_df["event_id"])
    else:
        events_df = self.events_df
        event_ids = None
    event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures)

    # 3. Process measurement data into the right base format:
    if event_ids:
        dynamic_measurements_df = self._filter_col_inclusion(
            self.dynamic_measurements_df, {"event_id": event_ids}
        )
    else:
        dynamic_measurements_df = self.dynamic_measurements_df

    dynamic_ids = ["event_id", "measurement_id"] if do_sort_outputs else ["event_id"]
    dynamic_data = self._melt_df(dynamic_measurements_df, dynamic_ids, dynamic_measures)

    if do_sort_outputs:
        dynamic_data = dynamic_data.sort("event_id", "measurement_id")

    # 4. Join dynamic and event data.

    event_data = pl.concat([event_data, dynamic_data], how="diagonal")
    event_data = (
        event_data.groupby("event_id")
        .agg(
            pl.col("timestamp").drop_nulls().first().alias("timestamp"),
            pl.col("subject_id").drop_nulls().first().alias("subject_id"),
            pl.col("measurement_index").alias("dynamic_measurement_indices"),
            pl.col("index").alias("dynamic_indices"),
            pl.col("value").alias("dynamic_values"),
        )
        .sort("subject_id", "timestamp")
        .groupby("subject_id")
        .agg(
            pl.col("timestamp").first().alias("start_time"),
            ((pl.col("timestamp") - pl.col("timestamp").min()).dt.nanoseconds() / (1e9 * 60)).alias(
                "time"
            ),
            pl.col("dynamic_measurement_indices"),
            pl.col("dynamic_indices"),
            pl.col("dynamic_values"),
        )
    )

    out = static_data.join(event_data, on="subject_id", how="outer")
    if do_sort_outputs:
        out = out.sort("subject_id")

    return out
```

Our new code will be very similar, but it will:

  1. Capture the extra output columns when present
  2. Use the nested struct format instead of aligned lists
  3. Use string types instead of categorical types

In [18]:
from typing import Sequence
from EventStream.data.types import DataModality, TemporalityType, NumericDataModalitySubtype

def _HF_template_melt_df(
    self, source_df: pl.DataFrame, id_cols: Sequence[str], measures: list[str],
    default_struct_fields: dict[str, pl.DataType] | None = None,
) -> pl.DataFrame:
    """Re-formats `source_df` into long format, with one `measurement` struct per row."""
    struct_fields_by_m = {}
    total_vocab_size = self.vocabulary_config.total_vocab_size
    idx_dt = self.get_smallest_valid_uint_type(total_vocab_size)
    
    # Work on a copy so the caller's dictionary is never mutated.
    default_struct_fields = dict(default_struct_fields or {})

    for m in measures:
        if m == "event_type":
            cfg = None
            modality = DataModality.SINGLE_LABEL_CLASSIFICATION
        else:
            cfg = self.measurement_configs[m]
            modality = cfg.modality

        if modality != DataModality.UNIVARIATE_REGRESSION:
            idx_value_expr = (
                pl.when(pl.col(m).is_not_null())
                .then(f"{m}/" + pl.col(m).cast(pl.Utf8))
                .otherwise(pl.lit(None, dtype=pl.Utf8))
            )
        else:
            idx_value_expr = (
                pl.when(pl.col(m).is_not_null())
                .then(pl.lit(f"{m}", dtype=pl.Utf8))
                .otherwise(pl.lit(None, dtype=pl.Utf8))
            )
            
        idx_value_expr = idx_value_expr.alias("code")

        if (modality == DataModality.UNIVARIATE_REGRESSION) and (
            cfg.measurement_metadata.value_type
            in (NumericDataModalitySubtype.FLOAT, NumericDataModalitySubtype.INTEGER)
        ):
            val_expr = pl.col(m).cast(pl.Float32)
        elif modality == DataModality.MULTIVARIATE_REGRESSION:
            val_expr = pl.col(cfg.values_column).cast(pl.Float32)
        else:
            val_expr = pl.lit(None, dtype=pl.Float32)
            
        struct_fields = {**default_struct_fields}
        
        struct_fields.update({
            'code': idx_value_expr,
            'numeric_value': val_expr.alias("numeric_value"),
        })
                
        if cfg is not None and cfg.modifiers is not None:
            for mod_col in cfg.modifiers:
                mod_col_expr = pl.col(mod_col)
                if source_df[mod_col].dtype == pl.Categorical:
                    mod_col_expr = mod_col_expr.cast(pl.Utf8)
                    
                struct_fields[mod_col] = mod_col_expr.alias(mod_col)
        
        struct_fields_by_m[m] = struct_fields
        
    struct_field_order = ['code', 'numeric_value', 'text_value', 'datetime_value']
    struct_field_order += sorted([k for k in default_struct_fields.keys() if k not in struct_field_order])
    struct_exprs = [
        pl.struct([fields[k] for k in struct_field_order]).alias(m)
        for m, fields in struct_fields_by_m.items()
    ]
    
    return (
        source_df.select(*id_cols, *struct_exprs)
        .melt(
            id_vars=id_cols,
            value_vars=measures,
            variable_name="_to_drop",
            value_name="measurement",
        )
        .filter(pl.col("measurement").struct.field("code").is_not_null())
        .select(*id_cols, "measurement")
    )


def build_HF_representation(
    self, subject_ids: list[int] | None = None, do_sort_outputs: bool = False
) -> pl.DataFrame:
    # Identify the measurements sourced from each dataframe:
    subject_measures, event_measures, dynamic_measures = [], ["event_type"], []
    default_struct_fields = {
        'text_value': pl.lit(None, dtype=pl.Utf8).alias('text_value'),
        'datetime_value': pl.lit(None, dtype=pl.Datetime).alias("datetime_value"),
    }
    for m in self.unified_measurements_vocab[1:]:
        cfg = self.measurement_configs[m]
        match cfg.temporality:
            case TemporalityType.STATIC:
                source_df = self.subjects_df
                subject_measures.append(m)
            case TemporalityType.FUNCTIONAL_TIME_DEPENDENT:
                source_df = self.events_df
                event_measures.append(m)
            case TemporalityType.DYNAMIC:
                source_df = self.dynamic_measurements_df
                dynamic_measures.append(m)
            case _:
                raise ValueError(f"Unknown temporality type {cfg.temporality} for {m}")
        
        if cfg.modifiers is None: continue
            
        for mod_col in cfg.modifiers:
            if mod_col not in source_df:
                raise IndexError(f"mod_col {mod_col} missing!")
            
            out_dt = source_df[mod_col].dtype
            if out_dt == pl.Categorical:
                out_dt = pl.Utf8
            default_struct_fields[mod_col] = pl.lit(None, dtype=out_dt).alias(mod_col)

    # 1. Process subject data into the right format.
    if subject_ids:
        subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
    else:
        subjects_df = self.subjects_df
        
    static_data = (
        _HF_template_melt_df(
            self, subjects_df, ["subject_id"], subject_measures,
            default_struct_fields=default_struct_fields
        )
        .groupby("subject_id")
        .agg(pl.col("measurement").alias("static_measurements"))
    )

    # 2. Process event data into the right format.
    if subject_ids:
        events_df = self._filter_col_inclusion(self.events_df, {"subject_id": subject_ids})
        event_ids = list(events_df["event_id"])
    else:
        events_df = self.events_df
        event_ids = None
    event_data = _HF_template_melt_df(
        self, events_df, ["subject_id", "timestamp", "event_id"], event_measures,
        default_struct_fields=default_struct_fields
    )

    # 3. Process measurement data into the right base format:
    if event_ids:
        dynamic_measurements_df = self._filter_col_inclusion(
            self.dynamic_measurements_df, {"event_id": event_ids}
        )
    else:
        dynamic_measurements_df = self.dynamic_measurements_df

    dynamic_ids = ["event_id", "measurement_id"] if do_sort_outputs else ["event_id"]
    dynamic_data = _HF_template_melt_df(
        self, dynamic_measurements_df, dynamic_ids, dynamic_measures,
        default_struct_fields=default_struct_fields
    )

    if do_sort_outputs:
        dynamic_data = dynamic_data.sort("event_id", "measurement_id")

    # 4. Join dynamic and event data.

    event_data = pl.concat([event_data, dynamic_data], how="diagonal")
    event_data = (
        event_data.groupby("event_id")
        .agg(
            pl.col('subject_id').drop_nulls().first(),
            pl.col('timestamp').drop_nulls().first(),
            pl.col("measurement").alias("measurements")
        )
        .with_columns(
            pl.struct([
                pl.col("timestamp").alias("time"),
                pl.col("measurements").alias("measurements")
            ]).alias("event")
        )
        .sort("subject_id", "timestamp")
        .groupby("subject_id")
        .agg(pl.col("event").alias("events"))
    )

    out = static_data.join(event_data, on="subject_id", how="outer")
    if do_sort_outputs:
        out = out.sort("subject_id")

    return out
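
The `code` strings built by `_HF_template_melt_df` follow a simple convention: univariate regression measurements use the measurement name alone, while all other modalities use `name/value`, and missing observations produce no code at all. The plain-Python function below, `make_code`, is a hypothetical analogue of that polars expression, included only to illustrate the convention:

```python
def make_code(measure, value, is_univariate_regression):
    """Plain-Python analogue of the `code` expression; None means "not observed"."""
    if value is None:
        return None
    if is_univariate_regression:
        # The numeric value is stored separately, so the code is just the name.
        return measure
    # Categorical-style measurements fold the observed value into the code.
    return f"{measure}/{value}"


print(make_code("lab_name", "SpO2", False))
print(make_code("age", 24.831888, True))
print(make_code("medication", None, False))
```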

In [19]:
import pyarrow as pa
measurement = pa.struct([
    ('code', pa.string()),
    ('numeric_value', pa.float32()),
    ('text_value', pa.string()),
    ('datetime_value', pa.timestamp('us')),
   
    # Extra fields -- only needed here as we explicitly convert from polars to pyarrow format
    ('dose', pa.float32()),
    ('duration', pa.string()),
    ('frequency', pa.string()),
    ('generic_name', pa.string()),
])


event = pa.struct([
    ('time', pa.timestamp('us')),
    ('measurements', pa.list_(measurement))
])


schema = pa.schema([
    ('subject_id', pa.int64()),
    ('static_measurements', pa.list_(measurement)),
    ('events', pa.list_(event)), # Require ordered by time
])

In [20]:
import math, numpy as np
import pyarrow.parquet
from tqdm.auto import tqdm

for sp, subjs in tqdm(list(ESD.split_subjects.items())):
    n_chunks = int(math.ceil(len(subjs) / 20))
    for i, subjs_chunk in enumerate(np.array_split(list(subjs), n_chunks)):
        df = build_HF_representation(ESD, do_sort_outputs=True, subject_ids=list(subjs_chunk))
        arr_table = df.to_arrow().cast(schema)
        fp = ESD.config.save_dir / "HF_Dataset" / sp / f"{i}.parquet"
        fp.parent.mkdir(exist_ok=True, parents=True)
        pyarrow.parquet.write_table(arr_table, fp)

  0%|          | 0/3 [00:00<?, ?it/s]
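
The loop above caps each parquet file at roughly 20 subjects by computing `n_chunks` with `math.ceil` and letting `np.array_split` divide the subjects as evenly as possible. The sketch below reproduces the resulting chunk sizes in plain Python; `chunk_sizes` is a hypothetical helper, not part of the notebook. Note that an empty split would give `n_chunks = 0`, which `np.array_split` does not accept, so the sketch returns an empty list in that case:

```python
import math


def chunk_sizes(n_subjects, max_chunk=20):
    """Sizes of the near-equal chunks produced for `n_subjects` subjects."""
    if n_subjects == 0:
        # Guard for empty splits; the loop above has no such guard.
        return []
    n_chunks = math.ceil(n_subjects / max_chunk)
    base, extra = divmod(n_subjects, n_chunks)
    # The first `extra` chunks receive one additional subject each.
    return [base + 1 if i < extra else base for i in range(n_chunks)]


print(chunk_sizes(50))
print(chunk_sizes(21))
```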

Now that we've run this, what does our output look like?

In [21]:
print("In raw form:")
display(df.head(2))
print("Exploded out:")
display(
    df[0].explode('events').unnest('events')
    .explode('measurements').unnest('measurements')
    .filter(pl.col('code').str.starts_with('lab_name'))
    .head(2)
)
display(
    df[0].explode('events').unnest('events')
    .explode('measurements').unnest('measurements')
    .filter(pl.col('code').str.starts_with('medication'))
    .head(2)
)

In raw form:


subject_id,static_measurements,events
u8,list[struct[8]],list[struct[2]]
1,"[{""eye_color/BROWN"",null,null,null,null,null,null,null}]","[{2010-02-12 20:16:13,[{""event_type/ADMISSION&VITAL"",null,null,null,null,null,null,null}, {""age"",24.831881,null,null,null,null,null,null}, … {""temp"",98.900002,null,null,null,null,null,null}]}, {2010-02-12 20:19:55,[{""event_type/LAB"",null,null,null,null,null,null,null}, {""age"",24.831888,null,null,null,null,null,null}, {""lab_name/SpO2"",50.0,null,null,null,null,null,null}]}, … {2010-02-15 17:30:51,[{""event_type/DISCHARGE"",null,null,null,null,null,null,null}, {""age"",24.839781,null,null,null,null,null,null}, {""department/ORTHOPEDIC"",null,null,null,null,null,null,null}]}]"
5,"[{""eye_color/BROWN"",null,null,null,null,null,null,null}]","[{2010-01-16 07:34:43,[{""event_type/ADMISSION&VITAL"",null,null,null,null,null,null,null}, {""age"",38.637413,null,null,null,null,null,null}, … {""temp"",95.699997,null,null,null,null,null,null}]}, {2010-01-16 07:59:28,[{""event_type/LAB"",null,null,null,null,null,null,null}, {""age"",38.637463,null,null,null,null,null,null}, {""lab_name/SpO2"",50.0,null,null,null,null,null,null}]}, … {2010-01-20 14:30:42,[{""event_type/DISCHARGE"",null,null,null,null,null,null,null}, {""age"",38.649158,null,null,null,null,null,null}, {""department/CARDIAC"",null,null,null,null,null,null,null}]}]"


Exploded out:


subject_id,static_measurements,time,code,numeric_value,text_value,datetime_value,dose,duration,frequency,generic_name
u8,list[struct[8]],datetime[μs],str,f32,str,datetime[μs],f32,str,str,str
1,"[{""eye_color/BROWN"",null,null,null,null,null,null,null}]",2010-02-12 20:19:55,"""lab_name/SpO2""",50.0,,,,,,
1,"[{""eye_color/BROWN"",null,null,null,null,null,null,null}]",2010-02-12 20:24:39,"""lab_name/SOFA""",1.0,,,,,,


subject_id,static_measurements,time,code,numeric_value,text_value,datetime_value,dose,duration,frequency,generic_name
u8,list[struct[8]],datetime[μs],str,f32,str,datetime[μs],f32,str,str,str
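
The `explode`/`unnest` chain used above flattens the nested structure into one row per measurement. For readers less familiar with polars, here is a plain-Python sketch of the same flattening on a trimmed record; `explode_record` is a hypothetical helper, and the record keeps only the `code` and `numeric_value` fields for brevity:

```python
from datetime import datetime

# A trimmed record based on subject 1 above (only two fields per measurement).
record = {
    "subject_id": 1,
    "events": [
        {
            "time": datetime(2010, 2, 12, 20, 19, 55),
            "measurements": [
                {"code": "event_type/LAB", "numeric_value": None},
                {"code": "age", "numeric_value": 24.831888},
                {"code": "lab_name/SpO2", "numeric_value": 50.0},
            ],
        },
        {
            "time": datetime(2010, 2, 12, 20, 24, 39),
            "measurements": [{"code": "lab_name/SOFA", "numeric_value": 1.0}],
        },
    ],
}


def explode_record(rec):
    """Flatten one nested record into one row per (event, measurement) pair."""
    rows = []
    for event in rec["events"]:
        for measurement in event["measurements"]:
            rows.append(
                {"subject_id": rec["subject_id"], "time": event["time"], **measurement}
            )
    return rows


# Same filter as the polars example: keep only lab measurements.
lab_rows = [row for row in explode_record(record) if row["code"].startswith("lab_name")]
for row in lab_rows:
    print(row)
```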


We can use this as the specification for a Hugging Face `datasets` Dataset, which can then be processed for tokenization, subsequence sampling, etc.

In [22]:
import datasets

features = datasets.Features.from_arrow_schema(schema)
ds = datasets.load_dataset(
    "parquet",
    data_files={
        sp: str(ESD.config.save_dir / "HF_Dataset" / sp / "*.parquet")
        for sp in ESD.split_subjects
    },
    features=features
)

print(ds['train'][0])


{'subject_id': 0, 'static_measurements': [{'code': 'eye_color/GREEN', 'numeric_value': None, 'text_value': None, 'datetime_value': None, 'dose': None, 'duration': None, 'frequency': None, 'generic_name': None}], 'events': [{'time': datetime.datetime(2010, 6, 24, 13, 23), 'measurements': [{'code': 'event_type/ADMISSION&VITAL', 'numeric_value': None, 'text_value': None, 'datetime_value': None, 'dose': None, 'duration': None, 'frequency': None, 'generic_name': None}, {'code': 'age', 'numeric_value': 28.90775489807129, 'text_value': None, 'datetime_value': None, 'dose': None, 'duration': None, 'frequency': None, 'generic_name': None}, {'code': 'department/ORTHOPEDIC', 'numeric_value': None, 'text_value': None, 'datetime_value': None, 'dose': None, 'duration': None, 'frequency': None, 'generic_name': None}, {'code': 'HR', 'numeric_value': 166.0, 'text_value': None, 'datetime_value': None, 'dose': None, 'duration': None, 'frequency': None, 'generic_name': None}, {'code': 'temp', 'numeric_val

In [23]:
ds['train'][0]

{'subject_id': 0,
 'static_measurements': [{'code': 'eye_color/GREEN',
   'numeric_value': None,
   'text_value': None,
   'datetime_value': None,
   'dose': None,
   'duration': None,
   'frequency': None,
   'generic_name': None}],
 'events': [{'time': datetime.datetime(2010, 6, 24, 13, 23),
   'measurements': [{'code': 'event_type/ADMISSION&VITAL',
     'numeric_value': None,
     'text_value': None,
     'datetime_value': None,
     'dose': None,
     'duration': None,
     'frequency': None,
     'generic_name': None},
    {'code': 'age',
     'numeric_value': 28.90775489807129,
     'text_value': None,
     'datetime_value': None,
     'dose': None,
     'duration': None,
     'frequency': None,
     'generic_name': None},
    {'code': 'department/ORTHOPEDIC',
     'numeric_value': None,
     'text_value': None,
     'datetime_value': None,
     'dose': None,
     'duration': None,
     'frequency': None,
     'generic_name': None},
    {'code': 'HR',
     'numeric_value': 166.0,