### Example E.5.3. BAYESIAN STUDENT-T REGRESSION from https://openreview.net/pdf?id=HltJfwwfhX

In [None]:
import pystan
import pandas as pd
import numpy as np
import requests

from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import rpy2py
import rpy2.robjects as ro
import json
import jax
from collections import OrderedDict

from scipy.special import logsumexp

from amis_algorithms import alpha_AMIS_fixed_dof, AMIS_student_fixed_dof

import bridgestan

from utils import old_ksd

import matplotlib.pyplot as plt

# Enable LaTeX for nicer plotting
plt.rc('text', usetex=True)
plt.rc('font', family='serif')


# from tueplots import bundles
# plt.rcParams.update(bundles.aistats2023())

from experiments_amis import run_AMIS_real_dataset


# Load and prepare the dataset
url = "https://github.com/faosorios/heavy/blob/master/data/creatinine.rda?raw=true"
# The timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being written to disk as if
# it were the dataset (which would only fail later, inside R's load()).
with requests.get(url, timeout=60) as resp:
    resp.raise_for_status()
    with open("creatinine.rda", "wb") as f:
        f.write(resp.content)

# Read the downloaded RDA file through R, then convert the R object to pandas
ro.r['load']("creatinine.rda")
df = pandas2ri.rpy2py_dataframe(ro.r['creatinine'])

# Log-transformed covariates and response, following
# https://openreview.net/pdf?id=HltJfwwfhX
data_df = pd.DataFrame(
    {
        'log_SC': np.log(df['SC']),
        'log_WT': np.log(df['WT']),
        'log_140_minus_A': np.log(140 - df['Age']),
        'log_CR': np.log(df['CR']),
    }
)
# Discard every row in which a transformation produced NaN
data_df = data_df.dropna()

# Compile the Stan model
sm = pystan.StanModel(file="./student_reg_model.stan")

# Name given to each transformed column in the data handed to Stan
# ('y' is the response variable, the others are covariates)
_stan_inputs = {
    'x1': 'log_SC',
    'x2': 'log_WT',
    'x3': 'log_140_minus_A',
    'y': 'log_CR',
}
data_for_stan = {
    'N': len(data_df),
    **{key: data_df[col].values.tolist() for key, col in _stan_inputs.items()},
}

# Persist the same dictionary as JSON so it can also be read from disk
with open("student_regression_data.json", "w") as f:
    json.dump(data_for_stan, f, indent=4)


# Fit the model and sample from the posterior using NUTS (NUTS paper: https://arxiv.org/abs/1111.4246)
# These draws are the reference sample for the KSD comparison further down.
# NOTE(review): iter=100 with a single chain is a very short run, and in PyStan 2
# `iter` is documented as including warmup — confirm enough draws are retained.
fit = sm.sampling(data=data_for_stan, iter=100, chains=1)
mcmc_samples = fit.extract()

# Build a BridgeStan handle on the same Stan program and the JSON data written
# above; it is used below for the Hessian and gradient of the log density.
stan = "./student_reg_model.stan"
data = "./student_regression_data.json"
bridgestan_model = bridgestan.StanModel.from_stan_file(stan, data)

# Log density of the fitted model, passed as `log_pi_tilde` to the AMIS runs.
# NOTE(review): presumably evaluated on the unconstrained parameter scale — confirm.
true_log_pdf = fit.log_prob

# Step 3: Find the MAP solution
# The result is a mapping from parameter name to its optimised value.
map_sol = sm.optimizing(data=data_for_stan)

# Retrieve the values, extract the single element from each array, and convert to an ndarray
# (this relies on every parameter being a scalar; .item() fails on longer arrays)
map_sol_array = np.array([value.item() for value in map_sol.values()])

map_sol_list = list(map_sol.values())
# Log density and Hessian at the MAP point; the gradient (second output) is discarded.
# NOTE(review): the values are passed as `theta_unc`, i.e. as unconstrained
# parameters. If student_reg_model.stan declares any constrained parameter
# (e.g. a positive scale), the optimiser output presumably needs to be
# unconstrained first — TODO confirm against the Stan program.
log_dens_at_map, _, hessian_at_map = bridgestan_model.log_density_hessian(theta_unc=map_sol_array, propto=True)

