# 03 - Outcomes - Sample Size Estimation for Research Questions

1. Diagnostic Performance Compared to the Bayley Screener
    - Sensitivity, specificity, PPV, and NPV, each with a 95% CI half-width of 0.1
2. Agreement with Bayley
3. Test-Retest Reliability

### Estimated Prevalence of Neuromotor Development Problems

- Gross Motor (4 months): 2.6%
- Gross Motor (12 months): 3.6%
- Gross Motor (12 months, Preterm): 12%





In [None]:
%reload_ext autoreload
%autoreload 2

import math

from scipy.stats import norm

import polars as pl
from polars import DataFrame

from early_markers.cribsy.common.enums import MetricType
from early_markers.cribsy.common.metrics_n import RocMetricSampleSize
from early_markers.cribsy.common.constants import (
    IPC_DIR,
    DETAIL_MAP,
)
# from early_markers.cribsy.common.metrics_n import MetricsSampleSize
#
# mss = MetricsSampleSize()
#
# Invert DETAIL_MAP (value -> key) and use the inverse to rename the columns
# of the two saved "top 25" frames (presumably ranked by AUC and by J, per the
# file names -- confirm against the notebook that writes them).
rev_map = dict(zip(DETAIL_MAP.values(), DETAIL_MAP.keys()))

df_top_auc = pl.read_ipc(IPC_DIR / "top_25_auc.ipc").rename(rev_map)
df_top_j = pl.read_ipc(IPC_DIR / "top_25_j.ipc").rename(rev_map)

# ss = RocMetricSampleSize()
#
# for row in df_top_auc.rows(named=True):
#

#
# for row in df_top_auc.rows(named=True):
#     mss.estimate_n_for_sensitivity(row["Sensitivity"], ci_width=0.2, prevalence=0.05)
#     mss.estimate_n_for_specificity(row["Specificity"], ci_width=0.2, prevalence=0.05)
#     mss.estimate_n_for_ppv(ppv=row["PPV"], sensitivity=row["Sensitivity"], ci_width=0.2, prevalence=0.05)
#     mss.estimate_n_for_npv(row["NPV"], sensitivity=row["Sensitivity"], specificity=row["Specificity"], ci_width=0.2, prevalence=0.05)
#     mss.estimate_n_for_f1(row["F1"], row["PPV"], sensitivity=row["Sensitivity"], specificity=row["Specificity"], ci_width=0.2, prevalence=0.05)
#     mss.estimate_n_for_accuracy(row["Accuracy"], ci_width=0.2, prevalence=0.05)
#
# results = mss.results_as_frames

# PEB 2025.03.20 23:19 => TODO: in BayesianData class create base dfs from roc metrics for further analyses (e.g., sensitivity instead of sensitivity with ci) .  Can select from bd._metrics[<rfe_name>]["metrics"] df.

In [None]:
# Full width of the observed sensitivity 95% CI [0.685313, 0.987278].
0.987278 - 0.685313

In [None]:
def se_from_ci(ci_width: float, conf_level: float = 0.95):
    """Back out a standard error from the full width of a symmetric normal CI.

    A two-sided interval at ``conf_level`` spans ``estimate +/- z * se``,
    where ``z`` is the standard-normal quantile at ``1 - (1 - conf_level) / 2``.
    The full width is therefore ``2 * z * se``, and this function solves that
    for ``se``.

    Args:
        ci_width: Total width of the interval (upper bound minus lower bound),
            not the half-width.
        conf_level: Two-sided confidence level of the interval.

    Returns:
        The implied standard error.
    """
    z_crit = norm.ppf(1 - (1 - conf_level) / 2)
    return ci_width / (2 * z_crit)

# [0.685313, 0.987278] sens ci

In [None]:
# SE implied by the sensitivity 95% CI [0.685313, 0.987278] (width 0.301965).
se_from_ci(ci_width=0.301965)

In [None]:


