# Prior/Likelihood Misspecification: Credible vs Conformal Coverage

목표: heavy-tailed(스튜던트 t, df=3) 노이즈가 섞인 데이터에 정규 노이즈를 가정한 베이지안 모델(BART, 공액 베이지안 선형회귀 BLR)을 적합하면
posterior predictive credible 구간의 실제 커버리지가 명목 수준(90%)보다 낮아질(언더커버) 수 있음을 보이고, split-conformal 보정으로 커버리지를 복원합니다.

In [None]:
# Make the repository importable from this notebook, whether it is launched
# from the repository root or from inside its `examples/` directory.
import sys
from pathlib import Path


def _find_repo_root(start_dir):
    """Return the directory to put on ``sys.path`` for ``start_dir``.

    * ``start_dir`` itself if it contains both ``examples/`` and ``bartpy/``;
    * its parent if ``start_dir`` is named ``examples`` and the parent
      contains ``bartpy/``;
    * otherwise ``start_dir`` unchanged (best-effort fallback).
    """
    if (start_dir / 'examples').exists() and (start_dir / 'bartpy').exists():
        return start_dir
    if start_dir.name == 'examples' and (start_dir.parent / 'bartpy').exists():
        return start_dir.parent
    return start_dir


repo_root = _find_repo_root(Path.cwd())
sys.path.insert(0, str(repo_root))
print('Repo root:', repo_root)

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

from bartpy.sklearnmodel import SklearnModel
from bartpy import conformal_quantile, build_intervals
from examples.bayes_linreg import design_matrix, bayes_linreg_posterior, sample_posterior

In [None]:
# Synthetic 1-D regression: f(x) = sin(x) observed under heavy-tailed
# Student-t noise (df=3), which a Gaussian noise model misspecifies.
rng = np.random.default_rng(7)
n = 600
df = 3.0     # t degrees of freedom: mean 0, finite variance df/(df-2) = 3, heavy tails
scale = 0.3  # multiplies each standard-t draw -> noise sd = scale*sqrt(df/(df-2)) ~ 0.52

# Inputs are drawn before the noise; keep this order to reproduce the RNG stream.
X = rng.uniform(-3.0, 3.0, size=(n, 1))
f_true = np.sin(X[:, 0])
z = scale * rng.standard_t(df, size=n)
y = f_true + z

# 60/20/20 train / calibration / test: hold out 40%, then halve the hold-out.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.4, random_state=42
)
X_cal, X_test, y_cal, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=43
)
(X_train.shape, X_cal.shape, X_test.shape)

## BLR(정규 노이즈 가정) 학습 및 평가

In [None]:
# Conjugate Bayesian linear regression fitted on the TRAIN split only, so the
# calibration and test splits stay held out.
# NOTE(review): tau2 / a0 / b0 are presumably the coefficient prior variance and
# the inverse-gamma noise-prior hyperparameters -- confirm in examples.bayes_linreg.
Phi_tr = design_matrix(X_train)
beta_n, Vn, a_n, b_n = bayes_linreg_posterior(
    Phi_tr,
    y_train,
    tau2=5.0,
    a0=2.0,
    b0=1.0,
)

# Monte Carlo draws of (beta, sigma) from that posterior, with a fixed seed.
S = 1000
beta_draws, sigma_draws = sample_posterior(
    beta_n, Vn, a_n, b_n, S=S, rng=np.random.default_rng(123)
)

def blr_predict_draws(X_new, rng=None):
    """Posterior draws of the BLR mean and of the predictive response at ``X_new``.

    Reads the notebook-level posterior draws ``beta_draws`` (one row per draw)
    and ``sigma_draws`` (one value per draw).

    Parameters
    ----------
    X_new : array
        New inputs; expanded with ``design_matrix``.
    rng : int, numpy.random.Generator, or None
        Source of the Gaussian predictive noise. ``None`` (default) uses a fresh
        ``np.random.default_rng(99)``, which matches the previous behaviour: every
        call without ``rng`` reuses the same noise stream. Pass a seed or a shared
        Generator to get independent noise across calls.

    Returns
    -------
    f : ndarray, shape (n_new, S)
        Mean predictions ``Phi @ beta_s`` for every posterior draw ``s``.
    ypred : ndarray, shape (n_new, S)
        ``f`` plus Normal(0, 1) noise scaled by ``sigma_draws`` per draw (column).
    """
    # default_rng accepts a seed or returns an existing Generator unchanged.
    rng = np.random.default_rng(99 if rng is None else rng)
    Phi = design_matrix(X_new)
    f = Phi @ beta_draws.T  # (n_new, S)
    # sigma_draws is used as the noise standard deviation: it scales standard normals.
    # NOTE(review): confirm sample_posterior returns sigma, not sigma**2.
    eps = rng.normal(size=f.shape) * sigma_draws[None, :]
    ypred = f + eps
    return f, ypred