dim = 4 # Fixed
dof_proposal = 3
mu_initial_proposal_laplace = map_sol_array

# Sanity check: BridgeStan and PyStan must report the same log density at the MAP point
assert np.isclose(log_dens_at_map, fit.log_prob(map_sol_list))

# Negative inverse of the Hessian at the MAP solution used as covariance
cov_laplace = -np.linalg.inv(hessian_at_map)
# The covariance must be positive definite (all eigenvalues strictly positive)
assert np.all(np.linalg.eigvals(cov_laplace) > 0)
# A Student-t with shape matrix S and dof > 2 has covariance dof / (dof - 2) * S,
# so this rescaling gives a proposal whose covariance equals cov_laplace.
shape_initial_proposal_laplace = (dof_proposal - 2) / (dof_proposal) * cov_laplace


# Isotropic initial shape matrix: identity scaled by sigma_initial and by the
# Student-t factor (dof - 2) / dof
sigma_initial = 10
student_shape_factor = (dof_proposal - 2) / dof_proposal
shape_initial = student_shape_factor * sigma_initial * np.eye(dim)


# Experiment size: samples per iteration, AMIS iterations, independent repetitions
num_samples = 100_000
n_iter = 25
nb_runs = 100

# random_mu_initial = np.random.multivariate_normal(mean=np.zeros(dim), cov=np.identity(dim), size=1)
# random_mu_initial = np.random.uniform(-1, 1, dim)
# shape_initial_proposal = (dof_proposal - 2) / (dof_proposal) * np.identity(dim)
mu_initial = np.ones(dim)

# The Laplace-based shape matrix must be positive definite
laplace_shape_eigenvalues = np.linalg.eigvals(shape_initial_proposal_laplace)
assert np.all(laplace_shape_eigenvalues > 0)

# Arguments shared by both runs, so the two calls differ only in `alg`
_amis_run_kwargs = dict(
    nb_runs=nb_runs,
    n_iterations=n_iter,
    log_pi_tilde=true_log_pdf,
    dof_proposal=dof_proposal,
    M=num_samples,
    d=dim,
    mu_initial=mu_initial,
    shape_initial=shape_initial,
)

# Baseline run (AMIS_student_fixed_dof)
(
    mean_Z_baseline,
    std_Z_baseline,
    mean_ESS_baseline,
    mean_alphaESS_baseline,
    std_ESS_baseline,
    std_alphaESS_baseline,
) = run_AMIS_real_dataset(alg=AMIS_student_fixed_dof, **_amis_run_kwargs)


# Comparison run (alpha_AMIS_fixed_dof) with the identical configuration
(
    mean_Z,
    std_Z,
    mean_ESS,
    mean_alphaESS,
    std_ESS,
    std_alphaESS,
) = run_AMIS_real_dataset(alg=alpha_AMIS_fixed_dof, **_amis_run_kwargs)

# Drop the log-probability entry so that only model parameters remain
exclude_key = "lp__"
mcmc_samples = OrderedDict(
    entry for entry in mcmc_samples.items() if entry[0] != exclude_key
)

# Stack the per-parameter draws as rows, then transpose
stacked_draws = np.vstack(list(mcmc_samples.values()))
mcmc_samples_array = stacked_draws.T

### Plot results

In [None]:
# Half-width multiplier for the error bars (1.96 standard deviations, 95% confidence)
ci_multiplier = 1.96

# First figure: baseline estimate of Z with its error bars, one point per index
plt.figure(figsize=(10, 6))
z_index = range(len(mean_Z_baseline))
z_half_width = std_Z_baseline * ci_multiplier
plt.errorbar(
    z_index,
    mean_Z_baseline,
    yerr=z_half_width,
    fmt='o',
    label='$Z_{baseline}$',
)
plt.title(r'Mean $Z_{baseline} \pm 1.96 \sigma$')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()
plt.show()

