In [1]:
import os
# Set before polars is imported below; presumably so polars picks the value up
# when it sizes its thread pool — TODO confirm against the polars docs.
os.environ['POLARS_MAX_THREADS'] = '16'
import pandas as pd
import polars as pl
from pathlib import Path

from esgpt_task_querying import main
from EventStream.data.dataset_polars import Dataset

# Raise the display limits so up to 100 rows (pandas and polars) and up to
# 100 columns (polars) are rendered instead of being truncated.
pd.set_option('display.max_rows', 100)
pl.Config.set_tbl_cols(100)
pl.Config.set_tbl_rows(100)

# Relative path to the dataset directory. In the visible part of the notebook
# it is only referenced by the commented-out loading cell.
data_path = '../MIMIC_ESD_new_schema_08-31-23-1'
# Last expression of the cell: the working directory is shown as the cell
# output, which makes it clear what the relative paths are resolved against.
os.getcwd()

  from .autonotebook import tqdm as notebook_tqdm


'/home/justinxu/esgpt/ESGPTTaskQuerying/validation'

In [2]:
# DATA_DIR = Path(data_path)
# ESD = Dataset.load(DATA_DIR)

# events_df = ESD.events_df.filter(~pl.all_horizontal(pl.all().is_null()))
# dynamic_measurements_df = ESD.dynamic_measurements_df.filter(
#     ~pl.all_horizontal(pl.all().is_null())
# )

# ESD_data = (
#     events_df.join(dynamic_measurements_df, on="event_id", how="left")
#     .drop(["event_id"])
#     .sort(by=["subject_id", "timestamp", "event_type"])
# )

# if ESD_data["timestamp"].dtype != pl.Datetime:
#     ESD_data = ESD_data.with_columns(
#         pl.col("timestamp")
#         .str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")
#         .cast(pl.Datetime)
#     )

In [3]:
def has_event_type(type_str: str) -> pl.Expr:
    """Build an expression that flags rows whose event type contains ``type_str``.

    The ``event_type`` column is cast to ``pl.Utf8`` first, then matched with
    ``str.contains``.

    Args:
        type_str: Pattern to look for inside the ``event_type`` value.
            NOTE(review): polars' ``str.contains`` treats the pattern as a
            regular expression unless ``literal=True`` is passed — confirm
            that callers never pass event types containing regex
            metacharacters.

    Returns:
        A polars expression; nothing is evaluated until it is used in a
        DataFrame context.
    """
    event_type_text = pl.col("event_type").cast(pl.Utf8)
    return event_type_text.str.contains(type_str)

