In [1]:
import json
import os
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from jax.scipy import special
from tqdm import tqdm
import matplotlib.pyplot as plt
import importlib.resources
import warnings
from typing import Any, Dict
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

warnings.filterwarnings('ignore')
jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_disable_jit", True) # Only if needed

from bayesianquilts.metrics.ais import AdaptiveImportanceSampler, LikelihoodFunction

data_dir = os.path.expanduser("~/Library/CloudStorage/Box-Box/bouldering/ovarian")


Patched _JointDistributionNamedSpec with _structure_with_callables=None
TFP Specs patched successfully.


In [2]:
# -----------------------------------------------------
# 1. Define Reparameterized Likelihood Class
# -----------------------------------------------------

class OvarianReparametrizedLikelihood(LikelihoodFunction):
    def __init__(self, data_dict, priors):
        # Store constants only. Data (X, y) comes at calls.
        self.constants = {k: jnp.asarray(v, dtype=jnp.float64) for k, v in data_dict['constants'].items()}
        # Keep reference to FULL data for reconstruction or if needed, but not primarily for LL
        self.X_full = jnp.asarray(data_dict['X'], dtype=jnp.float64)
        self.y_full = jnp.asarray(data_dict['y'], dtype=jnp.float64)

    def _get_params(self, params):
        beta0 = jnp.asarray(params['beta0'], dtype=jnp.float64)
        z = jnp.asarray(params['z'], dtype=jnp.float64)
        tau = jnp.exp(jnp.asarray(params['log_tau'], dtype=jnp.float64))
        lam = jnp.exp(jnp.asarray(params['log_lambda'], dtype=jnp.float64))
        caux = jnp.exp(jnp.asarray(params['log_caux'], dtype=jnp.float64))
        return beta0, z, tau, lam, caux

    def _compute_beta(self, z, tau, lam, caux):
        c = self.constants['slab_scale'] * jnp.sqrt(caux)
        num = (c**2) * (lam**2)
        denom = (c**2) + (tau**2) * (lam**2)
        lambda_tilde = jnp.sqrt(num / denom)
        beta = z * lambda_tilde * tau
        return beta

    def log_likelihood(self, data: Any, params: Dict[str, Any]) -> jnp.ndarray:
        # USE DATA ARGUMENT, NOT SELF.X
        X = jnp.asarray(data['X'], dtype=jnp.float64)
        y = jnp.asarray(data['y'], dtype=jnp.float64)
        
        beta0, z, tau, lam, caux = self._get_params(params)
        
        target_ndim = z.ndim
        def align_dims(x, target_rank):
            while x.ndim < target_rank:
                x = x[..., jnp.newaxis]
            return x

        tau = align_dims(tau, target_ndim)
        caux = align_dims(caux, target_ndim)
        lam = align_dims(lam, target_ndim)
        
        beta = self._compute_beta(z, tau, lam, caux)
        
        # Logits: X @ beta.T or einsum
        # Check if we are broadcasting over LOO-expanded params (Rank 3) or standard (Rank 2)
        if beta.ndim == 3: # (S, N_batch, D)
             logits = jnp.einsum('nd,snd->sn', X, beta)
             if beta0.ndim < 2: beta0 = align_dims(beta0, 2)
             if beta0.ndim == 3: beta0 = beta0.squeeze(-1)
             logits = logits + beta0
        else: # (S, D)
             logits = jnp.dot(beta, X.T) # (S, N_batch)
             if beta0.ndim == 1: beta0 = beta0[:, jnp.newaxis]
             logits = logits + beta0
        
        y_broad = y[jnp.newaxis, :]
        ll = y_broad * logits - jnp.logaddexp(0.0, logits)
        return ll

    def unormalized_log_prob(self, data=None, **params):
        # For prior, we don't depend on data, but likelihood part does.
        # Usually AIS calls this with FULL data for 'log_ell_original'.
        # If data is None, we default to full data stored in self.
        if data is None:
             data = {'X': self.X_full, 'y': self.y_full}
             
        ll_sum = jnp.sum(self.log_likelihood(data, params), axis=1)
        
        beta0, z, tau, lam, caux = self._get_params(params)
        lp_z = jnp.sum(tfd.Normal(jnp.float64(0.0), jnp.float64(1.0)).log_prob(z), axis=1)
        lp_beta0 = tfd.Normal(jnp.float64(0.0), self.constants['scale_icept']).log_prob(jnp.squeeze(beta0))
        dist_tau = tfd.StudentT(df=self.constants['nu_global'], loc=jnp.float64(0.0), scale=self.constants['scale_global'])
        lp_tau = dist_tau.log_prob(jnp.squeeze(tau)) + jnp.log(2.0) + params['log_tau'].squeeze()
        dist_lam = tfd.StudentT(df=self.constants['nu_local'], loc=jnp.float64(0.0), scale=jnp.float64(1.0))
        lp_lam = jnp.sum(dist_lam.log_prob(lam) + jnp.log(2.0), axis=1) + jnp.sum(params['log_lambda'], axis=1)
        dist_caux = tfd.InverseGamma(concentration=0.5*self.constants['slab_df'], scale=0.5*self.constants['slab_df'])
        lp_caux = dist_caux.log_prob(jnp.squeeze(caux)) + params['log_caux'].squeeze()
        
        return ll_sum + lp_z + lp_beta0 + lp_tau + lp_lam + lp_caux

    def log_likelihood_gradient(self, data: Any, params: Dict[str, Any]) -> jnp.ndarray:
        X = jnp.asarray(data['X'], dtype=jnp.float64)
        y = jnp.asarray(data['y'], dtype=jnp.float64)
        
        def single_point_grad(theta_s, x_n, y_n):
            p = self.reconstruct_parameters(theta_s, params)
            beta0, z, tau, lam, caux = self._get_params(p)
            beta = self._compute_beta(z, tau.squeeze(), lam, caux.squeeze())
            logits = beta0.squeeze() + jnp.dot(beta, x_n)
            ll = y_n * logits - jnp.logaddexp(0.0, logits)
            return ll

        grad_fn = jax.grad(single_point_grad)
        theta = self.extract_parameters(params)
        vmap_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0), out_axes=0)
        vvmap_grad = jax.vmap(vmap_grad, in_axes=(0, None, None), out_axes=0)
        grads = vvmap_grad(theta, X, y)
        return grads
        
    def log_likelihood_hessian_diag(self, data: Any, params: Dict[str, Any]) -> jnp.ndarray:
        X = jnp.asarray(data['X'], dtype=jnp.float64)
        y = jnp.asarray(data['y'], dtype=jnp.float64)
        def single_point_hess(theta_s, x_n, y_n):
            hess = jax.hessian(lambda t: 
                 self._single_point_ll_ad(t, x_n, y_n, params)
            )(theta_s)
            return jnp.diag(hess)
        theta = self.extract_parameters(params)
        vmap_hess = jax.vmap(single_point_hess, in_axes=(None, 0, 0), out_axes=0)
        vvmap_hess = jax.vmap(vmap_hess, in_axes=(0, None, None), out_axes=0)
        return vvmap_hess(theta, X, y)

    def _single_point_ll_ad(self, theta_s, x_n, y_n, template):
        p = self.reconstruct_parameters(theta_s, template)
        beta0, z, tau, lam, caux = self._get_params(p)
        beta = self._compute_beta(z, tau.squeeze(), lam, caux.squeeze())
        logits = beta0.squeeze() + jnp.dot(beta, x_n)
        ll = y_n * logits - jnp.logaddexp(0.0, logits)
        return ll

    def extract_parameters(self, params: Dict[str, Any]) -> jnp.ndarray:
        beta0 = jnp.asarray(params['beta0'], dtype=jnp.float64)
        log_tau = jnp.asarray(params['log_tau'], dtype=jnp.float64)
        log_caux = jnp.asarray(params['log_caux'], dtype=jnp.float64)
        z = jnp.asarray(params['z'], dtype=jnp.float64)
        log_lambda = jnp.asarray(params['log_lambda'], dtype=jnp.float64)
        if beta0.ndim == 1: beta0 = beta0[:, jnp.newaxis]
        if log_tau.ndim == 1: log_tau = log_tau[:, jnp.newaxis]
        if log_caux.ndim == 1: log_caux = log_caux[:, jnp.newaxis]
        return jnp.concatenate([beta0, log_tau, log_caux, z, log_lambda], axis=1)

    def reconstruct_parameters(self, flat_params: jnp.ndarray, template: Dict[str, Any]) -> Dict[str, Any]:
        D = template['z'].shape[-1]
        beta0 = flat_params[..., 0:1]
        log_tau = flat_params[..., 1:2]
        log_caux = flat_params[..., 2:3]
        z = flat_params[..., 3:3+D]
        log_lambda = flat_params[..., 3+D:]
        return {'beta0': beta0, 'log_tau': log_tau, 'log_caux': log_caux, 'z': z, 'log_lambda': log_lambda}