# --- Split-conformal calibration: absolute residuals of the posterior-median
# point prediction on the held-out calibration split determine q_blr.
f_cal_blr, ypred_cal_blr = blr_predict_draws(X_cal)
pred_point_cal_blr = np.median(f_cal_blr, axis=1)
res_cal_blr = np.abs(y_cal - pred_point_cal_blr)
q_blr = conformal_quantile(res_cal_blr, alpha=0.1)

# --- Test-set draws shared by both interval types.
f_test_blr, ypred_test_blr = blr_predict_draws(X_test)
pred_point_test_blr = np.median(f_test_blr, axis=1)

# Equal-tailed 90% posterior predictive (credible) band: 5th/95th percentiles
# of the predictive draws for each test point.
lo_cr_blr, hi_cr_blr = np.percentile(ypred_test_blr, [5, 95], axis=1)
cov_cred_blr = np.mean((lo_cr_blr <= y_test) & (y_test <= hi_cr_blr))
width_cred_blr = np.mean(hi_cr_blr - lo_cr_blr)

# Split-conformal band built around the posterior-median point prediction.
lo_c_blr, hi_c_blr = build_intervals(pred_point_test_blr, q_blr)
cov_conf_blr = np.mean((lo_c_blr <= y_test) & (y_test <= hi_c_blr))
width_conf_blr = np.mean(hi_c_blr - lo_c_blr)

print(f'BLR credible 90%: coverage={cov_cred_blr:.3f}, width={width_cred_blr:.3f}')
print(f'BLR conformal 90%: coverage={cov_conf_blr:.3f}, width={width_conf_blr:.3f}')

## BART(정규 노이즈 가정) 학습 및 평가

In [None]:
# BART with a Gaussian noise model, fitted on the TRAIN split only.
# In-sample predictions and the acceptance trace are kept on the model
# (presumably needed later by prepare_bart_loglik_and_draws -- confirm).
model = SklearnModel(
    n_trees=120,
    n_chains=2,
    n_burn=250,
    n_samples=250,
    thin=0.5,
    n_jobs=1,
    store_in_sample_predictions=True,
    store_acceptance_trace=True,
)
model.fit(X_train, y_train)

def bart_f_sigma_on(model, X_new):
    """Evaluate every stored BART posterior sample on new inputs.

    Parameters
    ----------
    model : fitted bartpy ``SklearnModel``
        Only its ``_model_samples`` list is read.
    X_new : array with ``X_new.shape[0] == n_new`` rows
        Passed unchanged to each sample's ``predict``.

    Returns
    -------
    f : ndarray, shape (n_new, S)
        Column ``s`` holds ``predict(X_new)`` of the ``s``-th stored sample.
    sig : ndarray, shape (S,)
        ``sigma.current_unnormalized_value()`` of each sample. It does not
        depend on ``X_new``.

    NOTE(review): sigma is read on the *unnormalized* scale. In upstream bartpy,
    a per-sample ``Model.predict`` appears to return values on the internally
    normalized y scale, with ``SklearnModel.predict`` unnormalizing on top.
    Confirm this fork's per-sample ``predict`` already returns original-scale
    values before comparing ``f`` with ``y``.
    """
    draws = model._model_samples
    f = np.empty((X_new.shape[0], len(draws)))
    sig = np.empty(len(draws))
    for col, sample in enumerate(draws):
        f[:, col] = sample.predict(X_new)
        sig[col] = sample.sigma.current_unnormalized_value()
    return f, sig

# --- Split-conformal calibration on the held-out calibration split.
f_cal_bt, sigma_bt = bart_f_sigma_on(model, X_cal)
pred_point_cal_bt = np.median(f_cal_bt, axis=1)
q_bt = conformal_quantile(np.abs(y_cal - pred_point_cal_bt), alpha=0.1)

# --- Test-set draws. The per-sample sigma does not depend on the inputs, so the
# calibration-set sigma_bt is reused for the credible band below.
f_test_bt, sigma_bt_test = bart_f_sigma_on(model, X_test)
pred_point_test_bt = np.median(f_test_bt, axis=1)

