In [1]:
import os
os.environ['POLARS_MAX_THREADS'] = '8'
import pandas as pd
import polars as pl
from pathlib import Path

from esgpt_task_querying import main
from EventStream.data.dataset_polars import Dataset

%load_ext autoreload
%autoreload 2

# Allow up to 100 rows (pandas) and up to 100 rows/columns (polars) when frames are rendered.
pd.set_option('display.max_rows', 100)
pl.Config.set_tbl_cols(100)
pl.Config.set_tbl_rows(100)

# Dataset directory, given relative to the working directory echoed below.
data_path = '../MIMIC_ESD_new_schema_08-31-23-1'
# Last expression of the cell: shows the working directory that relative paths resolve against.
os.getcwd()

  from .autonotebook import tqdm as notebook_tqdm


'/home/justinxu/esgpt/ESGPTTaskQuerying/to_organize'

In [2]:
# Load the dataset from the configured directory.
DATA_DIR = Path(data_path)
ESD = Dataset.load(DATA_DIR)

# Keep only rows that have at least one non-null column, in both tables.
row_has_data = ~pl.all_horizontal(pl.all().is_null())
events_df = ESD.events_df.filter(row_has_data)
dynamic_measurements_df = ESD.dynamic_measurements_df.filter(row_has_data)

# Attach the measurements to their events, then drop the join key and order
# the rows by subject, time and event type.
ESD_data = events_df.join(dynamic_measurements_df, on="event_id", how="left")
ESD_data = ESD_data.drop(["event_id"])
ESD_data = ESD_data.sort(by=["subject_id", "timestamp", "event_type"])

# Parse the timestamp column only when it is not already a datetime column.
if ESD_data["timestamp"].dtype != pl.Datetime:
    parsed_timestamp = pl.col("timestamp").str.strptime(
        pl.Datetime, format="%m/%d/%Y %H:%M"
    )
    ESD_data = ESD_data.with_columns(parsed_timestamp.cast(pl.Datetime))

Updating config.save_dir from /n/data1/hms/dbmi/zaklab/RAMMS/data/MIMIC_IV/ESD_new_schema_08-31-23-1 to ../MIMIC_ESD_new_schema_08-31-23-1
Loading events from ../MIMIC_ESD_new_schema_08-31-23-1/events_df.parquet...
Loading dynamic_measurements from ../MIMIC_ESD_new_schema_08-31-23-1/dynamic_measurements_df.parquet...


In [3]:
def has_event_type(type_str: str) -> pl.Expr:
    """Build an expression flagging rows whose "event_type" contains ``type_str``.

    The "event_type" column is cast to ``pl.Utf8`` and matched with
    ``str.contains``.
    NOTE(review): per the polars docs the pattern is treated as a regex by
    default, so metacharacters in ``type_str`` are not matched literally.
    """
    event_type_text = pl.col("event_type").cast(pl.Utf8)
    return event_type_text.str.contains(type_str)