In [3]:
# -----------------------------------------------------
# 2. Load Data and Posterior
# -----------------------------------------------------
with importlib.resources.path('bayesianquilts.data', "overianx.csv") as xpath:
    X = pd.read_csv(xpath, header=None)
with importlib.resources.path('bayesianquilts.data', "overiany.csv") as ypath:
    y = pd.read_csv(ypath, header=None)
X_scaled = (X - X.mean())/X.std()
X_scaled = X_scaled.fillna(0)
X_np = X_scaled.to_numpy(dtype=float)
y_np = y.to_numpy(dtype=float).flatten()
n = X_np.shape[0]
p = X_np.shape[1]
guessnumrelevcov = n / 10
scale_global_val = guessnumrelevcov / ((p - guessnumrelevcov) * np.sqrt(n))
constants = {'slab_scale': 2.5, 'scale_icept': 5.0, 'nu_global': 1.0, 'nu_local': 1.0, 'slab_df': 1.0, 'scale_global': scale_global_val}
data_dict_full = {'X': X_np, 'y': y_np, 'constants': constants}

print("Loading Full Posterior...")
fname_0 = os.path.join(data_dir, "ovarian_loo_0.npy")
params_full = np.load(fname_0, allow_pickle=True).item()

# AGGRESSIVE SUBSAMPLE for stability
n_sub = 50
idx = np.arange(n_sub)

