In [None]:
%load_ext autoreload
%autoreload 2

import json

import polars as pl
from polars.polars import ComputeError

from ethos.constants import MAPPINGS_DIR, PROJECT_ROOT
from ethos.inference.constants import Task

# "data" directory located directly under the project root.
data_dir = PROJECT_ROOT / "data/"

# Tasks whose result directories (PROJECT_ROOT / "results" / <task>) are
# scanned for run metadata further down in this notebook.
tasks = [
    Task.DRG_PREDICTION,
    Task.SOFA_PREDICTION,
    Task.READMISSION,
    Task.ICU_ADMISSION,
    Task.HOSPITAL_MORTALITY,
]

In [None]:
def load_metadata(metadata_fp) -> dict:
    """Load a run's ``metadata.json`` and attach progress statistics.

    The directory containing ``metadata_fp`` is searched recursively for
    ``*.parquet`` result files. When any are found, two fields are derived
    from their ``data_idx`` column:

    * ``rep_num`` -- mean number of rows per distinct ``data_idx``.
    * ``n`` -- number of distinct ``data_idx`` values.

    Both keys are always present in the returned dict. They stay ``None``
    (unless the JSON itself already provides them) when there are no result
    files yet or when polars fails to read them, so that every record has the
    same set of keys regardless of which path was taken.

    Args:
        metadata_fp: ``pathlib.Path``-like object pointing at the metadata
            JSON file.

    Returns:
        The parsed metadata extended with ``rep_num`` and ``n``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    meta = json.loads(metadata_fp.read_text(encoding="utf-8"))

    # Seed the derived fields up front so the early-return and error paths
    # yield the same keys as the success path.
    meta.setdefault("rep_num", None)
    meta.setdefault("n", None)

    res_fps = list(metadata_fp.parent.rglob("*.parquet"))
    if not res_fps:
        return meta

    try:
        rep_num, n = (
            # The paths are already concrete files, so glob expansion is off.
            pl.scan_parquet(res_fps, glob=False)
            .group_by("data_idx")
            .agg(rep_num=pl.count("data_idx"))
            .select(rep_num=pl.mean("rep_num"), n=pl.len())
            .collect()
            .row(0)
        )
    except ComputeError as e:
        # Best effort: report the unreadable results and keep the bare metadata.
        print(f"Error reading {metadata_fp.parent}: {e}")
    else:
        meta["rep_num"], meta["n"] = rep_num, n
    return meta


# Gather one metadata record per run directory whose name starts with the
# selected model-family prefix, across every task listed in `tasks`.
metadata = pl.from_dicts(
    list(
        map(
            load_metadata,
            (
                meta_fp
                for task_name in tasks
                for meta_fp in (PROJECT_ROOT / "results" / task_name).rglob("metadata.json")
                if meta_fp.parent.name.startswith("mimic_synth_layer_3_do_0.3_")
            ),
        )
    )
)

# Build the progress table: one row per (fold, temp, model), one column per
# task, each cell showing completion followed by the mean repetition count.
(
    metadata.select(
        pl.col("task"),
        pl.col("rep_num"),
        pl.col("n"),
        # Last path component of the input directory.
        pl.col("input_dir").str.split("/").list.last().alias("fold"),
        pl.col("temperature").alias("temp"),
        # Parent directory of the model file, with the leading
        # "layer_3_do_0.3_"-sized prefix cut off.
        pl.col("model_fp")
        .str.split("/")
        .list[-2]
        .str.slice(len("layer_3_do_0.3_"))
        .alias("model"),
        # NOTE(review): `variant` is dropped again before the pivot below.
        pl.when(pl.col("model_fp").str.contains("best"))
        .then(pl.lit("best"))
        .otherwise(pl.lit("recent"))
        .alias("variant"),
    )
    # Completion is relative to the largest `n` seen for the same fold/task.
    .with_columns(
        (pl.col("n") / pl.col("n").max().over("fold", "task")).alias("completion")
    )
    .with_columns(
        (
            pl.col("completion").map_elements(
                lambda frac: "DONE" if frac == 1 else f"{frac:.0%} ",
                return_dtype=pl.Utf8,
            )
            + pl.col("rep_num").map_elements(
                lambda reps: f" ({reps:.1f})",
                return_dtype=pl.Utf8,
            )
        ).alias("value")
    )
    .drop("n", "completion", "rep_num", "variant")
    .pivot("task", values="value")
    .sort("fold", "model", "temp")
).to_pandas()  # VS Code does not support nice printing of Polars dfs :(