In [1]:
import os
import sys
from pathlib import Path

# Repository root that contains the local `trace` package.
# Override with the TRACE_PAPER_ROOT environment variable; the fallback is the
# previously hard-coded path, so existing runs behave exactly as before.
repo_root = Path(
    os.environ.get("TRACE_PAPER_ROOT", "/Users/rjm707/Desktop/trace_paper")
)

# Add repo root to sys.path FIRST so the local `trace` package takes precedence
# over the standard-library module of the same name.
# NOTE: this only helps if the stdlib `trace` module has not already been
# imported in this kernel; if it has, `from trace.statistics import ...` fails
# with "'trace' is not a package" -- restart the kernel in that case.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Imported after the sys.path tweak on purpose.
from trace.statistics import logit, inv_logit
import polars as pl

In [63]:
# Analysis run to summarise; every input/output path in this cell derives from it.
folder = "semaglutide-atc"

# Per-run effect estimates, restricted to the IPW estimator.
df = pl.read_csv(f'data/{folder}/combined_estimates.txt').filter(
    pl.col("method") == "IPW"
)
# Raw outcome / denominator counts by outcome and treatment status.
combined_stats = pl.read_csv(
    f'data/{folder}/combined_stats.txt'
)

# Crude (unadjusted) prevalence per arm and their ratio, defined once and reused
# for both the RR and log_RR columns.
_untreated_prevalence = pl.col("Outcome_Untreated") / pl.col("Total_Untreated")
_treated_prevalence = pl.col("Outcome_Treated") / pl.col("Total_Treated")
_crude_rr = _treated_prevalence / _untreated_prevalence

# One row per outcome with `{Outcome,Total}_{status}` columns; if several rows
# share an (outcome, status) pair their counts are averaged. Only the Treated /
# Untreated columns are selected, so any `*_Total` columns produced by a "Total"
# status are discarded here without an explicit drop (and the cell still works
# when no "Total" status is present).
outcome_stats = combined_stats.pivot(
    index="outcome",
    on="status",
    values=["Outcome", "Total"],
    aggregate_function="mean",
).select(
    "outcome",
    outcome_untreated="Outcome_Untreated",
    # Capitalisation kept as-is: this column name is written to the output CSV.
    Untreated_prevalence=_untreated_prevalence,
    outcome_treated="Outcome_Treated",
    treated_prevalence=_treated_prevalence,
    RR=_crude_rr,
    log_RR=_crude_rr.log(),
)
# Adjusted risk differences from the IPW estimates.
# effect_0 / effect_1 are treated as probabilities (they are passed through
# `logit`). For each outcome they are pooled on the log-odds scale:
#   effect_a -> logit(effect_a) -> mean over that outcome's rows -> inv_logit -> prob_a   (a = 0, 1)
# and then
#   rd     (absolute risk difference) = prob1 - prob0
#   rel_rd (relative risk difference) = rd / prob0
# The crude risk ratio comes from the joined outcome_stats columns, not from prob1 / prob0.
table = df.group_by("outcome").agg(
    # `logit` runs element-wise in Python; declaring return_dtype avoids polars
    # inferring the output type from the first returned value.
    prob0=inv_logit(pl.col("effect_0").map_elements(logit, return_dtype=pl.Float64).mean()),
    prob1=inv_logit(pl.col("effect_1").map_elements(logit, return_dtype=pl.Float64).mean()),
).select(
    "outcome",
    rd=(pl.col("prob1") - pl.col("prob0")),
    rel_rd=(pl.col("prob1") - pl.col("prob0")) / pl.col("prob0"),
).join(
    # Inner join: outcomes absent from outcome_stats are silently dropped.
    outcome_stats, on="outcome"
).sort(
    # group_by does not preserve row order, so sort to make the CSV reproducible.
    "outcome"
)
table.write_csv(
    f"data/{folder}_table.csv"
)


## RR with CI
# One log risk ratio per row of combined_stats, summarised with a
# normal-approximation confidence interval on the log scale.
Z_95 = 1.96  # two-sided 95% standard-normal quantile (rounded)

CI_table = combined_stats.pivot(
    index="outcome",
    on="status",
    values=["Outcome", "Total"],
    # Collect every row of an (outcome, status) pair into a list instead of
    # aggregating, so each row contributes its own log RR below.
    aggregate_function=pl.element(),
).with_columns(
    # Element-wise arithmetic on the list columns.
    # NOTE(review): a zero outcome count in either arm yields an infinite log RR,
    # which propagates into the mean and CI -- confirm counts are always > 0.
    log_RR=(
        (pl.col("Outcome_Treated") / pl.col("Total_Treated")) /
        (pl.col("Outcome_Untreated") / pl.col("Total_Untreated"))
    ).list.eval(pl.element().log())
).select(
    "outcome",
    log_RR_mean=pl.col("log_RR").list.mean(),
    # NOTE(review): std / sqrt(n) is the standard error of the MEAN of n
    # independent estimates. If the rows of combined_stats are bootstrap
    # replicates, the standard error is the replicate std itself (no sqrt(n))
    # and these intervals are too narrow -- confirm what the rows represent.
    log_RR_se=pl.col("log_RR").list.std() / pl.col("log_RR").list.len().sqrt(),
).select(
    "outcome",
    # exp(mean log RR) is the geometric mean of the per-row RRs; it can differ
    # from the crude RR in outcome_stats, which uses averaged counts.
    RR_mean=pl.col("log_RR_mean").exp(),
    CI_low=(pl.col("log_RR_mean") - Z_95 * pl.col("log_RR_se")).exp(),
    CI_high=(pl.col("log_RR_mean") + Z_95 * pl.col("log_RR_se")).exp(),
)
CI_table.write_csv(
    f"data/{folder}_RR_CI_table.csv"
)