In [6]:
def validate_query(config_path, ESD_data, samples=10, verbose=False, seed=42):
    """Spot-check the output of ``main.query_task`` against recomputed counts.

    The task is queried on ``ESD_data`` and up to ``samples`` result rows are
    drawn. For each drawn row, the predicate columns produced by
    ``main.generate_predicate_columns`` are summed between consecutive window
    timestamps and compared with the counts stored in the row's
    ``<window>/window_summary`` entry.

    Args:
        config_path: Task config path, passed to ``main.load_config`` and
            ``main.query_task``.
        ESD_data: Event data handed to the ``main`` helpers.
        samples: Maximum number of result rows to validate. When the result
            has fewer rows than this, every row is validated.
        verbose: Forwarded to the ``main`` helpers; also enables per-window
            diagnostic output.
        seed: Seed passed to the sampling call.

    Returns:
        None. Prints "No results found." and returns early when the query
        yields no rows.

    Raises:
        AssertionError: If the row at the trigger timestamp is not flagged with
            the trigger predicate, or a recomputed count differs from the
            stored summary.
    """
    cfg = main.load_config(config_path)
    df_data = main.generate_predicate_columns(cfg, ESD_data, verbose=verbose)

    df_result = main.query_task(config_path, ESD_data, verbose=verbose)

    n_results = df_result.shape[0]
    if n_results == 0:
        print("No results found.")
        return

    # Requesting more rows than exist makes the sampling call fail. Check the
    # size explicitly rather than hiding every possible error behind a bare
    # ``except``; unrelated failures now surface instead of being swallowed.
    if samples > n_results:
        validation = df_result.to_pandas()
    else:
        validation = df_result.sample(n=samples, seed=seed).to_pandas()

    # Both values are the same for every row, so compute them once.
    timestamp_cols = [col for col in validation.columns if 'timestamp' in col]
    trigger_event = f'is_{cfg.windows.trigger.start}'

    for row_idx, row in validation.iterrows():
        if row_idx % 100 == 0:
            print(f"Validating row: {row_idx}/{len(validation)}...")

        # Restrict the predicate data to the subject of this result row.
        subject_id = row['subject_id']
        filtered_data = df_data.filter(pl.col("subject_id") == subject_id)

        if verbose:
            print('Checking subject_id:', subject_id)

        # Order the window boundaries of this row chronologically.
        timestamps = [(row[col], col) for col in timestamp_cols]
        timestamps.sort(key=lambda x: x[0])

        # Only events between the earliest and latest boundary are relevant.
        filtered_data = filtered_data.filter(pl.col("timestamp") >= timestamps[0][0])
        filtered_data = filtered_data.filter(pl.col("timestamp") <= timestamps[-1][0])

        # The first row found at the trigger timestamp must carry the trigger predicate.
        trigger_rows = filtered_data.filter(pl.col("timestamp") == row['trigger/timestamp'])
        trigger_flags = trigger_rows.select(trigger_event).to_pandas().values.flatten()
        assert trigger_flags[0] == 1

        after_trigger = False
        # A separate index name is used here so the row index above is not shadowed.
        for window_idx in range(len(timestamps) - 1):
            start_ts, start_col = timestamps[window_idx]
            end_ts, end_col = timestamps[window_idx + 1]

            name = start_col.split('/')[0]
            if verbose:
                print(f"Checking window: {name}->{end_col.split('/')[0]}")
            # From the trigger onwards, the summary is looked up under the
            # name of the boundary that closes the interval.
            if name == 'trigger':
                after_trigger = True
            if after_trigger:
                name = end_col.split('/')[0]

            window_summary = row[f'{name}/window_summary']
            if verbose:
                print(start_ts, end_ts)
                print(window_summary)

            # The interval is open at the start and closed at the end.
            window_data = filtered_data.filter(pl.col("timestamp") > start_ts)
            window_data = window_data.filter(pl.col("timestamp") <= end_ts)

            sum_counts = window_data.sum()
            if verbose:
                display(sum_counts)

            for predicate in window_summary:
                count = window_summary[predicate]
                # A missing (falsy) summary entry is compared as a count of zero.
                if not count:
                    count = 0
                assert sum_counts.select(predicate).to_pandas().values.flatten()[0] == count, (subject_id, start_ts, end_ts, window_summary, sum_counts)

In [8]:
# Task config to validate; uncomment one of the alternatives below to check a different task.
config_path = '../sample_configs/inhospital_mortality.yaml'
# config_path = '../sample_configs/abnormal_lab.yaml'
# config_path = '../sample_configs/imminent_mortality.yaml'
# config_path = '../sample_configs/intervention_weaning.yaml'
# config_path = '../sample_configs/long_term_incidence.yaml'
# config_path = '../sample_configs/readmission_risk.yaml'
# Validate sampled result rows (at most 10) with verbose diagnostics and a fixed seed.
validate_query(config_path, ESD_data, samples=10, verbose=True, seed=42)

Added predicate column is_admission.
Added predicate column is_discharge.
Loading config...

Generating predicate columns...

Added predicate column is_admission.
Added predicate column is_discharge.