def se_from_metrics(
    metric_type: MetricType,
    sens: float | None = None,
    spec: float | None = None,
    ppv: float | None = None,
    npv: float | None = None,
    acc: float | None = None,
    prev: float | None = None,
    n: int | None = None,
    tp: int | None = None,
    tn: int | None = None,
    fp: int | None = None,
    fn: int | None = None,
    ci: float | None = None,
):
    """Return the standard error of one diagnostic metric from confusion-matrix counts.

    Sensitivity, specificity, PPV, NPV and accuracy use the binomial (Wald)
    form ``sqrt(p * (1 - p) / m)``. Here ``m`` is the number of subjects the
    proportion is computed over:

    - sensitivity: ``tp + fn``
    - specificity: ``tn + fp``
    - PPV: ``tp + fp``
    - NPV: ``tn + fn``
    - accuracy: ``n``

    F1 uses a delta-method approximation. It combines the PPV and sensitivity
    standard errors with a precision/recall covariance term that depends on
    ``prev``, ``spec`` and ``n``.

    Args:
        metric_type: Metric whose standard error is wanted.
        sens, spec, ppv, npv, acc: Point estimates of the metrics, as proportions.
        prev: Prevalence. Only the F1 branch reads it.
        n: Total sample size. When omitted and all of ``tp``, ``tn``, ``fp``
            and ``fn`` are given, it is set to their sum.
        tp, tn, fp, fn: Confusion-matrix counts.
        ci: Accepted for signature parity with ``estimate_n``; not used here.

    Returns:
        The standard error, or ``None`` when ``metric_type`` is not one of the
        six metrics handled here.
    """
    counts = (tp, tn, fp, fn)
    if n is None and all(c is not None for c in counts):
        n = sum(counts)

    def wald(p, m):
        # Binomial standard error of a proportion p observed over m subjects.
        return math.sqrt(p * (1 - p) / m)

    if metric_type == MetricType.SENSITIVITY:
        return wald(sens, tp + fn)
    if metric_type == MetricType.SPECIFICITY:
        return wald(spec, tn + fp)
    if metric_type == MetricType.PPV:
        return wald(ppv, tp + fp)
    if metric_type == MetricType.NPV:
        return wald(npv, tn + fn)
    if metric_type == MetricType.F1:
        precision, recall = ppv, sens
        se_precision = wald(ppv, tp + fp)
        se_recall = wald(sens, tp + fn)
        # Covariance between precision and recall, scaled by 1/n.
        pq = precision * (1 - precision)
        cov = (pq * (1 - recall) / prev + pq * spec / (1 - prev)) / n
        # Var(F1) via partial derivatives of F1 = 2PR / (P + R):
        # dF/dP = 2R^2 / (P+R)^2 and dF/dR = 2P^2 / (P+R)^2.
        numerator = (
            recall**4 * se_precision**2
            + 2 * precision**2 * recall**2 * cov
            + precision**4 * se_recall**2
        )
        return math.sqrt(4 * numerator / (precision + recall) ** 4)
    if metric_type == MetricType.ACCURACY:
        return wald(acc, n)
    return None


