In [None]:
%load_ext autoreload
%autoreload 2

import polars as pl

from ethos.constants import PROJECT_ROOT
from ethos.inference.constants import Task

# Results (from Google Drive) are expected under <PROJECT_ROOT>/results.
result_dir = PROJECT_ROOT.joinpath("results")

# Data directory; the tokenized mimic_ed dataset is read from the
# sub-directory assembled below.
data_dir = PROJECT_ROOT.joinpath("data")
tokenized_data_dir = data_dir.joinpath("tokenized_datasets", "mimic_ed_procedures/test")

In [None]:
from ethos.datasets import (
    CriticalOutcomeAtTriageDataset,
    EdReattendenceDataset,
    HospitalAdmissionAtTriageDataset,
    TimelineDataset,
)

# Ground-truth dataset class for each ED task. Kept under its own name so that
# `gts` only ever refers to the per-task DataFrames built at the end of this
# cell (previously `gts` was first a dict of classes, then a dict of frames).
gt_dataset_classes = {
    Task.ED_HOSPITALIZATION: HospitalAdmissionAtTriageDataset,
    Task.ED_CRITICAL_OUTCOME: CriticalOutcomeAtTriageDataset,
    Task.ED_REPRESENTATION: EdReattendenceDataset,
}


def dataset2df(dataset):
    """Collect the label dicts of an iterable dataset into a polars DataFrame.

    Args:
        dataset: Iterable yielding 2-tuples. Only the second element of each
            pair is used; it must be a dict that contains a
            ``"prediction_time"`` key.

    Returns:
        A ``pl.DataFrame`` built from those dicts, with the
        ``prediction_time`` column cast to ``pl.Datetime``.
    """
    return pl.from_dicts(y for _, y in dataset).with_columns(
        pl.col("prediction_time").cast(pl.Datetime)
    )


# Instantiate every task's dataset on the tokenized data and convert it to a
# DataFrame once, so later cells can simply look up `gts[task]`.
gts = {
    task_key: dataset2df(dataset_cls(tokenized_data_dir))
    for task_key, dataset_cls in gt_dataset_classes.items()
}

In [None]:
# Task whose label mismatches are inspected below; change it to look at a
# different task.
task = Task.ED_REPRESENTATION

ed_bench_dir = result_dir / "baseline_ed_bench"

# The baseline results are read from the parquet file named after the task.
ed_bench_path = (ed_bench_dir / task).with_suffix(".parquet")

# `intime` / `outtime` are parsed from strings into datetimes so they can be
# sorted and compared as timestamps.
parsed_stay_times = pl.col("intime", "outtime").str.strptime(
    pl.Datetime, "%Y-%m-%d %H:%M:%S"
)

ed_bench_lazy = pl.scan_parquet(ed_bench_path)
ed_bench_results = (
    ed_bench_lazy.with_columns(parsed_stay_times)
    .sort(["subject_id", "intime", "outtime"])
    .collect()
)

In [None]:
# The per-task label files can be generated in
# `notebooks/all_task_label_dumps.ipynb`.
ed_task_labels_dir = data_dir / "ed_task_labels"
ethos_labels_path = (ed_task_labels_dir / task).with_suffix(".parquet")

# Labels of the test fold only, ordered by patient and time.
ethos_labels = (
    pl.scan_parquet(ethos_labels_path)
    .filter(pl.col("fold") == "test")
    .sort(["subject_id", "time"])
    .collect()
)

# Benchmark column matched against the label `time`: `outtime` for the ED
# representation task, `intime` for every other task.
bench_time_col = "outtime" if task == Task.ED_REPRESENTATION else "intime"

labels_vs_bench = ethos_labels.join(
    ed_bench_results,
    left_on=["subject_id", "time"],
    right_on=["subject_id", bench_time_col],
    how="full",
)

# Rows where the two `boolean_value` columns disagree.
# NOTE(review): `!=` evaluates to null when either side is null and `filter`
# drops null rows, so rows left unmatched by the full join never appear in
# `mismatch`. If those should be surfaced too, `ne_missing` would be needed -
# confirm the intent.
labels_disagree = pl.col("boolean_value") != pl.col("boolean_value_right")
mismatch = labels_vs_bench.filter(labels_disagree)

# Displayed as the cell output: every mismatch, enriched with the matching
# ground-truth row where one exists (left join keeps unmatched mismatches).
mismatch.join(
    gts[task],
    left_on=["subject_id", "time"],
    right_on=["patient_id", "prediction_time"],
    how="left",
    maintain_order="left",
)

In [None]:
# Print timeline fragments for the selected label mismatch.
# `n` is the row number in the mismatch table left-joined with the ground
# truth (the table displayed by the previous cell).
n = 0

d = TimelineDataset(tokenized_data_dir)

# Use the same left join as the displayed table so that row `n` here is row
# `n` there. The default inner join would drop mismatches without a
# ground-truth row and silently shift the numbering.
data_idx = mismatch.join(
    gts[task],
    left_on=["subject_id", "time"],
    right_on=["patient_id", "prediction_time"],
    how="left",
    maintain_order="left",
).row(n, named=True)["data_idx"]

# A left join yields null for mismatches that have no ground-truth row; fail
# with a clear message instead of indexing the dataset with None.
if data_idx is None:
    raise ValueError(
        f"Mismatch row {n} has no matching ground-truth row, "
        "so there is no timeline index to display."
    )

# Number of extra tokens shown past the index taken from
# `patient_data_end_at_idx`.
n_context_tokens = 5

timeline_start = data_idx
timeline_end = d.patient_data_end_at_idx[timeline_start].item() + n_context_tokens
print(timeline_start, timeline_end)

# Decode the selected token range and pair every token with its timestamp.
timeline_slice = slice(timeline_start, timeline_end)
tokens = d.vocab.decode(d.tokens[timeline_slice])
times = d.times[timeline_slice].tolist()
timeline = pl.DataFrame(
    [tokens, times], schema={"token": pl.Utf8, "time": pl.Datetime}
).with_row_index(offset=timeline_start)  # index column continues from the dataset index
timeline