In [1]:
import os
import pandas as pd
import polars as pl
from pathlib import Path

from esgpt_task_querying import main
from EventStream.data.dataset_polars import Dataset

%load_ext autoreload
%autoreload 2

pd.set_option('display.max_rows', 100)
pl.Config.set_tbl_cols(100)
pl.Config.set_tbl_rows(100)

data_path = '../MIMIC_ESD_new_schema_08-31-23-1'
os.getcwd()

  from .autonotebook import tqdm as notebook_tqdm


'/home/justinxu/esgpt/ESGPTTaskQuerying/to_organize'

In [2]:
def _drop_all_null_rows(frame):
    """Return ``frame`` without the rows in which every column is null."""
    every_column_null = pl.all_horizontal(pl.all().is_null())
    return frame.filter(~every_column_null)


DATA_DIR = Path(data_path)
ESD = Dataset.load(DATA_DIR)

events_df = _drop_all_null_rows(ESD.events_df)
dynamic_measurements_df = _drop_all_null_rows(ESD.dynamic_measurements_df)

# One row per (event, measurement): left-join the measurements onto their events,
# discard the join key, and order rows by subject, then time, then event type.
ESD_data = (
    events_df.join(dynamic_measurements_df, on="event_id", how="left")
    .drop(["event_id"])
    .sort(by=["subject_id", "timestamp", "event_type"])
)

# Parse the timestamp column only when it was not loaded as a datetime already.
if ESD_data["timestamp"].dtype != pl.Datetime:
    parsed_timestamp = (
        pl.col("timestamp")
        .str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")
        .cast(pl.Datetime)
    )
    ESD_data = ESD_data.with_columns(parsed_timestamp)

Updating config.save_dir from /n/data1/hms/dbmi/zaklab/RAMMS/data/MIMIC_IV/ESD_new_schema_08-31-23-1 to ../MIMIC_ESD_new_schema_08-31-23-1
Loading events from ../MIMIC_ESD_new_schema_08-31-23-1/events_df.parquet...
Loading dynamic_measurements from ../MIMIC_ESD_new_schema_08-31-23-1/dynamic_measurements_df.parquet...


In [3]:
def has_event_type(type_str: str) -> pl.Expr:
    """Build an expression flagging rows whose ``event_type`` contains ``type_str``.

    The ``event_type`` column is cast to ``pl.Utf8`` first, and ``type_str`` is
    handed unchanged to ``str.contains``.

    Args:
        type_str: Pattern to look for inside the event-type string.

    Returns:
        A polars expression to be evaluated against a frame that has an
        ``event_type`` column.
    """
    event_type_text = pl.col("event_type").cast(pl.Utf8)
    return event_type_text.str.contains(type_str)

### In-Hospital Mortality

Query

In [4]:
# Relative path of the YAML task definition used for the in-hospital mortality query.
config_path = "test_configs/inhospital_mortality.yaml"

In [5]:
# Load the task config and add the predicate columns to the event table
# (the library logs one "Added predicate column ..." line per column).
# NOTE(review): this rebinds ESD_data to the result of a call that takes ESD_data
# as input, so the cell is not idempotent -- re-run the data-loading cell before
# re-running this one. TODO confirm whether generate_predicate_columns tolerates
# a frame that already carries the predicate columns.
cfg = main.load_config(config_path)
ESD_data = main.generate_predicate_columns(cfg, ESD_data)

# Run the task query from the config and data paths; judging by the log output
# it loads the dataset and generates the predicate columns again on its own.
df_result = main.query_task(config_path, data_path, verbose=True)

Added predicate column is_admission.
Added predicate column is_discharge.
Added predicate column is_death.
Added predicate column is_discharge_or_death.
Updating config.save_dir from /n/data1/hms/dbmi/zaklab/RAMMS/data/MIMIC_IV/ESD_new_schema_08-31-23-1 to ../MIMIC_ESD_new_schema_08-31-23-1
Loading events from ../MIMIC_ESD_new_schema_08-31-23-1/events_df.parquet...
Loading dynamic_measurements from ../MIMIC_ESD_new_schema_08-31-23-1/dynamic_measurements_df.parquet...
Loading config...

Generating predicate columns...

Added predicate column is_admission.
Added predicate column is_discharge.
Added predicate column is_death.
Added predicate column is_discharge_or_death.