In [4]:
def validate_query(config_path, ESD_data, samples=10, seed=42):
    """Run a task query and cross-check sampled result rows against raw predicate counts.

    The predicate-augmented data and the query result are both displayed. Then,
    for each sampled result row, all of the row's ``*timestamp*`` columns are
    sorted chronologically and every pair of consecutive timestamps is treated
    as one window. The predicate columns of the subject's events that fall in
    that window are summed and compared with the counts stored in the row's
    ``<window>/window_summary`` entry.

    Window naming follows the original heuristic: before the ``trigger``
    timestamp a window is named after the column that starts it; from the
    trigger onwards it is named after the column that ends it.

    Args:
        config_path: Path of the task config, passed to ``main.load_config``
            and ``main.query_task``.
        ESD_data: Event data passed to ``main.generate_predicate_columns`` and
            ``main.query_task`` (presumably a polars DataFrame with
            ``subject_id`` and ``timestamp`` columns — confirm against caller).
        samples: Number of result rows to validate. When the result has fewer
            rows than this, every row is validated.
        seed: Seed forwarded to ``DataFrame.sample``.

    Returns:
        None. Progress is printed; the function returns early when the query
        yields no rows.

    Raises:
        AssertionError: If the trigger predicate is not set at the row's
            trigger timestamp, or if a summed predicate count differs from the
            count recorded in the window summary.
    """
    cfg = main.load_config(config_path)
    df_data = main.generate_predicate_columns(cfg, ESD_data)
    display(df_data)

    df_result = main.query_task(config_path, ESD_data)
    display(df_result)

    n_results = df_result.shape[0]
    if n_results == 0:
        print("No results found.")
        return

    # Sampling without replacement fails when more rows are requested than
    # exist. Check the size explicitly instead of hiding every possible error
    # behind a bare ``except``.
    if samples <= n_results:
        validation = df_result.sample(samples, seed=seed).to_pandas()
    else:
        validation = df_result.to_pandas()

    # Both values are the same for every row, so compute them once.
    timestamp_cols = [col for col in validation.columns if 'timestamp' in col]
    trigger_event = f'is_{cfg.windows.trigger.start}'

    for row_idx, row in validation.iterrows():
        if row_idx % 100 == 0:
            print(f"Validating row: {row_idx+1}/{len(validation)}...")

        # Restrict the predicate data to the subject of this result row.
        subject_id = row['subject_id']
        filtered_data = df_data.filter(pl.col("subject_id") == subject_id)

        print('Checking subject_id:', subject_id)

        # Order the row's window boundaries chronologically as
        # (timestamp, column name) pairs.
        timestamps = sorted(
            ((row[col], col) for col in timestamp_cols),
            key=lambda pair: pair[0],
        )

        # Keep only events between the earliest and latest boundary (inclusive).
        first_ts, last_ts = timestamps[0][0], timestamps[-1][0]
        filtered_data = filtered_data.filter(
            (pl.col("timestamp") >= first_ts) & (pl.col("timestamp") <= last_ts)
        )

        # The event at the trigger timestamp must carry the trigger predicate.
        trigger_flags = (
            filtered_data.filter(pl.col("timestamp") == row['trigger/timestamp'])
            .select(trigger_event)
            .to_pandas()
            .values.flatten()
        )
        if len(trigger_flags) == 0 or trigger_flags[0] != 1:
            raise AssertionError(
                f"Trigger predicate {trigger_event} is not set at "
                f"{row['trigger/timestamp']} for subject_id {subject_id}."
            )

        after_trigger = False
        for pos in range(len(timestamps) - 1):
            start_ts, start_col = timestamps[pos]
            end_ts, end_col = timestamps[pos + 1]
            start_name = start_col.split('/')[0]
            end_name = end_col.split('/')[0]
            print(f"Checking window: {start_name}->{end_name}")

            # From the trigger onwards a window is identified by its end column.
            if start_name == 'trigger':
                after_trigger = True
            name = end_name if after_trigger else start_name

            window_cfg = cfg.windows[name]
            if window_cfg.st_inclusive:
                window_data = filtered_data.filter(pl.col("timestamp") >= start_ts)
            else:
                window_data = filtered_data.filter(pl.col("timestamp") > start_ts)

            if window_cfg.end_inclusive:
                window_data = window_data.filter(pl.col("timestamp") <= end_ts)
            else:
                window_data = window_data.filter(pl.col("timestamp") < end_ts)

            sum_counts = window_data.sum()

            summary = row[f'{name}/window_summary']
            for predicate in summary:
                # Missing (null) counts in the summary are treated as zero.
                expected = summary[predicate] or 0

                # A failed lookup is reported like a mismatch, but the real
                # error is kept as the cause of the AssertionError.
                lookup_error = None
                try:
                    matched = (
                        sum_counts.select(predicate).to_pandas().values.flatten()[0]
                        == expected
                    )
                except Exception as exc:
                    matched, lookup_error = False, exc

                if not matched:
                    print('Failed for subject_id:', subject_id, ", on window:", name, f"{start_ts}->{end_ts}")
                    print(summary)
                    print(sum_counts)
                    raise AssertionError(
                        f"Predicate {predicate} failed for window {name}."
                    ) from lookup_error

In [5]:
# Task config to validate. The commented-out alternatives below are the other
# sample configs; enable exactly one assignment at a time.
config_path = '../sample_configs/inhospital_mortality.yaml'
# config_path = '../sample_configs/abnormal_lab.yaml'
# config_path = '../sample_configs/imminent_mortality.yaml'
# config_path = '../sample_configs/intervention_weaning.yaml'
# config_path = '../sample_configs/long_term_incidence.yaml'
# config_path = '../sample_configs/readmission_risk.yaml'
# Load the sample events with pandas and convert the frame to polars.
df_temp = pl.from_pandas(pd.read_csv('../sample_data/sample.csv'))
# Parse the string timestamps (month/day/year hour:minute) into polars
# Datetime values.
# NOTE(review): strptime already targets pl.Datetime, so the trailing cast
# looks redundant — confirm before removing.
df_temp = df_temp.with_columns(
    pl.col("timestamp")
    .str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")
    .cast(pl.Datetime)
)
# display(df_temp)
# Run the query and validate up to 10 sampled result rows with a fixed seed.
validate_query(config_path, df_temp, samples=10, seed=42)

