# Bayesian Logistic Regression with ESS and MESS

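This notebook implements the Bayesian logistic regression example from the original Elliptical Slice Sampling paper (Murray et al., 2010) using Elliptical Slice Sampling (ESS) and MESS (Multiple-proposal ESS). It is set up to compare ESS (MESS with $M = 1$) against the uniform, angular and Euclidean MESS variants; the angular and Euclidean runs are currently commented out, so only the uniform variant is executed.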

Key metrics tracked:
- Number of shrinking steps (intervals)
- Effective Sample Size (also abbreviated ESS — not to be confused with the Elliptical Slice Sampling algorithm)
- Mean Squared Jumping Distance (MSJD)
- Computation time

## Import Required Libraries

In [None]:
import sys
import os
import time

# Locate the repository's src/ directory so the `mess` package can be imported.
# The kernel's working directory is assumed to be a direct child of the repo
# root (e.g. notebooks/); the repo root is therefore one level up.
repo_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
src_path = os.path.join(repo_root, 'src')

# Guard the insert so re-running this cell does not stack duplicate entries
# on sys.path.
if src_path not in sys.path:
    sys.path.insert(0, src_path)

print(f"Repo root: {repo_root}")
print(f"Added to path: {src_path}")

import numpy as np
import matplotlib.pyplot as plt
from mess.algorithms.ess import ess_step
from mess.algorithms.mess import mess_step
from mess.algorithms.effective_sample_size import (
    estimate_effective_sample_size,
    compute_mean_squared_jumping_distance,
    compute_normalized_jumping_distance
)
from mess.problems.logistic_regression import BayesianLogisticRegression
from mess.data.logistic_regression import generate_logistic_regression_data

## Generate Synthetic Logistic Regression Data

In [None]:
# Build the synthetic dataset (100 samples, 10 features, fixed seed)
data = generate_logistic_regression_data(n_samples=100, n_features=10, seed=42)
X, y, beta_true = data['X'], data['y'], data['beta_true']

# Wrap the data in the problem object used by every sampler below
problem = BayesianLogisticRegression(X, y, prior_var=1.0)

# Report what was generated
print("Data shapes:")
print(f"  X: {X.shape}")
print(f"  y: {y.shape}")
for class_label in (0, 1):
    print(f"  Number of class {class_label}: {np.sum(y == class_label)}")
print(f"\nTrue coefficients: {beta_true}")

# Starting state for the chains, obtained from problem.sample_prior with its
# own seeded generator
prior_rng = np.random.default_rng(seed=42)
beta0 = problem.sample_prior(prior_rng)
initial_log_likelihood = problem.log_likelihood(beta0)
print(f"\nInitial log-likelihood: {initial_log_likelihood:.4f}")

## Experiment Setup

In [None]:
# Chain length, burn-in length and RNG seed shared by all runs
n_iters, burn_in, seed = 10000, 1000, 42

# Proposal counts M to sweep; the filtered list keeps only those with M <= 100
M_values = [1, 2, 5, 10, 15, 20, 25, 30, 40, 50, 75, 100, 200]
M_values_filtered = list(filter(lambda m: m <= 100, M_values))

# Echo the configuration
print("Experiment setup:")
for setting_name, setting_value in (
    ("n_iters", n_iters),
    ("burn_in", burn_in),
    ("M values", M_values),
):
    print(f"  {setting_name}: {setting_value}")

## Run MCMC Sampling Experiments

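Run MESS with the uniform variant for different $M$ values ($M = 1$ corresponds to standard ESS). The angular and Euclidean variants are kept in the cell below but are currently commented out.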

In [None]:
# Result containers.
# Per-variant entries ('uniform', 'angular', 'euclidean') hold, keyed by M, the
# raw chains, the per-iteration shrinking-step counts and their summary
# statistics; 'times' is a plain list of wall-clock durations in the order the
# M values were run. The top-level 'ess', 'msjd' and 'log_likelihood' entries
# are filled in by later cells.
results = {
    'uniform': {'intervals': {}, 'chains': {}, 'times': [], 'statistics': {}},
    'angular': {'intervals': {}, 'chains': {}, 'times': [], 'statistics': {}},
    'euclidean': {'intervals': {}, 'chains': {}, 'times': [], 'statistics': {}},
    'ess': {'uniform': {}, 'angular': {}, 'euclidean': {}},
    'msjd': {'uniform': {}, 'angular': {}, 'euclidean': {}},
    'log_likelihood': {'uniform': {}, 'angular': {}, 'euclidean': {}},
}

# Run MESS Uniform
print("Running MESS Uniform...")
for M in M_values:
    print(f"  M = {M}")
    # Every M starts from the same seed and the same initial state beta0
    rng = np.random.default_rng(seed)
    chain = np.zeros((n_iters, X.shape[1]))
    intervals = np.zeros(n_iters, dtype=int)
    x = beta0.copy()

    # perf_counter() is monotonic and high resolution, so the measured duration
    # cannot be distorted by system clock adjustments (unlike time.time()).
    t0 = time.perf_counter()
    for t in range(n_iters):
        x, nr_intervals, _ = mess_step(x, problem, rng, M=M, use_lp=False)
        chain[t] = x
        intervals[t] = nr_intervals
    elapsed = time.perf_counter() - t0

    results['uniform']['chains'][M] = chain
    results['uniform']['intervals'][M] = intervals
    results['uniform']['times'].append(elapsed)

    # Summary statistics of the shrinking-step counts after burn-in
    intervals_post = intervals[burn_in:]
    mean_intervals_post = np.mean(intervals_post)
    results['uniform']['statistics'][M] = {
        'mean_intervals': mean_intervals_post,
        'std_intervals': np.std(intervals_post),
        'median_intervals': np.median(intervals_post),
    }

    print(f"    Time: {elapsed:.2f}s, Mean nr. intervals: {mean_intervals_post:.4f}")

# NOTE: Angular/Euclidean runs are disabled for uniform-only analysis.
#
# # Run MESS Angular
# print(f"\nRunning MESS Angular...")
# for M in M_values_filtered:
#     print(f"  M = {M}")
#     rng = np.random.default_rng(seed)
#     chain = np.zeros((n_iters + 1, X.shape[1]))
#     chain[0] = beta0.copy()
#     intervals = np.zeros(n_iters, dtype=int)
#     x = beta0.copy()
#
#     t0 = time.time()
#     for t in range(n_iters):
#         x, nr_intervals, _ = mess_step(x, problem, rng, M=M, use_lp=True,
#                                        distance_metric='angular', lam=0.05)
#         chain[t + 1] = x
#         intervals[t] = nr_intervals
#     elapsed = time.time() - t0
#
#     results['angular']['chains'][M] = chain
#     results['angular']['intervals'][M] = intervals
#     results['angular']['times'].append(elapsed)
#
#     # Compute statistics
#     intervals_post = intervals[burn_in + 1:]
#     results['angular']['statistics'][M] = {
#         'mean_intervals': np.mean(intervals_post),
#         'std_intervals': np.std(intervals_post),
#         'median_intervals': np.median(intervals_post),
#     }
#
#     print(f"    Time: {elapsed:.2f}s, Mean intervals: {np.mean(intervals_post):.4f}")
#
# # Run MESS Euclidean
# print(f"\nRunning MESS Euclidean...")
# for M in M_values_filtered:
#     print(f"  M = {M}")
#     rng = np.random.default_rng(seed)
#     chain = np.zeros((n_iters + 1, X.shape[1]))
#     chain[0] = beta0.copy()
#     intervals = np.zeros(n_iters, dtype=int)
#     x = beta0.copy()
#
#     t0 = time.time()
#     for t in range(n_iters):
#         x, nr_intervals, _ = mess_step(x, problem, rng, M=M, use_lp=True,
#                                        distance_metric='euclidean', lam=0.05)
#         chain[t + 1] = x
#         intervals[t] = nr_intervals
#     elapsed = time.time() - t0
#
#     results['euclidean']['chains'][M] = chain
#     results['euclidean']['intervals'][M] = intervals
#     results['euclidean']['times'].append(elapsed)
#
#     # Compute statistics
#     intervals_post = intervals[burn_in + 1:]
#     results['euclidean']['statistics'][M] = {
#         'mean_intervals': np.mean(intervals_post),
#         'std_intervals': np.std(intervals_post),
#         'median_intervals': np.median(intervals_post),
#     }
#
#     print(f"    Time: {elapsed:.2f}s, Mean intervals: {np.mean(intervals_post):.4f}")

# Signal the end of the sampling cell (plain literal: nothing to interpolate)
print("\nAll experiments completed!")

## Compute ESS and MSJD

In [None]:
# Forwarded to estimate_effective_sample_size here and in the log-likelihood
# ESS cell further down
max_lag = 1000

print("\nComputing ESS and MSJD")

# Uniform variant: score the post-burn-in part of each stored chain
for M in M_values:
    chain = results['uniform']['chains'][M][burn_in:]
    results['ess']['uniform'][M] = estimate_effective_sample_size(chain, max_lag=max_lag)
    results['msjd']['uniform'][M] = compute_mean_squared_jumping_distance(chain)

# NOTE: Angular/Euclidean ESS and MSJD are disabled for uniform-only analysis.
#
# # Compute ESS for Angular
# for M in M_values_filtered:
#     chain = results['angular']['chains'][M][burn_in+1:, :]
#     ess_values = estimate_effective_sample_size(chain, max_lag=max_lag)
#     msjd_values = compute_mean_squared_jumping_distance(chain)
#
#     results['ess']['angular'][M] = ess_values
#     results['msjd']['angular'][M] = msjd_values
#
# # Compute ESS for Euclidean
# for M in M_values_filtered:
#     chain = results['euclidean']['chains'][M][burn_in+1:, :]
#     ess_values = estimate_effective_sample_size(chain, max_lag=max_lag)
#     msjd_values = compute_mean_squared_jumping_distance(chain)
#
#     results['ess']['euclidean'][M] = ess_values
#     results['msjd']['euclidean'][M] = msjd_values

# Compute log-likelihood for all chains
for M in M_values:
    # Uniform: evaluate every stored state (burn-in included). Iterating over
    # the rows of the stored chain instead of range(n_iters) keeps this correct
    # even if n_iters was edited after the sampling cell was run.
    chain_uniform = results['uniform']['chains'][M]
    ll_uniform = np.array([problem.log_likelihood(state) for state in chain_uniform])
    results['log_likelihood']['uniform'][M] = ll_uniform

    # Angular and Euclidean (only for M > 1) -- disabled for uniform-only analysis
    # if M > 1 and M in M_values_filtered:
    #     chain_angular = results['angular']['chains'][M]
    #     ll_angular = np.array([problem.log_likelihood(chain_angular[t, :]) for t in range(n_iters + 1)])
    #     results['log_likelihood']['angular'][M] = ll_angular
    #
    #     chain_euclidean = results['euclidean']['chains'][M]
    #     ll_euclidean = np.array([problem.log_likelihood(chain_euclidean[t, :]) for t in range(n_iters + 1)])
    #     results['log_likelihood']['euclidean'][M] = ll_euclidean

print("ESS, MSJD, and log-likelihood computation completed!")

## Figure 1: Effective Sample Size and MSJD Boxplots

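Compare the effective sample size (ESS) and MSJD distributions across $M$ values for the uniform MESS variant. The Angular and Euclidean variants are currently disabled, so they do not appear in the figure.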

In [None]:
def _boxplot_by_M(ax, samples_per_M, M_labels, ylabel, title,
                  label='Uniform', color='#1f77b4'):
    """Draw one box per M value on ``ax`` and apply the shared panel styling.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    samples_per_M : sequence holding one collection of values per entry of
        ``M_labels``, in the same order.
    M_labels : sequence of M values; boxes are placed at 0, 1, 2, ... and the
        M values are used as x tick labels.
    ylabel, title : y-axis label and panel title.
    label, color : legend entry and box face colour.

    Returns
    -------
    The dictionary returned by ``ax.boxplot``.
    """
    positions = list(range(len(M_labels)))
    bp = ax.boxplot(samples_per_M, positions=positions, widths=0.6,
                    patch_artist=True, label=label)
    for patch in bp['boxes']:
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax.set_xticks(positions)
    ax.set_xticklabels([str(int(M)) for M in M_labels])
    ax.set_xlabel('M (Number of Proposals)', fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='best')
    ax.grid(True, alpha=0.3, axis='y')
    return bp


# Create 1x2 figure: ESS (left) and MSJD (right)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Bayesian Logistic Regression: ESS and MSJD (Uniform Only)', fontsize=16, fontweight='bold')

# Boxplot data: one entry per M, taken from the uniform results only.
# NOTE: the Angular/Euclidean overlays are disabled for uniform-only analysis;
# when they were enabled they used the colours '#ff7f0e' and '#2ca02c'.
M_box = list(M_values)
uniform_ess_box = [results['ess']['uniform'][M] for M in M_box]
uniform_msjd_box = [results['msjd']['uniform'][M] for M in M_box]

# Left panel: effective sample size; right panel: mean squared jumping distance
_boxplot_by_M(axes[0], uniform_ess_box, M_box,
              ylabel='Effective Sample Size (ESS)', title='ESS')
_boxplot_by_M(axes[1], uniform_msjd_box, M_box,
              ylabel='Mean Squared Jumping Distance (MSJD)', title='MSJD')

plt.tight_layout()
plt.show()

In [None]:
# Write the figure currently bound to `fig` to a PNG at 600 dpi.
# Run this directly after the ESS/MSJD boxplot cell: later cells rebind `fig`.
fig.savefig(fname='log_regression_ess_msjd_boxplots.png', dpi=600)


## Figure 2: Shrinking Steps and Computation Time

In [None]:
# Toggle for the computation-time axis. The bars always show raw seconds; only
# the axis scaling changes, so every title/label below is derived from this
# flag instead of hard-coding "log".
use_log_scale = True

if use_log_scale:
    figure_title = 'Bayesian Logistic Regression: Shrinking Steps and Computation Time, log scale (Uniform Only)'
    time_ylabel = 'Computation Time (seconds, log scale)'
    time_title = 'Computation Time (log scale)'
else:
    figure_title = 'Bayesian Logistic Regression: Shrinking Steps and Computation Time (Uniform Only)'
    time_ylabel = 'Computation Time (seconds)'
    time_title = 'Computation Time'

# Create 1x2 figure: Intervals and Time
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle(figure_title, fontsize=16, fontweight='bold')

# Post-burn-in shrinking-step counts, one array per M (uniform variant only).
# NOTE: Angular/Euclidean data are disabled for uniform-only analysis.
M_box = list(M_values)
uniform_intervals_box = [results['uniform']['intervals'][M][burn_in:] for M in M_box]

# ===== SHRINKING STEPS BOXPLOT (LEFT) =====
ax = axes[0]

positions_uniform = list(range(len(M_box)))
bp1 = ax.boxplot(uniform_intervals_box, positions=positions_uniform, widths=0.6,
                 patch_artist=True, label='Uniform')

for patch in bp1['boxes']:
    patch.set_facecolor('#1f77b4')
    patch.set_alpha(0.7)

x_ticks = positions_uniform
x_labels = [str(int(M)) for M in M_box]

ax.set_xticks(x_ticks)
ax.set_xticklabels(x_labels)
ax.set_xlabel('M (Number of Proposals)', fontsize=12, fontweight='bold')
ax.set_ylabel('Number of Shrinking Steps', fontsize=12, fontweight='bold')
ax.set_title('Shrinking Steps', fontsize=13, fontweight='bold')
ax.legend(fontsize=10, loc='best')
ax.grid(True, alpha=0.3, axis='y')

# ===== COMPUTATION TIME (RIGHT) =====
ax = axes[1]

# results['uniform']['times'] is a list appended in M_values order by the
# sampling cell, so it lines up index-by-index with M_for_time.
M_for_time = list(M_values)
uniform_times = results['uniform']['times']

x_pos = np.arange(len(M_for_time))
width = 0.5

ax.bar(x_pos, uniform_times, width, label='Uniform', color='#1f77b4', alpha=0.8)

ax.set_xticks(x_pos)
ax.set_xticklabels([str(int(M)) for M in M_for_time])
ax.set_xlabel('M (Number of Proposals)', fontsize=12, fontweight='bold')
ax.set_ylabel(time_ylabel, fontsize=12, fontweight='bold')
ax.set_title(time_title, fontsize=13, fontweight='bold')

# Apply log scale if requested
if use_log_scale:
    ax.set_yscale('log')

ax.legend(fontsize=10, loc='best')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## Figure 2b: Uniform Shrinking vs Likelihood Evaluations
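Compare the mean number of shrinking steps of uniform MESS with the mean number of likelihood evaluations per iteration under serial evaluation, as a function of $M$. The parallelized-evaluation count is also computed, but its curve is currently commented out in the plot.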

In [None]:
# Imported at the top of the cell so a missing dependency fails before any
# plotting work is done.
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Uniform-only comparison: shrinking steps vs likelihood evaluations
M_plot = [M for M in M_values if M <= 200]
mean_shrinks = []
mean_serial_evals = []
mean_parallel_evals = []

for M in M_plot:
    intervals_post = results['uniform']['intervals'][M][burn_in:]
    mean_intervals = float(np.mean(intervals_post))
    mean_shrinks.append(mean_intervals)
    # Each loop evaluates M likelihoods; one extra loop beyond nr_intervals
    mean_serial = M * (mean_intervals + 1)
    mean_parallel = mean_serial / M
    mean_serial_evals.append(mean_serial)
    mean_parallel_evals.append(mean_parallel)

fig, ax = plt.subplots(figsize=(10, 4.5))

# Grayscale and different line styles/widths instead of colors
ax.plot(
    M_plot,
    mean_shrinks,
    marker='o',
    markersize=8,
    linewidth=2.5,
    linestyle='-',
    color='black',
    label='Nr. shrinking steps',
)
ax.plot(
    M_plot,
    mean_serial_evals,
    marker='s',
    markersize=7,
    linewidth=2.0,
    linestyle='-.',
    color='gray',
    label='Nr. likelihood evals (serial)',
)
# Parallelized curve (disabled):
# ax.plot(
#     M_plot,
#     mean_parallel_evals,
#     marker='^',
#     markersize=7,
#     linewidth=1.5,
#     linestyle='--',
#     color='black',
#     label='Parallelized likelihood evals (theoretical)',
# )

# Inset zoom for M = 1, 2, 5, 10 (only those that were actually run).
# Dict lookups replace repeated list.index() scans.
shrinks_by_M = dict(zip(M_plot, mean_shrinks))
serial_by_M = dict(zip(M_plot, mean_serial_evals))
inset_M = [M for M in [1, 2, 5, 10] if M in shrinks_by_M]
inset_shrinks = [shrinks_by_M[M] for M in inset_M]
inset_serial = [serial_by_M[M] for M in inset_M]

inset_ax = inset_axes(ax, width='40%', height='55%', loc='center right', borderpad=2)
inset_ax.plot(
    inset_M,
    inset_shrinks,
    marker='o',
    markersize=5,
    linewidth=1.8,
    linestyle='-',
    color='black',
    label='Nr. shrinking steps',
)
inset_ax.plot(
    inset_M,
    inset_serial,
    marker='s',
    markersize=4,
    linewidth=1.5,
    linestyle='-.',
    color='gray',
    label='Nr. likelihood evals',
)
inset_ax.set_xlim(0.5, 10.5)
inset_ax.grid(True, alpha=0.3)
# Raw string: '\l' is not a valid escape sequence in a normal string literal
# and triggers a SyntaxWarning on recent Python versions.
inset_ax.set_title(r'Zoom: $M\leq10$', fontsize=16, loc='left')
inset_ax.tick_params(labelsize=12)

ax.set_xlabel('M (Number of Proposals)', fontsize=16)
ax.set_ylabel('Mean count per iteration', fontsize=16)
# ax.set_title('Shrinking Steps and Likelihood Evaluations', fontsize=18)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=16, loc='best')
ax.tick_params(labelsize=12)

# # Add ESS label for M=1 with arrow
# ax.annotate('ESS',
#             xy=(1, mean_shrinks[0]),
#             xytext=(1, ax.get_ylim()[0] - 0.15*(ax.get_ylim()[1]-ax.get_ylim()[0])),
#             ha='center', fontsize=12, color='red',
#             arrowprops={'arrowstyle': '->', 'color': 'red', 'lw': 1.5})

plt.tight_layout()
plt.show()

## Figure 3: Trace Plots for M=1 (ESS)

Visualize the chains for the ESS sampler to assess convergence.

In [None]:
# Trace plots for the first coefficients of the M=1 (ESS) chain
chain_ess = results['uniform']['chains'][1]

# Show at most five coefficients; fewer if the chain has fewer columns
n_traces = min(5, chain_ess.shape[1])

# squeeze=False always returns a 2-D array of Axes, so indexing below also
# works when only a single panel is drawn
fig, axes_grid = plt.subplots(n_traces, 1, figsize=(14, 10), squeeze=False)
axes = axes_grid[:, 0]
fig.suptitle(f'ESS (M=1) Trace Plots - First {n_traces} Coefficients', fontsize=14, fontweight='bold')

for i in range(n_traces):
    ax = axes[i]
    ax.plot(chain_ess[:, i], alpha=0.8, linewidth=0.5)
    ax.axvline(x=burn_in, color='red', linestyle='--', label='Burn-in')
    # '\u03b2' is the Greek letter beta; the escape keeps the label intact
    # regardless of the file's encoding (it was previously garbled).
    ax.set_ylabel(f'\u03b2_{i}', fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend()

axes[-1].set_xlabel('Iteration', fontsize=11, fontweight='bold')
plt.tight_layout()
plt.show()

## Log-Likelihood Traceplots

Examine how the log-likelihood evolves over the burn-in iterations for each value of $M$ (uniform variant only) to assess convergence behavior.

In [None]:
# One stacked panel per M, each showing the log-likelihood trace up to burn-in
n_panels = len(M_values)
fig, axes_grid = plt.subplots(n_panels, 1, figsize=(14, 4*n_panels), squeeze=False)
# Flatten the (n_panels, 1) grid so a single panel is handled like many
axes = axes_grid.ravel()

fig.suptitle('Bayesian Logistic Regression: Log-Likelihood Traceplots (Uniform Only, Until Burnin)',
             fontsize=14, fontweight='bold', y=0.995)

for m_idx, (ax, M) in enumerate(zip(axes, M_values)):
    is_first_panel = m_idx == 0
    is_last_panel = m_idx == n_panels - 1

    # Uniform variant only; the Angular/Euclidean curves are disabled
    ll_uniform = results['log_likelihood']['uniform'][M][:burn_in]
    ax.plot(ll_uniform, linewidth=1, alpha=0.8, label='Uniform', color='#1f77b4')

    ax.set_ylabel(f'M={M}\nLog-likelihood', fontsize=10, fontweight='bold')
    if is_last_panel:
        # Only the bottom panel carries the shared x-axis label
        ax.set_xlabel('Iteration', fontsize=10)
    ax.set_title(f'M = {M}', fontsize=11, fontweight='bold')
    ax.grid(True, alpha=0.3)
    if is_first_panel:
        # A single legend on the top panel is enough
        ax.legend(fontsize=9, loc='best')

plt.tight_layout()
plt.show()

## Log-Likelihood ESS and MSJD Comparison

In [None]:
# ESS and MSJD of the scalar log-likelihood trace, keyed by variant then by M
variant_names = ('uniform', 'angular', 'euclidean')
ll_ess = {variant: {} for variant in variant_names}
ll_msjd = {variant: {} for variant in variant_names}

# Uniform: the post-burn-in trace is reshaped to a single column before being
# passed to the estimators, and the first entry of each result is kept.
# (The Angular and Euclidean variants are disabled.)
for M in M_values:
    ll_column = results['log_likelihood']['uniform'][M][burn_in:].reshape(-1, 1)
    ll_ess['uniform'][M] = estimate_effective_sample_size(ll_column, max_lag=max_lag)[0]
    ll_msjd['uniform'][M] = compute_mean_squared_jumping_distance(ll_column)[0]

# Side-by-side bar charts (uniform only)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

fig.suptitle('Bayesian Logistic Regression: Log-Likelihood ESS and MSJD (Uniform Only)',
             fontsize=16, fontweight='bold', y=0.995)

# Bar heights, one per M
M_box = list(M_values)
uniform_ess_box = [ll_ess['uniform'][M] for M in M_values]
uniform_msjd_box = [ll_msjd['uniform'][M] for M in M_values]

# Shared x-axis ticks: one bar per M, labelled with the M value
x_ticks = list(range(len(M_box)))
x_labels = [str(int(M)) for M in M_box]

# (axes, bar heights, y-axis label, panel title) for the left and right panels
panel_specs = (
    (axes[0], uniform_ess_box, 'Log-Likelihood ESS', 'ESS'),
    (axes[1], uniform_msjd_box, 'Log-Likelihood MSJD', 'MSJD'),
)

for ax, bar_heights, panel_ylabel, panel_title in panel_specs:
    ax.bar(range(len(M_box)), bar_heights, width=0.6, label='Uniform',
           color='#1f77b4', alpha=0.8)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel('M (Number of Proposals)', fontsize=12, fontweight='bold')
    ax.set_ylabel(panel_ylabel, fontsize=12, fontweight='bold')
    ax.set_title(panel_title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=11, loc='best')
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Previous multi-method comparison (disabled)
# angular_ess_box = [ll_ess['angular'][M] for M in M_values_filtered]
# euclidean_ess_box = [ll_ess['euclidean'][M] for M in M_values_filtered]
# angular_msjd_box = [ll_msjd['angular'][M] for M in M_values_filtered]
# euclidean_msjd_box = [ll_msjd['euclidean'][M] for M in M_values_filtered]
#
# positions_uniform = []
# positions_angular = []
# positions_euclidean = []
# pos_counter = 0
#
# for i, M in enumerate(M_box):
#     positions_uniform.append(pos_counter)
#     pos_counter += 1
#
#     if M > 1:
#         positions_angular.append(pos_counter)
#         pos_counter += 1
#         positions_euclidean.append(pos_counter)
#         pos_counter += 1
#
#     pos_counter += 1
#
# # ===== ESS PLOT (LEFT) =====
# ax = axes[0]
# ax.bar(positions_uniform, uniform_ess_box, width=0.25, label='Uniform',
#        color='#1f77b4', alpha=0.8)
# if angular_ess_box:
#     ax.bar(positions_angular, angular_ess_box, width=0.25, label='Angular',
#            color='#ff7f0e', alpha=0.8)
# if euclidean_ess_box:
#     ax.bar(positions_euclidean, euclidean_ess_box, width=0.25, label='Euclidean',
#            color='#2ca02c', alpha=0.8)
#
# x_ticks = []
# x_labels = []
# for i, M in enumerate(M_box):
#     if M == 1:
#         x_ticks.append(positions_uniform[i])
#         x_labels.append(f"M={int(M)}")
#     else:
#         j = len([m for m in M_box[:i+1] if m > 1]) - 1
#         avg_pos = (positions_uniform[i] + positions_angular[j] + positions_euclidean[j]) / 3
#         x_ticks.append(avg_pos)
#         x_labels.append(str(int(M)))

In [None]:
# Generate LaTeX table for ESS and MSJD (uniform variant, selected M values)
M_values_table = [10, 20, 50]

# Rows of (M, mean ESS, mean MSJD). Deliberately NOT called `data`: that name
# holds the dataset dictionary created in the data-generation cell and must not
# be overwritten with a list of tuples.
table_rows = []

print("Extracting data for LaTeX table:")
for M in M_values_table:
    # Average the stored ESS / MSJD values for this M into a single number each
    mean_ess = np.mean(results['ess']['uniform'][M])
    mean_msjd = np.mean(results['msjd']['uniform'][M])
    table_rows.append((M, mean_ess, mean_msjd))
    print(f"M={M}: ESS={mean_ess:.4f}, MSJD={mean_msjd:.6f}")

# Table header
latex_code = r"""\begin{table}[h]
    \centering
    \begin{tabular}{|c|c|c|}
        \hline
        $M$ & ESS & MSJD \\
        \hline
"""

# One table line per M; each ends with the LaTeX row terminator `\\`
latex_code += "".join(
    f"        {M} & {mean_ess:.4f} & {mean_msjd:.6f} \\\\\n"
    for M, mean_ess, mean_msjd in table_rows
)

# Table footer
latex_code += r"""        \hline
    \end{tabular}
    \caption{Effective Sample Size (ESS) and Mean Squared Jumping Distance (MSJD) for different values of M}
    \label{tab:ess_msjd}
\end{table}
"""

print("\n" + "="*60)
print("LaTeX Code:")
print("="*60)
print(latex_code)
