# PSIS-Conformal on More Complex Data: BART vs Bayesian Linear Regression

We generate a heteroscedastic, nonlinear dataset in 2D and compare PSIS-Conformal intervals (conformal intervals calibrated on PSIS leave-one-out residuals) from BART and from a conjugate Bayesian linear regression on a rich basis (polynomials + sin/cos).

In [None]:
# Make repo importable when run from examples/
import sys
from pathlib import Path
# Locate the repository root so that `bartpy` and `examples` can be imported
# whether the notebook is launched from the repo root or from inside examples/.
cwd = Path.cwd()
at_repo_root = (cwd / 'examples').exists() and (cwd / 'bartpy').exists()
inside_examples = cwd.name == 'examples' and (cwd.parent / 'bartpy').exists()
# When neither layout is recognised the working directory itself is used.
repo_root = cwd.parent if (inside_examples and not at_repo_root) else cwd
sys.path.insert(0, str(repo_root))
print('Repo root:', repo_root)

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt

from bartpy.sklearnmodel import SklearnModel
from bartpy import (prepare_bart_loglik_and_draws, loo_residuals_via_psis, conformal_quantile, build_intervals)
from examples.bayes_linreg import posterior_draws_and_loglik

In [None]:
# Synthetic 2-D regression problem: nonlinear mean with input-dependent noise.
# The generator is consumed in a fixed order (x1, then x2, then the noise),
# so the dataset is reproducible for the seed below.
rng = np.random.default_rng(123)
n = 300
X = np.column_stack([
    rng.uniform(-3.0, 3.0, size=n),  # x1
    rng.uniform(-2.0, 2.0, size=n),  # x2
])
x1, x2 = X[:, 0], X[:, 1]
# Mean function: sinusoid in x1, quadratic in x2, plus an interaction term.
f_true = np.sin(x1) + 0.5 * (x2 ** 2) + 0.3 * x1 * x2
# Noise scale steps up when |x1| > 1 and again when x2 > 0.
sigma_x = 0.2 + 0.15 * (np.abs(x1) > 1.0) + 0.1 * (x2 > 0)
y = f_true + rng.normal(0.0, sigma_x)
X[:3], y[:3], sigma_x[:3]

## BART: Fit once, PSIS-LOO conformal

In [None]:
# Fit BART once on the full dataset, then build conformal intervals from the
# PSIS-LOO residuals (no refitting per left-out point).
bart_settings = dict(
    n_trees=150, n_chains=2, n_samples=250, n_burn=250, thin=0.5,
    store_in_sample_predictions=True, store_acceptance_trace=True, n_jobs=1,
)
model = SklearnModel(**bart_settings)
model.fit(X, y)

y_obs_bart, draws_bart, loglik_bart = prepare_bart_loglik_and_draws(model)
# NOTE(review): the three outputs are used below as residuals, Pareto k values
# and LOO predictive means respectively - confirm against loo_residuals_via_psis.
res_bart, k_bart, loo_bart = loo_residuals_via_psis(y_obs_bart, draws_bart, loglik_bart)
q_bart = conformal_quantile(res_bart, alpha=0.1)
lo_bart, hi_bart = build_intervals(loo_bart, q_bart)

# In-sample summary: fraction of observations inside their interval, and mean width.
inside_bart = np.logical_and(y_obs_bart >= lo_bart, y_obs_bart <= hi_bart)
cov_bart = np.mean(inside_bart)
width_bart = np.mean(hi_bart - lo_bart)
print(f'BART: coverage={cov_bart:.3f}, mean width={width_bart:.3f}, k max={np.nanmax(k_bart):.3f}')

## Bayesian Linear Regression (rich basis): PSIS-LOO conformal

In [None]:
# Conjugate BLR on the full dataset: draw from the posterior, then build
# conformal intervals from the PSIS-LOO residuals.
blr_settings = dict(S=800, tau2=10.0, a0=2.0, b0=1.0, seed=99)
y_obs_blr, draws_blr, loglik_blr = posterior_draws_and_loglik(X, y, **blr_settings)
res_blr, k_blr, loo_blr = loo_residuals_via_psis(y_obs_blr, draws_blr, loglik_blr)
q_blr = conformal_quantile(res_blr, alpha=0.1)
lo_blr, hi_blr = build_intervals(loo_blr, q_blr)

# In-sample summary: fraction of observations inside their interval, and mean width.
inside_blr = np.logical_and(y_obs_blr >= lo_blr, y_obs_blr <= hi_blr)
cov_blr = np.mean(inside_blr)
width_blr = np.mean(hi_blr - lo_blr)
print(f'BLR: coverage={cov_blr:.3f}, mean width={width_blr:.3f}, k max={np.nanmax(k_blr):.3f}')

