In [2]:
import os
import pandas as pd
import polars as pl
from pathlib import Path

from esgpt_task_querying import main
from EventStream.data.dataset_polars import Dataset

%load_ext autoreload
%autoreload 2

# Preview limits so joined event/measurement frames are readable when displayed.
pd.options.display.max_rows = 100
pl.Config.set_tbl_rows(100)
pl.Config.set_tbl_cols(100)

# Dataset directory, relative to the current working directory shown below.
data_path = "../MIMIC_ESD_new_schema_08-31-23-1"
os.getcwd()

  from .autonotebook import tqdm as notebook_tqdm


'/home/justinxu/esgpt/ESGPTTaskQuerying/to_organize'

In [3]:
DATA_DIR = Path(data_path)
ESD = Dataset.load(DATA_DIR)

# Discard rows that are null in every column, in both source tables.
events_df = ESD.events_df.filter(pl.all_horizontal(pl.all().is_null()).not_())
dynamic_measurements_df = ESD.dynamic_measurements_df.filter(
    pl.all_horizontal(pl.all().is_null()).not_()
)

# Left join keeps events that have no matching measurement rows; event_id is
# only needed as the join key, so it is dropped afterwards.
ESD_data = events_df.join(dynamic_measurements_df, on="event_id", how="left")
ESD_data = ESD_data.drop(["event_id"]).sort(by=["subject_id", "timestamp", "event_type"])

# If timestamps were not loaded as Datetime, parse them from "%m/%d/%Y %H:%M" strings.
if ESD_data["timestamp"].dtype != pl.Datetime:
    ESD_data = ESD_data.with_columns(
        pl.col("timestamp")
        .str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")
        .cast(pl.Datetime)
    )

Updating config.save_dir from /n/data1/hms/dbmi/zaklab/RAMMS/data/MIMIC_IV/ESD_new_schema_08-31-23-1 to ../MIMIC_ESD_new_schema_08-31-23-1
Loading events from ../MIMIC_ESD_new_schema_08-31-23-1/events_df.parquet...
Loading dynamic_measurements from ../MIMIC_ESD_new_schema_08-31-23-1/dynamic_measurements_df.parquet...


In [None]:
def has_event_type(type_str: str, literal: bool = False) -> pl.Expr:
    """Build a boolean expression flagging rows whose ``event_type`` contains ``type_str``.

    ``event_type`` is cast to a string first so the ``.str`` namespace is usable
    regardless of the dtype it was loaded with (NOTE(review): presumably
    categorical — confirm).

    Args:
        type_str: Pattern to search for. Interpreted as a regular expression
            unless ``literal`` is True.
        literal: If True, match ``type_str`` as a plain substring, so characters
            such as ``.``, ``+`` or ``|`` in an event-type name are not treated as
            regex syntax. Defaults to False, which preserves the original
            regex-matching behavior.

    Returns:
        A polars expression that is True where ``event_type`` contains
        ``type_str`` (null where ``event_type`` is null).
    """
    return pl.col("event_type").cast(pl.Utf8).str.contains(type_str, literal=literal)

In [None]:
def _sample_results(df_result, samples, seed):
    """Return up to ``samples`` randomly chosen result rows as a pandas DataFrame."""
    # Sampling without replacement fails when asking for more rows than exist,
    # so fall back to all rows explicitly instead of catching every exception.
    if samples is not None and samples <= df_result.height:
        sampled = df_result.sample(samples, seed=seed)
        print(sampled)
        return sampled.to_pandas()
    return df_result.to_pandas()


def _sorted_window_boundaries(row, columns):
    """Collect ``(timestamp, column_name)`` for every column containing 'timestamp', ordered by time."""
    return sorted(
        ((row[col], col) for col in columns if 'timestamp' in col),
        key=lambda pair: pair[0],
    )