beta0 = jnp.asarray(np.squeeze(params_full['beta0'])[idx], dtype=jnp.float64)
z = jnp.asarray(np.squeeze(params_full['z'])[idx], dtype=jnp.float64)
tau = jnp.asarray(np.squeeze(params_full['tau'])[idx], dtype=jnp.float64)
lam = jnp.asarray(np.squeeze(params_full['lambda'])[idx], dtype=jnp.float64)
caux = jnp.asarray(np.squeeze(params_full['caux'])[idx], dtype=jnp.float64)

# Log transforms
log_tau = jnp.log(tau)
log_lambda = jnp.log(lam)
log_caux = jnp.log(caux)

params_jax = {
    'beta0': beta0, 
    'z': z, 
    'log_tau': log_tau, 
    'log_lambda': log_lambda, 
    'log_caux': log_caux
}


Loading Full Posterior...


In [None]:
# -----------------------------------------------------
# 3. Run Adaptive IS in Batches
# -----------------------------------------------------

model = OvarianReparametrizedLikelihood(data_dict_full, priors=None)
ais = AdaptiveImportanceSampler(likelihood_fn=model)
ais.model = model 

methods = ['identity', 'll', 'kl', 'var', 'mm1', 'mm2', 'pmm1', 'pmm2']
rhos = [1/2**i for i in range(1, 7)]

# Batching strategy
batch_size = 2
n_data = X_np.shape[0]
batches = range(0, n_data, batch_size)

accumulated_results = {m: [] for m in methods}

print(f"Running AIS with methods: {methods}, Batch Size: {batch_size}")

for start_idx in tqdm(batches, desc="Processing Batches"):
    end_idx = min(start_idx + batch_size, n_data)
    
    # Slice Data
    X_batch = X_np[start_idx:end_idx]
    y_batch = y_np[start_idx:end_idx]
    data_batch = {'X': X_batch, 'y': y_batch, 'constants': constants}
    
    try:
        batch_res = ais.adaptive_is_loo(
            data=data_batch,
            params=params_jax,
            rhos=rhos,
            variational=False,
            transformations=methods
        )
        
        for method in methods:
            if method in batch_res:
                accumulated_results[method].append(batch_res[method]['p_loo_psis'])
    except Exception as e:
        print(f"Batch {start_idx}-{end_idx} failed: {e}")

# Concatenate Results
final_results = {}
for method in methods:
    if accumulated_results[method]:
        concatenated = np.concatenate([np.asarray(a) for a in accumulated_results[method]])
        final_results[method] = {'p_loo_psis': concatenated}
        
print("AIS Batch Processing Complete.")


Running AIS with methods: ['identity', 'll', 'kl', 'var', 'mm1', 'mm2', 'pmm1', 'pmm2'], Batch Size: 2


Processing Batches:   0%|          | 0/27 [00:00<?, ?it/s]

In [6]:
# -----------------------------------------------------
# 4. Compare with Ground Truth
# -----------------------------------------------------

n_samples = X_np.shape[0]
loo_gt_elpd = np.zeros(n_samples)
missing = []

def log_bernoulli(y_true, f_logits):
    return y_true * f_logits - jnp.logaddexp(0.0, f_logits)
    
print("Computing Ground Truth...")