## Diagnostics and visualization (projected onto x1)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Top row: histogram of the finite Pareto k values per model, with dashed
# reference lines at 0.5 and 0.7.
for ax, k_vals, face, title in [
    (axes[0, 0], k_bart, '#4472c4', 'BART Pareto k'),
    (axes[0, 1], k_blr, '#6eaa2c', 'BLR Pareto k'),
]:
    ax.hist(k_vals[np.isfinite(k_vals)], bins=20, color=face)
    ax.axvline(0.5, color='r', ls='--')
    ax.axvline(0.7, color='orange', ls='--')
    ax.set_title(title)

# Bottom row: observations, LOO mean and 90% band, sorted by the first feature.
order = np.argsort(X[:, 0])
x1_sorted = X[order, 0]
for ax, center, lower, upper, line_col, band_col, line_lab, band_lab, title in [
    (axes[1, 0], loo_bart, lo_bart, hi_bart, '#2ca02c', '#9bd39b',
     'BART LOO mean', 'BART 90%', 'BART intervals (train)'),
    (axes[1, 1], loo_blr, lo_blr, hi_blr, '#1f77b4', '#7da6d9',
     'BLR LOO mean', 'BLR 90%', 'BLR intervals (train)'),
]:
    ax.plot(x1_sorted, y[order], 'k.', alpha=0.4)
    ax.plot(x1_sorted, center[order], color=line_col, label=line_lab)
    ax.fill_between(x1_sorted, lower[order], upper[order], color=band_col, alpha=0.4, label=band_lab)
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.show()

## Train/Cal/Test split + credible vs split vs PSIS (test coverage)

In [None]:
from sklearn.model_selection import train_test_split
# Split the complex dataset into TRAIN / CAL / TEST.
# First hold out 40% of the rows (test_size=0.4), leaving 60% for TRAIN ...
X_tr, X_tmp, y_tr, y_tmp = train_test_split(X, y, test_size=0.4, random_state=123)
# ... then cut the held-out part in half (test_size=0.5) into CAL and TEST.
X_cal, X_te,  y_cal, y_te  = train_test_split(X_tmp, y_tmp, test_size=0.5, random_state=124)
# Cell output: the shapes of the three feature matrices.
X_tr.shape, X_cal.shape, X_te.shape

### BLR: credible vs split vs PSIS (test)

In [None]:
from examples.bayes_linreg import design_matrix, bayes_linreg_posterior, sample_posterior, posterior_draws_and_loglik
# Fit BLR on TRAIN only, then draw S posterior samples with a fixed seed.
S = 800
Phi_tr = design_matrix(X_tr)
beta_n, Vn, a_n, b_n = bayes_linreg_posterior(Phi_tr, y_tr, tau2=10.0, a0=2.0, b0=1.0)
blr_rng = np.random.default_rng(999)
beta_draws, sigma_draws = sample_posterior(beta_n, Vn, a_n, b_n, S=S, rng=blr_rng)

def blr_predict_draws_on(X_new):
    """Evaluate the TRAIN-fitted BLR posterior draws at new inputs.

    Reads the module-level ``beta_draws`` and ``sigma_draws``.

    Returns a pair ``(f, ypred)``:
      * ``f`` is ``design_matrix(X_new) @ beta_draws.T`` (one column per draw);
      * ``ypred`` is ``f`` plus standard-normal noise scaled column-wise by
        ``sigma_draws``.

    The noise generator is re-created with seed 77 on every call, so calls with
    the same output shape reuse the same standard-normal noise.
    """
    basis = design_matrix(X_new)
    mean_draws = basis @ beta_draws.T
    noise_rng = np.random.default_rng(77)
    noise = noise_rng.normal(size=mean_draws.shape) * sigma_draws[None, :]
    return mean_draws, mean_draws + noise

# Split-conformal q on CAL: absolute residuals of the calibration targets
# around the per-point median of the posterior mean draws.
f_cal, ypred_cal = blr_predict_draws_on(X_cal)
pred_med_cal = np.median(f_cal, axis=1)
q_split_blr = conformal_quantile(np.abs(y_cal - pred_med_cal), alpha=0.1)

# PSIS-Conformal q from TRAIN.
# The prior hyperparameters are passed explicitly (tau2=10.0, a0=2.0, b0=1.0)
# so this posterior uses the same prior as the TRAIN fit whose draws produce
# the test predictions below. Relying on the function's defaults could
# silently calibrate q against a differently-specified model.
y_psis_tr, f_psis_tr, loglik_psis_tr = posterior_draws_and_loglik(
    X_tr, y_tr, S=S, tau2=10.0, a0=2.0, b0=1.0, seed=42
)
res_psis_blr, k_psis_blr, _ = loo_residuals_via_psis(y_psis_tr, f_psis_tr, loglik_psis_tr)
q_psis_blr = conformal_quantile(res_psis_blr, alpha=0.1)