def validate_query(config_path, data_path, ESD_data, samples=10, verbose=False, seed=42):
    """Cross-check ``main.query_task`` output against counts recomputed from predicates.

    For each validated result row, the subject's predicate columns are restricted
    to the span between the row's earliest and latest timestamp columns. For
    each consecutive pair of timestamps, the predicate columns are summed over
    that span and compared with the corresponding ``<window>/window_summary``
    reported by the query.

    Args:
        config_path: Path to the task config YAML.
        data_path: Dataset directory, passed through to ``main.query_task``.
        ESD_data: Joined events/measurements frame on which predicate columns
            are generated.
        samples: Number of result rows to validate. If ``None`` or greater than
            the number of result rows, every result row is validated.
        verbose: Print per-subject and per-window diagnostics.
        seed: Random seed used when sampling result rows.

    Raises:
        AssertionError: If the trigger predicate is not 1 at the trigger
            timestamp, or if a recomputed predicate count differs from the
            query's window summary.
    """
    cfg = main.load_config(config_path)
    df_data = main.generate_predicate_columns(cfg, ESD_data, verbose=verbose)

    df_result = main.query_task(config_path, data_path, verbose=verbose)

    if df_result.shape[0] == 0:
        print("No results found.")
        return

    validation = _sample_results(df_result, samples, seed)
    trigger_event = f'is_{cfg.windows.trigger.start}'

    for row_idx, row in validation.iterrows():
        if row_idx % 100 == 0:
            print(f"Validating row: {row_idx}/{len(validation)}...")

        subject_id = row['subject_id']
        if verbose:
            print('Checking subject_id:', subject_id)

        # Restrict to this subject, between the row's earliest and latest timestamps.
        boundaries = _sorted_window_boundaries(row, validation.columns)
        first_ts, last_ts = boundaries[0][0], boundaries[-1][0]
        subject_data = df_data.filter(
            (pl.col("subject_id") == subject_id)
            & (pl.col("timestamp") >= first_ts)
            & (pl.col("timestamp") <= last_ts)
        )

        trigger_flags = (
            subject_data.filter(pl.col("timestamp") == row['trigger/timestamp'])
            .select(trigger_event)
            .to_pandas()
            .values.flatten()
        )
        assert len(trigger_flags) > 0 and trigger_flags[0] == 1, (
            subject_id, row['trigger/timestamp'], trigger_event
        )

        # Before the trigger, a span is checked against the summary of the window
        # named by its earlier boundary column; from the trigger onward, against
        # the window named by its later boundary column.
        after_trigger = False
        for (start_ts, start_col), (end_ts, end_col) in zip(boundaries, boundaries[1:]):
            name = start_col.split('/')[0]
            if verbose:
                print(f"Checking window: {name}->{end_col.split('/')[0]}")
            if name == 'trigger':
                after_trigger = True
            if after_trigger:
                name = end_col.split('/')[0]

            summary = row[f'{name}/window_summary']
            if verbose:
                print(start_ts, end_ts)
                print(summary)

            # Span is (start, end]: exclusive start, inclusive end.
            window_data = subject_data.filter(
                (pl.col("timestamp") > start_ts) & (pl.col("timestamp") <= end_ts)
            )
            sum_counts = window_data.sum()
            if verbose:
                display(sum_counts)

            for predicate in summary:
                # Missing/None counts in the summary are treated as zero.
                expected = summary[predicate] or 0
                actual = sum_counts.select(predicate).to_pandas().values.flatten()[0]
                assert actual == expected, (subject_id, start_ts, end_ts, summary, sum_counts)

In [None]:
# Task config to validate. Other sample configs available in ../sample_configs/:
#   inhospital_mortality.yaml, abnormal_lab.yaml, imminent_mortality.yaml,
#   long_term_incidence.yaml, outlier_detection.yaml, readmission_risk.yaml
config_path = '../sample_configs/intervention_weaning.yaml'

# Requires the cells defining validate_query, data_path and ESD_data to have run first.
validate_query(
    config_path,
    data_path,
    ESD_data,
    samples=10,
    verbose=True,
    seed=42,
)

NameError: name 'validate_query' is not defined