Building tree...
trigger
┣━━ gap
┃   ┗━━ target
┗━━ input


12127 subjects (14623763 rows) were excluded due to trigger condition: {'predicate': 'admission', 'min': 1, 'max': 1}.


Querying...


Querying subtree rooted at gap...
858 subjects (1478 rows) were excluded due to constraint: [(col("is_admission")) <= (0)].
947

Validation

In [6]:
# Draw the spot-check sample reproducibly: a fixed seed keeps the validated rows
# stable across "Restart & Run All", and the size is capped at the number of
# available rows so small result sets do not make sample() raise.
# Assumes df_result is a polars DataFrame (it exposes .to_pandas()).
SAMPLE_SEED = 0
n_validation_rows = min(10, df_result.height)
validation = df_result.sample(n=n_validation_rows, seed=SAMPLE_SEED).to_pandas()

In [7]:
# Display the sampled result rows that the validation below is run against.
validation

Unnamed: 0,subject_id,trigger/timestamp,gap/timestamp,target/timestamp,input/timestamp,gap/window_summary,target/window_summary,input/window_summary,label
0,4933,2141-10-20 15:56:00,2141-10-24 15:56:00,2141-11-04 16:40:00,2141-09-20 15:56:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 2, 'is_discharge': 2, 'is_dea...",0
1,8324,2140-11-13 14:31:00,2140-11-17 14:31:00,2140-11-21 17:03:00,2140-10-14 14:31:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 2, 'is_discharge': 1, 'is_dea...",0
2,11955,2200-07-27 20:33:00,2200-07-31 20:33:00,2200-08-09 14:59:00,2200-06-27 20:33:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 2, 'is_discharge': 1, 'is_dea...",0
3,3912,2147-09-22 18:32:00,2147-09-26 18:32:00,2147-10-04 15:39:00,2147-08-23 18:32:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 2, 'is_discharge': 1, 'is_dea...",0
4,6290,2183-07-09 08:15:00,2183-07-13 08:15:00,2183-07-13 18:11:00,2183-06-09 08:15:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 3, 'is_discharge': 3, 'is_dea...",0
5,10153,2165-03-14 23:59:00,2165-03-18 23:59:00,2165-03-19 14:22:00,2165-02-12 23:59:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 2, 'is_discharge': 1, 'is_dea...",0
6,5663,2150-01-15 13:55:00,2150-01-19 13:55:00,2150-02-13 12:27:00,2149-12-16 13:55:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 1, 'is_discharge': 1, 'is_dea...",0
7,6958,2162-07-17 14:31:00,2162-07-21 14:31:00,2162-07-22 13:30:00,2162-06-17 14:31:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 3, 'is_discharge': 2, 'is_dea...",0
8,3430,2124-05-13 07:15:00,2124-05-17 07:15:00,2124-05-23 12:45:00,2124-04-13 07:15:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 4, 'is_discharge': 3, 'is_dea...",0
9,5760,2147-12-19 00:14:00,2147-12-23 00:14:00,2147-12-25 02:56:00,2147-11-19 00:14:00,"{'is_admission': 0, 'is_discharge': 0, 'is_dea...","{'is_admission': 0, 'is_discharge': 1, 'is_dea...","{'is_admission': 2, 'is_discharge': 1, 'is_dea...",1


Statistics:
- http://varianceexplained.org/statistics/beta_distribution_and_baseball/
- https://www.getguesstimate.com/scratchpad

TODO:

-/ computational profile

-/ prompt provide predicates

-/ local data schema: https://eventstreamml.readthedocs.io/en/dev/_collections/local_tutorial_notebook.html

- check that the filtered-out cohort is correct, and that the final cohort is correct