# Test predictions: both conformal intervals are centred on the per-point
# median of the posterior mean draws.
f_te_blr, ypred_te_blr = blr_predict_draws_on(X_te)
pred_med_te_blr = np.median(f_te_blr, axis=1)
# credible 90%: 5th/95th percentiles of the noisy predictive draws
lo_cr_blr, hi_cr_blr = np.percentile(ypred_te_blr, [5, 95], axis=1)
# split-conformal
lo_sc_blr, hi_sc_blr = build_intervals(pred_med_te_blr, q_split_blr)
# psis-conformal
lo_pc_blr, hi_pc_blr = build_intervals(pred_med_te_blr, q_psis_blr)

def cov_width(lo, hi):
    """Return (empirical coverage of ``y_te`` by [lo, hi], mean interval width) as floats.

    Reads the module-level test targets ``y_te``.
    """
    inside = (y_te >= lo) & (y_te <= hi)
    coverage = float(np.mean(inside))
    mean_width = float(np.mean(hi - lo))
    return coverage, mean_width
# Score the three BLR interval types on TEST (credible, split, PSIS - in that order).
(cov_cr, w_cr), (cov_sc, w_sc), (cov_pc, w_pc) = (
    cov_width(lower, upper)
    for lower, upper in (
        (lo_cr_blr, hi_cr_blr),
        (lo_sc_blr, hi_sc_blr),
        (lo_pc_blr, hi_pc_blr),
    )
)
print(f'BLR credible 90%: cov={cov_cr:.3f}, width={w_cr:.3f}')
print(f'BLR split    90%: cov={cov_sc:.3f}, width={w_sc:.3f}')
print(f'BLR PSIS     90%: cov={cov_pc:.3f}, width={w_pc:.3f}, k_max={np.nanmax(k_psis_blr):.3f}')

### BART: credible vs split vs PSIS (test)

In [None]:
# Fit BART on TRAIN only (same sampler settings as the full-data fit).
bart_te_settings = dict(
    n_trees=150, n_chains=2, n_samples=250, n_burn=250, thin=0.5,
    store_in_sample_predictions=True, store_acceptance_trace=True, n_jobs=1,
)
model_te = SklearnModel(**bart_te_settings)
model_te.fit(X_tr, y_tr)

# Helper: per-sample out-of-sample predictions (unnormalized) and sigma (unnormalized)
def bart_f_sigma_unnorm(model, X_new):
    """Collect per-sample predictions and noise scales from a fitted model.

    For every entry of ``model._model_samples`` the out-of-sample prediction
    at ``X_new`` is passed through that sample's ``data.y.unnormalize_y`` and
    stored as one column of ``f``; the sample's
    ``sigma.current_unnormalized_value()`` is stored in ``sig``.

    Returns ``(f, sig)`` with ``f`` of shape ``(X_new.shape[0], n_samples)``
    and ``sig`` of shape ``(n_samples,)``.
    """
    samples = model._model_samples
    n_samples = len(samples)
    f = np.empty((X_new.shape[0], n_samples))
    sig = np.empty(n_samples)
    for col, sample in enumerate(samples):
        # Predictions come back on the normalized scale; map them to the original scale.
        f[:, col] = sample.data.y.unnormalize_y(sample._out_of_sample_predict(X_new))
        sig[col] = sample.sigma.current_unnormalized_value()
    return f, sig

# Split-conformal on CAL: absolute residuals around the per-point median prediction.
f_cal_bt, sigma_bt = bart_f_sigma_unnorm(model_te, X_cal)
pred_med_cal_bt = np.median(f_cal_bt, axis=1)
abs_res_cal_bt = np.abs(y_cal - pred_med_cal_bt)
q_split_bt = conformal_quantile(abs_res_cal_bt, alpha=0.1)

# PSIS-Conformal q from TRAIN (reuses the single TRAIN fit; no refitting).
y_tr_bt, f_tr_bt, loglik_tr_bt = prepare_bart_loglik_and_draws(model_te)
res_psis_bt, k_psis_bt, _ = loo_residuals_via_psis(y_tr_bt, f_tr_bt, loglik_tr_bt)
q_psis_bt = conformal_quantile(res_psis_bt, alpha=0.1)

# Test predictions: per-sample means at X_te and their per-point median.
f_te_bt, sigma_te_bt = bart_f_sigma_unnorm(model_te, X_te)
pred_med_te_bt = np.median(f_te_bt, axis=1)

# Credible band: add Gaussian noise scaled by each sample's sigma, then take
# the 5th/95th percentiles across samples.
noise_rng_bt = np.random.default_rng(321)
eps_bt = noise_rng_bt.normal(size=f_te_bt.shape) * sigma_te_bt[None, :]
ypred_te_bt = f_te_bt + eps_bt
lo_cr_bt, hi_cr_bt = np.percentile(ypred_te_bt, [5, 95], axis=1)