# Equal-tailed 90% posterior predictive band: add Gaussian noise with each
# sample's sigma to that sample's mean prediction, then take 5th/95th percentiles.
eps_bt = sigma_bt[None, :] * np.random.default_rng(1234).normal(size=f_test_bt.shape)
ypred_test_bt = f_test_bt + eps_bt
lo_cr_bt, hi_cr_bt = np.percentile(ypred_test_bt, [5, 95], axis=1)
cov_cred_bt = np.mean((lo_cr_bt <= y_test) & (y_test <= hi_cr_bt))
width_cred_bt = np.mean(hi_cr_bt - lo_cr_bt)

# Split-conformal band around the posterior-median point prediction.
lo_c_bt, hi_c_bt = build_intervals(pred_point_test_bt, q_bt)
cov_conf_bt = np.mean((lo_c_bt <= y_test) & (y_test <= hi_c_bt))
width_conf_bt = np.mean(hi_c_bt - lo_c_bt)

print(f'BART credible 90%: coverage={cov_cred_bt:.3f}, width={width_cred_bt:.3f}')
print(f'BART conformal 90%: coverage={cov_conf_bt:.3f}, width={width_conf_bt:.3f}')

## 시각화 (테스트 세트)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12,4))
order = np.argsort(X_test[:,0])


def _credible_vs_conformal_panel(ax, f_draws, lo_cr, hi_cr, lo_c, hi_c,
                                 line_color, cred_color, conf_color,
                                 median_label, title):
    """Draw test points, the posterior-median fit and both 90% bands on ``ax``.

    Points are plotted in increasing order of the test input (``order``).
    """
    xs = X_test[order, 0]
    ax.plot(xs, y_test[order], 'k.', alpha=0.5, label='y')
    ax.plot(xs, np.median(f_draws, axis=1)[order], color=line_color, label=median_label)
    ax.fill_between(xs, lo_cr[order], hi_cr[order], color=cred_color, alpha=0.5, label='Credible 90%')
    ax.fill_between(xs, lo_c[order], hi_c[order], color=conf_color, alpha=0.35, label='Conformal 90%')
    ax.set_title(title)
    ax.legend()


_credible_vs_conformal_panel(
    axes[0], f_test_blr, lo_cr_blr, hi_cr_blr, lo_c_blr, hi_c_blr,
    '#1f77b4', '#aec7e8', '#7da6d9', 'BLR median', 'BLR: credible vs conformal',
)
_credible_vs_conformal_panel(
    axes[1], f_test_bt, lo_cr_bt, hi_cr_bt, lo_c_bt, hi_c_bt,
    '#2ca02c', '#b5e3b5', '#9bd39b', 'BART median', 'BART: credible vs conformal',
)
plt.tight_layout()
plt.show()

## PSIS-LOO 기반 Conformal (TRAIN에서 잔차, TEST 적용)

In [None]:
# BLR: compute the conformal quantile q from PSIS-LOO residuals on TRAIN
# (instead of a separate calibration split), then apply it on TEST.
from examples.bayes_linreg import posterior_draws_and_loglik
from bartpy import loo_residuals_via_psis, conformal_quantile, build_intervals

# Pareto-k diagnostic level above which PSIS importance weights are commonly
# treated as unreliable (Vehtari et al.). LOO residuals for such points, and
# hence q, should be trusted less.
PSIS_K_THRESHOLD = 0.7

y_tr_blr, f_tr_blr, loglik_tr_blr = posterior_draws_and_loglik(X_train, y_train, S=800, seed=99)
res_psis_blr, k_psis_blr, loo_tr_blr = loo_residuals_via_psis(y_tr_blr, f_tr_blr, loglik_tr_blr)
q_psis_blr = conformal_quantile(res_psis_blr, alpha=0.1)

# Apply on TEST (point prediction = posterior median of f).
f_test_blr, _ = blr_predict_draws(X_test)
pred_point_test_blr = np.median(f_test_blr, axis=1)
lo_psis_blr, hi_psis_blr = build_intervals(pred_point_test_blr, q_psis_blr)
cov_psis_blr = np.mean((y_test >= lo_psis_blr) & (y_test <= hi_psis_blr))
width_psis_blr = np.mean(hi_psis_blr - lo_psis_blr)
print(f'BLR PSIS-conformal 90%: coverage={cov_psis_blr:.3f}, width={width_psis_blr:.3f}, k_max={np.nanmax(k_psis_blr):.3f}')