for i in tqdm(range(1, n_samples + 1), desc="Loading GT Folds"):
    fname = os.path.join(data_dir, f"ovarian_loo_{i}.npy")
    if os.path.exists(fname):
        try:
            p = np.load(fname, allow_pickle=True).item()
            xi = X_np[i-1]
            yi = y_np[i-1]
            beta_sample = np.squeeze(p['beta'])
            beta0_sample = np.squeeze(p['beta0'])
            f = beta0_sample + np.dot(beta_sample, xi)
            ll = log_bernoulli(yi, f)
            loo_gt_elpd[i-1] = special.logsumexp(ll) - np.log(len(ll))
        except Exception as e:
             missing.append(i)
    else: missing.append(i)

gt_total = loo_gt_elpd.sum()
# SE for Ground Truth (approximate using variance of single points against their mean)
# SE = sqrt(N) * std(loo_gt_elpd, ddof=1)
gt_se = np.sqrt(n_samples) * np.std(loo_gt_elpd, ddof=1)
print(f"Ground Truth Total LOO: {gt_total:.2f} +/- {gt_se:.2f}")

# Collect Results
summary = []
summary.append({'Method': 'Ground Truth', 'Total LOO': gt_total, 'SE': gt_se, 'Diff': 0})

for key, res in final_results.items():
    if key == 'best': continue
    
    # Note: final_results[key]['p_loo_psis'] should be shape (N_samples,)
    p_loo = res['p_loo_psis']
    
    if len(p_loo) == n_samples:
        # The sum is: sum(log(1/S sum(w))) for each i
        # p_loo vector contains p(y_i|y_-i)
        
        elpd_vec = np.log(p_loo + 1e-100) # Safety epsilon
        total = np.sum(elpd_vec)
        se = np.sqrt(n_samples) * np.std(elpd_vec, ddof=1)
        summary.append({
            'Method': key, 
            'Total LOO': total, 
            'SE': se,
            'Diff': total - gt_total
        })
    else:
        print(f"Method {key} has incomplete results: {len(p_loo)}/{n_samples}")

df = pd.DataFrame(summary)

# Visualization: Publication Quality Forest Plot
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'font.size': 12, 'font.family': 'sans-serif'})

# Separation of groups
baselines = ['Ground Truth', 'identity', 'll']
proposed = ['kl', 'var', 'mm1', 'mm2', 'pmm1', 'pmm2']

# Order map: Proposed Top, Baselines Bottom. 
# Since we plot bottom-up, higher order map index = higher in plot.
# Let's put Baselines at the very bottom (Index 0..k)
# Proposed above them.
full_order = baselines + proposed
order_map = {name: i for i, name in enumerate(full_order)}
df['Order'] = df['Method'].map(order_map)
df = df.sort_values('Order', ascending=True) # Sort for consistent plotting logic below

plt.figure(figsize=(8, 4.5))

# Plot Loop
for i, row in df.iterrows():
    method = row['Method']
    y_val = row['Order'] # Use fixed y-positions based on group
    x_val = row['Total LOO']
    x_err = row['SE']
    
    # Style logic
    if method in proposed:
        color = '#D55E00' # Vermilion
        fmt = 'o'
    elif method == 'Ground Truth':
        color = 'black'
        fmt = 'D'
    else: # Other baselines
        color = '#0072B2' # Blue
        fmt = 's'
        
    plt.errorbar(
        x=x_val, 
        y=y_val, 
        xerr=x_err, 
        fmt=fmt, 
        color=color, 
        capsize=4, 
        markersize=6, 
        linewidth=1.5
    )
    
plt.axvline(gt_total, color='black', linestyle=':', alpha=0.6)

# Set Y-ticks based on what is in DF
yticks = sorted(df['Order'].unique())
yticklabels = [df[df['Order'] == y]['Method'].iloc[0] for y in yticks]
plt.yticks(yticks, yticklabels)

# Bold proposed methods
ax = plt.gca()
for tick in ax.get_yticklabels():
    if tick.get_text() in proposed:
        tick.set_fontweight('bold')
        
plt.xlabel('Total LOO-ELPD')
plt.grid(axis='y', alpha=0.0) # Only x grid
plt.grid(axis='x', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig('ais_reparam_comparison.png', dpi=300, bbox_inches='tight')
plt.show()


Computing Ground Truth...


Loading GT Folds: 100%|██████████| 54/54 [00:03<00:00, 14.85it/s]


Ground Truth Total LOO: -14.65 +/- 3.35


NameError: name 'final_results' is not defined