In [8]:
def validate_row(row, verbose=False):
    """Recompute one result row's window summaries from ``ESD_data`` and assert they match.

    Reads the notebook globals ``ESD_data`` (event table with predicate columns)
    and ``cfg`` (task config).

    Args:
        row: One row of ``validation`` (a pandas Series) holding ``subject_id``,
            the ``<window>/timestamp`` boundaries and the matching
            ``<window>/window_summary`` mappings.
        verbose: When True, print each window's bounds and expected summary and
            display the recomputed sums.

    Raises:
        AssertionError: If the trigger row is missing or not flagged with the
            trigger predicate, or if any recomputed count differs from the
            stored summary.
    """
    subject_id = row["subject_id"]

    # Every "<window>/timestamp" value of the row, in chronological order.
    boundaries = sorted(
        ((row[col], col) for col in row.index if "timestamp" in col),
        key=lambda pair: pair[0],
    )

    # Restrict to this subject's events between the earliest and latest boundary.
    subject_events = ESD_data.filter(
        (pl.col("subject_id") == subject_id)
        & (pl.col("timestamp") >= boundaries[0][0])
        & (pl.col("timestamp") <= boundaries[-1][0])
    )

    # The trigger timestamp must point at a row flagged with the trigger predicate.
    trigger_event = f"is_{cfg.windows.trigger.start}"
    trigger_rows = subject_events.filter(pl.col("timestamp") == row["trigger/timestamp"])
    assert trigger_rows.height > 0, (
        f"subject {subject_id}: no event found at trigger time {row['trigger/timestamp']}"
    )
    assert trigger_rows[trigger_event][0] == 1, (
        f"subject {subject_id}: event at trigger time is not flagged {trigger_event}"
    )

    past_trigger = False
    for (start_ts, start_col), (end_ts, end_col) in zip(boundaries, boundaries[1:]):
        # Before the trigger a window is named after its start boundary; from the
        # trigger onwards it is named after its end boundary.
        if start_col.split("/")[0] == "trigger":
            past_trigger = True
        name = (end_col if past_trigger else start_col).split("/")[0]
        expected = row[f"{name}/window_summary"]

        # Windows are open on the left and closed on the right: (start_ts, end_ts].
        sum_counts = subject_events.filter(
            (pl.col("timestamp") > start_ts) & (pl.col("timestamp") <= end_ts)
        ).sum()

        if verbose:
            print(start_ts, end_ts)
            print(expected)
            display(sum_counts)

        for predicate in expected:
            actual = sum_counts[predicate][0]
            assert actual == expected[predicate], (
                f"subject {subject_id}, window {name!r}: {predicate} recomputed as "
                f"{actual}, summary says {expected[predicate]}"
            )


# Validate every sampled row (the original stopped after the first one); only the
# first row's details are shown so the output stays short.
for row_pos, (_, row) in enumerate(validation.iterrows()):
    validate_row(row, verbose=(row_pos == 0))

2141-09-20 15:56:00 2141-10-20 15:56:00
{'is_admission': 2, 'is_discharge': 2, 'is_death': 0, 'is_discharge_or_death': 2, 'is_any': 103}


subject_id,timestamp,is_admission,is_discharge,is_death,is_discharge_or_death,is_any
i64,datetime[μs],i32,i32,i32,i32,i32
508099,,2,2,0,2,103


2141-10-20 15:56:00 2141-10-24 15:56:00
{'is_admission': 0, 'is_discharge': 0, 'is_death': 0, 'is_discharge_or_death': 0, 'is_any': 0}


subject_id,timestamp,is_admission,is_discharge,is_death,is_discharge_or_death,is_any
i64,datetime[μs],i32,i32,i32,i32,i32
0,,0,0,0,0,0


2141-10-24 15:56:00 2141-11-04 16:40:00
{'is_admission': 0, 'is_discharge': 1, 'is_death': 0, 'is_discharge_or_death': 1, 'is_any': 1}


subject_id,timestamp,is_admission,is_discharge,is_death,is_discharge_or_death,is_any
i64,datetime[μs],i32,i32,i32,i32,i32
4933,,0,1,0,1,1


In [9]:
# Spot check for one subject: list every row flagged as an admission.
subject_of_interest = 7374
is_subject = pl.col("subject_id") == subject_of_interest
is_admission_row = pl.col('is_admission') == 1
ESD_data.filter(is_subject & is_admission_row)

subject_id,timestamp,is_admission,is_discharge,is_death,is_discharge_or_death,is_any
u16,datetime[μs],i32,i32,i32,i32,i32
7374,2164-06-09 10:56:00,1,0,0,0,1
7374,2164-07-10 20:42:00,1,0,0,0,1
7374,2164-07-30 15:21:00,1,0,0,0,1
7374,2164-08-20 15:54:00,1,0,0,0,1
7374,2164-08-27 03:47:00,1,0,0,0,1
7374,2164-10-05 17:12:00,1,0,0,0,1
7374,2164-10-13 21:00:00,1,0,0,0,1
