# Test EventStreamGPT on Medpar data

First, create the `esgpt` environment:

1. mamba create --name esgpt python=3.10
2. mamba activate esgpt
3. pip install -e .

In [None]:
import os
import rootutils

root = rootutils.setup_root(os.path.abspath(''), dotenv=True, pythonpath=True, cwd=True)

import polars as pl
pl.Config.set_tbl_cols(7)

#### Denom_yyyy.parquet

In your terminal, create mbsf_medpar_denom by creating the symlinks found in the medpar_data/README.md.

This file contains per-subject data. It has one row per subject, with each row containing a subject identifier (here called `bene_id`), a date of birth, state, sex, race, and other information printed below.

In [None]:
df = pl.scan_parquet('medpar_data/mbsf_medpar_denom/denom_2000.parquet')
print("Subject Columns:\n  * " + '\n  * '.join(df.columns))
display(df.head(5).collect())

In [None]:
# Print dod column:
print("dod column:")
print(df.select('dod').collect())

#### Small EDA to assess the temporality of each variable

In [None]:
# Can a bene_id have multiple race values?
# Load the DataFrame lazily
df = pl.scan_parquet('medpar_data/mbsf_medpar_denom/denom_2000.parquet')

# Group by 'bene_id', count unique 'race' codes, and filter for those with more than one unique code
multiple_races = (
    df
    .groupby("bene_id")
    .agg(pl.col("race").n_unique().alias("unique_race_count"))
    .filter(pl.col("unique_race_count") > 1)
).head(5)

# Collect the results to view
result = multiple_races.collect()

# Display the result
print(result)

In our case, a bene_id can have only one race value.

In [None]:
# Unique values of zcta per bene_id
# Load the DataFrame lazily
df = pl.scan_parquet('medpar_data/mbsf_medpar_denom/denom_2000.parquet')

# Group by 'bene_id', count unique 'zcta' codes, and filter for those with more than one unique code
multiple_zcta = (
    df
    .groupby("bene_id")
    .agg(pl.col("zcta").n_unique().alias("unique_zcta_count"))
    .filter(pl.col("unique_zcta_count") > 1)
).head(5)

# Collect the results to view
result = multiple_zcta.collect()

# Display the result
print(result)

So here, within a single year, zcta can be considered static. Across years, however, it would probably be time-dependent?
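
Whether `zcta` stays static across years can be checked with the same per-subject unique count, applied to several years at once. The sketch below uses plain Python on a tiny made-up table to show the logic; the rows and the helper name `ids_with_multiple_values` are illustrative assumptions, not taken from the Medpar files.

```python
from collections import defaultdict


def ids_with_multiple_values(rows, id_key, value_key):
    """Return the ids that carry more than one distinct value, sorted."""
    seen = defaultdict(set)
    for row in rows:
        seen[row[id_key]].add(row[value_key])
    return sorted(i for i, values in seen.items() if len(values) > 1)


# Made-up rows: one row per subject and year, mimicking the denom files.
rows = [
    {"bene_id": "A", "year": 2000, "zcta": "02138"},
    {"bene_id": "A", "year": 2001, "zcta": "02139"},  # moved between years
    {"bene_id": "B", "year": 2000, "zcta": "10001"},
    {"bene_id": "B", "year": 2001, "zcta": "10001"},
]

# Within a single year, no subject has more than one zcta.
rows_2000 = [r for r in rows if r["year"] == 2000]
print(ids_with_multiple_values(rows_2000, "bene_id", "zcta"))

# Across years, subject A does.
print(ids_with_multiple_values(rows, "bene_id", "zcta"))
```

If the multi-year check returns any subjects on the real data, `zcta` would need to be treated as time-varying rather than static.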

#### Inpatient_yyyy.parquet

This file contains dynamic data quantifying both fictional subject hospital admissions and diagnoses measured for those subjects. Each row of this file records a unique diagnosis measurement for a patient, affiliated with the associated admission listed in the row. This means that admission-level information is _heavily duplicated_ within this file, which is a phenomenon sometimes observed in real data, and something we'll need to account for in our pipeline's setup.

In [None]:
df = pl.scan_parquet('medpar_data/mbsf_medpar_denom/inpatient_2000.parquet')
print("Dynamic Measurement Columns:\n  * " + '\n  * '.join(df.columns))
display(df.head(5).collect())

In [None]:
display(pl.read_parquet('medpar_data/mbsf_medpar_denom/inpatient_2000.parquet').select([
    'admission_date',
    'discharge_date'
]).head(5))

`ts_col` is used for data sources where each row represents one event, and `start_ts_col`/`end_ts_col` for data sources where each row specifies a range in time.

So here the range in time is given by the admission and discharge dates.
Would one event be the primary diagnosis, for example?
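
One way to read the distinction: a range row carries two timestamps and can be unrolled into two point-in-time events, whereas a single-timestamp row is already one event. The sketch below illustrates this on a made-up admission. The function `range_row_to_events` and the labels `ADMISSION`/`DISCHARGE` are assumptions for illustration, not ESGPT API.

```python
from datetime import date


def range_row_to_events(row, start_key="admission_date", end_key="discharge_date"):
    """Unroll one range row into (timestamp, event_type) pairs."""
    events = [(row[start_key], "ADMISSION")]
    # A missing end timestamp yields only the start event.
    if row.get(end_key) is not None:
        events.append((row[end_key], "DISCHARGE"))
    return events


# Made-up row with the two date columns shown above.
row = {
    "bene_id": "A",
    "admission_date": date(2000, 1, 3),
    "discharge_date": date(2000, 1, 7),
}

for ts, kind in range_row_to_events(row):
    print(ts.isoformat(), kind)
```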

In [None]:
# Print the discharge, admission type, and diagnosis columns
display(pl.read_parquet('medpar_data/mbsf_medpar_denom/inpatient_2000.parquet').select([
    'dschrgcd',
    'admsn_type_cd',
    "dschrg_dstntn_cd",
    "primary_diag",
    "diagnoses"
]).head(5))

In [None]:
# Number of unique primary_diag values
print(pl.read_parquet('medpar_data/mbsf_medpar_denom/inpatient_2000.parquet').select([
    "primary_diag"
]).n_unique())

In [None]:
# Output the dob column
display(pl.read_parquet('medpar_data/mbsf_medpar_denom/denom_2000.parquet').select('dob').head(5))

In [None]:
# Display admission date and discharge date
display(pl.read_parquet('medpar_data/mbsf_medpar_denom/inpatient_2000.parquet').select([
    'admission_date',
    'discharge_date'
]).head(5))

## Processing Medpar Data with ESGPT

Now that we see the form of this medpar data, we can examine how to process it with Event Stream GPT. From
the base directory of the ESGPT repository, we can run the following command:

```bash
PYTHONPATH=$(pwd):$PYTHONPATH ./scripts/build_dataset.py \
	--config-path="$(pwd)/medpar_data/" \
	--config-name=dataset \
	"hydra.searchpath=[$(pwd)/configs]"
```

Note that this script, like all built-in ESGPT scripts, uses [Hydra](https://hydra.cc/), a configuration file and experiment run-script library. In Hydra, all scripts can take as input a set of composable configuration files which can be overridden via files or via the command line. If you aren't already familiar with Hydra, you should read through some of their examples or tutorials to gain some familiarity with their system.

Before we actually run this command, we need to do 2 things:

  1. Decide what we _want_ the command to do, conceptually.
  2. Understand what we're _telling_ the library to do, via its input arguments.
  
### What do we _want_ to happen?
We can see that our synthetic data has a few different kinds of things happening to these subjects. In the ESGPT data model, we want to organize this data so that we clearly know who our subjects are, quantify when things happen to those subjects, and record in a sparse manner what is happening to those patients. Let's list a few more specific desiderata:

  1. We should expect our system to quantify those subjects in our synthetic data that meet our inclusion criteria (which we haven't yet specified).
  2. The system should bucket all interactions for subjects into appropriately defined events, across admissions, discharges and diagnoses.
  3. The system should learn appropriate categorical vocabularies, numerical outlier detector models, numerical normalization models, for the various measurements we want to extract (which we haven't yet specified).
  4. The system should produce "deep-learning friendly" representations of these data.

A quick tangent -- what do we mean by "deep-learning friendly" representations of these data? Well, right now, if we were to try to run these data through _any_ deep-learning system for longitudinal data, we'd need to re-format these data such that it is easy to efficiently (ideally $O(1)$) retrieve all data corresponding to a single subject in an organized timeseries format that we can then efficiently (meaning in a manner requiring minimal GPU memory) pass into a sequential neural network. 

In the current representation, this retrieval process would not be $O(1)$; instead, if we didn't modify the data's organization at all, for each new subject (`bene_id`), we'd need to select from each data file all those rows with that `bene_id` (each selection being an $O(N)$ operation), and then we would need to subsequently sort all the temporal data by timestamp (another $O(L\ln(L))$ operation).

Similarly, if we use a naive, dense encoding of the data per measurement for our DL representation, this will be very wasteful in terms of GPU memory, as each record will need to occupy memory proportionate to the total number of possible measurements we could observe in our data (e.g., the total number of lab tests, plus the total number of vital signs, plus the total number of admission departments, etc.). Instead, a sparse encoding should be used.

These two properties are exactly what we mean by a "deep-learning friendly" representation of the data.
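
As a rough illustration of both properties, the sketch below builds a per-subject index once, so that later lookups are dictionary accesses, and stores each event as a sparse list of (measurement, value) pairs instead of a dense vector. The data and names (`build_subject_index`, the toy rows) are invented for illustration and do not reflect ESGPT's internal format.

```python
from collections import defaultdict


def build_subject_index(rows):
    """Group rows by subject and sort each subject's rows by time, once."""
    index = defaultdict(list)
    for row in rows:
        index[row["bene_id"]].append(row)
    for subject_rows in index.values():
        subject_rows.sort(key=lambda r: r["timestamp"])
    return dict(index)


# Made-up rows; timestamps are plain integers to keep the example short.
rows = [
    {"bene_id": "A", "timestamp": 5, "measurements": [("diagnoses", "4280")]},
    {"bene_id": "B", "timestamp": 2, "measurements": [("los_day_cnt", 3.0)]},
    {"bene_id": "A", "timestamp": 1,
     "measurements": [("admsn_type_cd", "1"), ("diagnoses", "41401")]},
]

index = build_subject_index(rows)

# Retrieval is now a single dictionary lookup, already in time order.
timeline_a = index["A"]
print([r["timestamp"] for r in timeline_a])

# Sparse encoding: each event stores only what was observed, so its size
# depends on the number of observations, not on the vocabulary size.
print([len(r["measurements"]) for r in timeline_a])
```

The one-time cost of grouping and sorting is paid up front, which is the trade that makes per-subject retrieval cheap during training.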

We can see that there are several questions posed by these desiderata that we need to answer, such as:

  1. What are our inclusion criteria?
  2. How should we bucket interactions into events?
  3. What measurements do we want to extract?
  4. How do we want to define "outliers"?
  5. How do we define "appropriate categorical vocabularies"?
  6. How do we want to normalize numerical measurements?
  
To start us off, let's use the following answers:

  1. We'll include all subjects who have at least 3 events, with no other inclusion/exclusion criteria.
  2. We'll define an "event" to be any interactions happening to a patient within a 1 hour period. We'll bucket these interactions together starting at the earliest event. **QUESTION: within what time period should an event be defined?**
  3. Ideally, we'd like to extract _all_ measurements. As we'll see, however, due to a limitation in the current version of ESGPT, we'll extract all measurements except for the patient's height. **QUESTION: what could the equivalent issue be in our case?** In particular, we'll extract the occurrence of admissions, discharges, diagnoses, `los_day_cnt`, as well as the subject's race, sex, state, zcta, county, the values recorded for discharge (like the discharge code), and all claims values. **QUESTION: do we need the claims values?**
  4. We'll use a very simple outlier model, that excludes numerical data as outliers if their values exceed 1.5 standard deviations from the mean. This is an extremely aggressive cutoff only suitable for this synthetic data setting.
  5. We'll keep any categorical observation as a vocabulary element if it occurs at least 5 times.
  6. We'll normalize our numerical observations to have zero mean and unit variance.
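
To make answers 4 to 6 concrete, here is a minimal plain-Python sketch of the three operations on made-up values. It is not ESGPT's implementation: the function names are hypothetical, and the use of the population standard deviation is an assumption.

```python
from collections import Counter
from statistics import mean, pstdev


def drop_outliers(values, n_std=1.5):
    """Keep values within n_std standard deviations of the mean."""
    mu, sigma = mean(values), pstdev(values)  # population std: an assumption
    if sigma == 0:
        return list(values)
    return [v for v in values if abs(v - mu) <= n_std * sigma]


def standardize(values):
    """Rescale values to zero mean and unit variance."""
    mu, sigma = mean(values), pstdev(values)
    if sigma == 0:
        return [0.0 for _ in values]
    return [(v - mu) / sigma for v in values]


def build_vocabulary(observations, min_count=5):
    """Keep categorical values observed at least min_count times."""
    counts = Counter(observations)
    return sorted(value for value, n in counts.items() if n >= min_count)


# Made-up lengths of stay; the last value is extreme.
los = [2, 3, 3, 4, 3, 2, 4, 3, 40]
kept = drop_outliers(los)
scaled = standardize(kept)

# Made-up diagnosis codes with 6, 5, and 2 occurrences.
codes = ["4280"] * 6 + ["41401"] * 5 + ["V3000"] * 2
vocab = build_vocabulary(codes)

print(kept)
print(vocab)
```

Note that the outlier statistics here are computed on the same values being filtered, so a single extreme value inflates the standard deviation it is judged against.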
  
### Telling the pipeline what to do: input config
Now that we have some basic idea of what we want the pipeline to do, let's examine the input configuration file that we pass to the dataset script:

In [None]:
!cat medpar_data/dataset.yaml

There are a number of sections in this file. Firstly, the first three lines ensure this config builds on the defaults provided with the ESGPT library, via Hydra's normal mechanisms. If you aren't familiar with this syntax, check out the [Hydra documentation](https://hydra.cc/docs/1.3/advanced/defaults_list/).

Next, there is a section defining some overarching variables and a section defining our input sources. We can see this section details the paths to each of our input files as well as the formatting used for (most of) the timestamps within these files. Note that this section makes use of [Hydra/OmegaConf's Interpolations](https://omegaconf.readthedocs.io/en/2.3_branch/grammar.html#interpolation-strings) to simplify the specification of the file paths used. 

**Warning**: Two parameters in this section are required: `subject_id_col`, and `cohort_name`. This will be explored in more detail later in this tutorial.

#### Temporality

Temporality describes how a measure varies in time:

- If `TemporalityType.STATIC`, this is a static measurement.
- If `TemporalityType.FUNCTIONAL_TIME_DEPENDENT`, this measurement varies with time and static data in an analytically computable manner (e.g., age).
- If `TemporalityType.DYNAMIC`, this measurement varies in time in a manner that cannot be computed a priori.

Each measurement also has a name, which is also the column in the appropriate internal dataframe (`subjects_df`, `events_df`, or `dynamic_measurements_df`) that will contain this measurement. All measurements will have this set.

The "column" linkage has slightly different meanings depending on `self.modality`:

- If `modality == DataModality.UNIVARIATE_REGRESSION`, this column stores the values associated with this continuous-valued measure.
- If `modality == DataModality.MULTIVARIATE_REGRESSION`, this column stores the keys that dictate the dimensions for which the associated `values_column` has the values.
- Otherwise, this column stores the categorical values of this measure.

Similarly, it has slightly different meanings depending on `self.temporality`:

- If `temporality == TemporalityType.STATIC`, this is an existing column in the `subjects_df` dataframe.
- If `temporality == TemporalityType.DYNAMIC`, this is an existing column in the `dynamic_measurements_df` dataframe.
- Otherwise (when `temporality == TemporalityType.FUNCTIONAL_TIME_DEPENDENT`), this is the name that the to-be-created output column will take in the `events_df` dataframe.

As stated above, measurements can take on one of the following three modes relating to how they vary in time:

- `STATIC`: they are unchanging and can be linked uniquely to a subject.
- `FUNCTIONAL_TIME_DEPENDENT`: they can be specified in functional form, dependent only on static subject data and/or a continuous timestamp.
- `DYNAMIC`: they are time-varying, but the manner of this variation cannot be specified in a static functional form as in the case of `FUNCTIONAL_TIME_DEPENDENT`. Accordingly, these measurements are linked to events in a many-to-one fashion and are identified via a separate `metadata_id` identifier.
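
Age is the standard example of a functional time-dependent measure: it needs only a static value (the date of birth) and the event timestamp. A minimal sketch follows; the helper `age_in_years` and the 365.25-day year are assumptions for illustration, not ESGPT's own functor.

```python
from datetime import datetime

SECONDS_PER_YEAR = 365.25 * 24 * 60 * 60  # average year length (assumption)


def age_in_years(dob, event_time):
    """Age at event_time, computed only from the static dob and the timestamp."""
    return (event_time - dob).total_seconds() / SECONDS_PER_YEAR


# Made-up subject and event times.
dob = datetime(1930, 1, 1)
admission = datetime(2000, 1, 1)
discharge = datetime(2000, 1, 8)

print(round(age_in_years(dob, admission), 2))
print(age_in_years(dob, discharge) > age_in_years(dob, admission))
```

Nothing about the value has to be stored per event, since it can be recomputed for any timestamp from the subject's static data.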

#### Defining Temporality in Medpar 

**denom_yyyy.parquet**

STATIC:

- `sex`: static, single_label_classification. Why? A `bene_id` can have only one sex.
- `race`: static, single_label_classification. Why? A `bene_id` can have only one race (unlike MIMIC-IV).
- `zcta`: static, single_label_classification. Why? A `bene_id` has only one zcta code within the single year 2000. (Is it time-dependent, though? Could a subject move during the year?)
- `state`: static, single_label_classification.
- `age_dob`: static, single_label_classification.
- `dod`: dynamic, multi_label_classification (can be null or hold one or more dates, type: datetime[ns]). **Question: is NULL = zero in that case?**

FUNCTIONAL_TIME_DEPENDENT:

- `age`: functional_time_dependent measure

**inpatient_yyyy.parquet**

- bene_id
- year
- adm_id
- admission_date
- discharge_date
- dschrgcd: dynamic, single_label_classification (codes are string from 0-9)
- diagnoses: dynamic, multi_label_classification
- los_day_cnt: dynamic, univariate_regression
- dschrg_dstntn_cd: dynamic, single_label_classification
- src_admsn_cd: dynamic, single_label_classification
- admsn_type_cd: dynamic, single_label_classification
  * drg_price_amt
  * drg_outlier_pmt_amt
  * pass_thru_amt
  * mdcr_pmt_amt
  * bene_blood_ddctbl_amt
  * bene_prmry_pyr_amt
  * bene_ip_ddctbl_amt
  * bene_pta_coinsrnc_amt
  * admission_index
  * non_external_all
- primary_diag: dynamic, multi_label_classification (as MIMIC-IV)
  * non_external_primary
- multivariate_regression ["primary_diag", "drg_price_amt"]


Next, we have a section defining the various measurements we'll extract in this dataset. We can see we specify each of the measurements we discussed above:
  1. `sex`, `race`, `zcta`, and `state` are extracted as static, multiple/single_label_classification measures. **QUESTION: are state and zcta static?**
  2. `age` is extracted as a `functional_time_dependent` measure, leveraging the date-of-birth column `dob`. _Note that this is where we define the timestamp format for the `dob` column, as it is a timestamp formatted static column!_
  3. `admsn_type_cd`, `diagnoses`, and `primary_diag` are extracted as `dynamic`, `multi_label_classification` measures.
  4. `HR`, and `temp` are extracted as `dynamic`, `univariate_regression` measures.
  5. `lab_name` and `lab_value` are extracted as a single `dynamic`, `multivariate_regression` measure.
  
Note that the terms `static`, `functional_time_dependent`, & `dynamic` and `single_label_classification`, `multi_label_classification`, `univariate_regression`, and `multivariate_regression`, are defined enumerations in the `EventStream.data.config` sub-module, and dictate where measurements are stored and how they are pre-processed.
  
Finally, we have the remaining set of parameters, which define our inclusion-exclusion criteria (by specifying `min_events_per_subject`), our outlier and normalizer model configuration parameters (`normalization` being omitted here as what we want is the default value), our filtering thresholds for vocabulary elements, and the aggregation time-scale for events.
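
To see how the aggregation time-scale and `min_events_per_subject` interact, the sketch below buckets made-up timestamps into 1 hour windows anchored at each bucket's earliest timestamp, then keeps subjects with at least 3 resulting events. The function names are hypothetical and the anchoring rule is one reading of the description above; ESGPT's actual windowing may differ.

```python
from datetime import datetime, timedelta


def bucket_timestamps(timestamps, window=timedelta(hours=1)):
    """Group timestamps into buckets anchored at each bucket's earliest time."""
    buckets = []
    for ts in sorted(timestamps):
        if buckets and ts - buckets[-1][0] < window:
            buckets[-1].append(ts)  # still inside the current window
        else:
            buckets.append([ts])    # start a new bucket at this timestamp
    return buckets


def subjects_with_enough_events(timestamps_by_subject, min_events=3):
    """Return subjects whose bucketed event count reaches min_events."""
    return sorted(
        subject
        for subject, stamps in timestamps_by_subject.items()
        if len(bucket_timestamps(stamps)) >= min_events
    )


day = datetime(2000, 1, 3)
timestamps_by_subject = {
    # Four interactions that collapse into three events.
    "A": [day.replace(hour=8), day.replace(hour=8, minute=30),
          day.replace(hour=9, minute=10), day.replace(hour=12)],
    # Three interactions within one hour collapse into a single event.
    "B": [day.replace(hour=10), day.replace(hour=10, minute=15),
          day.replace(hour=10, minute=40)],
}

print(len(bucket_timestamps(timestamps_by_subject["A"])))
print(subjects_with_enough_events(timestamps_by_subject))
```

The inclusion criterion counts events after bucketing, so subject B is dropped even though it has three raw rows.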

#### What else _could_ we have specified?
To better understand the structure of this input specification, let's explore this input configuration file in a bit more detail. To start with, let's look at what the default, base config contains (the config we inherit from in the defaults list):

Here, they usually define a special event such as `VISIT` when `admission_time` equals `discharge_time`.

In [None]:
import subprocess
os.environ["HYDRA_FULL_ERROR"] = "1"
command = """\
PYTHONPATH=$(pwd):$PYTHONPATH ./scripts/build_dataset.py \
 --config-path="$(pwd)/medpar_data/" \
 --config-name=dataset \
 "hydra.searchpath=[$(pwd)/configs]" """

command_out = subprocess.run(command, shell=True, capture_output=True)
print(command_out.stdout.decode())

if command_out.returncode != 0:  # any non-zero exit code signals failure
    print("Command Errored!")

print(command_out.stderr.decode())

In [None]:
!du -sh processed_data/CMS_sample/