Building tree...
trigger
┗━━ input
    ┗━━ target


12127 subjects (14623763 rows) were excluded due to trigger event: admission.


Querying...


Querying subtree rooted at input...
88 subjects (119 rows) were excluded due to constraint: [(col("is_discharge")) >= (1)].
12 subjects (12 rows) were excluded due to constraint: [(col("is_discharge")) <= (1)].


Querying subtree rooted at target...


Done.

Validating row: 0/10...
Checking subject_id: 1989
Checking window: trigger->input
2114-05-19 12:32:00 2114-05-24 10:10:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 64}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
127296,,0,1,64


Checking window: input->target
2114-05-24 10:10:00 2114-06-23 10:10:00
{'is_admission': None, 'is_discharge': None, 'is_any': None}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
0,,0,0,0


Checking subject_id: 3738
Checking window: trigger->input
2174-09-06 07:41:00 2174-09-08 16:00:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 1}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
3738,,0,1,1


Checking window: input->target
2174-09-08 16:00:00 2174-10-08 16:00:00
{'is_admission': None, 'is_discharge': None, 'is_any': None}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
0,,0,0,0


Checking subject_id: 3620
Checking window: trigger->input
2171-01-30 04:22:00 2171-02-06 18:18:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 1}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
3620,,0,1,1


Checking window: input->target
2171-02-06 18:18:00 2171-03-08 18:18:00
{'is_admission': 2.0, 'is_discharge': 2.0, 'is_any': 4.0}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
14480,,2,2,4


Checking subject_id: 8838
Checking window: trigger->input
2173-12-10 16:36:00 2173-12-23 15:30:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 1}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
8838,,0,1,1


Checking window: input->target
2173-12-23 15:30:00 2174-01-22 15:30:00
{'is_admission': 1.0, 'is_discharge': 0.0, 'is_any': 1.0}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
8838,,1,0,1


Checking subject_id: 6498
Checking window: trigger->input
2114-08-22 10:47:00 2114-09-10 18:16:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 265}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
1721970,,0,1,265


Checking window: input->target
2114-09-10 18:16:00 2114-10-10 18:16:00
{'is_admission': 1.0, 'is_discharge': 1.0, 'is_any': 2.0}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
12996,,1,1,2


Checking subject_id: 1883
Checking window: trigger->input
2156-02-17 12:59:00 2156-03-07 19:00:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 687}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
1293621,,0,1,687


Checking window: input->target
2156-03-07 19:00:00 2156-04-06 19:00:00
{'is_admission': None, 'is_discharge': None, 'is_any': None}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
0,,0,0,0


Checking subject_id: 3897
Checking window: trigger->input
2184-11-12 21:53:00 2184-11-24 19:50:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 810}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
3156570,,0,1,810


Checking window: input->target
2184-11-24 19:50:00 2184-12-24 19:50:00
{'is_admission': None, 'is_discharge': None, 'is_any': None}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
0,,0,0,0


Checking subject_id: 5546
Checking window: trigger->input
2157-09-26 21:03:00 2157-09-26 23:35:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 1}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
5546,,0,1,1


Checking window: input->target
2157-09-26 23:35:00 2157-10-26 23:35:00
{'is_admission': 1.0, 'is_discharge': 1.0, 'is_any': 2.0}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
11092,,1,1,2


Checking subject_id: 160
Checking window: trigger->input
2138-06-10 15:38:00 2138-06-11 00:53:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 1}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
160,,0,1,1


Checking window: input->target
2138-06-11 00:53:00 2138-07-11 00:53:00
{'is_admission': 1.0, 'is_discharge': 1.0, 'is_any': 2.0}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
320,,1,1,2


Checking subject_id: 2836
Checking window: trigger->input
2172-12-02 00:26:00 2172-12-03 14:01:00
{'is_admission': 0, 'is_discharge': 1, 'is_any': 1}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
2836,,0,1,1


Checking window: input->target
2172-12-03 14:01:00 2173-01-02 14:01:00
{'is_admission': 1.0, 'is_discharge': 1.0, 'is_any': 2.0}