# Create the second plot: baseline ESS and alpha-ESS, each with error bars of
# ci_multiplier standard deviations
plt.figure(figsize=(10, 6))
plt.errorbar(range(len(mean_ESS_baseline)), mean_ESS_baseline, yerr=std_ESS_baseline * ci_multiplier, fmt='o', label='$ESS_{baseline}$')
plt.errorbar(range(len(mean_alphaESS_baseline)), mean_alphaESS_baseline, yerr=std_alphaESS_baseline * ci_multiplier, fmt='s', label='$\\alpha ESS_{baseline}$')
# The title is a raw string, so the LaTeX command needs a single backslash:
# r'\\alpha' would hand LaTeX a double backslash followed by the word "alpha".
plt.title(r'Mean $ESS_{baseline} \pm 1.96 \sigma$ and Mean $\alpha ESS_{baseline} \pm 1.96 \sigma$')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()
plt.show()


In [None]:
# def numpy_callback(x):
#   # Need to forward-declare the shape & dtype of the expected output.
#   result_shape = jax.core.ShapedArray(x.shape, x.dtype)
#   return jax.pure_callback(np.sin, result_shape, x)

def log_density_gradient_correct(theta):
    """Return the gradient part of BridgeStan's log-density/gradient pair at theta.

    BridgeStan returns the log density together with its gradient; only the
    gradient (the second element) is handed back to the caller.
    """
    _, gradient = bridgestan_model.log_density_gradient(theta)
    return gradient

def log_density_gradient(theta):
    """Evaluate the BridgeStan gradient at theta through a JAX host callback.

    The gradient itself is computed outside JAX by
    log_density_gradient_correct; wrapping it in io_callback lets the function
    be called from JAX code.

    Args:
        theta: array of parameter values; the result is declared to have the
            same shape and dtype as this argument.

    Returns:
        The gradient of the log density at theta.
    """
    # Import the callback explicitly instead of relying on `import jax` having
    # loaded the jax.experimental subpackage as a side effect.
    from jax.experimental import io_callback

    # ShapeDtypeStruct is the documented way to declare a callback's output
    # shape and dtype; the gradient has the same shape and dtype as theta.
    result_shape = jax.ShapeDtypeStruct(theta.shape, theta.dtype)
    gradient = io_callback(log_density_gradient_correct, result_shape, theta)
    return gradient

In [None]:
# Switch JAX to 64-bit floats before any KSD computation.
# NOTE(review): presumably required so that arrays reaching BridgeStan through
# log_density_gradient are float64 — confirm.
jax.config.update("jax_enable_x64", True)
# compare results with MCMC via the KSD
# (the MCMC draws act as the reference sample; the score function is the
# BridgeStan gradient wrapped for JAX)
ksd_mcmc_samples = old_ksd(mcmc_samples_array, log_density_gradient)
print("KSD using true samples:", ksd_mcmc_samples)



In [None]:
from scipy.special import logsumexp

# NOTE(review): `adapted_proposal` is not defined anywhere in the visible part
# of this notebook (the AMIS runs above only return summary statistics). It
# must be bound to the final adapted proposal before this cell can run.
final_samples = adapted_proposal.rvs(size=50000)

# BridgeStan's log_density evaluates one parameter vector per call, so the
# target is evaluated row by row instead of on the whole sample matrix.
log_target = np.array(
    [
        bridgestan_model.log_density(np.ascontiguousarray(sample, dtype=np.float64))
        for sample in final_samples
    ]
)

# Log importance weights: log target minus log proposal
weights = log_target - adapted_proposal.logpdf(final_samples)
# Self-normalise in log space (logsumexp) to avoid overflow, then exponentiate
normalized_weights = np.exp(weights - logsumexp(weights))
ksd_fixed_dof = old_ksd(final_samples, log_density_gradient, weights=normalized_weights)
print("KSD for adapted proposal samples:", ksd_fixed_dof)
