In [None]:
import os
from pathlib import Path
import polars as pl
pl.enable_string_cache()

from tqdm.auto import tqdm

In [None]:
# Load KEY=VALUE settings from the local ".env" file into this process's environment.
# NOTE(review): every value is echoed to the cell output below, so any secret kept
# in ".env" will end up in the saved notebook -- clear outputs before sharing.
for env_line in Path(".env").read_text().splitlines():
    env_line = env_line.strip()
    # Blank lines and "#" comment lines carry no assignment.
    if not env_line or env_line.startswith("#"):
        continue
    # Split on the FIRST "=" only, so values may themselves contain "="
    # (URLs with query strings, base64 padding, ...).
    var, sep, val = env_line.partition("=")
    var, val = var.strip(), val.strip()
    if not sep or not var:
        raise ValueError(f"Malformed .env entry (expected KEY=VALUE): {env_line!r}")
    print(f"Setting {var} to {val}")
    os.environ[var] = val

In [None]:
# Dataset root comes from the MEDS_DIR environment variable; the parquet shards
# are looked up (recursively) under its "final_cohort" sub-directory.
MEDS_dir = Path(os.environ["MEDS_DIR"])
MEDS_final_cohort = MEDS_dir / "final_cohort"
# Shard names are kept relative to the cohort directory. as_posix() forces "/" as
# the separator so the "train/" prefix test below also works where the native
# separator is "\"; the names can still be joined back onto MEDS_final_cohort.
shards = [fp.relative_to(MEDS_final_cohort).as_posix() for fp in MEDS_final_cohort.glob("**/*.parquet")]
train_shards = [s for s in shards if s.startswith("train/")]

In [None]:
%%time
# Build one table of per-code statistics over all training shards:
#   n_patients    - per-shard count of distinct patient_id values, summed over shards
#   n_occurrences - number of rows carrying the code, summed over shards
# NOTE(review): summing per-shard distinct counts equals the overall distinct-patient
# count only if no patient appears in more than one shard -- confirm for this dataset.
shard_code_counts = []
for s in tqdm(train_shards):
    # Each shard is scanned and aggregated exactly once; the result has one row
    # per code, so keeping every shard's table in memory is cheap.
    shard_df = (
        pl.scan_parquet(MEDS_final_cohort / s)
        .drop_nulls(subset="code")
        .group_by("code")
        .agg(pl.col("patient_id").n_unique().alias("n_patients"), pl.len().alias("n_occurrences"))
        .collect()
    )

    # Sanity check on the already-materialised result (drop_nulls should make this
    # impossible); done here so it does not trigger a second scan of the shard.
    if shard_df["code"].null_count() > 0:
        raise ValueError(f"Null code found after drop_nulls in shard {s}")

    shard_code_counts.append(shard_df)

if not shard_code_counts:
    # Fail with a clear message instead of an AttributeError on None further down.
    raise ValueError(f"No train shards found under {MEDS_final_cohort}")

# Stack the per-shard tables and add the counts up per code. This replaces a chain
# of pairwise outer joins with a single concat + group_by.
code_df = (
    pl.concat(shard_code_counts)
    .group_by("code")
    .agg(pl.col("n_patients").sum(), pl.col("n_occurrences").sum())
)

In [None]:
# Keep only codes seen in more than 10 patients, then order them from the most
# to the least frequently occurring.
seen_in_enough_patients = pl.col("n_patients") > 10
code_df = code_df.filter(seen_in_enough_patients)
code_df = code_df.sort("n_occurrences", descending=True)

In [None]:
# Bare last expression: the notebook renders the code table as this cell's output.
code_df

In [None]:
# Plain Python list of the values in the "code" column, in the frame's current row order.
code_strs = code_df.get_column("code").to_list()

In [None]:
def _codes_starting_with(prefix):
    """Return the entries of `code_strs` (order preserved) that begin with `prefix`."""
    return [code_str for code_str in code_strs if code_str.startswith(prefix)]


# Prefix-derived groups.
hosp_admit_codes = _codes_starting_with("HOSPITAL_ADMISSION//")
hosp_disch_codes = _codes_starting_with("HOSPITAL_DISCHARGE//")
icu_disch_codes = _codes_starting_with("UNIT_DISCHARGE//")

# ICU admissions are an explicit, hand-picked list rather than being derived from
# code_strs (e.g. via a "UNIT_ADMISSION//" prefix plus an "icu" substring match).
icu_admit_codes = [
    "UNIT_ADMISSION//ICU//stepdown/other",
    "UNIT_ADMISSION//ICU//transfer",
    "UNIT_ADMISSION//Other ICU//admit",
    "UNIT_ADMISSION//Other ICU//stepdown/other",
    "UNIT_ADMISSION//ICU//admit",
    "UNIT_ADMISSION//ICU//readmit",
]

# Any code whose text contains "death", ignoring case.
death_codes = [code_str for code_str in code_strs if "death" in code_str.lower()]

In [None]:
import string

def make_plain_predicate(code: str, i: int, base_name: str | None = None) -> str:
    pred_name = f"{base_name if base_name is not None else code.split('//')[0].lower()}_{i}"
    return "\n".join([f"  {pred_name}:", f"    code: {code}"])
def make_or_predicate(codes: str, pred_name: str, base_name: str | None = None) -> str:
    codes_as_preds = [f"{base_name if base_name is not None else c.split('//')[0].lower()}_{i}" for i, c in enumerate(codes)]
    return "\n".join([f"  {pred_name}:", f"    expr: or({','.join(codes_as_preds)})"])

## Hospital Admission

In [None]:
# One plain predicate per hospital-admission code, followed by the predicate that ORs them.
hosp_admit_yaml = [make_plain_predicate(c, n) for n, c in enumerate(hosp_admit_codes)]
hosp_admit_yaml.append(make_or_predicate(hosp_admit_codes, "hospital_admission"))
print(*hosp_admit_yaml, sep="\n")

## Hospital Discharge

In [None]:
# One plain predicate per hospital-discharge code, followed by the predicate that ORs them.
hosp_disch_yaml = [make_plain_predicate(c, n) for n, c in enumerate(hosp_disch_codes)]
hosp_disch_yaml.append(make_or_predicate(hosp_disch_codes, "hospital_discharge"))
print(*hosp_disch_yaml, sep="\n")

## ICU Admission

In [None]:
# One plain predicate per ICU-admission code, followed by the predicate that ORs them.
icu_admit_yaml = [make_plain_predicate(c, n) for n, c in enumerate(icu_admit_codes)]
icu_admit_yaml.append(make_or_predicate(icu_admit_codes, "icu_admission"))
print(*icu_admit_yaml, sep="\n")

## ICU Discharge

In [None]:
# One plain predicate per unit-discharge code, followed by the predicate that ORs them.
icu_disch_yaml = [make_plain_predicate(c, n) for n, c in enumerate(icu_disch_codes)]
icu_disch_yaml.append(make_or_predicate(icu_disch_codes, "icu_discharge"))
print(*icu_disch_yaml, sep="\n")

## Death

In [None]:
# Death codes do not share a common prefix-derived stem here, so every member
# predicate is named with the fixed "death" stem (death_0, death_1, ...).
death_yaml = [make_plain_predicate(c, n, "death") for n, c in enumerate(death_codes)]
death_yaml.append(make_or_predicate(death_codes, "death", base_name="death"))
print(*death_yaml, sep="\n")