# GNN-DRE + Label Shift Correction (LSC) Binary WCP

**Extends**: `gnn_dre_wcp.ipynb`  
**Methods compared** (all at α = 0.10):

| Method | Covariate shift | Label shift | Notes |
|--------|----------------|-------------|-------|
| **Std CP** | ✗ | ✗ | Baseline |
| **WCP-GNN** | GNN-DRE (ESS≈31%) | ✗ | From gnn_dre_wcp_report |
| **WCP-GNN+EM-LSC** | GNN-DRE (ESS≈31%) | EM-estimated NIH prior | **New** |
| **WCP-GNN+Oracle-LSC** | GNN-DRE (ESS≈31%) | True NIH prior | Upper bound |

**Core research question**: Does correcting for label shift (via Bayesian odds-ratio adjustment)
reduce the dangerously high singleton FNR (up to 99.2% for Pneumothorax) observed in WCP-GNN?

In [None]:
import sys
import warnings
warnings.filterwarnings('ignore')
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score

ROOT = Path('../..').resolve()
if str(ROOT / 'src') not in sys.path:
    sys.path.insert(0, str(ROOT / 'src'))

from wcp_l2d.dre import AdaptiveDRE
from wcp_l2d.features import ExtractedFeatures
from wcp_l2d.gnn import build_adjacency_matrix, train_gnn
from wcp_l2d.pathologies import COMMON_PATHOLOGIES
from wcp_l2d.conformal import ConformalPredictor, WeightedConformalPredictor
from wcp_l2d.evaluation import evaluate_standard_cp, evaluate_wcp

# Experiment-wide constants reused by every later cell.
SEED   = 42
ALPHA  = 0.10          # 90 % coverage target
EXPERT = 0.85          # passed as expert_accuracy to evaluate_standard_cp / evaluate_wcp
K      = len(COMMON_PATHOLOGIES)   # number of pathologies (one label column each)
DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'   # Apple-silicon GPU when present

# Seed NumPy's global RNG and torch for reproducibility.
np.random.seed(SEED)
torch.manual_seed(SEED)

plt.rcParams.update({'figure.dpi': 100, 'figure.facecolor': 'white',
                     'axes.grid': True, 'grid.alpha': 0.3})
print(f'Device:      {DEVICE}')
print(f'Pathologies: {COMMON_PATHOLOGIES}')

## 1. Load Features

In [None]:
# Load pre-extracted image features + labels for the source (CheXpert) and target (NIH)
# domains; the file names indicate both were extracted with the same
# densenet121-res224-chex backbone.
FEAT_DIR = ROOT / 'data' / 'features'
chex = ExtractedFeatures.load(FEAT_DIR / 'chexpert_densenet121-res224-chex_features.npz')
nih  = ExtractedFeatures.load(FEAT_DIR / 'nih_densenet121-res224-chex_features.npz')

print(f'CheXpert: {chex.features.shape}  labels: {chex.labels.shape}')
print(f'NIH:      {nih.features.shape}   labels: {nih.labels.shape}')

## 2. Global Data Splits

Identical splits to `gnn_dre_wcp.ipynb` (same `SEED=42`).

In [None]:
# One RandomState drives both permutations below, so the NIH split depends on the CheXpert
# permutation having been drawn first -- keep this order to reproduce the splits.
rng = np.random.RandomState(SEED)

# --- CheXpert: 60 % train / 20 % cal / 20 % ignored ---
N_chex = len(chex.features)
idx    = rng.permutation(N_chex)
n_tr   = int(0.60 * N_chex)
n_cal  = int(0.20 * N_chex)

X_train_raw = chex.features[idx[:n_tr]]
Y_train     = chex.labels[idx[:n_tr]]
X_cal_raw   = chex.features[idx[n_tr:n_tr + n_cal]]
Y_cal       = chex.labels[idx[n_tr:n_tr + n_cal]]

# --- NIH: 50 % DRE pool / 50 % labelled test ---
# Pool labels are not extracted: the pool is used only as unlabelled target data for DRE (§7).
N_nih    = len(nih.features)
nih_perm = rng.permutation(N_nih)
n_pool   = N_nih // 2

X_pool_raw = nih.features[nih_perm[:n_pool]]
X_nih_raw  = nih.features[nih_perm[n_pool:]]
Y_nih_test = nih.labels[nih_perm[n_pool:]]

# StandardScaler fitted on train set
# (fit on CheXpert train only; the same transform is then applied to every other split)
scaler  = StandardScaler().fit(X_train_raw)
X_train = scaler.transform(X_train_raw)
X_cal   = scaler.transform(X_cal_raw)
X_pool  = scaler.transform(X_pool_raw)
X_nih   = scaler.transform(X_nih_raw)

print(f'CheXpert  train={len(X_train):,}  cal={len(X_cal):,}')
print(f'NIH       pool={len(X_pool):,}    test={len(X_nih):,}')

## 3. Build 7×7 Label Co-occurrence Adjacency Matrix

In [None]:
# Label co-occurrence graph estimated from CheXpert train labels; the assert below checks that
# it is row-normalised.  (tau=0.10 is a threshold parameter of build_adjacency_matrix -- see
# that function for its exact role.)
A = build_adjacency_matrix(Y_train, tau=0.10)
assert torch.allclose(A.sum(dim=1), torch.ones(K), atol=1e-5), 'Row sums must equal 1'
print(f'Adjacency matrix built. Shape: {A.shape}')
# Subtracting K assumes every diagonal (self-loop) entry is non-zero -- TODO confirm.
n_nonzero = int((A > 0).sum()) - K
print(f'Non-zero off-diagonal entries: {n_nonzero} / {K*(K-1)} ({100*n_nonzero/(K*(K-1)):.0f}%)')

## 4. Train Binary LR Classifiers (GNN Residual Init)

In [None]:
# One binary LR per pathology on CheXpert train (rows with a NaN label for that pathology are
# dropped).  Pathologies with < 10 labelled rows or a single observed class get None.
# These models supply the residual-init logits passed to the GNN.
# NOTE(review): §9 refits a per-pathology LR on the same rows with the same settings
# (C=1.0 is the sklearn default), so lrs[k] could probably be reused there.
lrs = []
for k, path in enumerate(COMMON_PATHOLOGIES):
    valid = ~np.isnan(Y_train[:, k])
    if valid.sum() < 10 or len(np.unique(Y_train[valid, k])) < 2:
        lrs.append(None)
        continue
    lr = LogisticRegression(solver='lbfgs', max_iter=1000, random_state=SEED)
    lr.fit(X_train[valid], Y_train[valid, k].astype(int))
    lrs.append(lr)

def get_logits_lr(lrs_, X_s):
    """Stack per-pathology LR decision functions into an [N, len(lrs_)] float32 matrix.

    Args:
        lrs_ : sequence of fitted binary classifiers, or None for pathologies that could
               not be trained; column k is filled from ``lrs_[k]``.
        X_s  : [N, D] feature matrix.
    Returns:
        float32 array of shape [N, len(lrs_)].  Columns whose classifier is None stay 0.
    """
    # The width comes from the classifier list itself rather than the notebook-global K,
    # so the helper cannot silently drop or mis-size columns if the two ever disagree.
    out = np.zeros((len(X_s), len(lrs_)), dtype=np.float32)
    for k, clf in enumerate(lrs_):
        if clf is not None:
            out[:, k] = clf.decision_function(X_s)
    return out

# LR residual-init logits for every split (columns are zero where lrs[k] is None).
init_tr   = get_logits_lr(lrs, X_train)
init_cal  = get_logits_lr(lrs, X_cal)
init_pool = get_logits_lr(lrs, X_pool)
init_nih  = get_logits_lr(lrs, X_nih)
print('LR classifiers trained.')

## 5. Train LabelGCN

In [None]:
print(f'Training LabelGCN on {DEVICE} (50 epochs) ...')

# The CheXpert calibration split doubles as the GNN validation set (presumably used for
# checkpoint selection via save_best=True -- TODO confirm in train_gnn).
# NOTE(review): selecting a model on the same rows later used for conformal calibration can
# weaken the exchangeability assumption behind the coverage guarantee.
gnn, history = train_gnn(
    features_train=X_train,
    labels_train=Y_train,
    features_val=X_cal,
    labels_val=Y_cal,
    adjacency=A,
    init_logits_train=init_tr,
    init_logits_val=init_cal,
    epochs=50,
    save_best=True,
    batch_size=512,
    lr=1e-3,
    weight_decay=1e-4,
    device=DEVICE,
    verbose=False,
)

# Assumes history['best_epoch'] is a sequence whose first element is the selected epoch -- verify.
best_ep = history['best_epoch'][0]
print(f'Best val AUC: {max(history["val_auc"]):.4f}  at epoch {best_ep}/50')

## 6. GNN Probability Extraction

In [None]:
def gnn_probs(model, X_s, init_np=None):
    """Run a no-grad forward pass and return sigmoid probabilities as a NumPy array.

    Args:
        model   : torch module called as ``model(X, init_logits)``.
        X_s     : [N, D] feature matrix.
        init_np : optional residual-init logits aligned with ``X_s`` (e.g. [N, K]), or None.
    Returns:
        float32 NumPy array of probabilities with the shape of the model output.
    """
    # Build the inputs on whatever device the model's parameters live on (the model may
    # come back from training on MPS/CUDA); parameter-free modules fall back to CPU.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    model.eval()
    with torch.no_grad():
        Xt = torch.as_tensor(X_s, dtype=torch.float32, device=device)
        it = (torch.as_tensor(init_np, dtype=torch.float32, device=device)
              if init_np is not None else None)
        # torch.sigmoid is numerically stable (no exp overflow for large negative logits).
        probs = torch.sigmoid(model(Xt, it))
    # .cpu() is required before .numpy() for tensors on an accelerator.
    return probs.cpu().numpy()

# GNN probabilities for every split, each paired with its own LR residual-init logits.
p_train = gnn_probs(gnn, X_train, init_tr)
p_cal   = gnn_probs(gnn, X_cal,   init_cal)
p_pool  = gnn_probs(gnn, X_pool,  init_pool)
p_nih   = gnn_probs(gnn, X_nih,   init_nih)

print(f'GNN probs: p_cal={p_cal.shape}  p_pool={p_pool.shape}  p_nih={p_nih.shape}')

# Per-pathology AUC on NIH test
# Rows with a NaN label are dropped; AUC is reported as NaN when fewer than 2 labelled rows
# remain or only one class is present (roc_auc_score would raise in that case).
rows = []
for k, path in enumerate(COMMON_PATHOLOGIES):
    valid = ~np.isnan(Y_nih_test[:, k])
    if valid.sum() < 2 or len(np.unique(Y_nih_test[valid, k])) < 2:
        rows.append({'Pathology': path, 'GNN AUC': float('nan')}); continue
    y = Y_nih_test[valid, k]
    rows.append({'Pathology': path, 'GNN AUC': round(roc_auc_score(y, p_nih[valid, k]), 4)})

auc_df = pd.DataFrame(rows).set_index('Pathology')
# pandas .mean() skips NaN AUCs by default.
print(f'\nNIH GNN AUC (mean={auc_df["GNN AUC"].mean():.4f}):')
print(auc_df.T.to_string())

## 7. GNN-DRE

Source = CheXpert calibration set (GNN probability space).  
Target = NIH unlabelled pool (GNN probability space).  
No PCA, no clipping — identical to `gnn_dre_wcp.ipynb`.

In [None]:
# Density-ratio estimation in GNN-probability space: source = CheXpert cal, target = NIH pool.
# The fitted ratio is then evaluated on the cal set (calibration weights) and on the NIH
# test set (test-point weights).
dre_gnn = AdaptiveDRE(n_components=None, weight_clip=None, random_state=SEED)
dre_gnn.fit(source_features=p_cal, target_features=p_pool)
w_cal_gnn = dre_gnn.compute_weights(p_cal)
w_nih_gnn = dre_gnn.compute_weights(p_nih)
diag_gnn  = dre_gnn.diagnostics(p_cal)

print('=== GNN-DRE Diagnostics (CheXpert cal set) ===')
print(f'  Domain AUC : {diag_gnn.domain_auc:.4f}')
print(f'  ESS        : {diag_gnn.ess:.1f}  ({diag_gnn.ess_fraction*100:.1f}%)')
print(f'  Weight mean: {diag_gnn.weight_mean:.3f}  median: {diag_gnn.weight_median:.3f}  max: {diag_gnn.weight_max:.3f}')

# Per-pathology ESS
# Kish ESS fraction (sum w)^2 / sum(w^2) / n over the cal rows labelled for that pathology.
# NOTE(review): divides by zero if a pathology has no labelled cal rows.
ess_rows = []
for k, path in enumerate(COMMON_PATHOLOGIES):
    c_mask = ~np.isnan(Y_cal[:, k])
    wc = w_cal_gnn[c_mask]
    ess_k = float(wc.sum()**2 / (wc**2).sum()) / c_mask.sum()
    ess_rows.append({'Pathology': path, 'n_cal': int(c_mask.sum()), 'ESS%': round(ess_k*100, 1)})
print('\nPer-pathology ESS on filtered cal subset:')
print(pd.DataFrame(ess_rows).set_index('Pathology').T.to_string())

## 8. Label Shift Correction (LSC)

### Motivation

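WCP-GNN addresses covariate shift but not label shift.  Under label shift, the class prior
changes between domains: CheXpert has higher disease prevalence than NIH (e.g. Pneumothorax
6.7% vs 0.9%).  The binary LR classifier trained on CheXpert is therefore calibrated to the
CheXpert prior and over-estimates $P(y=1\mid x)$ on NIH.  The hypothesis tested here is that this
prior mismatch drives the high FNR on NIH singleton decisions.  Note that an upward bias on its own
would inflate false positives rather than false negatives.  LSC pushes scores toward the negative
class, so whether it lowers singleton FNR is an empirical question (see §A3–A4).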

**LSC** applies a Bayesian odds-ratio correction:
$$
\tilde{p}(y=1|x) = \sigma\!\left(\log\text{odds}_{\text{src}}(x) + \log\frac{\pi_{\text{tgt}}/(1-\pi_{\text{tgt}})}{\pi_{\text{src}}/(1-\pi_{\text{src}})}\right)
$$
where $\text{odds}_{\text{src}}(x) = p_{\text{src}}(y=1|x)/(1-p_{\text{src}}(y=1|x))$ is the
source-domain model's odds, and $\pi_{\text{src}},\pi_{\text{tgt}}$ are the source and target
class priors.

### EM Algorithm (Unsupervised)

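The target prior $\pi_{\text{tgt}}$ is estimated without labels using Expectation-Maximisation
(Saerens et al., 2002) on the GNN probability outputs for the NIH test set.  The E-step applies the
odds-ratio correction above with the current estimate of $\pi_{\text{tgt}}$.  The M-step sets
$\pi_{\text{tgt}}$ to the mean corrected posterior.  The two steps repeat until the estimate stops changing.
The prior is estimated from the GNN outputs, but §9 applies the correction to the
per-pathology binary LR probabilities.  This assumes both models are calibrated to the same CheXpert prior.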

In [None]:
# =====================================================
# LSC helper functions (provided in experiment spec)
# =====================================================

def estimate_target_prior_em(gnn_probs_tgt, prior_src, max_iters=100, tol=1e-5):
    """
    EM algorithm (Saerens et al., 2002) to estimate target-domain prevalence from unlabelled data.
    Treats each of the K pathologies as an independent binary classification task.

    Args:
        gnn_probs_tgt : [N, K]  source-calibrated P(y=1|x) on NIH test (GNN probabilities)
        prior_src     : [K]     CheXpert train prevalences
        max_iters     : maximum number of EM iterations
        tol           : stop once the largest per-pathology change of the prior is < tol
    Returns:
        estimated_prior_tgt : [K]  estimated NIH prevalences, kept inside [eps, 1 - eps]
    """
    eps = 1e-7

    prior_src_clipped = np.clip(prior_src, eps, 1 - eps)
    src_prior_odds    = prior_src_clipped / (1 - prior_src_clipped)
    probs_clipped     = np.clip(gnn_probs_tgt, eps, 1 - eps)
    odds_src          = probs_clipped / (1 - probs_clipped)      # loop-invariant

    prior_tgt_current = np.copy(prior_src_clipped)

    for i in range(max_iters):
        # E-step: Bayes-adjust the source posteriors to the current target prior
        odds_ratio     = (prior_tgt_current / (1 - prior_tgt_current)) / src_prior_odds
        odds_tgt       = odds_src * odds_ratio
        adjusted_probs = odds_tgt / (1 + odds_tgt)

        # M-step: new prior = mean adjusted posterior.  Clip so that a saturated mean of
        # exactly 1.0 (or 0.0) cannot make the next E-step divide by zero and yield NaN.
        prior_tgt_new = np.clip(np.mean(adjusted_probs, axis=0), eps, 1 - eps)

        converged = np.max(np.abs(prior_tgt_new - prior_tgt_current)) < tol
        prior_tgt_current = prior_tgt_new
        if converged:
            print(f'  EM converged at iteration {i+1}')
            break
    else:
        print(f'  EM did not converge within {max_iters} iterations')

    return prior_tgt_current


def apply_label_shift_correction(gnn_probs, prior_src, prior_tgt):
    """
    Re-weight source posteriors to a new class prior (Bayesian odds-ratio rule).

    Each column of ``gnn_probs`` is adjusted independently; the priors broadcast against the
    trailing axis, so [N, K] probabilities with [K] priors handle K pathologies at once.
    Probabilities and priors are first clipped to [1e-7, 1 - 1e-7].

    Returns:
        Target-domain posteriors with the broadcast shape of the inputs.
    """
    lo, hi = 1e-7, 1 - 1e-7
    p_s = np.clip(prior_src, lo, hi)
    p_t = np.clip(prior_tgt, lo, hi)
    p = np.clip(gnn_probs, lo, hi)

    # Multiplicative factor mapping source odds to target odds.
    shift = (p_t / (1 - p_t)) / (p_s / (1 - p_s))
    odds = p / (1 - p) * shift
    return odds / (1 + odds)


def corrected_binary_logits(clf, X, prior_src_k, prior_tgt_k, eps=1e-7):
    """
    Apply per-pathology LSC to a binary classifier and return binary logits.

    The correction is an additive shift in log-odds space,
        d_tgt(x) = d_src(x) + log[(pi_t / (1 - pi_t)) / (pi_s / (1 - pi_s))],
    which is algebraically the same odds-ratio rule as ``apply_label_shift_correction``.
    Working directly in log-odds avoids clipping probabilities to [eps, 1 - eps], which
    would cap |d| at about 16.1 and turn all extreme samples into ties.  With equal priors
    the output therefore coincides with ``binary_logits_from_lr``.

    Args:
        clf          : fitted binary classifier.  ``decision_function`` is used when
                       available (for LogisticRegression it is the log-odds of
                       ``predict_proba[:, 1]``); otherwise the log-odds of the clipped
                       ``predict_proba[:, 1]`` are used.
        X            : features passed to ``clf``.
        prior_src_k  : source prevalence for this pathology.
        prior_tgt_k  : target prevalence for this pathology.
        eps          : clipping constant for the priors (and for probabilities in the fallback).
    Returns:
        [N, 2] array ``[-d, d]`` where d is the corrected log-odds of y=1.
    """
    if hasattr(clf, 'decision_function'):
        src_logits = np.asarray(clf.decision_function(X), dtype=float)
    else:
        probs = np.clip(clf.predict_proba(X)[:, 1], eps, 1 - eps)
        src_logits = np.log(probs) - np.log1p(-probs)

    prior_s = np.clip(prior_src_k, eps, 1 - eps)
    prior_t = np.clip(prior_tgt_k, eps, 1 - eps)
    log_odds_ratio = ((np.log(prior_t) - np.log1p(-prior_t))
                      - (np.log(prior_s) - np.log1p(-prior_s)))

    logits = src_logits + log_odds_ratio                                  # corrected log-odds
    return np.column_stack([-logits, logits])                             # [N, 2]

In [None]:
# --- Source prior from CheXpert train ---
# (nanmean ignores NaN / unlabelled entries)
prior_src = np.nanmean(Y_train, axis=0)      # [K]

# --- Oracle target prior (true NIH test prevalence) ---
# Uses NIH test labels, so it serves as an upper-bound reference, not a deployable estimate.
prior_tgt_oracle = np.nanmean(Y_nih_test, axis=0)  # [K]

# --- EM-estimated target prior (unsupervised, runs on NIH test GNN probs) ---
# EM uses every NIH test row, including rows whose label is NaN for some pathology.
print('Running EM on NIH test GNN probabilities ...')
prior_tgt_em = estimate_target_prior_em(p_nih, prior_src)

# --- Comparison table ---
# Odds ratio = target prior odds / source prior odds (priors clipped to [1e-7, 1 - 1e-7]);
# this is the multiplicative factor LSC applies to each sample's odds.
df_priors = pd.DataFrame({
    'CheXpert (src)':    np.round(prior_src, 4),
    'NIH oracle':        np.round(prior_tgt_oracle, 4),
    'NIH EM-estimated':  np.round(prior_tgt_em, 4),
    'EM error (pp)':     np.round((prior_tgt_em - prior_tgt_oracle) * 100, 2),
    'Odds ratio (EM/src)': np.round(
        (np.clip(prior_tgt_em, 1e-7, 1-1e-7) / (1 - np.clip(prior_tgt_em, 1e-7, 1-1e-7))) /
        (np.clip(prior_src,   1e-7, 1-1e-7) / (1 - np.clip(prior_src,   1e-7, 1-1e-7))),
        4),
}, index=COMMON_PATHOLOGIES)
print('\nPrevalence comparison:')
print(df_priors.to_string())
print('\n(Odds ratio < 1 = model needs to predict LESS positive in target)')

## 9. Per-Pathology Binary WCP — Four Methods

For each pathology:
1. Filter cal / NIH-test to non-NaN rows for that pathology.
2. Train per-pathology binary LR on filtered train set.
3. Compute binary RAPS logits — both raw (no LSC) and LSC-corrected variants.
4. Run all four methods across α ∈ [0.01, 0.50].

> **Note on LSC application**: LSC is applied to *both* calibration and test LR probabilities
> so that RAPS scores are computed from the same (target-recalibrated) score function.
> This ensures calibration and test scores live on the same scale.

In [None]:
# 50 miscoverage levels from 0.01 to 0.50 in steps of 0.01 (so ALPHA = 0.10 is a grid point).
alphas = np.linspace(0.01, 0.50, 50)

def binary_logits_from_lr(clf, X):
    """Return [N, 2] logits ``[-d, d]`` from the classifier's decision function ``d``.

    Column 0 scores the negative class and column 1 the positive class.
    """
    margin = clf.decision_function(X)
    negated = -margin
    return np.column_stack((negated, margin))

def at_alpha(res_list, a=ALPHA):
    """Pick the per-alpha result evaluated closest to ``a``.

    Args:
        res_list : sequence of results exposing ``alpha_or_threshold``.
        a        : target miscoverage level (default: the value of ALPHA when this cell ran).
    Returns:
        The element minimising |alpha_or_threshold - a|; the earliest one on ties.
    Raises:
        ValueError: if ``res_list`` is empty (raised by ``min``).
    """
    def distance_to_target(res):
        return abs(res.alpha_or_threshold - a)

    return min(res_list, key=distance_to_target)

# Per-pathology evaluation of all four methods over the alpha grid.
all_results = {}
clfs_per_path = {}   # cache classifiers for later analysis
# (stored as (classifier, cal mask, NIH mask, y_cal, y_nih) tuples; reused in §12)

for pathology in COMMON_PATHOLOGIES:
    k = COMMON_PATHOLOGIES.index(pathology)

    # Keep only rows labelled for this pathology, in both the cal and NIH test sets.
    c_mask = ~np.isnan(Y_cal[:, k])
    n_mask = ~np.isnan(Y_nih_test[:, k])
    Xc, yc = X_cal[c_mask],  Y_cal[c_mask, k].astype(int)
    Xn, yn = X_nih[n_mask],  Y_nih_test[n_mask, k].astype(int)

    # Per-pathology binary LR
    # Conformal scores come from this LR, not from the GNN; the GNN enters only through the
    # DRE weights.  NOTE(review): same rows and settings as lrs[k] from §4.
    tr_mask = ~np.isnan(Y_train[:, k])
    clf_p   = LogisticRegression(solver='lbfgs', max_iter=1000, C=1.0, random_state=SEED)
    clf_p.fit(X_train[tr_mask], Y_train[tr_mask, k].astype(int))
    clfs_per_path[pathology] = (clf_p, c_mask, n_mask, yc, yn)

    # Logits — raw and LSC-corrected
    # (LSC is applied to both cal and test so calibration and test scores share one scale)
    cal_lg        = binary_logits_from_lr(clf_p, Xc)
    nih_lg        = binary_logits_from_lr(clf_p, Xn)
    cal_lg_em     = corrected_binary_logits(clf_p, Xc, prior_src[k], prior_tgt_em[k])
    nih_lg_em     = corrected_binary_logits(clf_p, Xn, prior_src[k], prior_tgt_em[k])
    cal_lg_oracle = corrected_binary_logits(clf_p, Xc, prior_src[k], prior_tgt_oracle[k])
    nih_lg_oracle = corrected_binary_logits(clf_p, Xn, prior_src[k], prior_tgt_oracle[k])

    # DRE weights
    # (GNN-DRE weights computed on all rows, restricted to this pathology's labelled rows)
    wc_gnn = w_cal_gnn[c_mask]
    wn_gnn = w_nih_gnn[n_mask]

    # NOTE(review): roc_auc_score raises if yn contains only one class.
    nih_auc = roc_auc_score(yn, clf_p.predict_proba(Xn)[:, 1])

    std_cp         = evaluate_standard_cp(cal_lg,        yc, nih_lg,        yn, alphas, expert_accuracy=EXPERT)
    wcp_gnn        = evaluate_wcp(cal_lg,        yc, wc_gnn, nih_lg,        yn, wn_gnn, alphas, expert_accuracy=EXPERT)
    wcp_gnn_em     = evaluate_wcp(cal_lg_em,     yc, wc_gnn, nih_lg_em,     yn, wn_gnn, alphas, expert_accuracy=EXPERT)
    wcp_gnn_oracle = evaluate_wcp(cal_lg_oracle, yc, wc_gnn, nih_lg_oracle, yn, wn_gnn, alphas, expert_accuracy=EXPERT)

    all_results[pathology] = {
        'std_cp':         std_cp,
        'wcp_gnn':        wcp_gnn,
        'wcp_gnn_em':     wcp_gnn_em,
        'wcp_gnn_oracle': wcp_gnn_oracle,
        'nih_auc':        nih_auc,
        'n_cal':          int(c_mask.sum()),
        'n_nih':          int(n_mask.sum()),
        'n_pos':          int(yn.sum()),
    }

    # Progress line at the grid point closest to ALPHA.
    pt_gnn    = at_alpha(wcp_gnn)
    pt_em     = at_alpha(wcp_gnn_em)
    pt_oracle = at_alpha(wcp_gnn_oracle)
    print(f'{pathology:<16} AUC={nih_auc:.3f}  '
          f'GNN={pt_gnn.deferral_rate:.3f}({pt_gnn.coverage_rate:.3f})  '
          f'EM-LSC={pt_em.deferral_rate:.3f}({pt_em.coverage_rate:.3f})  '
          f'Oracle={pt_oracle.deferral_rate:.3f}({pt_oracle.coverage_rate:.3f})')

print('\n(deferral | coverage)  Done.')

## 10. Summary Table at α = 0.10

In [None]:
# One summary row per pathology at the grid point closest to ALPHA.
# Loop variables carry a *_pt suffix so they do not overwrite notebook globals -- in
# particular `gnn`, which holds the trained LabelGCN from §5.
rows = []
for path in COMMON_PATHOLOGIES:
    r      = all_results[path]
    std_pt = at_alpha(r['std_cp'])
    gnn_pt = at_alpha(r['wcp_gnn'])
    em_pt  = at_alpha(r['wcp_gnn_em'])
    orc_pt = at_alpha(r['wcp_gnn_oracle'])
    rows.append({
        'Pathology':        path,
        'NIH AUC':          f"{r['nih_auc']:.3f}",
        'Std Defer':        f"{std_pt.deferral_rate:.3f}",
        'GNN Defer':        f"{gnn_pt.deferral_rate:.3f}",
        'GNN Cov':          f"{gnn_pt.coverage_rate:.3f}",
        'EM-LSC Defer':     f"{em_pt.deferral_rate:.3f}",
        'EM-LSC Cov':       f"{em_pt.coverage_rate:.3f}",
        'Oracle Defer':     f"{orc_pt.deferral_rate:.3f}",
        'Oracle Cov':       f"{orc_pt.coverage_rate:.3f}",
    })

df_sum = pd.DataFrame(rows)
print(f'LSC Experiment Summary — α={ALPHA}')
print('=' * 120)
print(df_sum.to_string(index=False))

# Unweighted means across pathologies for each method.
print()
for col, key in [('Std Defer','std_cp'), ('GNN Defer','wcp_gnn'),
                  ('EM-LSC Defer','wcp_gnn_em'), ('Oracle Defer','wcp_gnn_oracle')]:
    mean_d = np.mean([at_alpha(all_results[p][key]).deferral_rate for p in COMMON_PATHOLOGIES])
    mean_c = np.mean([at_alpha(all_results[p][key]).coverage_rate  for p in COMMON_PATHOLOGIES])
    print(f'  Mean {col:<18}  defer={mean_d:.3f}   coverage={mean_c:.3f}')

## 11. Deferral Rate vs Confidence Level

In [None]:
fig, axes = plt.subplots(2, 4, figsize=(22, 9), sharey=False)
axes_flat = axes.flatten()
conf_lvls = 1 - alphas          # x-axis: confidence level 1 − α for each grid point

# (display label, all_results key, line style, colour, line width)
METHODS = [
    ('Std CP',     'std_cp',         'o-',  '#1f77b4', 1.5),
    ('WCP-GNN',    'wcp_gnn',        's-',  '#2ca02c', 1.8),
    ('WCP+EM-LSC', 'wcp_gnn_em',     '^-',  '#d62728', 2.0),
    ('WCP+Oracle', 'wcp_gnn_oracle', 'D--', '#9467bd', 1.5),
]

for i, path in enumerate(COMMON_PATHOLOGIES):
    ax = axes_flat[i]
    r  = all_results[path]
    for label, key, style, col, lw in METHODS:
        ax.plot(conf_lvls, [x.deferral_rate for x in r[key]],
                style, ms=2, lw=lw, alpha=0.85, label=label, color=col)
    ax.axvline(1 - ALPHA, color='gray', linestyle=':', alpha=0.5)
    ax.set_title(f'{path}\nAUC={r["nih_auc"]:.3f}', fontsize=10, fontweight='bold')
    ax.set_xlabel('Confidence (1−α)'); ax.set_ylabel('Deferral rate')
    ax.set_xlim(0.5, 1.0); ax.set_ylim(-0.05, 1.05)
    ax.legend(fontsize=7)

# Hide every panel without a pathology instead of assuming exactly one spare panel.
for ax in axes_flat[len(COMMON_PATHOLOGIES):]:
    ax.axis('off')
plt.suptitle('Deferral Rate vs Confidence Level — Four Methods\n'
             f'Vertical dotted line marks α={ALPHA:.2f} operating point',
             fontsize=13, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 12. Extended Analysis

### A0. Prediction Set Collection at α = 0.10

In [None]:
# Re-run every method at a single alpha to keep the actual prediction sets (the §9 results
# are only used for summary metrics), plus the calibration scores/weights plotted in §A4.
ALPHA_FIXED = 0.10
detail = {}

for pathology in COMMON_PATHOLOGIES:
    k = COMMON_PATHOLOGIES.index(pathology)
    clf_p, c_mask, n_mask, yc, yn = clfs_per_path[pathology]

    Xc = X_cal[c_mask]
    Xn = X_nih[n_mask]

    # Logits
    # (built exactly as in §9, from the cached per-pathology LR)
    cal_lg        = binary_logits_from_lr(clf_p, Xc)
    nih_lg        = binary_logits_from_lr(clf_p, Xn)
    cal_lg_em     = corrected_binary_logits(clf_p, Xc, prior_src[k], prior_tgt_em[k])
    nih_lg_em     = corrected_binary_logits(clf_p, Xn, prior_src[k], prior_tgt_em[k])
    cal_lg_oracle = corrected_binary_logits(clf_p, Xc, prior_src[k], prior_tgt_oracle[k])
    nih_lg_oracle = corrected_binary_logits(clf_p, Xn, prior_src[k], prior_tgt_oracle[k])

    wc_gnn = w_cal_gnn[c_mask]
    wn_gnn = w_nih_gnn[n_mask]

    # Standard CP
    # RAPS with penalty=0.1, kreg=1, deterministic (randomized=False); presumably the same
    # settings used inside evaluate_standard_cp / evaluate_wcp -- TODO confirm.
    cp_std = ConformalPredictor(penalty=0.1, kreg=1, randomized=False)
    q_std  = cp_std.calibrate(cal_lg, yc, alpha=ALPHA_FIXED)
    ps_std = cp_std.predict(nih_lg)
    cal_scores_std = cp_std.cal_scores

    # WCP-GNN
    # Weighted CP: alpha is supplied at predict time together with per-test-point weights.
    wcp_g = WeightedConformalPredictor(penalty=0.1, kreg=1, randomized=False)
    wcp_g.calibrate(cal_lg, yc, wc_gnn)
    ps_gnn = wcp_g.predict(nih_lg, wn_gnn, alpha=ALPHA_FIXED)
    cal_scores_gnn = wcp_g.cal_scores_sorted   # sorted cal RAPS scores (no LSC)
    cal_w_gnn      = wcp_g.cal_weights_sorted

    # WCP-GNN + EM-LSC
    # (same DRE weights as WCP-GNN; only the scores change)
    wcp_em = WeightedConformalPredictor(penalty=0.1, kreg=1, randomized=False)
    wcp_em.calibrate(cal_lg_em, yc, wc_gnn)
    ps_em = wcp_em.predict(nih_lg_em, wn_gnn, alpha=ALPHA_FIXED)
    cal_scores_em = wcp_em.cal_scores_sorted
    cal_w_em      = wcp_em.cal_weights_sorted

    # WCP-GNN + Oracle-LSC
    wcp_oracle = WeightedConformalPredictor(penalty=0.1, kreg=1, randomized=False)
    wcp_oracle.calibrate(cal_lg_oracle, yc, wc_gnn)
    ps_oracle = wcp_oracle.predict(nih_lg_oracle, wn_gnn, alpha=ALPHA_FIXED)
    cal_scores_oracle = wcp_oracle.cal_scores_sorted
    cal_w_oracle      = wcp_oracle.cal_weights_sorted

    detail[pathology] = dict(
        ps_std=ps_std, ps_gnn=ps_gnn, ps_em=ps_em, ps_oracle=ps_oracle,
        yn=yn, nih_lg=nih_lg, nih_lg_em=nih_lg_em, nih_lg_oracle=nih_lg_oracle,
        cal_scores_std=cal_scores_std,
        cal_scores_gnn=cal_scores_gnn, cal_w_gnn=cal_w_gnn,
        cal_scores_em=cal_scores_em,   cal_w_em=cal_w_em,
        cal_scores_oracle=cal_scores_oracle, cal_w_oracle=cal_w_oracle,
        wc_gnn=wc_gnn, wn_gnn=wn_gnn, q_std=q_std,
    )

print(f'Prediction sets collected for {len(detail)} pathologies at α={ALPHA_FIXED}')
print(f'\n{"Pathology":<16}  {"n_nih":>6}  '
      f'{"Std |C|=0/1/2":>18}  {"GNN |C|=0/1/2":>18}  '
      f'{"EM-LSC |C|=0/1/2":>20}  {"Oracle |C|=0/1/2":>20}')
print('-' * 110)
for path, d in detail.items():
    # Fractions of test rows whose prediction set has size 0 / 1 / 2 (row sums of ps).
    def fmtsz(ps):
        s = ps.sum(axis=1)
        return f'{(s==0).mean():.2f}/{(s==1).mean():.2f}/{(s==2).mean():.2f}'
    print(f'{path:<16}  {len(d["yn"]):>6}  '
          f'{fmtsz(d["ps_std"]):>18}  {fmtsz(d["ps_gnn"]):>18}  '
          f'{fmtsz(d["ps_em"]):>20}  {fmtsz(d["ps_oracle"]):>20}')

### A1. Empirical Coverage Validity

In [None]:
target_cov = 1 - ALPHA
n_paths    = len(COMMON_PATHOLOGIES)
print(f'A1: Coverage validity at α={ALPHA}  (target ≥ {target_cov:.2f}):')
hdr = (f"{'Pathology':<16} | "
       f"{'Std cov':>9} {'dev':>7} | "
       f"{'GNN cov':>9} {'dev':>7} | "
       f"{'EM-LSC cov':>11} {'dev':>7} | "
       f"{'Oracle cov':>11} {'dev':>7}")
print(hdr); print('-' * len(hdr))


def flag(c):
    """Return a ' ✗' marker when coverage ``c`` falls below the 1 − α target."""
    return ' ✗' if c < target_cov else '  '


# Number of pathologies on which each method's empirical coverage is below target.
# NOTE: CP guarantees coverage only in expectation, so small negative deviations
# can be finite-sample noise rather than a genuine failure.
under_cov = {'Std CP': 0, 'WCP-GNN': 0, 'WCP-GNN+EM': 0, 'WCP-GNN+Oracle': 0}

for path in COMMON_PATHOLOGIES:
    r = all_results[path]
    std_c   = at_alpha(r['std_cp']).coverage_rate
    gnn_c   = at_alpha(r['wcp_gnn']).coverage_rate
    em_c    = at_alpha(r['wcp_gnn_em']).coverage_rate
    orc_c   = at_alpha(r['wcp_gnn_oracle']).coverage_rate

    print(f'{path:<16} | '
          f'{std_c:>9.3f} {std_c - target_cov:>+7.3f}{flag(std_c)} | '
          f'{gnn_c:>9.3f} {gnn_c - target_cov:>+7.3f}{flag(gnn_c)} | '
          f'{em_c:>11.3f} {em_c - target_cov:>+7.3f}{flag(em_c)} | '
          f'{orc_c:>11.3f} {orc_c - target_cov:>+7.3f}{flag(orc_c)}')

    for cnt_key, cov in [('Std CP', std_c), ('WCP-GNN', gnn_c),
                         ('WCP-GNN+EM', em_c), ('WCP-GNN+Oracle', orc_c)]:
        if cov < target_cov:
            under_cov[cnt_key] += 1

print(f'\nUnder-coverage count (cov < {target_cov:.2f}) at α={ALPHA}:')
for method, cnt in under_cov.items():
    print(f'  {method:<20}  {cnt}/{n_paths} pathologies  '
          f'{"← INVALID" if cnt > 0 else "✓ all valid"}')

### A2. Prediction Set Size Breakdown

In [None]:
# (display label, key into `detail`, base colour -- unused by the stacked bars)
methods_vis = [
    ('Std CP',     'ps_std',    '#1f77b4'),
    ('WCP-GNN',    'ps_gnn',    '#2ca02c'),
    ('WCP+EM-LSC', 'ps_em',     '#d62728'),
    ('WCP+Oracle', 'ps_oracle', '#9467bd'),
]

fig, axes = plt.subplots(1, 4, figsize=(22, 5))
x = np.arange(len(COMMON_PATHOLOGIES))

for ax, (label, ps_key, base_col) in zip(axes, methods_vis):
    # Set sizes |C(x)| per pathology.  Each pathology's NIH subset is filtered by its own
    # non-NaN label mask, so the arrays can differ in length; reduce each one separately
    # instead of stacking them into a (possibly ragged) 2-D array.
    sizes = [detail[p][ps_key].sum(axis=1) for p in COMMON_PATHOLOGIES]
    f0  = np.array([(s == 0).mean() for s in sizes])   # empty sets
    f1  = np.array([(s == 1).mean() for s in sizes])   # singletons
    f2  = np.array([(s == 2).mean() for s in sizes])   # full sets (defer)
    avg = np.array([s.mean() for s in sizes])          # average |C(x)|

    ax.bar(x, f0,         color='#d62728', alpha=0.85, label='|C|=0 (empty)')
    ax.bar(x, f1, bottom=f0,     color='#2ca02c', alpha=0.85, label='|C|=1 (singleton)')
    ax.bar(x, f2, bottom=f0+f1, color='#ff7f0e', alpha=0.85, label='|C|=2 (full/defer)')

    for xi, av in zip(x, avg):
        ax.text(xi, 1.02, f'{av:.2f}', ha='center', va='bottom', fontsize=8)

    ax.set_xticks(x)
    ax.set_xticklabels([p[:9] for p in COMMON_PATHOLOGIES], rotation=35, ha='right')
    ax.set_ylabel('Fraction of test samples')
    ax.set_title(f'{label}  (α=0.10)', fontsize=11, fontweight='bold')
    ax.legend(fontsize=8); ax.set_ylim(0, 1.12)

plt.suptitle('A2: Prediction Set Size Distribution per Method\n'
             'Number above bar = average |C(X)|;  ideal = 1.0',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

# Summary table
print(f'\nA2 Summary: singleton rate (f₁) and average set size at α={ALPHA}:')
hdr2 = (f"{'Pathology':<16} | "
        f"{'Std f1':>7} {'avg':>5} | "
        f"{'GNN f1':>7} {'avg':>5} | "
        f"{'EM f1':>7} {'avg':>5} | "
        f"{'Oracle f1':>9} {'avg':>5}")
print(hdr2); print('-' * len(hdr2))
for path in COMMON_PATHOLOGIES:
    d = detail[path]
    def ss(ps_key):
        s = d[ps_key].sum(axis=1)
        return f'{(s==1).mean():>7.3f} {s.mean():>5.2f}'
    print(f'{path:<16} | {ss("ps_std")} | {ss("ps_gnn")} | {ss("ps_em")} | {ss("ps_oracle"):>9}')

### A3. Singleton Error Rate — FNR / FPR

**Core question**: Does LSC reduce the dangerously high FNR on singleton decisions?

Recall: with q̂ = 1.0 (binary RAPS), every singleton `{k}` satisfies
"model's top prediction = k". So singleton FNR = P(top prediction = 0 | true label = 1,
sample is singleton). This is the fraction of true positives *within the singleton subset* for which
the model ranks class 1 second.  Singleton FPR is defined symmetrically over the true negatives.

In [None]:
# Singleton FNR / FPR per method: among test rows whose prediction set has exactly one class,
# FNR = share of true positives predicted 0 and FPR = share of true negatives predicted 1.
singleton_rows = []

for path in COMMON_PATHOLOGIES:
    d  = detail[path]
    yn = d['yn']
    row = {'Pathology': path, 'n_pos': int(yn.sum())}

    for label, ps_key in [('Std CP','ps_std'), ('WCP-GNN','ps_gnn'),
                           ('WCP+EM','ps_em'),  ('WCP+Oracle','ps_oracle')]:
        ps   = d[ps_key]
        sizes = ps.sum(axis=1)
        singleton_idx = np.where(sizes == 1)[0]
        n_s = len(singleton_idx)

        if n_s == 0:
            row[f'{label}_n_pct']  = '0 (0%)'
            row[f'{label}_FNR']    = np.nan
            row[f'{label}_FPR']    = np.nan
            continue

        # Assuming ps is a 0/1 membership matrix, argmax on a singleton row returns its only class.
        preds_s = ps[singleton_idx].argmax(axis=1)   # top prediction
        y_s     = yn[singleton_idx]
        pos     = y_s == 1;  neg = y_s == 0

        fnr = float((preds_s[pos] == 0).mean()) if pos.sum() > 0 else np.nan
        fpr = float((preds_s[neg] == 1).mean()) if neg.sum() > 0 else np.nan

        row[f'{label}_n_pct'] = f'{n_s} ({100*n_s/len(yn):.0f}%)'
        row[f'{label}_FNR']   = round(fnr, 4)
        row[f'{label}_FPR']   = round(fpr, 4)

    singleton_rows.append(row)

df_single = pd.DataFrame(singleton_rows).set_index('Pathology')

# The printed table covers the three WCP variants only; Std CP columns remain in df_single.
print(f'A3: Singleton Error Rate at α={ALPHA}')
print('=' * 100)
hdr3 = (f"{'Pathology':<16} {'n_pos':>6} | "
        f"{'GNN n_single':>14} {'FNR':>6} {'FPR':>6} | "
        f"{'EM-LSC n_single':>16} {'FNR':>6} {'FPR':>6} | "
        f"{'Oracle n_single':>16} {'FNR':>6} {'FPR':>6}")
print(hdr3); print('-' * len(hdr3))

for path in COMMON_PATHOLOGIES:
    r = df_single.loc[path]
    def fmt(m):
        n   = str(r.get(f'{m}_n_pct', '—'))
        fnr = r.get(f'{m}_FNR', np.nan)
        fpr = r.get(f'{m}_FPR', np.nan)
        fnr_s = f'{fnr:.3f}' if not np.isnan(fnr) else '  n/a'
        fpr_s = f'{fpr:.3f}' if not np.isnan(fpr) else '  n/a'
        return f'{n:>14} {fnr_s:>6} {fpr_s:>6}'
    print(f'{path:<16} {r["n_pos"]:>6} | {fmt("WCP-GNN")} | {fmt("WCP+EM")} | {fmt("WCP+Oracle")}')

# FNR/FPR bar plot
fig, axes = plt.subplots(1, 2, figsize=(16, 5))
x = np.arange(len(COMMON_PATHOLOGIES))
width = 0.25

vis_methods = [('WCP-GNN','#2ca02c'), ('WCP+EM','#d62728'), ('WCP+Oracle','#9467bd')]

for ax, metric in zip(axes, ['FNR', 'FPR']):
    for m_i, (method, col) in enumerate(vis_methods):
        vals = [float(df_single.loc[p].get(f'{method}_{metric}', np.nan))
                for p in COMMON_PATHOLOGIES]
        # NOTE(review): undefined (NaN) rates are drawn as 0-height bars, which looks the
        # same as a genuine 0% error rate.
        vals = [v if not np.isnan(v) else 0 for v in vals]
        # Offsets -1/0/+1 bar widths place the three methods side by side.
        ax.bar(x + (m_i - 1) * width, vals, width, label=method, color=col, alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(COMMON_PATHOLOGIES, rotation=30, ha='right')
    ax.set_ylabel(metric)
    lbl = ('False Negative Rate (missed diagnoses)' if metric == 'FNR'
           else 'False Positive Rate (false alarms)')
    ax.set_title(f'Singleton {metric}: {lbl}\n(non-deferred samples only, α=0.10)')
    ax.legend(fontsize=9); ax.set_ylim(0, 1.05)

plt.suptitle('A3: Singleton Error Rate — Impact of Label Shift Correction on FNR & FPR',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

### A4. q̂ Stability — Weighted CDF of Calibration RAPS Scores

LSC shifts the distribution of calibration RAPS scores because it recalibrates
the model probabilities.  Under lower target prevalence:
- True-negative calibration samples → model becomes MORE confident negative → RAPS score drops
- True-positive calibration samples → model becomes LESS confident positive → RAPS score rises

If NIH-like calibration samples are predominantly negatives, the weighted quantile
q̂ drops further under LSC, potentially creating more singletons.

In [None]:
# Pathologies shown in the A4 grid: one row each; columns = no LSC / EM-LSC / Oracle-LSC.
focus_paths = ['Cardiomegaly', 'Pneumothorax', 'Effusion', 'Atelectasis']
fig, axes = plt.subplots(len(focus_paths), 3, figsize=(18, 5 * len(focus_paths)))

def plot_weighted_cdf(ax, scores, weights, label, color, alpha_line=ALPHA, title=''):
    """Plot weighted vs unweighted empirical CDF and mark q̂.

    Args:
        ax         : matplotlib Axes to draw on.
        scores     : calibration RAPS scores (any order; sorted here).
        weights    : importance weights aligned with ``scores``.
        label      : tag for the weighted curve in the legend and title.
        color      : colour of the weighted curve and its q̂ line.
        alpha_line : miscoverage level; q̂ is read off at CDF level 1 − alpha_line
                     (default: the value of ALPHA when this cell ran).
        title      : first line of the panel title.

    Note: q̂ here is the first calibration score whose normalised weighted CDF reaches
    1 − α.  Unlike ``scalar_qhat`` no mass is reserved for the test point, so it is a
    visual approximation of the WCP threshold.
    """
    sort_idx = np.argsort(scores)
    s_sorted = scores[sort_idx]
    w_sorted = weights[sort_idx]
    w_norm   = w_sorted / w_sorted.sum()
    cum_w    = np.cumsum(w_norm)
    cum_uw   = np.arange(1, len(s_sorted) + 1) / len(s_sorted)

    ax.step(s_sorted, cum_uw, where='post', color='#1f77b4', lw=1.5, alpha=0.7, label='Unweighted CDF')
    ax.step(s_sorted, cum_w,  where='post', color=color,     lw=2.0, alpha=0.9, label=f'Weighted CDF ({label})')

    target  = 1 - alpha_line
    # First index with cum_w >= target; clamped to the last score in case float rounding
    # leaves the final cumulative sum just below the target.
    idx_q   = min(int(np.searchsorted(cum_w, target, side='left')), len(s_sorted) - 1)
    q_hat_w = s_sorted[idx_q]

    ax.axvline(q_hat_w, color=color, linestyle='--', lw=1.8, label=f'q̂ weighted = {q_hat_w:.3f}')
    ax.axhline(target, color='gray', linestyle=':', lw=1.0, alpha=0.5, label=f'1−α = {target:.2f}')

    # Kish effective sample size as a percentage of the number of calibration points.
    ess_pct = float(weights.sum()**2 / (weights**2).sum()) / len(weights) * 100
    ax.set_title(f'{title}\n{label} | ESS={ess_pct:.1f}% | q̂={q_hat_w:.3f}',
                 fontsize=9, fontweight='bold')
    ax.set_xlabel('RAPS score'); ax.set_ylabel('Cumulative probability')
    ax.legend(fontsize=7)
    ax.set_xlim(max(-0.05, float(s_sorted.min()) - 0.05),
                min(float(s_sorted.max()) + 0.05, 1.25))

# All three columns were calibrated with the same GNN-DRE weights (wc_gnn in §A0); only
# the RAPS scores change under LSC.
for row_i, path in enumerate(focus_paths):
    d = detail[path]
    plot_weighted_cdf(axes[row_i, 0], d['cal_scores_gnn'],    d['cal_w_gnn'],    'GNN (no LSC)',   '#2ca02c', title=path)
    plot_weighted_cdf(axes[row_i, 1], d['cal_scores_em'],     d['cal_w_em'],     'GNN+EM-LSC',    '#d62728', title=path)
    plot_weighted_cdf(axes[row_i, 2], d['cal_scores_oracle'], d['cal_w_oracle'], 'GNN+Oracle-LSC','#9467bd', title=path)

plt.suptitle('A4: Weighted CDF of Calibration RAPS Scores\n'
             'Columns: no LSC | EM-LSC | Oracle-LSC;  Rows: focus pathologies',
             fontsize=12, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()

# q̂ table (using median test weight as representative)
def scalar_qhat(scores_sorted, cal_w, test_weight, alpha):
    """Return the weighted-conformal threshold q̂ for a single test weight.

    Builds the weighted distribution that places mass proportional to
    ``cal_w`` on each calibration score and mass proportional to
    ``test_weight`` on +∞.  It then returns its ``1 - alpha`` quantile: the
    smallest calibration score whose cumulative normalised weight reaches
    ``1 - alpha``.

    Args:
        scores_sorted: Calibration scores (array-like).  Despite the name,
            they need not be pre-sorted.  They are stably sorted here, with
            ``cal_w`` reordered alongside, so raw calibration arrays are fine.
            Already-sorted input is left unchanged.
        cal_w: Calibration weights aligned element-wise with ``scores_sorted``.
        test_weight: Weight of the test point; its mass sits at +∞.
        alpha: Miscoverage level.

    Returns:
        The threshold as a float, or ``inf`` when the calibration mass alone
        never reaches ``1 - alpha``.
    """
    scores = np.asarray(scores_sorted, dtype=float)
    weights = np.asarray(cal_w, dtype=float)
    # The cumulative-mass search below is only a quantile if it walks the
    # scores in ascending order; keep each weight attached to its score.
    order = np.argsort(scores, kind='stable')
    scores, weights = scores[order], weights[order]

    all_w = np.append(weights, test_weight)
    p = all_w / all_w.sum()
    cum_p = np.cumsum(p[:-1])  # exclude the test point's mass at +∞
    reached = cum_p >= (1 - alpha)
    if not reached.any():
        return float('inf')
    return float(scores[int(np.argmax(reached))])

def fmt(q):
    """Format a q̂ value to 3 decimals, rendering an infinite q̂ as '  ∞'."""
    return '  ∞' if q == float('inf') else f'{q:.3f}'


print(f'\nq̂ at median test weight (α={ALPHA}):')
hdr4 = f"{'Pathology':<16} {'Std q̂':>8} {'GNN q̂':>8} {'EM-LSC q̂':>11} {'Oracle q̂':>11}"
print(hdr4)
print('-' * len(hdr4))
for path in COMMON_PATHOLOGIES:
    d = detail[path]
    # The median GNN test weight stands in for a "typical" test point.
    # NOTE(review): this same GNN-derived weight is also fed to the EM-LSC and
    # Oracle columns; confirm that is intended rather than their own weights.
    med_w = float(np.median(d['wn_gnn']))
    q_gnn, q_em, q_orc = [
        scalar_qhat(d[f'cal_scores_{tag}'], d[f'cal_w_{tag}'], med_w, ALPHA)
        for tag in ('gnn', 'em', 'oracle')
    ]
    cells = (
        f'{path:<16}',
        f'{d["q_std"]:>8.3f}',
        f'{fmt(q_gnn):>8}',
        f'{fmt(q_em):>11}',
        f'{fmt(q_orc):>11}',
    )
    print(' '.join(cells))

### A5. FNR Breakdown: Positives in Singleton Set

In [None]:
# Breakdown: for singleton positive samples, how many are predicted correctly?
print('A5: Positive samples in singleton set and their predictions')
print('(key: with q̂=1.0, singleton prediction = model top-ranked class)')
print('=' * 100)

# (display name, key of the prediction-set matrix in `detail[path]`)
A5_METHODS = (('WCP-GNN', 'ps_gnn'), ('WCP+EM-LSC', 'ps_em'), ('WCP+Oracle', 'ps_oracle'))

for path in COMMON_PATHOLOGIES:
    d = detail[path]
    yn = d['yn']
    print(f'\n{path}  (n_pos={yn.sum()}, n_nih={len(yn)})')
    hdr5 = (f"  {'Method':<18} {'n_single':>9} {'n_single_pos':>13} "
            f"{'n_pos_correct':>15} {'FNR':>7} {'FPR':>7}")
    print(hdr5)
    print('  ' + '-' * (len(hdr5) - 2))

    for method, ps_key in A5_METHODS:
        ps = d[ps_key]
        sizes = ps.sum(axis=1)
        sidx = np.flatnonzero(sizes == 1)
        n_s = sidx.size
        if not n_s:
            print(f'  {method:<18} {0:>9}')
            continue

        # NOTE(review): assumes column j of ps flags label j, so argmax of a
        # singleton row is its predicted label; confirm against ps construction.
        preds = ps[sidx].argmax(axis=1)
        y_s = yn[sidx]
        pos, neg = y_s == 1, y_s == 0
        n_pos_s = int(pos.sum())
        n_pos_c = int((preds[pos] == 1).sum()) if n_pos_s else 0
        fnr = np.nan if n_pos_s == 0 else 1 - n_pos_c / n_pos_s
        fpr = np.nan if not neg.sum() > 0 else float((preds[neg] == 1).mean())
        print(f'  {method:<18} {n_s:>9} {n_pos_s:>13} {n_pos_c:>15} '
              f'{fnr:>7.3f}{fpr:>8.3f}')

## 13. Comparison with gnn_dre_wcp_report.md Baseline

In [None]:
# Reference GNN-DRE results from gnn_dre_wcp_report.md (Table 5, α=0.10)
REF_GNN = {
    'Atelectasis':   dict(defer=0.952, cov=0.994, fnr=0.784, fpr=0.092, f1=0.05, avg=1.95),
    'Cardiomegaly':  dict(defer=0.038, cov=0.887, fnr=0.779, fpr=0.059, f1=0.96, avg=0.96),
    'Consolidation': dict(defer=0.120, cov=0.849, fnr=0.861, fpr=0.023, f1=0.88, avg=0.88),
    'Edema':         dict(defer=0.956, cov=0.996, fnr=0.500, fpr=0.096, f1=0.04, avg=1.96),
    'Effusion':      dict(defer=0.230, cov=0.923, fnr=0.527, fpr=0.082, f1=0.77, avg=1.23),
    'Pneumonia':     dict(defer=0.042, cov=0.883, fnr=0.872, fpr=0.073, f1=0.96, avg=0.96),
    'Pneumothorax':  dict(defer=0.087, cov=0.901, fnr=0.992, fpr=0.005, f1=0.91, avg=0.91),
}

cmp_rows = []
for path in COMMON_PATHOLOGIES:
    ref = REF_GNN[path]
    d   = detail[path]
    r   = all_results[path]
    em  = at_alpha(r['wcp_gnn_em'])

    # Singleton metrics for EM-LSC, recomputed from the prediction sets.
    ps    = d['ps_em']
    sizes = ps.sum(axis=1)
    yn    = d['yn']
    # f₁ = fraction of ALL samples whose prediction set is a singleton.
    # It is 0.0 (not undefined) when no singleton occurs, so it is computed
    # outside the singleton branch; only an empty set of samples gives n/a.
    f1_em = float((sizes == 1).mean()) if len(sizes) > 0 else np.nan
    sidx  = np.where(sizes == 1)[0]
    if len(sidx) > 0:
        # NOTE(review): assumes column j of ps flags label j, so argmax of a
        # singleton row is its predicted label; confirm against ps construction.
        preds = ps[sidx].argmax(axis=1)
        y_s   = yn[sidx]
        pos   = y_s == 1
        neg   = y_s == 0
        fnr_em = float((preds[pos] == 0).mean()) if pos.sum() > 0 else np.nan
        fpr_em = float((preds[neg] == 1).mean()) if neg.sum() > 0 else np.nan
    else:
        # No singletons: error rates *among singletons* are undefined.
        fnr_em = fpr_em = np.nan

    cmp_rows.append({
        'Pathology':         path,
        'Ref-GNN Defer':     f"{ref['defer']:.3f}",
        'EM-LSC Defer':      f"{em.deferral_rate:.3f}",
        'Δ Defer':           f"{em.deferral_rate - ref['defer']:+.3f}",
        'Ref-GNN Cov':       f"{ref['cov']:.3f}",
        'EM-LSC Cov':        f"{em.coverage_rate:.3f}",
        'Ref-GNN FNR':       f"{ref['fnr']:.3f}",
        'EM-LSC FNR':        f"{fnr_em:.3f}" if not np.isnan(fnr_em) else 'n/a',
        'Δ FNR':             f"{fnr_em - ref['fnr']:+.3f}" if not np.isnan(fnr_em) else 'n/a',
        'Ref-GNN f₁':        f"{ref['f1']:.3f}",
        'EM-LSC f₁':         f"{f1_em:.3f}" if not np.isnan(f1_em) else 'n/a',
    })

df_cmp = pd.DataFrame(cmp_rows)
print('Comparison: WCP-GNN (baseline) vs WCP-GNN+EM-LSC (new)  — α=0.10')
print('='*120)
print(df_cmp.to_string(index=False))

## 14. Final Summary

In [None]:
print('=' * 75)
print(f'GNN-DRE + EM-LSC Summary  (α={ALPHA})')
print('=' * 75)

print('\nEstimated vs True NIH Prevalences:')
for k, path in enumerate(COMMON_PATHOLOGIES):
    # Signed error in percentage points (EM − oracle): positive means EM
    # over-estimates the NIH prevalence, negative means it under-estimates it.
    err_pp = (prior_tgt_em[k] - prior_tgt_oracle[k]) * 100
    print(f'  {path:<16}  src={prior_src[k]:.4f}  oracle={prior_tgt_oracle[k]:.4f}  '
          f'EM={prior_tgt_em[k]:.4f}  '
          f'err={err_pp:+.2f}pp')

# Label with the configured level rather than a hard-coded value.
print(f'\nMean results at α={ALPHA:.2f}:')
for label, key in [('Std CP','std_cp'), ('WCP-GNN','wcp_gnn'),
                    ('WCP+EM-LSC','wcp_gnn_em'), ('WCP+Oracle','wcp_gnn_oracle')]:
    # Evaluate each pathology once and reuse the result for both means.
    res_at_alpha = [at_alpha(all_results[p][key]) for p in COMMON_PATHOLOGIES]
    mean_d = np.mean([res.deferral_rate for res in res_at_alpha])
    mean_c = np.mean([res.coverage_rate for res in res_at_alpha])
    print(f'  {label:<20}  defer={mean_d:.3f}   coverage={mean_c:.3f}')

print('\nKey pathology outcomes (EM-LSC vs baseline GNN):')
DEFER_TOL = 0.01  # deferral changes within ±DEFER_TOL are shown as '≈'
for path in COMMON_PATHOLOGIES:
    gnn_d  = at_alpha(all_results[path]['wcp_gnn']).deferral_rate
    em_res = at_alpha(all_results[path]['wcp_gnn_em'])
    em_d   = em_res.deferral_rate
    em_cov = em_res.coverage_rate
    arrow  = '↓' if em_d < gnn_d - DEFER_TOL else ('↑' if em_d > gnn_d + DEFER_TOL else '≈')
    print(f'  {path:<16}  GNN defer={gnn_d:.3f}  EM-LSC defer={em_d:.3f}  '
          f'({arrow}{abs(em_d-gnn_d):.3f})  cov={em_cov:.3f}')