In [None]:
%load_ext autoreload
%autoreload 2

import polars as pl

from ethos.constants import PROJECT_ROOT

# Directory of tokenized datasets, located under the project root.
data_dir = PROJECT_ROOT / "data/tokenized_datasets"

# Dataset examined by this notebook; the cells below read code-count CSVs and
# training folds from beneath this directory.
dataset_dir = data_dir / "mimic_synth"

In [None]:
# Per-code counts for the original (non-synthetic) variant of the dataset.
counts_orig = pl.read_csv(dataset_dir / "big/code_counts.csv")

# Maps each raw code onto a coarser code group:
#   * codes starting with "ATC", except "ATC//4//..." and "ATC//SFX//...",
#     collapse into the single group "ATC";
#   * the remaining codes whose first three characters are "ICD" or "ATC"
#     become "<prefix>_<second '//'-separated field>" (null if that field is
#     missing);
#   * every other code is reduced to its first "//"-separated field.
group_expr = (
    pl.when(
        pl.col("code").str.starts_with("ATC")
        & ~(
            pl.col("code").str.starts_with("ATC//4//")
            | pl.col("code").str.starts_with("ATC//SFX//")
        )
    )
    .then(pl.lit("ATC"))
    .when(pl.col("code").str.slice(0, 3).is_in(["ICD", "ATC"]))
    .then(
        pl.col("code").str.slice(0, 3)
        + pl.lit("_")
        + pl.col("code").str.split("//").list.get(1, null_on_oob=True)
    )
    .otherwise(pl.col("code").str.split("//").list.get(0))
    .alias("code")
)


def _aggregate_by_group(code_counts):
    """Sum `count` and count its non-null entries (`n`) for every code group."""
    return code_counts.group_by(group_expr).agg(
        pl.col("count").sum(),
        pl.col("count").count().alias("n"),
    )


df = _aggregate_by_group(counts_orig).sort("count", descending=True, nulls_last=True)

# Load the code counts of every synthetic variant, keyed by directory suffix.
sfx_to_counts = []
for sfx in ("_synth", "_synth_temp0.9", "_synth_temp0.7", "_synth_temp1.1"):
    sfx_to_counts.append((sfx, pl.read_csv(dataset_dir / f"big{sfx}/code_counts.csv")))

# Full-join each variant's per-group summary onto the running table. The
# clashing `count` / `n` columns of the right-hand side receive the suffix
# "_<sfx>", which yields names such as "count__synth_temp0.9".
for sfx, counts in sfx_to_counts:
    df = df.join(
        _aggregate_by_group(counts),
        on="code",
        how="full",
        coalesce=True,
        suffix=f"_{sfx}",
    )

# Order by the original counts and switch to "<stat>__<source>" column names.
df = df.sort("count", descending=True, nulls_last=True)
df = df.rename(
    {
        "code": "Code Group",
        "count": "count__original",
        "n": "n__original",
        "count__synth": "count__synth_temp1",
        "n__synth": "n__synth_temp1",
    }
)
df

In [None]:
# One-row frame of column-wise sums; the label column carries no meaningful
# sum, so it is overwritten with the marker "Total".
total = df.sum().with_columns(pl.lit("Total").alias("Code Group"))
total

In [None]:
import pandas as pd

# Append the column-wise sums as a final row, render every numeric value with
# thousands separators, then replace null "Code Group" entries (the appended
# sums row is expected to have one) with "Total" before moving to pandas.
pdf = (
    pl.concat([df, df.sum()])
    .with_columns(
        pl.selectors.numeric().map_elements(
            lambda value: format(value, ","), return_dtype=pl.Utf8
        )
    )
    .with_columns(pl.col("Code Group").fill_null("Total"))
    .to_pandas()
)

# Split "<stat>__<source>" names on "__" and reverse the parts so that the
# source becomes the outer header level and the statistic the inner one.
pdf.columns = pd.MultiIndex.from_tuples(
    [list(reversed(name.split("__"))) for name in pdf.columns]
)

print(
    pdf.to_latex(
        index=False,
        escape=True,
        multicolumn=True,
        multicolumn_format="c",
        # Left-align the first column, centre all remaining ones.
        column_format="l" + "c" * (pdf.shape[1] - 1),
        label="tab:token-summary-in-synthetic",
    )
)

In [None]:
import numpy as np

from ethos.datasets import TimelineDataset

# One entry per training fold: (fold name, formatted token count, quantiles of
# the per-patient offset differences).
data_stats = []

# `glob` yields directory entries in arbitrary, filesystem-dependent order;
# sorting keeps the row order of the resulting table reproducible across runs
# and machines.
for fold_dir in sorted(dataset_dir.glob("train*")):
    d = TimelineDataset(fold_dir)
    # Differences between consecutive patient offsets, gathered over all shards.
    # NOTE(review): presumably each difference is one patient's timeline length
    # in tokens — confirm against the shard layout of TimelineDataset.
    timeline_lengths = [
        length
        for shard in d._data.shards
        for length in shard["patient_offsets"][1:] - shard["patient_offsets"][:-1]
    ]

    # Min, quartiles and max, truncated to integers for display.
    quantiles = np.quantile(timeline_lengths, [0, 0.25, 0.5, 0.75, 1]).astype(int).tolist()
    data_stats.append((fold_dir.name, f"{len(d.tokens):,}", quantiles))

q_cols = ["min", "25%", "50%", "75%", "max"]
data_stats = (
    pl.DataFrame(
        data_stats,
        schema=["dataset", "num_tokens", "timeline_lengths"],
        orient="row",
    )
    # Spread the five quantiles into their own named columns ...
    .with_columns(pl.col("timeline_lengths").list.to_struct(fields=q_cols).struct.unnest())
    # ... and render them with thousands separators, like `num_tokens`.
    .with_columns(pl.col(q_cols).map_elements(lambda s: f"{s:,}", return_dtype=pl.String))
    .drop("timeline_lengths")
)
data_stats