[32m2024-04-14 03:40:33.581[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.event_predicates[0m:[36mgenerate_predicate_columns[0m:[36m118[0m - [34m[1mAdded predicate column is_admission.[0m
[32m2024-04-14 03:40:33.582[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.event_predicates[0m:[36mgenerate_predicate_columns[0m:[36m118[0m - [34m[1mAdded predicate column is_discharge.[0m
[32m2024-04-14 03:40:33.583[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.event_predicates[0m:[36mgenerate_predicate_columns[0m:[36m118[0m - [34m[1mAdded predicate column is_death.[0m
[32m2024-04-14 03:40:33.584[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.event_predicates[0m:[36mgenerate_predicate_columns[0m:[36m125[0m - [34m[1mAdded predicate column is_discharge_or_death.[0m


subject_id,timestamp,is_admission,is_discharge,is_death,is_discharge_or_death,is_any
i64,datetime[μs],i32,i32,i32,i32,i32
1,1989-12-01 12:03:00,1,0,0,0,1
1,1989-12-01 13:14:00,0,0,0,0,1
1,1989-12-01 15:17:00,0,0,0,0,1
1,1989-12-01 16:17:00,0,0,0,0,1
1,1989-12-01 20:17:00,0,0,0,0,1
1,1989-12-02 03:00:00,0,0,0,0,1
1,1989-12-02 09:00:00,0,0,0,0,1
1,1989-12-02 15:00:00,0,1,0,1,1
1,1991-01-27 23:32:00,1,0,0,0,1
1,1991-01-27 23:46:00,0,0,0,0,1


[32m2024-04-14 03:40:33.593[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.main[0m:[36mquery_task[0m:[36m63[0m - [34m[1mLoading config...[0m
[32m2024-04-14 03:40:33.606[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.main[0m:[36mquery_task[0m:[36m66[0m - [34m[1mGenerating predicate columns...[0m
[32m2024-04-14 03:40:33.608[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.event_predicates[0m:[36mgenerate_predicate_columns[0m:[36m118[0m - [34m[1mAdded predicate column is_admission.[0m
[32m2024-04-14 03:40:33.609[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.event_predicates[0m:[36mgenerate_predicate_columns[0m:[36m118[0m - [34m[1mAdded predicate column is_discharge.[0m
[32m2024-04-14 03:40:33.610[0m | [34m[1mDEBUG   [0m | [36mesgpt_task_querying.event_predicates[0m:[36mgenerate_predicate_columns[0m:[36m118[0m - [34m[1mAdded predicate column is_death.[0m
[32m2024-04-14 03:40:33.612[0m | [34m[1mDEBUG   [0m

trigger
┣━━ gap
┃   ┗━━ target
┗━━ input


subject_id,trigger/timestamp,gap/timestamp,target/timestamp,input/timestamp,gap/window_summary,target/window_summary,input/window_summary,label
i64,datetime[μs],datetime[μs],datetime[μs],datetime[μs],struct[5],struct[5],struct[5],i32
1,1991-01-27 23:32:00,1991-01-29 23:32:00,1991-01-31 02:15:00,1990-12-28 23:32:00,"{0,0,0,0,3}","{0,1,0,1,6}","{1,0,0,0,1}",0
3,1996-03-08 02:24:00,1996-03-10 02:24:00,1996-03-12 00:00:00,1996-02-07 02:24:00,"{0,0,0,0,5}","{0,0,1,1,1}","{1,0,0,0,1}",1


Validating row: 1/2...
Checking subject_id: 1
Checking window: input->trigger
Checking window: trigger->gap
Checking window: gap->target
Checking subject_id: 3
Checking window: input->trigger
Checking window: trigger->gap
Checking window: gap->target