# Split- and PSIS-conformal bands share the same centre and differ only in q.
lo_sc_bt, hi_sc_bt = build_intervals(pred_med_te_bt, q_split_bt)
lo_pc_bt, hi_pc_bt = build_intervals(pred_med_te_bt, q_psis_bt)

def _bart_cov_width(lower, upper):
    """Return (fraction of y_te inside [lower, upper], mean interval width) as floats."""
    inside = (y_te >= lower) & (y_te <= upper)
    return float(np.mean(inside)), float(np.mean(upper - lower))


# Score the three BART interval types on TEST.
cov_cr_bt, w_cr_bt = _bart_cov_width(lo_cr_bt, hi_cr_bt)
cov_sc_bt, w_sc_bt = _bart_cov_width(lo_sc_bt, hi_sc_bt)
cov_pc_bt, w_pc_bt = _bart_cov_width(lo_pc_bt, hi_pc_bt)
print(f'BART credible 90%: cov={cov_cr_bt:.3f}, width={w_cr_bt:.3f}')
print(f'BART split    90%: cov={cov_sc_bt:.3f}, width={w_sc_bt:.3f}')
print(f'BART PSIS     90%: cov={cov_pc_bt:.3f}, width={w_pc_bt:.3f}, k_max={np.nanmax(k_psis_bt):.3f}')

### Overlay plots (test): credible vs split vs PSIS (high-contrast)

In [None]:
# Sort the test points by the first feature so lines and bands draw left to right.
order = np.argsort(X_te[:,0])
x = X_te[order,0]

fig, axes = plt.subplots(1, 2, figsize=(12,4))
# Colors
c_pts = '0.25'
c_med_blr, c_med_bt = '#1f77b4', '#2ca02c'
c_cred, c_split, c_psis = '#ff7f0e', '#9467bd', '#17becf'


def _overlay_panel(ax, median, median_color, median_label, bands, title):
    """Draw test points, the median prediction and each interval band on ``ax``.

    ``bands`` is a sequence of ``(lower, upper, color, linestyle, label)``
    tuples in original (unsorted) test order.
    """
    ax.plot(x, y_te[order], 'o', ms=3, color=c_pts, alpha=0.6, label='y', zorder=5)
    ax.plot(x, median[order], color=median_color, lw=2.2, label=median_label, zorder=6)
    for lower, upper, band_color, dash, band_label in bands:
        lower_sorted, upper_sorted = lower[order], upper[order]
        ax.fill_between(x, lower_sorted, upper_sorted, color=band_color, alpha=0.18, zorder=1)
        ax.plot(x, lower_sorted, color=band_color, ls=dash, lw=1.6, zorder=2)
        # Only the upper edge carries the legend label, so each band appears once.
        ax.plot(x, upper_sorted, color=band_color, ls=dash, lw=1.6, zorder=2, label=band_label)
    ax.set_title(title)
    ax.legend()


# BLR overlay
_overlay_panel(
    axes[0], pred_med_te_blr, c_med_blr, 'BLR median',
    [
        (lo_cr_blr, hi_cr_blr, c_cred,  '--', 'Credible 90%'),
        (lo_sc_blr, hi_sc_blr, c_split, ':',  'Split 90%'),
        (lo_pc_blr, hi_pc_blr, c_psis,  '-.', 'PSIS 90%'),
    ],
    'BLR (test)',
)

# BART overlay
_overlay_panel(
    axes[1], pred_med_te_bt, c_med_bt, 'BART median',
    [
        (lo_cr_bt, hi_cr_bt, c_cred,  '--', 'Credible 90%'),
        (lo_sc_bt, hi_sc_bt, c_split, ':',  'Split 90%'),
        (lo_pc_bt, hi_pc_bt, c_psis,  '-.', 'PSIS 90%'),
    ],
    'BART (test)',
)
plt.tight_layout()
plt.show()

### Summary table (test)

In [None]:
import pandas as pd
# Test-set (coverage, mean width) per model and interval method, in display order.
results = {
    'BLR': {
        'credible': (cov_cr, w_cr),
        'split': (cov_sc, w_sc),
        'psis': (cov_pc, w_pc),
    },
    'BART': {
        'credible': (cov_cr_bt, w_cr_bt),
        'split': (cov_sc_bt, w_sc_bt),
        'psis': (cov_pc_bt, w_pc_bt),
    },
}
rows = [
    (model_name, method, coverage, mean_width)
    for model_name, by_method in results.items()
    for method, (coverage, mean_width) in by_method.items()
]
df = pd.DataFrame(rows, columns=['model','method','coverage','mean_width'])
display(df)
print('Pareto k max: BLR={:.3f}, BART={:.3f}'.format(np.nanmax(k_psis_blr), np.nanmax(k_psis_bt)))