# Label-Shift Correction for WCP-L2D

The CheXpert-to-NIH shift includes massive **prevalence shift**
(e.g., Effusion: 47% &rarr; 4%). The DRE-based WCP only corrects
**covariate shift** (feature distribution), not label shift.

This notebook implements two label-shift correction approaches:

1. **Prior-adjusted CP** &mdash; adjust classifier logits for prevalence shift
   via Bayes' rule, then run standard CP on corrected posteriors.
2. **Label-shift WCP** &mdash; use per-class importance weights
   $w_i = p_{\text{target}}(y_i) / p_{\text{source}}(y_i)$ for the calibration scores in the weighted quantile,
   and a test weight $p_{\text{target}}(y) / p_{\text{source}}(y)$ for each candidate label $y$.

Both require an estimate of the target prevalence. We compare two estimators:
- **BBSE** (Black-Box Shift Estimation) &mdash; confusion matrix inversion (can fail with extreme shifts)
- **MLLS** (Maximum Likelihood Label Shift) &mdash; EM algorithm, more robust

We also combine label-shift correction with DRE (covariate shift) for
a joint correction: **Prior-adjusted + DRE WCP**.

In [None]:
import math

import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from torchcp.classification.score import RAPS

from wcp_l2d.features import ExtractedFeatures
from wcp_l2d.pathologies import COMMON_PATHOLOGIES
from wcp_l2d.label_utils import extract_binary_labels
from wcp_l2d.dre import AdaptiveDRE
from wcp_l2d.conformal import ConformalPredictor, WeightedConformalPredictor
from wcp_l2d.evaluation import (
    compute_coverage,
    compute_system_accuracy,
    _predictions_from_sets,
    DeferralResult,
)

# Experiment configuration and global RNG seeding.
SEED = 42  # global numpy/torch seed below; also every random_state / RandomState in this notebook
EXPERT_ACCURACY = 0.85  # passed as expert_accuracy to compute_system_accuracy for deferred samples
FEATURE_DIR = Path("../data/features")  # directory of the pre-extracted .npz feature files
TARGET_PATHOLOGY = "Effusion"  # pathology analysed in sections 1-6 (section 7 loops over all)

np.random.seed(SEED)
torch.manual_seed(SEED)
print("Setup complete.")

## 1. Data Preparation

In [None]:
# Load features
chexpert = ExtractedFeatures.load(
    FEATURE_DIR / "chexpert_densenet121-res224-chex_features.npz"
)
nih = ExtractedFeatures.load(FEATURE_DIR / "nih_densenet121-res224-chex_features.npz")

# Binary labels for target pathology (the third return value is unused here).
# NOTE(review): extract_binary_labels may drop rows without a usable label, in
# which case chex_feats / nih_feats are not row-aligned with *.features — confirm.
chex_feats, chex_labels, _ = extract_binary_labels(
    chexpert.features, chexpert.labels, COMMON_PATHOLOGIES, TARGET_PATHOLOGY
)
nih_feats, nih_labels, _ = extract_binary_labels(
    nih.features, nih.labels, COMMON_PATHOLOGIES, TARGET_PATHOLOGY
)

# Splits: CheXpert -> 60% train / 20% calibration / 20% test, stratified by label
chex_train_feats, chex_temp_feats, chex_train_labels, chex_temp_labels = (
    train_test_split(
        chex_feats, chex_labels, test_size=0.4, random_state=SEED, stratify=chex_labels
    )
)
chex_cal_feats, chex_test_feats, chex_cal_labels, chex_test_labels = train_test_split(
    chex_temp_feats,
    chex_temp_labels,
    test_size=0.5,
    random_state=SEED,
    stratify=chex_temp_labels,
)

# NIH pool (unlabeled, for DRE + BBSE) and test
# The pool is a random half of *all* NIH feature rows (nih.features). The test
# set is a stratified half of nih_feats, drawn independently of the pool.
# NOTE(review): because the two draws are independent, pool and test are
# expected to overlap, so DRE / BBSE / MLLS see (unlabeled) test features.
# Confirm this transductive setup is intended; otherwise derive both from one
# permutation so they are disjoint.
rng = np.random.RandomState(SEED)
nih_all_perm = rng.permutation(len(nih.features))
nih_pool_feats_all = nih.features[nih_all_perm[: len(nih.features) // 2]]

_, nih_test_feats, _, nih_test_labels = train_test_split(
    nih_feats, nih_labels, test_size=0.5, random_state=SEED, stratify=nih_labels
)

# Classifier + scaler
# The scaler is fit on CheXpert train only and reused for every other split.
scaler = StandardScaler()
X_train = scaler.fit_transform(chex_train_feats)
X_cal = scaler.transform(chex_cal_feats)
X_test_nih = scaler.transform(nih_test_feats)
X_pool_nih = scaler.transform(nih_pool_feats_all)

clf = LogisticRegression(solver="lbfgs", max_iter=1000, C=1.0, random_state=SEED)
clf.fit(X_train, chex_train_labels)


def get_binary_logits(clf, X):
    """Return two-column logits whose softmax matches the classifier's posterior.

    For scikit-learn's binary ``LogisticRegression``,
    ``predict_proba(X)[:, 1] == sigmoid(decision_function(X))``, i.e. the
    decision function ``d`` is the log-odds of class 1. A softmax over two
    columns depends only on their difference, so that difference must be
    ``d``. Stacking ``[-d, d]`` (difference ``2d``) would give
    ``sigmoid(2d)``. That posterior is sharper than the fitted model and
    inconsistent with the ``predict_proba`` values used for MLLS, so both the
    RAPS scores and the additive log-prior (Bayes) correction would act on the
    wrong posterior. The log-odds are therefore split symmetrically.

    Args:
        clf: Fitted binary classifier exposing ``decision_function``.
        X: Feature matrix accepted by ``clf``.

    Returns:
        Array of shape (n_samples, 2) with columns ``[-d/2, d/2]``.
    """
    d = np.asarray(clf.decision_function(X), dtype=float)
    half = 0.5 * d
    return np.column_stack([-half, half])


# Two-column logits for the CheXpert calibration set and the NIH test set
cal_logits = get_binary_logits(clf, X_cal)
test_nih_logits = get_binary_logits(clf, X_test_nih)

# DRE: fitted on the *unscaled* CheXpert calibration features and the NIH pool
# (presumably (source, target) argument order — confirm in AdaptiveDRE.fit),
# then evaluated on the calibration and NIH test points.
# NOTE(review): weight_clip=20.0 presumably caps the importance weights at 20 —
# confirm in wcp_l2d.dre.AdaptiveDRE.
dre = AdaptiveDRE(n_components=4, weight_clip=20.0, random_state=SEED)
dre.fit(chex_cal_feats, nih_pool_feats_all)
cal_weights = dre.compute_weights(chex_cal_feats)
test_nih_weights = dre.compute_weights(nih_test_feats)

# Sanity summary: sizes and prevalences, NIH discrimination of the CheXpert
# classifier, and the DRE effective-sample-size fraction on the calibration set
print(f"CheXpert cal: {len(chex_cal_labels)} (prev={chex_cal_labels.mean():.3f})")
print(f"NIH test:     {len(nih_test_labels)} (prev={nih_test_labels.mean():.3f})")
print(f"NIH AUC: {roc_auc_score(nih_test_labels, clf.predict_proba(X_test_nih)[:, 1]):.4f}")
print(f"DRE ESS: {dre.diagnostics(chex_cal_feats).ess_fraction:.3f}")

## 2. Prevalence Estimation

**BBSE** (Lipton et al., 2018): Estimate confusion matrix $C_{jk} = P(\hat{y}=j \mid y=k)$
on source data, get predicted proportions $\hat{\mu}$ on target, solve $C \cdot p_{\text{target}} = \hat{\mu}$.
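Can produce negative (or greater-than-one) values when the classifier has high error rates and the shift is extreme; the implementation below therefore clips the solution to $[10^{-4}, 1 - 10^{-4}]$ and renormalizes it.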

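**MLLS** (Saerens et al., 2002; revisited with bias-corrected calibration by Alexandari et al., 2020): an EM algorithm that
iteratively re-weights the source posteriors by the current prevalence ratio $p_{\text{target}}(y)/p_{\text{source}}(y)$,
then re-estimates the target prevalence as the mean of the re-weighted posteriors.
Naturally produces valid probability distributions.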

In [None]:
def estimate_prevalence_bbse(clf, X_cal, y_cal, X_target_unlabeled):
    """Black-Box Shift Estimation of the binary target class prior.

    Builds the source confusion matrix ``C[j, k] = P(y_hat = j | y = k)`` from
    labelled calibration data and measures the predicted-label frequencies
    ``mu[j] = P_target(y_hat = j)`` on unlabelled target data. It then solves
    ``C @ p_target = mu``. The raw solution can leave the probability simplex
    (weak classifier and/or extreme shift), so it is clipped to
    ``[1e-4, 1 - 1e-4]`` and renormalised.

    Args:
        clf: Fitted classifier whose ``predict`` returns labels in {0, 1}.
        X_cal: Source calibration features.
        y_cal: Source calibration labels (0/1), as an array.
        X_target_unlabeled: Target-domain features (no labels needed).

    Returns:
        Length-2 array ``[P(y=0), P(y=1)]`` estimated for the target domain.
    """
    classes = range(2)
    source_preds = clf.predict(X_cal)
    target_preds = clf.predict(X_target_unlabeled)

    # Column k: distribution of predicted labels among true-class-k samples.
    confusion = np.array(
        [[np.mean(source_preds[y_cal == k] == j) for k in classes] for j in classes]
    )
    predicted_freq = np.array([np.mean(target_preds == j) for j in classes])

    estimate = np.linalg.solve(confusion, predicted_freq)
    estimate = np.clip(estimate, 1e-4, 1 - 1e-4)
    return estimate / estimate.sum()


def estimate_prevalence_mlls(clf, X_target_unlabeled, p_source, n_iter=200, tol=1e-8):
    """MLLS: Maximum Likelihood Label Shift via EM (Saerens et al., 2002).

    Iteratively:
      E-step: re-weight the source posteriors ``p_s(y|x)`` by the current prior
              ratio ``w = p_target / p_source`` and renormalise each row.
      M-step: the new target prior is the mean of the re-weighted posteriors.

    Stops when the prior ratio changes by less than ``tol`` (max-abs) or after
    ``n_iter`` iterations.

    Args:
        clf: Fitted classifier exposing ``predict_proba`` (source posteriors).
        X_target_unlabeled: Target-domain features.
        p_source: Source class prior, one entry per ``predict_proba`` column.
        n_iter: Maximum number of EM iterations. If < 1, no EM step runs and
            ``p_source`` (the EM starting point) is returned.
        tol: Convergence threshold on the change of the prior ratio.

    Returns:
        Array with the estimated target prior (normalised to sum to 1 whenever
        at least one EM step runs).
    """
    p_source = np.asarray(p_source, dtype=float)
    probs = clf.predict_proba(X_target_unlabeled)  # [N, K] source posteriors
    w = np.ones(len(p_source))  # current ratio p_target / p_source

    # EM starts from "no shift". Initialising here means n_iter < 1 returns the
    # starting point instead of raising UnboundLocalError.
    p_target = p_source.copy()

    for _ in range(n_iter):
        # E-step: adjust posteriors
        adjusted = probs * w[np.newaxis, :]
        adjusted = adjusted / adjusted.sum(axis=1, keepdims=True)

        # M-step: estimate target prevalence. The floor keeps every class
        # strictly positive; a zero ratio would pin that class at zero forever.
        p_target = adjusted.mean(axis=0)
        p_target = np.clip(p_target, 1e-8, 1.0)
        p_target = p_target / p_target.sum()

        w_new = p_target / p_source
        if np.max(np.abs(w_new - w)) < tol:
            break
        w = w_new

    return p_target


# Estimate prevalences
# Source prior: CheXpert calibration labels. Oracle target prior: NIH *test*
# labels, used only for reporting and for the oracle baseline. BBSE and MLLS
# see only the unlabeled, scaled NIH pool.
p_source = np.array([1 - chex_cal_labels.mean(), chex_cal_labels.mean()])
p_target_oracle = np.array([1 - nih_test_labels.mean(), nih_test_labels.mean()])
p_target_bbse = estimate_prevalence_bbse(clf, X_cal, chex_cal_labels, X_pool_nih)
p_target_mlls = estimate_prevalence_mlls(clf, X_pool_nih, p_source)

print(f"Source prevalence (CheXpert):  neg={p_source[0]:.4f}  pos={p_source[1]:.4f}")
print(f"BBSE estimate:                neg={p_target_bbse[0]:.4f}  pos={p_target_bbse[1]:.4f}")
print(f"MLLS estimate:                neg={p_target_mlls[0]:.4f}  pos={p_target_mlls[1]:.4f}")
print(f"Oracle (true NIH):            neg={p_target_oracle[0]:.4f}  pos={p_target_oracle[1]:.4f}")
print()
# Absolute error of the positive-class prevalence estimate
print(f"BBSE error:  |est - true| = {abs(p_target_bbse[1] - p_target_oracle[1]):.4f}")
print(f"MLLS error:  |est - true| = {abs(p_target_mlls[1] - p_target_oracle[1]):.4f}")

## 3. Label-Shift Correction Methods

In [None]:
def adjust_logits(logits, p_source, p_target):
    """Shift logits by the log prior ratio (Bayes' rule for label shift).

    Under label shift, p_target(y|x) is proportional to
    p_source(y|x) * p_target(y) / p_source(y). Adding
    ``log(p_target(y) / p_source(y))`` to the logit of every class ``y``
    therefore turns the softmax of the source logits into the target
    posterior:

        p_adjusted(y|x) = softmax(logit_y + log(p_target(y) / p_source(y)))

    Args:
        logits: Array of shape (N, K).
        p_source: Length-K source class prior (array).
        p_target: Length-K (estimated) target class prior (array).

    Returns:
        A new (N, K) array of adjusted logits; ``logits`` is not modified.
    """
    prior_ratio = p_target / p_source
    offset = np.expand_dims(np.log(prior_ratio), axis=0)
    return logits + offset


class LabelShiftWCP:
    """Weighted CP with per-class label-shift importance weights.

    Cal weight for sample i:       w_i = p_target(y_i) / p_source(y_i)
    Test weight for candidate y:   w_test(y) = p_target(y) / p_source(y)

    Since weights depend only on class (not features), each candidate class
    gets a single constant quantile threshold — efficient for any test set size.

    Optionally multiplies by DRE weights for joint covariate+label correction.
    In that case the test weight, and hence the threshold, varies per test point.

    Threshold rule (weighted split conformal): calibration scores are sorted
    ascending as V_(1) <= ... <= V_(n) with weights W_(j), and w_t is the test
    weight. Then q_hat = V_(m) for the smallest m with

        sum_{j<=m} W_(j) / (sum_j W_j + w_t) >= 1 - alpha,

    and q_hat = +inf when no such m exists (the test point's own mass sits at
    +inf).
    """

    def __init__(self, penalty=0.1, kreg=1, randomized=False):
        self.score_fn = RAPS(penalty=penalty, kreg=kreg, randomized=randomized)
        self.cal_scores_sorted = None
        self.cal_weights_sorted = None

    def calibrate(self, logits, labels, p_source, p_target, dre_weights=None):
        """Calibrate with label-shift (and optionally DRE) weights.

        Args:
            logits: (n_cal, K) calibration logits.
            labels: (n_cal,) integer labels used to index the priors.
            p_source, p_target: Length-K class priors (arrays).
            dre_weights: Optional (n_cal,) covariate-shift weights that are
                multiplied into the label-shift weights.
        """
        logits_t = torch.tensor(logits, dtype=torch.float32)
        labels_t = torch.tensor(labels, dtype=torch.long)
        scores = self.score_fn(logits_t, labels_t).numpy()

        # Label-shift weights per calibration sample
        weights = p_target[labels] / p_source[labels]
        if dre_weights is not None:
            weights = weights * dre_weights

        sort_idx = np.argsort(scores)
        self.cal_scores_sorted = scores[sort_idx]
        self.cal_weights_sorted = weights[sort_idx]

    def _quantile_thresholds(self, test_weights, alpha):
        """Weighted (1 - alpha) quantile for each entry of ``test_weights``.

        The condition ``cumsum(W)[m] / (sum(W) + w_t) >= 1 - alpha`` is
        rewritten as ``cumsum(W)[m] >= (1 - alpha) * (sum(W) + w_t)``. The
        cumulative weights are non-decreasing (weights are non-negative), so
        the first index satisfying it is found by binary search. This takes
        O(n_cal + N) memory instead of an (N, n_cal) matrix per class.

        Returns:
            1-D array of thresholds (``inf`` where the quantile is unreachable).
        """
        test_w = np.atleast_1d(np.asarray(test_weights, dtype=float))
        n_cal = len(self.cal_scores_sorted)
        if n_cal == 0:
            return np.full(test_w.shape, np.inf)

        cum_w = np.cumsum(self.cal_weights_sorted)
        targets = (1 - alpha) * (cum_w[-1] + test_w)
        idx = np.searchsorted(cum_w, targets, side="left")
        scores_at = self.cal_scores_sorted[np.minimum(idx, n_cal - 1)]
        return np.where(idx < n_cal, scores_at, np.inf)

    def predict(self, logits, p_source, p_target, alpha=0.1, dre_test_weights=None):
        """Predict with per-class test weights.

        Returns an (N, K) int32 matrix where entry [i, c] is 1 iff the score of
        candidate class c for test point i is <= that class's weighted
        quantile threshold.
        """
        logits_t = torch.tensor(logits, dtype=torch.float32)
        all_scores = self.score_fn(logits_t).numpy()

        N, K = all_scores.shape
        prediction_sets = np.zeros((N, K), dtype=np.int32)

        for c in range(K):
            ls_weight = p_target[c] / p_source[c]

            if dre_test_weights is not None:
                # Per-test-point weights: DRE * label-shift -> one threshold each
                q_hat = self._quantile_thresholds(dre_test_weights * ls_weight, alpha)
            else:
                # Constant test weight -> single quantile per class
                q_hat = self._quantile_thresholds(ls_weight, alpha)[0]

            prediction_sets[:, c] = (all_scores[:, c] <= q_hat).astype(np.int32)

        return prediction_sets


# Marks that the correction helpers in this cell were defined.
print("Methods defined.")

## 4. Evaluate on Target Pathology

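Compare 5 methods, plus an oracle reference (methods 3–5 use the MLLS prevalence estimate):
1. **Standard CP** — no shift correction (baseline)
2. **WCP** — DRE covariate-shift correction only (baseline)
3. **Prior-adjusted CP** — logit adjustment for label shift
4. **Label-shift WCP** — weighted quantile with label-shift weights
5. **Prior-adj + DRE WCP** — logit adjustment + DRE weights (both shifts)
6. **PA-CP (oracle)** — prior-adjusted CP using the true NIH test prevalence (upper-bound reference, not deployable)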

In [None]:
alphas = np.linspace(0.01, 0.5, 50)  # miscoverage levels 0.01, 0.02, ..., 0.50

# Use MLLS estimate (more robust than BBSE)
p_est = p_target_mlls

# Pre-compute adjusted logits
# The same Bayes prior shift is applied to calibration and test logits, so both
# are scored on the same (target-prior) scale.
adj_cal_logits = adjust_logits(cal_logits, p_source, p_est)
adj_test_logits = adjust_logits(test_nih_logits, p_source, p_est)

# Also with oracle for comparison
adj_cal_logits_oracle = adjust_logits(cal_logits, p_source, p_target_oracle)
adj_test_logits_oracle = adjust_logits(test_nih_logits, p_source, p_target_oracle)


def run_evaluation(alphas):
    """Evaluate every conformal variant on the NIH test set for each alpha.

    Relies on module-level state from earlier cells: calibration/test logits
    (raw, MLLS-adjusted and oracle-adjusted), DRE weights, priors and labels.
    For each alpha, six predictors are calibrated on CheXpert calibration data
    and applied to NIH test data. Each resulting set matrix is scored for
    coverage and for learning-to-defer system accuracy.

    Args:
        alphas: Iterable of miscoverage levels.

    Returns:
        Dict mapping method name -> list of ``DeferralResult`` (one per alpha,
        in the order of ``alphas``).
    """
    raps_kwargs = {"penalty": 0.1, "kreg": 1, "randomized": False}

    def split_cp(cal_lg, test_lg, alpha):
        # Unweighted CP: the threshold is fixed at calibration time.
        predictor = ConformalPredictor(**raps_kwargs)
        predictor.calibrate(cal_lg, chex_cal_labels, alpha=alpha)
        return predictor.predict(test_lg)

    def covariate_wcp(cal_lg, test_lg, alpha):
        # DRE-weighted CP: alpha is applied at prediction time.
        predictor = WeightedConformalPredictor(**raps_kwargs)
        predictor.calibrate(cal_lg, chex_cal_labels, cal_weights)
        return predictor.predict(test_lg, test_nih_weights, alpha=alpha)

    def label_shift_wcp(alpha):
        # Class-ratio-weighted CP using the MLLS prior estimate.
        predictor = LabelShiftWCP(**raps_kwargs)
        predictor.calibrate(cal_logits, chex_cal_labels, p_source, p_est)
        return predictor.predict(test_nih_logits, p_source, p_est, alpha=alpha)

    results = {}
    for alpha in alphas:
        # Insertion order fixes the method order in results (tables/plots).
        sets_by_method = {
            "Standard CP": split_cp(cal_logits, test_nih_logits, alpha),
            "WCP (DRE)": covariate_wcp(cal_logits, test_nih_logits, alpha),
            "Prior-adj CP": split_cp(adj_cal_logits, adj_test_logits, alpha),
            "LS-WCP": label_shift_wcp(alpha),
            "Prior-adj+DRE": covariate_wcp(adj_cal_logits, adj_test_logits, alpha),
            "PA-CP (oracle)": split_cp(
                adj_cal_logits_oracle, adj_test_logits_oracle, alpha
            ),
        }

        for method, sets in sets_by_method.items():
            coverage = compute_coverage(sets, nih_test_labels)
            # NOTE(review): model predictions are derived from the *unadjusted*
            # test logits for every method — confirm _predictions_from_sets
            # only needs them for tie-breaking / non-singleton sets.
            preds, defer_mask = _predictions_from_sets(sets, test_nih_logits)
            system_metrics = compute_system_accuracy(
                preds, nih_test_labels, defer_mask, expert_accuracy=EXPERT_ACCURACY
            )
            record = DeferralResult(
                method=method,
                alpha_or_threshold=float(alpha),
                system_accuracy=system_metrics["system_accuracy"],
                deferral_rate=system_metrics["deferral_rate"],
                coverage_rate=coverage["coverage_rate"],
                average_set_size=coverage["average_set_size"],
                model_accuracy_on_kept=system_metrics["model_accuracy_on_kept"],
                n_total=len(nih_test_labels),
                n_deferred=system_metrics["n_deferred"],
            )
            results.setdefault(method, []).append(record)

    return results


# Dict: method name -> list of DeferralResult, one per alpha in `alphas`
all_results = run_evaluation(alphas)
print("All methods evaluated.")

In [None]:
# Summary at alpha = 0.1
alpha_target = 0.1
rows = []
for name, res_list in all_results.items():
    r = min(res_list, key=lambda r: abs(r.alpha_or_threshold - alpha_target))
    rows.append(
        {
            "Method": name,
            "Coverage": f"{r.coverage_rate:.4f}",
            "Avg |C|": f"{r.average_set_size:.3f}",
            "Deferral": f"{r.deferral_rate:.4f}",
            "Sys Acc": f"{r.system_accuracy:.4f}",
            "Model Acc": f"{r.model_accuracy_on_kept:.4f}",
        }
    )

df = pd.DataFrame(rows)
print(f"\n{TARGET_PATHOLOGY} at alpha={alpha_target} (target coverage >= 90%)")
print("=" * 85)
print(df.to_string(index=False))

## 5. Class-Conditional Coverage Analysis

In [None]:
alpha_target = 0.1  # single operating point for the class-conditional breakdown


def class_conditional_analysis(pred_sets, labels, method_name):
    """Summarise coverage, deferral and set size per true class and overall.

    A test point is *covered* when its true label is in its prediction set,
    and *deferred* when its set does not contain exactly one label.

    Args:
        pred_sets: (N, K) 0/1 prediction-set matrix; per-class rows are
            produced for classes 0 ("Neg") and 1 ("Pos").
        labels: (N,) integer true labels.
        method_name: Value copied into the "Method" column.

    Returns:
        List of three row dicts (Neg, Pos, All) with string-formatted metrics.
    """
    sizes = pred_sets.sum(axis=1)
    hit = pred_sets[np.arange(len(labels)), labels].astype(bool)
    is_deferred = sizes != 1

    def _row(class_name, count, mask):
        return {
            "Method": method_name,
            "Class": class_name,
            "N": count,
            "Cov": f"{hit[mask].mean():.4f}",
            "Defer": f"{is_deferred[mask].mean():.4f}",
            "|C|": f"{sizes[mask].mean():.3f}",
        }

    per_class = []
    for class_id, class_name in enumerate(("Neg", "Pos")):
        in_class = labels == class_id
        per_class.append(_row(class_name, int(in_class.sum()), in_class))

    everyone = np.ones(len(labels), dtype=bool)
    return per_class + [_row("All", len(labels), everyone)]


# Recompute prediction sets at alpha_target for every method: all_results only
# stores aggregate metrics, not the sets themselves.
all_rows = []
# (name, calibration logits, test logits, use DRE weights, use LabelShiftWCP)
method_configs = [
    ("Standard CP", cal_logits, test_nih_logits, False, False),
    ("WCP (DRE)", cal_logits, test_nih_logits, True, False),
    ("Prior-adj CP", adj_cal_logits, adj_test_logits, False, False),
    ("LS-WCP", cal_logits, test_nih_logits, False, True),
    ("Prior-adj+DRE", adj_cal_logits, adj_test_logits, True, False),
    ("PA-CP (oracle)", adj_cal_logits_oracle, adj_test_logits_oracle, False, False),
]

for name, c_lg, t_lg, use_dre, use_ls in method_configs:
    if use_ls:
        # label-shift weights only (no DRE), using the MLLS prior p_est
        pred = LabelShiftWCP(penalty=0.1, kreg=1, randomized=False)
        pred.calibrate(c_lg, chex_cal_labels, p_source, p_est)
        ps = pred.predict(t_lg, p_source, p_est, alpha=alpha_target)
    elif use_dre:
        # DRE-weighted CP; alpha is supplied at prediction time
        pred = WeightedConformalPredictor(penalty=0.1, kreg=1, randomized=False)
        pred.calibrate(c_lg, chex_cal_labels, cal_weights)
        ps = pred.predict(t_lg, test_nih_weights, alpha=alpha_target)
    else:
        # unweighted CP; alpha is supplied at calibration time
        pred = ConformalPredictor(penalty=0.1, kreg=1, randomized=False)
        pred.calibrate(c_lg, chex_cal_labels, alpha=alpha_target)
        ps = pred.predict(t_lg)
    all_rows.extend(class_conditional_analysis(ps, nih_test_labels, name))

print(f"Class-conditional analysis for {TARGET_PATHOLOGY} at alpha={alpha_target}")
print("=" * 75)
print(pd.DataFrame(all_rows).to_string(index=False))

## 6. Coverage and Deferral vs Alpha

In [None]:
# Coverage (left) and deferral rate (right) of every method across the alpha
# grid; oracle methods are drawn dashed.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Fixed colour per method name so both panels use the same legend colours
colors = {
    "Standard CP": "#1f77b4",
    "WCP (DRE)": "#ff7f0e",
    "Prior-adj CP": "#2ca02c",
    "LS-WCP": "#d62728",
    "Prior-adj+DRE": "#9467bd",
    "PA-CP (oracle)": "#8c564b",
}

# Coverage vs alpha
ax = axes[0]
for name, res in all_results.items():
    a = [r.alpha_or_threshold for r in res]
    c = [r.coverage_rate for r in res]
    ls = "--" if "oracle" in name else "-"
    ax.plot(a, c, label=name, color=colors[name], linewidth=1.5, linestyle=ls, marker="o", markersize=2)
# Reference line: nominal marginal coverage 1 - alpha
ax.plot(alphas, 1 - alphas, "k--", alpha=0.5, linewidth=1.5, label=r"Ideal $1-\alpha$")
ax.set_xlabel(r"$\alpha$")
ax.set_ylabel("Coverage")
ax.set_title(f"Coverage vs Alpha ({TARGET_PATHOLOGY})")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# Deferral vs alpha
ax = axes[1]
for name, res in all_results.items():
    a = [r.alpha_or_threshold for r in res]
    d = [r.deferral_rate for r in res]
    ls = "--" if "oracle" in name else "-"
    ax.plot(a, d, label=name, color=colors[name], linewidth=1.5, linestyle=ls, marker="o", markersize=2)
ax.set_xlabel(r"$\alpha$")
ax.set_ylabel("Deferral Rate")
ax.set_title(f"Deferral Rate vs Alpha ({TARGET_PATHOLOGY})")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Multi-Pathology Comparison

In [None]:
alpha_target = 0.1
multi_results = []


def _logits(m, X):
    """Return two-column logits whose softmax equals ``m.predict_proba(X)``.

    ``decision_function`` of a binary logistic regression is the log-odds
    ``d``. A two-column softmax depends only on the column difference, so the
    columns are ``[-d/2, d/2]`` (difference ``d``). Using ``[-d, d]`` would make
    the softmax ``sigmoid(2d)``. That is sharper than the fitted model and
    inconsistent with the ``predict_proba`` posteriors MLLS uses, so the Bayes
    prior adjustment would be applied to the wrong posterior.
    """
    d = m.decision_function(X)
    return np.column_stack([-0.5 * d, 0.5 * d])


for pathology in COMMON_PATHOLOGIES:
    # Binary labels
    c_feats, c_labels, _ = extract_binary_labels(
        chexpert.features, chexpert.labels, COMMON_PATHOLOGIES, pathology
    )
    n_feats, n_labels, _ = extract_binary_labels(
        nih.features, nih.labels, COMMON_PATHOLOGIES, pathology
    )

    # Splits (same 60/20/20 CheXpert and 50% NIH test scheme as section 1)
    c_tr_f, c_tmp_f, c_tr_l, c_tmp_l = train_test_split(
        c_feats, c_labels, test_size=0.4, random_state=SEED, stratify=c_labels
    )
    c_cal_f, _, c_cal_l, _ = train_test_split(
        c_tmp_f, c_tmp_l, test_size=0.5, random_state=SEED, stratify=c_tmp_l
    )
    _, n_te_f, _, n_te_l = train_test_split(
        n_feats, n_labels, test_size=0.5, random_state=SEED, stratify=n_labels
    )

    # Classifier (per-pathology scaler fit on CheXpert train only)
    sc = StandardScaler()
    Xtr = sc.fit_transform(c_tr_f)
    Xcal = sc.transform(c_cal_f)
    Xte = sc.transform(n_te_f)
    Xpool = sc.transform(nih_pool_feats_all)

    model = LogisticRegression(solver="lbfgs", max_iter=1000, C=1.0, random_state=SEED)
    model.fit(Xtr, c_tr_l)
    nih_auc = roc_auc_score(n_te_l, model.predict_proba(Xte)[:, 1])

    c_lg = _logits(model, Xcal)
    t_lg = _logits(model, Xte)

    # DRE (fit on unscaled features, as in section 1)
    d = AdaptiveDRE(n_components=4, weight_clip=20.0, random_state=SEED)
    d.fit(c_cal_f, nih_pool_feats_all)
    cw = d.compute_weights(c_cal_f)
    tw = d.compute_weights(n_te_f)

    # Prevalence estimation
    p_src = np.array([1 - c_cal_l.mean(), c_cal_l.mean()])
    p_tgt_mlls = estimate_prevalence_mlls(model, Xpool, p_src)
    p_tgt_oracle = np.array([1 - n_te_l.mean(), n_te_l.mean()])

    # Adjusted logits (MLLS)
    adj_c_lg = adjust_logits(c_lg, p_src, p_tgt_mlls)
    adj_t_lg = adjust_logits(t_lg, p_src, p_tgt_mlls)

    # Adjusted logits (Oracle)
    adj_c_lg_o = adjust_logits(c_lg, p_src, p_tgt_oracle)
    adj_t_lg_o = adjust_logits(t_lg, p_src, p_tgt_oracle)

    row = {
        "Pathology": pathology,
        "AUC": f"{nih_auc:.3f}",
        "Src prev": f"{p_src[1]:.3f}",
        "Tgt prev": f"{p_tgt_oracle[1]:.3f}",
        "MLLS prev": f"{p_tgt_mlls[1]:.3f}",
    }

    # Evaluate methods
    # (short name, predictor class or "ls_wcp", cal logits, test logits,
    #  cal DRE weights, test DRE weights, source prior, target prior)
    method_configs = [
        ("Std", ConformalPredictor, c_lg, t_lg, None, None, None, None),
        ("WCP", WeightedConformalPredictor, c_lg, t_lg, cw, tw, None, None),
        ("PA-CP", ConformalPredictor, adj_c_lg, adj_t_lg, None, None, None, None),
        ("LS", "ls_wcp", c_lg, t_lg, None, None, p_src, p_tgt_mlls),
        ("PA+DRE", WeightedConformalPredictor, adj_c_lg, adj_t_lg, cw, tw, None, None),
        ("PA(orc)", ConformalPredictor, adj_c_lg_o, adj_t_lg_o, None, None, None, None),
    ]

    for mname, mclass, cl, tl, cw_, tw_, ps, pt in method_configs:
        if mclass == "ls_wcp":
            pred = LabelShiftWCP(penalty=0.1, kreg=1, randomized=False)
            pred.calibrate(cl, c_cal_l, ps, pt)
            ps_out = pred.predict(tl, ps, pt, alpha=alpha_target)
        elif mclass == WeightedConformalPredictor:
            pred = mclass(penalty=0.1, kreg=1, randomized=False)
            pred.calibrate(cl, c_cal_l, cw_)
            ps_out = pred.predict(tl, tw_, alpha=alpha_target)
        else:
            pred = mclass(penalty=0.1, kreg=1, randomized=False)
            pred.calibrate(cl, c_cal_l, alpha=alpha_target)
            ps_out = pred.predict(tl)

        # Marginal coverage and deferral (set size != 1) on the NIH test split
        cov = ps_out[np.arange(len(n_te_l)), n_te_l].mean()
        defer = (ps_out.sum(axis=1) != 1).mean()
        row[f"{mname} Cov"] = f"{cov:.3f}"
        row[f"{mname} Def"] = f"{defer:.3f}"

    multi_results.append(row)
    print(
        f"{pathology:<16} MLLS={p_tgt_mlls[1]:.3f}(true={p_tgt_oracle[1]:.3f})  "
        f"Std={row['Std Def']}  WCP={row['WCP Def']}  "
        f"PA-CP={row['PA-CP Cov']}/{row['PA-CP Def']}  "
        f"PA+DRE={row['PA+DRE Cov']}/{row['PA+DRE Def']}  "
        f"PA(orc)={row['PA(orc) Cov']}/{row['PA(orc) Def']}"
    )

df_multi = pd.DataFrame(multi_results)
print(f"\n{'=' * 150}")
print(f"Multi-pathology comparison at alpha={alpha_target}")
print(f"{'=' * 150}")
print(df_multi.to_string(index=False))

In [None]:
# Multi-pathology visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

pathology_names = [r["Pathology"] for r in multi_results]
x = np.arange(len(pathology_names))
width = 0.13
method_keys = [
    ("Std", "Standard CP"),
    ("WCP", "WCP (DRE)"),
    ("PA-CP", "Prior-adj CP"),
    ("LS", "LS-WCP"),
    ("PA+DRE", "Prior-adj+DRE"),
    ("PA(orc)", "PA-CP (oracle)"),
]

# Deferral rates
ax = axes[0]
for i, (key, label) in enumerate(method_keys):
    vals = [float(r[f"{key} Def"]) for r in multi_results]
    ax.bar(x + i * width, vals, width, label=label)
ax.set_xticks(x + 2.5 * width)
ax.set_xticklabels(pathology_names, rotation=45, ha="right")
ax.set_ylabel("Deferral Rate")
ax.set_title(f"Deferral Rate by Pathology (alpha={alpha_target})")
ax.legend(fontsize=7, loc="upper left")
ax.grid(True, alpha=0.3, axis="y")

# Coverage
ax = axes[1]
for i, (key, label) in enumerate(method_keys):
    vals = [float(r[f"{key} Cov"]) for r in multi_results]
    ax.bar(x + i * width, vals, width, label=label)
ax.axhline(y=0.9, color="red", linestyle="--", alpha=0.7, label="Target (90%)")
ax.set_xticks(x + 2.5 * width)
ax.set_xticklabels(pathology_names, rotation=45, ha="right")
ax.set_ylabel("Coverage")
ax.set_title(f"Coverage by Pathology (alpha={alpha_target})")
ax.legend(fontsize=7, loc="lower left")
ax.grid(True, alpha=0.3, axis="y")
ax.set_ylim(0.5, 1.05)

plt.tight_layout()
plt.show()