def estimate_n(
    metric_type: MetricType,
    sens: float | None = None,
    spec: float | None = None,
    ppv: float | None = None,
    npv: float | None = None,
    acc: float | None = None,
    prev: float | None = None,
    n: int | None = None,
    tp: int | None = None,
    tn: int | None = None,
    fp: int | None = None,
    fn: int | None = None,
    ci: float | None = None,
):
    if n is None and all([x is not None for x in [tp, tn, fp, fn]]):
        n = sum([tp, tn, fp, fn])

    if ci is None:
        se = se_from_metrics(metric_type=metric_type, sens=sens, spec=spec, ppv=ppv, npv=npv, acc=acc, prev=prev, n=n, tp=tp, tn=tn, fp=fp, fn=fn,)
    else:
        se = se_from_ci(ci)

    est_n = None
    match metric_type:
        case MetricType.SENSITIVITY:
            num = sens * (1 - sens)
            den = se**2 * prev
            est_n = num / den
        case MetricType.SPECIFICITY:
            num = spec * (1 - spec)
            den = se**2 * (1 - prev)
            est_n = num / den
        case MetricType.PPV:
            num = ppv**2 * (1-ppv)
            den = se**2 * prev * sens
            est_n = num / den
        case MetricType.NPV:
            num = npv * (1 - ppv)
            den = se**2 * (spec * (1 - prev) + prev * (1 - sens))
            est_n = num / den
        case MetricType.F1:
            pr = ppv
            re = sens
            se_pr = se_from_metrics(MetricType.PPV, sens=sens, spec=spec, ppv=ppv, npv=npv, acc=acc, prev=prev, n=n, tp=tp, tn=tn, fp=fp, fn=fn)
            se_re = se_from_metrics(MetricType.SENSITIVITY, sens=sens, spec=spec, ppv=ppv, npv=npv, acc=acc, prev=prev, n=n, tp=tp, tn=tn, fp=fp, fn=fn)
            num1 = 2 * pr**2 * re**2
            num2 = pr * (1 - pr) * (1 - re) / prev
            num3 = pr * (1 - pr) * spec / (1 - prev)
            num = num1 * (num2 + num3)
            den1 = se**2 * (pr + re)**4 / 4
            den2 = re**4 * se_pr**2
            den3 = pr**4 * se_re**2
            den = den1 - den2 - den3
            est_n = num / den
        case MetricType.ACCURACY:
            est_n = math.sqrt(acc * (1 - acc) / n)
        case _:
            ...
    return est_n


# NOTE(review): these three lists are not used by this cell.
metric = []
from_ci = []
from_metrics = []

# For every MetricType, compare the SE implied by a 95% CI of total width 0.2
# with the SE computed from the pilot confusion matrix (tp=13, tn=97, fp=36,
# fn=1, n=147) at 5% prevalence. Metric types that se_from_metrics does not
# handle appear as None in the "from_metrics" column.
test = [
    {
        "metric": t.name,
        "from_ci": se_from_ci(ci_width=0.2, conf_level=0.95),
        "from_metrics": se_from_metrics(
            metric_type=t,
            prev=0.05,
            sens=0.928571,
            spec=0.729323,
            ppv=0.265306,
            npv=0.989796,
            acc=0.748299,
            tp=13,
            tn=97,
            fp=36,
            fn=1,
            n=147,
        ),
    }
    for t in MetricType
]

DataFrame(test)

In [None]:
# Required sample size per metric for a 95% CI of total width 0.2
# (half-width 0.1). Inputs are the pilot confusion matrix (n=147) and an
# assumed 5% prevalence. The estimated total is split into expected positives
# and negatives by prevalence. Each arm is rounded UP, because rounding to
# nearest can leave the design slightly short of the target precision.
n = 147
prev = 0.05
test = []
for t in MetricType:
    est_n = estimate_n(
        t,
        prev=prev,
        sens=0.928571, spec=0.729323, ppv=0.265306, npv=0.989796, acc=0.748299,
        tp=13, tn=97, fp=36, fn=1, n=n, ci=0.2,
    )
    if est_n is None:
        # estimate_n returns None for metric types it does not handle. Keep
        # the row with empty counts instead of failing on None * prev.
        n_pos = n_neg = n_tot = None
    else:
        n_pos = math.ceil(est_n * prev)
        n_neg = math.ceil(est_n * (1 - prev))
        n_tot = n_pos + n_neg
    test.append(
        {
            "metric": t.name,
            "se_from_ci": se_from_ci(ci_width=0.2, conf_level=0.95),
            "se_from_metrics": se_from_metrics(
                metric_type=t, prev=prev,
                sens=0.928571, spec=0.729323, ppv=0.265306, npv=0.989796, acc=0.748299,
                tp=13, tn=97, fp=36, fn=1, n=n
            ),
            "n": n_tot,
            "n_pos": n_pos,
            "n_neg": n_neg,
        }
    )

DataFrame(test)