# k_max alone hides how many points are affected; report the count explicitly.
# (NaN k values compare False and are not counted.)
n_bad_k_blr = int(np.sum(np.asarray(k_psis_blr) > PSIS_K_THRESHOLD))
if n_bad_k_blr:
    print(f'  warning: {n_bad_k_blr}/{np.size(k_psis_blr)} TRAIN points have Pareto k > '
          f'{PSIS_K_THRESHOLD}; their PSIS-LOO residuals (and q) may be unreliable')

In [None]:
# BART: compute the conformal quantile q from PSIS-LOO residuals on TRAIN
# (instead of a separate calibration split), then apply it on TEST.
from bartpy import prepare_bart_loglik_and_draws, loo_residuals_via_psis, conformal_quantile, build_intervals

# Pareto-k diagnostic level above which PSIS importance weights are commonly
# treated as unreliable (Vehtari et al.). LOO residuals for such points, and
# hence q, should be trusted less.
PSIS_K_THRESHOLD = 0.7

y_tr_bt, f_tr_bt, loglik_tr_bt = prepare_bart_loglik_and_draws(model)
res_psis_bt, k_psis_bt, loo_tr_bt = loo_residuals_via_psis(y_tr_bt, f_tr_bt, loglik_tr_bt)
q_psis_bt = conformal_quantile(res_psis_bt, alpha=0.1)

# Apply on TEST (point prediction = posterior median of f).
f_test_bt, _ = bart_f_sigma_on(model, X_test)
pred_point_test_bt = np.median(f_test_bt, axis=1)
lo_psis_bt, hi_psis_bt = build_intervals(pred_point_test_bt, q_psis_bt)
cov_psis_bt = np.mean((y_test >= lo_psis_bt) & (y_test <= hi_psis_bt))
width_psis_bt = np.mean(hi_psis_bt - lo_psis_bt)
print(f'BART PSIS-conformal 90%: coverage={cov_psis_bt:.3f}, width={width_psis_bt:.3f}, k_max={np.nanmax(k_psis_bt):.3f}')

# k_max alone hides how many points are affected; report the count explicitly.
# (NaN k values compare False and are not counted.)
n_bad_k_bt = int(np.sum(np.asarray(k_psis_bt) > PSIS_K_THRESHOLD))
if n_bad_k_bt:
    print(f'  warning: {n_bad_k_bt}/{np.size(k_psis_bt)} TRAIN points have Pareto k > '
          f'{PSIS_K_THRESHOLD}; their PSIS-LOO residuals (and q) may be unreliable')

## 비교 플롯 (credible vs split vs PSIS)

In [None]:
import matplotlib.pyplot as plt
order = np.argsort(X_test[:,0])
x = X_test[order,0]

# Styling of the three 90% bands, drawn in this order on each panel:
# (fill colour, alpha, legend label).
_BAND_STYLES = (
    ('#ff7f0e', 0.35, 'Credible 90%'),
    ('#9467bd', 0.25, 'Split 90%'),
    ('#17becf', 0.25, 'PSIS 90%'),
)


def _three_band_panel(ax, f_draws, bands, median_color, median_label, title):
    """Plot test points, the posterior-median fit and three (lo, hi) bands on ``ax``.

    ``bands`` holds the credible, split-conformal and PSIS-conformal bounds, in
    that order. Everything is drawn against the sorted test inputs ``x``.
    """
    ax.plot(x, y_test[order], 'k.', ms=3, alpha=0.6, label='y')
    ax.plot(x, np.median(f_draws, axis=1)[order], color=median_color, lw=2, label=median_label)
    for (lo, hi), (color, alpha, label) in zip(bands, _BAND_STYLES):
        ax.fill_between(x, lo[order], hi[order], color=color, alpha=alpha, label=label)
    ax.set_title(title)
    ax.legend()


fig, axes = plt.subplots(1, 2, figsize=(12,4))
_three_band_panel(
    axes[0], f_test_blr,
    [(lo_cr_blr, hi_cr_blr), (lo_c_blr, hi_c_blr), (lo_psis_blr, hi_psis_blr)],
    '#1f77b4', 'BLR median', 'BLR: credible vs split vs PSIS',
)
_three_band_panel(
    axes[1], f_test_bt,
    [(lo_cr_bt, hi_cr_bt), (lo_c_bt, hi_c_bt), (lo_psis_bt, hi_psis_bt)],
    '#2ca02c', 'BART median', 'BART: credible vs split vs PSIS',
)
plt.tight_layout()
plt.show()