subject_id,timestamp,is_admission,is_discharge,is_any
i64,datetime[μs],i32,i32,i32
5672,,1,1,2


In [6]:
# Run the task query directly (no validation) and keep the result frame for inspection.
config_path = '../sample_configs/inhospital_mortality.yaml'
df_result = main.query_task(config_path, ESD_data, verbose=True)

Loading config...

Generating predicate columns...



Added predicate column is_admission.
Added predicate column is_discharge.
Added predicate column is_death.
Added predicate column is_discharge_or_death.

Building tree...
trigger
┣━━ gap
┃   ┗━━ target
┗━━ input


12127 subjects (14623763 rows) were excluded due to trigger event: admission.


Querying...


Querying subtree rooted at gap...
342 subjects (631 rows) were excluded due to constraint: [(col("is_admission")) <= (0)].
7143 subjects (17534 rows) were excluded due to constraint: [(col("is_discharge")) <= (0)].


Querying subtree rooted at target...
75 subjects (100 rows) were excluded due to constraint: [(col("is_discharge_or_death")) >= (1)].
5721 subjects (14185 rows) were excluded due to constraint: [(col("is_discharge_or_death")) <= (1)].


Querying subtree rooted at input...


Done.



In [7]:
# Render the query result; the number of rows shown is bounded by the polars display config.
df_result

subject_id,trigger/timestamp,gap/timestamp,target/timestamp,input/timestamp,gap/window_summary,target/window_summary,input/window_summary,label
u16,datetime[μs],datetime[μs],datetime[μs],datetime[μs],struct[5],struct[5],struct[5],i32
0,2125-11-30 16:02:00,2125-12-02 16:02:00,2125-12-03 14:44:00,1852-02-16 16:02:00,"{0,0,0,0,0}","{0,1,0,1,1}","{3,2,0,2,5}",0
0,2130-04-15 21:10:00,2130-04-17 21:10:00,2130-04-19 16:00:00,1856-07-01 21:10:00,"{0,0,0,0,0}","{0,1,0,1,1}","{10,9,0,9,19}",0
0,2130-08-21 15:26:00,2130-08-23 15:26:00,2130-08-23 16:40:00,1856-11-06 15:26:00,"{0,0,0,0,0}","{0,1,0,1,1}","{12,11,0,11,23}",0
0,2130-10-08 19:03:00,2130-10-10 19:03:00,2130-10-12 19:05:00,1856-12-24 19:03:00,"{0,0,0,0,0}","{0,1,0,1,1}","{14,13,0,13,27}",0
0,2130-10-19 16:20:00,2130-10-21 16:20:00,2130-10-22 15:13:00,1857-01-04 16:20:00,"{0,0,0,0,0}","{0,1,0,1,1}","{16,15,0,15,31}",0
0,2130-12-27 21:23:00,2130-12-29 21:23:00,2130-12-30 15:33:00,1857-03-14 21:23:00,"{0,0,0,0,0}","{0,1,0,1,1}","{21,20,0,20,41}",0
0,2131-01-07 20:39:00,2131-01-09 20:39:00,2131-01-20 05:15:00,1857-03-25 20:39:00,"{0,0,0,0,0}","{0,1,1,1,1009}","{22,21,0,21,43}",1
1,2128-07-29 17:01:00,2128-07-31 17:01:00,2128-07-31 18:00:00,1854-10-15 17:01:00,"{0,0,0,0,0}","{0,1,0,1,1}","{1,0,0,0,1}",0
1,2129-08-04 12:44:00,2129-08-06 12:44:00,2129-08-18 16:53:00,1855-10-21 12:44:00,"{0,0,0,0,124}","{0,1,0,1,242}","{2,1,0,1,3}",0
1,2130-09-23 21:59:00,2130-09-25 21:59:00,2130-09-29 18:55:00,1856-12-09 21:59:00,"{0,0,0,0,113}","{0,1,0,1,101}","{4,3,0,3,372}",0
