In [None]:
#The file for the simple test cases used in the paper. 
#Please cite the paper, if you use the code or parts of it, as writing the code took a lot of time and effort. 
#Staicova D., Universe 2025, 11(2), 68; https://doi.org/10.3390/universe11020068, arXiv:2501.06022
# NOTE: to reproduce the plots you need to run the notebook yourself, because the GitHub version does not load the original results.
import numpy as np
import emcee
import pymc as pm
import numpyro
import dynesty
import jax.numpy as jnp
from jax import random
import time
import arviz as az
import traceback
from numpyro.infer import NUTS, MCMC
import threading
import psutil
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
import time
import pymc as pm
import emcee
import numpyro
import dynesty
from jax import random
import jax.numpy as jnp
import numpy as np
import time
import arviz as az
import memory_profiler
from functools import wraps
from mpi4py import MPI


class TestProblems:
    """Factories for the benchmark target densities.

    Every static method takes ``ndim`` (currently unused by the closures, which
    work on whatever length vector they are given) and returns a pair
    ``(log_like, jax_log_like)``:

    * ``log_like(x)`` dispatches on its input: objects exposing a ``type``
      attribute are treated as PyMC tensors and produce a symbolic expression
      built with ``pm.math``; anything else is converted with ``np.asarray``
      and produces a Python ``float``.
    * ``jax_log_like(x)`` is the same density written with ``jax.numpy``.

    All three code paths of a problem must describe the same density, otherwise
    the samplers being compared would be targeting different distributions.
    """

    @staticmethod
    def correlated_gaussian(ndim):
        """Gaussian test problem, ``log L = -0.5 * sum(x**2)``.

        Note that, despite the name, the expression has no cross terms between
        components.
        """
        def log_like(x):
            # Check if input is a PyMC variable
            if hasattr(x, 'type'):
                # PyMC version (symbolic)
                import pymc as pm
                return -0.5 * pm.math.sum(x**2)
            else:
                # Regular numpy version
                x = np.asarray(x)
                return float(-0.5 * np.sum(x**2))

        def jax_log_like(x):
            return -0.5 * jnp.sum(x**2)

        return log_like, jax_log_like

    @staticmethod
    def rosenbrock(ndim):
        """Rosenbrock "banana" test problem.

        ``log L = -sum(100 * (x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2)``.
        """
        def log_like(x):
            if hasattr(x, 'type'):
                # PyMC version (symbolic)
                import pymc as pm
                term1 = 100.0 * (x[1:] - x[:-1]**2.0)**2.0
                term2 = (1 - x[:-1])**2.0
                return -pm.math.sum(term1 + term2)
            else:
                # Regular numpy version
                x = np.asarray(x)
                return float(-np.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0))

        def jax_log_like(x):
            return -jnp.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0)

        return log_like, jax_log_like

    @staticmethod
    def gaussian_mixture(ndim):
        """Two-component Gaussian mixture with modes at ``x = -2`` and ``x = +2``.

        ``log L = logaddexp(-0.5 * sum((x + 2)**2), -0.5 * sum((x - 2)**2))``.
        """
        def log_like(x):
            if hasattr(x, 'type'):
                # PyMC version (symbolic).
                import pymc as pm
                log_prob1 = -0.5 * pm.math.sum((x + 2)**2)
                log_prob2 = -0.5 * pm.math.sum((x - 2)**2)
                # Use logaddexp (not maximum) so the PyMC target is the same
                # mixture density as the numpy and JAX versions below.
                return pm.math.logaddexp(log_prob1, log_prob2)
            else:
                # Regular numpy version
                x = np.asarray(x)
                log_prob1 = -0.5 * np.sum((x + 2)**2)
                log_prob2 = -0.5 * np.sum((x - 2)**2)
                return float(np.logaddexp(log_prob1, log_prob2))

        def jax_log_like(x):
            log_prob1 = -0.5 * jnp.sum((x + 2)**2)
            log_prob2 = -0.5 * jnp.sum((x - 2)**2)
            return jnp.logaddexp(log_prob1, log_prob2)

        return log_like, jax_log_like


def measure_memory(func):
    """Wrap ``func`` so that each call also reports its memory delta.

    The wrapper reads the first entry of ``memory_profiler.memory_usage()``
    immediately before and immediately after calling ``func`` and returns
    ``(result, after - before)``. Exceptions raised by ``func`` propagate
    unchanged.
    """
    def _usage_now():
        # First element of the reading returned by memory_profiler.
        return memory_profiler.memory_usage()[0]

    @wraps(func)
    def wrapper(*args, **kwargs):
        usage_at_entry = _usage_now()
        outcome = func(*args, **kwargs)
        usage_at_exit = _usage_now()
        return outcome, usage_at_exit - usage_at_entry

    return wrapper



In [None]:
### full code

In [None]:
import numpy as np
import emcee
import pymc as pm
import numpyro
import dynesty
import jax.numpy as jnp
from jax import random
import time
import arviz as az
import traceback
from numpyro.infer import NUTS, MCMC
import threading
import psutil

class MemoryTracker:
    """Sample the memory use of this process and its children in a background thread.

    Usage: ``start()`` records a baseline and launches the sampler thread,
    ``stop()`` halts it and takes a final snapshot, ``get_memory_stats()``
    summarises the run. All reported numbers are in MiB (bytes divided by
    ``1024 * 1024``); parent-process figures are relative to the baseline taken
    in ``start()``, children figures are absolute.

    Parameters
    ----------
    sampling_interval : float
        Seconds the sampler thread sleeps between two readings.
    """

    def __init__(self, sampling_interval=0.1):
        self.baseline_memory = None
        self.peak_memory = {'rss': 0, 'vms': 0, 'shared': 0}
        # Highest children totals seen in any sample (not just start/end).
        self.peak_children = {'rss': 0, 'vms': 0}
        self.start_memory = None
        self.end_memory = None
        self.sampling_interval = sampling_interval
        self.is_tracking = False
        self.tracking_thread = None

    def _get_memory_usage(self):
        """Return one reading as a dict of MiB values.

        Keys: ``rss``, ``vms``, ``shared`` (parent minus baseline, plus all
        children) and ``children_rss``, ``children_vms`` (children only).
        Requires ``baseline_memory`` to have been set by ``start()``.
        """
        current = psutil.Process()
        memory_info = current.memory_info()
        children_memory = {'rss': 0, 'vms': 0, 'shared': 0}
        try:
            for child in current.children(recursive=True):
                try:
                    child_mem = child.memory_info()
                    children_memory['rss'] += child_mem.rss
                    children_memory['vms'] += child_mem.vms
                    # 'shared' is not available on every platform.
                    children_memory['shared'] += getattr(child_mem, 'shared', 0)
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    # Child exited or is not inspectable: skip it.
                    continue
        except psutil.Error:
            # Best effort: if the children cannot be listed, report the parent only.
            pass

        return {
            'rss': (memory_info.rss - self.baseline_memory.rss + children_memory['rss']) / (1024 * 1024),
            'vms': (memory_info.vms - self.baseline_memory.vms + children_memory['vms']) / (1024 * 1024),
            'shared': (getattr(memory_info, 'shared', 0) - getattr(self.baseline_memory, 'shared', 0) +
                      children_memory['shared']) / (1024 * 1024),
            'children_rss': children_memory['rss'] / (1024 * 1024),
            'children_vms': children_memory['vms'] / (1024 * 1024)
        }

    def _record_sample(self, usage):
        """Fold one reading from ``_get_memory_usage`` into the running peaks."""
        for key in self.peak_memory:
            self.peak_memory[key] = max(self.peak_memory[key], usage[key])
        for key in self.peak_children:
            self.peak_children[key] = max(self.peak_children[key], usage[f'children_{key}'])

    def _track_memory(self):
        """Sampler loop: read, update peaks, sleep, until ``is_tracking`` is cleared."""
        while self.is_tracking:
            self._record_sample(self._get_memory_usage())
            time.sleep(self.sampling_interval)

    def start(self):
        """Record the baseline and the start snapshot, then launch the sampler thread."""
        self.baseline_memory = psutil.Process().memory_info()
        self.start_memory = self._get_memory_usage()
        self._record_sample(self.start_memory)
        self.is_tracking = True
        # Daemon thread: a tracker that is never stopped (e.g. because the
        # measured code raised) must not keep the interpreter alive.
        self.tracking_thread = threading.Thread(target=self._track_memory, daemon=True)
        self.tracking_thread.start()

    def stop(self):
        """Stop the sampler thread (waiting for it to finish) and take the end snapshot."""
        self.is_tracking = False
        if self.tracking_thread:
            self.tracking_thread.join()
        self.end_memory = self._get_memory_usage()
        self._record_sample(self.end_memory)

    def get_memory_stats(self):
        """Summarise the run; call after ``start()`` and ``stop()``.

        Returns a dict with:

        * ``peak`` - highest ``rss`` / ``vms`` / ``shared`` reading seen,
        * ``change`` - end snapshot minus start snapshot,
        * ``children_peak`` - highest children ``rss`` / ``vms`` seen in any
          sample, including the start and end snapshots.
        """
        return {
            'peak': dict(self.peak_memory),
            'change': {
                'rss_change': self.end_memory['rss'] - self.start_memory['rss'],
                'vms_change': self.end_memory['vms'] - self.start_memory['vms'],
                'shared_change': self.end_memory['shared'] - self.start_memory['shared']
            },
            'children_peak': {
                'rss': max(self.peak_children['rss'],
                           self.start_memory['children_rss'],
                           self.end_memory['children_rss']),
                'vms': max(self.peak_children['vms'],
                           self.start_memory['children_vms'],
                           self.end_memory['children_vms'])
            }
        }

def run_traditional_mcmc(log_like, ndim, param_ranges, ndraws=1000):
    """Sample a test problem with PyMC's Metropolis stepper.

    Parameters
    ----------
    log_like : callable
        Log-likelihood accepting the stacked PyMC parameter tensor; its result
        enters the model through ``pm.Potential``.
    ndim : int
        Number of parameters; each gets a ``pm.Uniform`` prior named ``p_<i>``.
    param_ranges : sequence of (min, max)
        Prior bounds, indexed by parameter.
    ndraws : int, optional
        Base draw count. The sampler is asked for ``5 * ndraws`` draws after
        ``int(0.5 * ndraws)`` tuning steps, on 4 chains and 4 cores.

    Returns
    -------
    dict
        ``trace``; ``samples`` and ``raw_samples`` (the same array, parameters
        stacked on a trailing axis); ``runtime`` in seconds; ``n_active``;
        ``active_mask``; ``accept_rate`` - the mean of the trace's ``accept``
        statistic multiplied by 100, or ``None`` when it cannot be read.

    Raises
    ------
    Exception
        Any error raised while building the model or sampling is printed and
        re-raised.
    """
    all_active = [True] * ndim  # every parameter is active in the test problems

    def _uniform_prior(index):
        lower, upper = param_ranges[index]
        return pm.Uniform(f'p_{index}', lower, upper)

    with pm.Model():
        try:
            # Priors, then the likelihood as a potential on the stacked vector.
            params = [_uniform_prior(index) for index in range(ndim)]
            pm.Potential('likelihood', log_like(pm.math.stack(params)))

            # Three Metropolis steppers over the same variables, differing only
            # in their initial proposal scaling; each retunes every 50 steps.
            steps = [
                pm.Metropolis(vars=params, tune=True, scaling=scale, tune_interval=50)
                for scale in (0.1, 0.05, 0.01)
            ]

            # Every chain starts from the midpoint of the prior box.
            start = {}
            for index in range(ndim):
                start[f'p_{index}'] = 0.5 * (param_ranges[index][0] + param_ranges[index][1])

            sample_kwargs = dict(
                draws=ndraws*5,
                tune=int(ndraws * 0.5),
                return_inferencedata=True,
                chains=4,
                cores=4,
                progressbar=False,
                initvals=start,
                step=steps,
                discard_tuned_samples=True,
            )

            t0 = time.time()
            trace = pm.sample(**sample_kwargs)
            runtime = time.time() - t0

            per_param = [trace.posterior[f'p_{index}'].values for index in range(ndim)]
            samples = np.stack(per_param, axis=-1)

            # Acceptance rate, expressed as a percentage.
            accept_rate = None
            try:
                if hasattr(trace, 'sample_stats'):
                    mean_accept = trace.sample_stats.accept.mean().item()
                    accept_rate = float(mean_accept * 100)
            except Exception as e:
                print(f"Could not extract acceptance rate: {e}")
                accept_rate = None

            return {
                'trace': trace,
                'samples': samples,
                'raw_samples': samples,
                'runtime': runtime,
                'n_active': ndim,
                'active_mask': all_active,
                'accept_rate': accept_rate
            }

        except Exception as e:
            print(f"PyMC sampling error: {str(e)}")
            raise
            
def run_emcee(log_like, param_ranges, nwalkers=None, nsteps=1000):
    """Sample a test problem with emcee's ensemble sampler.

    Parameters
    ----------
    log_like : callable
        Log-likelihood of a parameter vector.
    param_ranges : sequence of (min, max)
        Hard bounds per parameter; points outside get ``-inf``.
    nwalkers : int, optional
        Number of walkers; defaults to ``max(20 * ndim, 40)``.
    nsteps : int, optional
        Production steps, run after a fixed 100-step burn-in.

    Returns
    -------
    dict or None
        ``samples`` (``sampler.chain`` with an extra leading axis),
        ``raw_samples`` (``sampler.chain``), ``runtime`` in seconds covering
        burn-in and production, ``accept_rate`` (mean acceptance fraction),
        ``log_prob``, ``active_mask`` and ``n_active``. ``None`` if sampling
        raised; the error is printed.
    """
    ndim = len(param_ranges)
    if nwalkers is None:
        nwalkers = max(20 * ndim, 40)

    lower = np.array([bounds[0] for bounds in param_ranges])
    upper = np.array([bounds[1] for bounds in param_ranges])

    def wrap_likelihood(x_phys):
        """Bounded log-likelihood: ``-inf`` outside the box or on any error."""
        try:
            out_of_box = np.any(x_phys < lower) or np.any(x_phys > upper)
            return -np.inf if out_of_box else float(log_like(x_phys))
        except Exception as e:
            return -np.inf

    # Walkers start uniformly inside the central 80% of the box along each axis.
    margin = 0.1 * (upper - lower)
    pos = np.random.uniform(low=lower + margin, high=upper - margin, size=(nwalkers, ndim))

    sampler = emcee.EnsembleSampler(nwalkers, ndim, wrap_likelihood, a=2.0)

    try:
        t0 = time.time()

        # Burn-in: keep only the final walker positions, then forget the chain.
        pos, _burn_prob, _burn_state = sampler.run_mcmc(pos, 100, progress=True)
        sampler.reset()

        # Production run.
        _end_pos, _end_prob, _end_state = sampler.run_mcmc(pos, nsteps, progress=True)

        all_active = [True] * ndim
        chain = sampler.chain

        return {
            'samples': np.expand_dims(chain, 0),
            'raw_samples': chain,
            'runtime': time.time() - t0,
            'accept_rate': np.mean(sampler.acceptance_fraction),
            'log_prob': sampler.lnprobability,
            'active_mask': all_active,
            'n_active': ndim
        }

    except Exception as e:
        print(f"Error during sampling: {str(e)}")
        return None

def run_hmc(jax_log_like, ndim, param_ranges, num_warmup=500, num_samples=1000):
    """Sample a test problem with NumPyro's NUTS kernel (single chain).

    Parameters
    ----------
    jax_log_like : callable
        JAX-traceable log-likelihood of the parameter vector.
    ndim : int
        Number of parameters.
    param_ranges : sequence of (min, max)
        Box bounds per parameter.
    num_warmup, num_samples : int, optional
        Warm-up and retained iterations.

    Returns
    -------
    dict or None
        ``samples`` with shape ``(1, num_samples, ndim)`` built from the
        ``p_<i>`` sites, ``raw_samples`` (the dict returned by
        ``mcmc.get_samples()``, containing both ``u_<i>`` and ``p_<i>``),
        ``runtime`` in seconds, ``accept_rate`` (mean ``accept_prob`` or
        ``None``), ``n_active`` and ``active_mask``. ``None`` if sampling
        raised; the error and traceback are printed.
    """
    import jax
    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    active_mask = [True] * ndim

    def model():
        params = []
        for i in range(ndim):
            min_val, max_val = param_ranges[i]
            # Sample on the unconstrained line and squash into [min, max].
            # The latent must be Logistic(0, 1): its CDF is the sigmoid, so
            # sigmoid(u) is Uniform(0, 1) and p_i gets the same flat prior on
            # the box as in the other samplers. (A Normal latent would put a
            # logit-normal prior on p_i and bias the posterior.)
            u = numpyro.sample(f'u_{i}', dist.Logistic(0.0, 1.0))
            param = numpyro.deterministic(
                f'p_{i}',
                min_val + (max_val - min_val) * jax.nn.sigmoid(u)
            )
            params.append(param)

        params = jnp.stack(params)
        like = jax_log_like(params)
        numpyro.factor('likelihood', like)
        return params

    start_time = time.time()

    try:
        # Initialize the sampler
        rng_key = random.PRNGKey(0)
        kernel = NUTS(model,
                     target_accept_prob=0.8,
                     adapt_step_size=True,
                     max_tree_depth=10)

        mcmc = MCMC(
            kernel,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=1,
            progress_bar=False
        )

        # 'accept_prob' is only collected when requested explicitly; without
        # it the acceptance rate below could never be reported.
        mcmc.run(rng_key, extra_fields=('accept_prob',))

        # Samples of every site (latent u_i and box-space p_i).
        samples = mcmc.get_samples()

        # Box-space parameters, stacked on a trailing axis.
        active_samples = np.stack(
            [np.array(samples[f'p_{i}']) for i in range(ndim)],
            axis=-1
        )

        # Prefer an acceptance rate exposed by the kernel, otherwise average
        # the per-iteration acceptance probabilities.
        sampler = getattr(mcmc, 'sampler', None)
        if sampler is not None and hasattr(sampler, 'acceptance_rate'):
            accept_rate = float(sampler.acceptance_rate)
        else:
            sample_stats = mcmc.get_extra_fields()
            if 'accept_prob' in sample_stats:
                accept_rate = float(np.mean(sample_stats['accept_prob']))
            else:
                print("Warning: Could not find acceptance rate in MCMC results")
                accept_rate = None

        return {
            'samples': active_samples.reshape(1, -1, ndim),
            'raw_samples': samples,
            'runtime': time.time() - start_time,
            'accept_rate': accept_rate,
            'n_active': ndim,
            'active_mask': active_mask
        }

    except Exception as e:
        print(f"Error during HMC sampling: {str(e)}")
        traceback.print_exc()
        return None

def run_nested(log_like, ndim, param_ranges, nlive=1000):
    """Run dynesty nested sampling on a test problem.

    Parameters
    ----------
    log_like : callable
        Log-likelihood of a physical parameter vector.
    ndim : int
        Number of parameters.
    param_ranges : sequence of (min, max)
        Bounds of the uniform prior, per parameter.
    nlive : int, optional
        Number of live points.

    Returns
    -------
    dict
        ``samples`` - equally weighted resampling, shape ``(1, N, ndim)``;
        ``raw_samples`` - the weighted dead points; ``runtime`` in seconds;
        ``accept_rate`` - dynesty's ``results.eff`` (NOTE(review): dynesty
        documents this as a percentage - confirm units before comparing with
        the fractions returned by other samplers); ``logz``; ``weights`` -
        ``exp(logwt - logz[-1])``; ``results`` - the dynesty results object;
        ``n_active`` and ``active_mask`` - as returned by the other samplers.
    """
    # Bounds are fixed for the whole run, so unpack them once instead of on
    # every prior evaluation.
    bounds = np.asarray(param_ranges[:ndim], dtype=float)
    lower = bounds[:, 0]
    span = bounds[:, 1] - bounds[:, 0]

    def prior_transform(unit_coords):
        """Transform from unit cube to physical parameter space"""
        return lower + span * np.asarray(unit_coords)[:ndim]

    sampler = dynesty.NestedSampler(
        log_like,
        prior_transform,
        ndim,
        nlive=nlive,
        bound='multi',
        sample='rwalk'
    )

    start_time = time.time()
    sampler.run_nested(dlogz=0.25)
    runtime = time.time() - start_time

    results = sampler.results
    samples = results.samples
    # Importance weights normalised by the final evidence estimate.
    weights = np.exp(results.logwt - results.logz[-1])
    samples_equal = dynesty.utils.resample_equal(samples, weights)

    active_mask = [True] * ndim
    samples_reshaped = samples_equal.reshape(1, -1, ndim)

    return {
        'samples': samples_reshaped,
        'raw_samples': samples,
        'runtime': runtime,
        'accept_rate': results.eff,
        'logz': results.logz[-1],
        'weights': weights,
        'results': results,
        # Same bookkeeping keys as the MCMC samplers return.
        'n_active': ndim,
        'active_mask': active_mask
    }


def run_slice(log_like, ndim, param_ranges, ndraws=1000):
    """Coordinate-wise slice sampling of a test problem (4 chains).

    Each draw updates every coordinate once, in order. All chains start at the
    midpoint of the box given by ``param_ranges``.

    Parameters
    ----------
    log_like : callable
        Log-likelihood of a parameter vector.
    ndim : int
        Number of parameters.
    param_ranges : sequence of (min, max)
        Box bounds per parameter; the slice interval never leaves the box.
    ndraws : int, optional
        Draws per chain.

    Returns
    -------
    dict
        ``samples`` / ``raw_samples`` (the same ``(4, ndraws, ndim)`` array),
        ``runtime`` in seconds, ``accept_rate`` (fraction of draws whose point
        differs from the previous one, averaged over chains), ``n_active`` and
        ``active_mask``.
    """
    def _with_coord(point, axis, value):
        """Copy of ``point`` with coordinate ``axis`` set to ``value``."""
        moved = point.copy()
        moved[axis] = value
        return moved

    def slice_sample_step(x0, d, width=0.1):
        """One slice-sampling update of coordinate ``d`` starting from ``x0``."""
        x = x0.copy()
        # Log-height of the slice: log L(x) + log(U), with -log(U) ~ Exp(1).
        slice_height = log_like(x) - np.random.exponential()

        pmin, pmax = param_ranges[d]
        # Randomly positioned initial interval of size `width`, clipped to the box.
        offset = np.random.uniform(0, width)
        left = max(pmin, x[d] - offset)
        right = min(pmax, left + width)
        left = max(pmin, right - width)

        # Step out to the left until outside the slice or at the lower bound.
        while not (log_like(_with_coord(x, d, left)) <= slice_height or left <= pmin):
            left = max(pmin, left - width)

        # Step out to the right until outside the slice or at the upper bound.
        while not (log_like(_with_coord(x, d, right)) <= slice_height or right >= pmax):
            right = min(pmax, right + width)

        # Shrinkage: propose inside the interval, shrink towards x on rejection.
        while True:
            candidate = np.random.uniform(left, right)
            proposal = _with_coord(x, d, candidate)
            if log_like(proposal) > slice_height:
                return proposal
            if candidate < x[d]:
                left = candidate
            else:
                right = candidate

    chains = 4
    samples = np.zeros((chains, ndraws, ndim))
    acceptance_rates = np.zeros(chains)

    t0 = time.time()
    all_active = [True] * ndim

    for chain in range(chains):
        # Start from the centre of the box.
        x = np.zeros(ndim)
        for axis, (lo, hi) in enumerate(param_ranges):
            x[axis] = (lo + hi) / 2

        n_moved = 0
        for draw in range(ndraws):
            previous = x.copy()
            for d in range(ndim):
                x = slice_sample_step(x, d)
            samples[chain, draw] = x
            n_moved += int(not np.array_equal(x, previous))

        acceptance_rates[chain] = n_moved / ndraws

    return {
        'samples': samples,
        'raw_samples': samples,
        'runtime': time.time() - t0,
        'accept_rate': np.mean(acceptance_rates),
        'n_active': ndim,
        'active_mask': all_active
    }


def run_polychord(log_like, ndim, param_ranges, problem="test", nlive=100):
    """Run PolyChord nested sampling, with clustering disabled, on a test problem.

    Parameters
    ----------
    log_like : callable
        Log-likelihood of a physical parameter vector.
    ndim : int
        Number of sampled parameters.
    param_ranges : sequence of (min, max)
        Bounds of the uniform prior, per parameter.
    problem : str, optional
        Label inserted into the output file root
        (``chains/polychord_run_<problem>_<ndim>dim``).
    nlive : int, optional
        Number of live points.

    Returns
    -------
    dict or None
        ``samples`` (first ``ndim`` columns of the posterior samples),
        ``raw_samples`` (all columns), ``weights``, ``equal_weighted_samples``
        (loaded from the ``*_equal_weights.txt`` file, or ``None`` if it cannot
        be read), ``runtime`` in seconds, ``accept_rate``, ``logz``,
        ``logz_err`` and ``output`` (the PolyChord output object).
        ``None`` if PolyChord raised; the error and traceback are printed.

    Side effects: creates the ``chains`` directory and lets PolyChord write its
    output files there.
    """
    import pypolychord
    from pypolychord.settings import PolyChordSettings
    import os
    import numpy as np

    # Create chains directory if it doesn't exist
    os.makedirs('chains', exist_ok=True)

    # Settings with dimension handling
    settings = PolyChordSettings(ndim, 1)  # ndim parameters, 1 derived parameter (likelihood)
    settings.nlive = nlive
    settings.num_repeats = max(ndim * 5, 30)
    settings.file_root = f'polychord_run_{problem}_{ndim}dim'
    settings.base_dir = 'chains'

    # Clustering settings: everything cluster-related is switched off here
    settings.do_clustering = False
    settings.cluster_posteriors = False
    settings.boost_posterior = 1.0
    settings.read_resume = False  # always start a fresh run

    def prior(hypercube):
        """Transform unit hypercube to parameter space"""
        from pypolychord.priors import UniformPrior
        physical_params = []
        for i, (pmin, pmax) in enumerate(param_ranges):
            physical_params.append(UniformPrior(pmin, pmax)(hypercube[i]))
        return physical_params

    def wrapped_log_like(theta):
        """Wrapper for likelihood to match PolyChord's interface"""
        like = log_like(theta)
        return like, [like]  # Return likelihood and derived parameters

    # Run sampler
    start_time = time.time()
    try:
        output = pypolychord.run_polychord(
            wrapped_log_like,
            ndim, 
            1,  # One derived parameter (likelihood)
            settings,
            prior
        )
        runtime = time.time() - start_time
        # Parameter names p0..p{ndim-1} plus the derived column 'L'; the
        # paramnames file is written before the posterior is loaded.
        paramnames = [(f'p{i}', rf'p_{{{i}}}') for i in range(ndim )]
        paramnames.append(('L', r'L'))  # Add the last parameter as 'L'
        output.make_paramnames_files(paramnames)
        posterior = output.posterior

        samples_full = posterior.samples
        samples = posterior.samples[:, :ndim]  # Exclude derived parameters
        logZ, logZerr = output.logZ, output.logZerr

        log_likelihoods = posterior['L']
        logZ = output.logZ
        logZe = output.logZerr
        # NOTE(review): this is log L - log Z per sample, i.e. a log-ratio,
        # not a normalised posterior weight - confirm how consumers of
        # 'weights' interpret it.
        weights= log_likelihoods - logZ

        # samples_reshaped = samples.reshape(1, -1, ndim)
        nlike = float(output.nlike) if hasattr(output, 'nlike') else None
        npost = float(output.nposterior) if hasattr(output, 'nposterior') else None

        # Must match settings.base_dir / settings.file_root above.
        equal_weights_file = f'chains/polychord_run_{problem}_{ndim}dim_equal_weights.txt'
        try:
            equal_weighted_samples = np.loadtxt(equal_weights_file)
        except Exception as e:
            print(f"Warning: Could not read equal weights file: {str(e)}")
            equal_weighted_samples = None

        return {
            'samples': samples, #np.expand_dims(samples, 0),  # Add chain dimension
            'raw_samples': samples_full,
            'weights': weights,
            'equal_weighted_samples': equal_weighted_samples,
            'runtime': float(runtime),
            # NOTE(review): computed as nlike / nposterior, which is a count of
            # likelihood calls per posterior sample rather than a fraction -
            # confirm this is the intended definition.
            'accept_rate':float(nlike/npost) if (nlike is not None and npost is not None) else None,
            'logz': float(output.logZ) if hasattr(output, 'logZ') else None,
            'logz_err': float(output.logZerr) if hasattr(output, 'logZerr') else None,
            'output': output
        }

    except Exception as e:
        print(f"PolyChord error: {str(e)}")
        traceback.print_exc()
        return None


In [None]:
def run_polychord_clust(log_like, ndim, param_ranges, problem="test", nlive=100):
    """Run PolyChord nested sampling, with clustering enabled, on a test problem.

    Differs from ``run_polychord`` only in that ``do_clustering`` and
    ``cluster_posteriors`` are switched on and the output file root is
    ``chains/polychord_run_clust_<problem>_<ndim>dim``.

    Parameters
    ----------
    log_like : callable
        Log-likelihood of a physical parameter vector.
    ndim : int
        Number of sampled parameters.
    param_ranges : sequence of (min, max)
        Bounds of the uniform prior, per parameter.
    problem : str, optional
        Label inserted into the output file root.
    nlive : int, optional
        Number of live points.

    Returns
    -------
    dict or None
        ``samples`` (first ``ndim`` columns of the posterior samples),
        ``raw_samples`` (all columns), ``weights``, ``equal_weighted_samples``
        (loaded from the ``*_equal_weights.txt`` file, or ``None`` if it cannot
        be read), ``runtime`` in seconds, ``accept_rate``, ``logz``,
        ``logz_err`` and ``output`` (the PolyChord output object).
        ``None`` if PolyChord raised; the error and traceback are printed.

    Side effects: creates the ``chains`` directory and lets PolyChord write its
    output files there.
    """
    import pypolychord
    from pypolychord.settings import PolyChordSettings
    import os
    import numpy as np

    # Create chains directory if it doesn't exist
    os.makedirs('chains', exist_ok=True)

    # Settings with dimension handling
    settings = PolyChordSettings(ndim, 1)  # ndim parameters, 1 derived parameter (likelihood)
    settings.nlive = nlive
    settings.num_repeats = max(ndim * 5, 30)
    settings.file_root = f'polychord_run_clust_{problem}_{ndim}dim'
    settings.base_dir = 'chains'

    # Clustering settings: clustering is switched on for this variant
    settings.do_clustering = True
    settings.cluster_posteriors = True
    settings.boost_posterior = 1.0
    settings.read_resume = False  # always start a fresh run

    def prior(hypercube):
        """Transform unit hypercube to parameter space"""
        from pypolychord.priors import UniformPrior
        physical_params = []
        for i, (pmin, pmax) in enumerate(param_ranges):
            physical_params.append(UniformPrior(pmin, pmax)(hypercube[i]))
        return physical_params

    def wrapped_log_like(theta):
        """Wrapper for likelihood to match PolyChord's interface"""
        like = log_like(theta)
        # The log-likelihood is also reported as the single derived parameter.
        return like, [like]  

    # Run sampler
    start_time = time.time()
    try:
        output = pypolychord.run_polychord(
            wrapped_log_like,
            ndim, 
            1,  # One derived parameter (likelihood)
            settings,
            prior
        )
        runtime = time.time() - start_time
        # Parameter names p0..p{ndim-1} plus the derived column 'L'; the
        # paramnames file is written before the posterior is loaded.
        paramnames = [(f'p{i}', rf'p_{{{i}}}') for i in range(ndim )]
        paramnames.append(('L', r'L'))  # Add the last parameter as 'L'
        output.make_paramnames_files(paramnames)
        posterior = output.posterior

        samples_full = posterior.samples
        samples = posterior.samples[:, :ndim]  # Exclude derived parameters
        logZ, logZerr = output.logZ, output.logZerr

        log_likelihoods = posterior['L']
        logZ = output.logZ
        logZe = output.logZerr
        # NOTE(review): this is log L - log Z per sample, i.e. a log-ratio,
        # not a normalised posterior weight - confirm how consumers of
        # 'weights' interpret it.
        weights= log_likelihoods - logZ

        # samples_reshaped = samples.reshape(1, -1, ndim)
        nlike = float(output.nlike) if hasattr(output, 'nlike') else None
        npost = float(output.nposterior) if hasattr(output, 'nposterior') else None

        # Must match settings.base_dir / settings.file_root above.
        equal_weights_file = f'chains/polychord_run_clust_{problem}_{ndim}dim_equal_weights.txt'
        try:
            equal_weighted_samples = np.loadtxt(equal_weights_file)
        except Exception as e:
            print(f"Warning: Could not read equal weights file: {str(e)}")
            equal_weighted_samples = None

        return {
            'samples': samples,
            'raw_samples': samples_full,
            'weights': weights,
            'equal_weighted_samples': equal_weighted_samples,
            'runtime': float(runtime),
            # NOTE(review): computed as nlike / nposterior, which is a count of
            # likelihood calls per posterior sample rather than a fraction -
            # confirm this is the intended definition.
            'accept_rate':float(nlike/npost) if (nlike is not None and npost is not None) else None,
            'logz': float(output.logZ) if hasattr(output, 'logZ') else None,
            'logz_err': float(output.logZerr) if hasattr(output, 'logZerr') else None,
            'output': output
        }

    except Exception as e:
        print(f"PolyChord error: {str(e)}")
        traceback.print_exc()
        return None


In [None]:
# Namespace exposing every sampler wrapped with `measure_memory`.
class Samplers:
    """Static-method collection of the benchmark samplers.

    Each attribute is the module-level function of the same name wrapped with
    ``measure_memory``, so every call returns ``(result, memory_delta)``.
    """
    # The right-hand names are not yet defined in the class body, so they
    # resolve to the module-level sampler functions.
    run_traditional_mcmc = staticmethod(measure_memory(run_traditional_mcmc))
    run_emcee = staticmethod(measure_memory(run_emcee))
    run_hmc = staticmethod(measure_memory(run_hmc))
    run_nested = staticmethod(measure_memory(run_nested))
    run_slice = staticmethod(measure_memory(run_slice))
    run_polychord = staticmethod(measure_memory(run_polychord))
    run_polychord_clust = staticmethod(measure_memory(run_polychord_clust))


In [None]:
def calculate_diagnostics(samples, method, weights=None):
    """Compute convergence diagnostics appropriate for the sampler type.

    Parameters
    ----------
    samples : array-like, ``az.InferenceData`` or None
        Used for the MCMC methods (``'traditional'``, ``'emcee'``, ``'hmc'``,
        ``'slice'``). Arrays are interpreted by ArviZ as
        ``(chain, draw, *parameter_shape)``. Two special cases are normalised
        first: a 2-D ``(draw, param)`` array is split into two half-length
        chains (dropping the last draw if the count is odd), and emcee's
        ``(1, walker, step, param)`` array has its singleton axis removed so
        that walkers act as chains.
    method : str
        Sampler name; selects the diagnostic.
    weights : array-like, optional
        Importance weights, used for ``'nested'``, ``'polychord'`` and
        ``'polychord_clust'``.

    Returns
    -------
    (ess, r_hat) : tuple
        MCMC methods: mean effective sample size and mean R-hat over all
        parameters, ignoring NaNs. Nested methods: Kish effective sample size
        ``1 / sum(w_normalised**2)`` and ``None``. Any value that cannot be
        computed is ``None``; failures are printed, never raised.
    """
    def _finite_mean(stat):
        """Mean of the non-NaN entries of an ArviZ result, or None if there are none."""
        if isinstance(stat, dict):
            pieces = list(stat.values())
        elif hasattr(stat, 'to_dataarray'):
            pieces = [stat.to_dataarray().values]
        elif hasattr(stat, 'to_array'):
            # Older xarray releases only provide Dataset.to_array().
            pieces = [stat.to_array().values]
        else:
            pieces = [stat]
        if not pieces:
            return None
        flat = np.concatenate([np.asarray(piece, dtype=float).ravel() for piece in pieces])
        flat = flat[~np.isnan(flat)]
        return float(np.mean(flat)) if flat.size else None

    try:
        if method in ['traditional', 'emcee', 'hmc', 'slice']:
            if samples is not None:
                try:
                    # Convert samples to Arviz format if needed
                    if not isinstance(samples, az.InferenceData):
                        if hasattr(samples, 'shape') and hasattr(samples, 'reshape'):
                            samples = np.asarray(samples)
                            if method == 'emcee' and samples.ndim == 4 and samples.shape[0] == 1:
                                # (1, walker, step, param) -> (walker, step, param):
                                # treat walkers as chains instead of as draws.
                                samples = samples[0]
                            if samples.ndim == 2:  # (samples, params)
                                # Split into 2 chains; an odd trailing draw is
                                # dropped so the reshape cannot fail.
                                half = samples.shape[0] // 2
                                samples = samples[:2 * half].reshape(2, half, samples.shape[-1])
                        samples = az.convert_to_dataset(samples)
                    ess_mean = _finite_mean(az.ess(samples))
                    r_hat_mean = _finite_mean(az.rhat(samples))
                    return ess_mean, r_hat_mean
                except Exception as e:
                    print(f"Failed to calculate MCMC diagnostics: {str(e)}")
                    return None, None
        elif method in ['nested', 'polychord', 'polychord_clust']:
            # Special handling for nested samplers
            if weights is not None:
                try:
                    weight_array = np.asarray(weights, dtype=float)
                    total = np.sum(weight_array)
                    if weight_array.size == 0 or not np.isfinite(total) or total == 0:
                        # Normalising would produce NaN/inf; report "not available".
                        print(f"Failed to calculate {method} ESS: weights do not have a finite, non-zero sum")
                        return None, None
                    normalized_weights = weight_array / total
                    ess = 1.0 / np.sum(normalized_weights ** 2)
                    return float(ess), None
                except Exception as e:
                    print(f"Failed to calculate {method} ESS: {str(e)}")
                    return None, None
            else:
                print(f"No weights provided for {method}")
                return None, None
        return None, None
    except Exception as e:
        print(f"Diagnostic calculation failed: {str(e)}")
        return None, None

def _run_sampler_once(method, log_like, jax_log_like, d, param_ranges, problem_name):
    """Dispatch a single sampler run for `method` and return that sampler's result."""
    if method == 'traditional':
        return run_traditional_mcmc(log_like, d, param_ranges)
    if method == 'emcee':
        return run_emcee_robust(log_like, param_ranges)
    if method == 'hmc':
        return run_hmc(jax_log_like, d, param_ranges)
    if method == 'slice':
        return run_slice(log_like, d, param_ranges)
    if method == 'nested':
        return run_nested(log_like, d, param_ranges)
    if method == 'polychord':
        return run_polychord(log_like, d, param_ranges, problem=problem_name)
    if method == 'polychord_clust':
        return run_polychord_clust(log_like, d, param_ranges, problem=problem_name)
    raise ValueError(f"Unknown sampling method: {method}")


def _aggregate_memory_stats(all_memory_stats):
    """Combine per-run memory stats: maximum of the peaks, mean of the changes."""
    if not all_memory_stats:
        return {
            'peak': {'rss': 0, 'vms': 0, 'shared': 0},
            'change': {'rss_change': 0, 'vms_change': 0, 'shared_change': 0},
            'children_peak': {'rss': 0, 'vms': 0}
        }
    return {
        'peak': {
            key: max(m['peak'][key] for m in all_memory_stats)
            for key in ('rss', 'vms', 'shared')
        },
        'change': {
            key: np.mean([m['change'][key] for m in all_memory_stats])
            for key in ('rss_change', 'vms_change', 'shared_change')
        },
        'children_peak': {
            key: max(m['children_peak'][key] for m in all_memory_stats)
            for key in ('rss', 'vms')
        }
    }


def run_benchmark(problems=["gaussian", "rosenbrock", "mixture"], 
                 dims=[2, 5, 10],
                 methods=['traditional', 'emcee', 'hmc', 'nested', 'slice', 'polychord', 'polychord_clust']):
    """
    Run benchmarks on test problems

    Every (problem, dimension, method) combination is run up to 3 times. The
    first successful run provides the stored samples and metrics; the spread of
    the per-run sample means gives 'init_sensitivity', and the memory statistics
    are aggregated over all completed runs.

    Parameters
    ----------
    problems : list of str
        Any of "gaussian", "rosenbrock", "mixture".
    dims : list of int
        Dimensionalities to test.
    methods : list of str
        Sampler names understood by the dispatcher above.

    Returns
    -------
    dict
        results[problem] has the keys 'dims', 'samples', 'raw_samples' and
        'metrics'. Every per-method list holds exactly one entry per dimension,
        in the order of `dims`; the entry is None when that combination failed.

    Raises
    ------
    KeyError
        If `problems` contains an unknown name (checked before any sampling).
    """
    # Copy the arguments so the mutable defaults are never shared with `results`.
    problems = list(problems)
    dims = list(dims)
    methods = list(methods)

    problem_funcs = {
        "gaussian": TestProblems.correlated_gaussian,
        "rosenbrock": TestProblems.rosenbrock,
        "mixture": TestProblems.gaussian_mixture
    }

    default_ranges = {
        "gaussian": (-5, 5),
        "rosenbrock": (-5, 5),
        "mixture": (-5, 5)
    }

    # Reference point per problem; for the mixture this is only one of the modes.
    true_param_values = {"gaussian": 0.0, "rosenbrock": 1.0, "mixture": 2.0}

    metric_names = ['runtime', 'ess_per_sec', 'accuracy', 'accept_rate', 'r_hat',
                    'memory_usage', 'init_sensitivity', 'weights', 'raw_weights',
                    'log_weights']

    # Fail fast: a typo should not surface only after earlier problems have run.
    unknown_problems = [name for name in problems if name not in problem_funcs]
    if unknown_problems:
        raise KeyError(f"Unknown test problem(s): {unknown_problems}")

    results = {}

    for problem_name in problems:
        results[problem_name] = {
            'dims': list(dims),
            'samples': {method: [] for method in methods},
            'raw_samples': {method: [] for method in methods},
            'metrics': {
                name: {method: [] for method in methods} for name in metric_names
            }
        }
        store = results[problem_name]

        print(f"\nBenchmarking {problem_name} distribution...")

        for d in dims:
            print(f"\nDimension: {d}")

            log_like, jax_log_like = problem_funcs[problem_name](d)
            range_min, range_max = default_ranges[problem_name]
            param_ranges = [(range_min, range_max) for _ in range(d)]
            true_params = np.full(d, true_param_values[problem_name])

            for method in methods:
                print(f"\nRunning {method}...")

                # The full record is assembled first and committed in one step
                # below, so a failure can never leave the lists partially filled.
                record = None
                try:
                    init_results = []
                    all_memory_stats = []
                    result = None

                    for init_run in range(3):
                        memory_tracker = MemoryTracker()
                        memory_tracker.start()
                        tracker_running = True

                        try:
                            result_i = _run_sampler_once(
                                method, log_like, jax_log_like, d, param_ranges, problem_name
                            )

                            memory_tracker.stop()
                            tracker_running = False
                            all_memory_stats.append(memory_tracker.get_memory_stats())

                            if result_i is None:
                                raise ValueError(f"{method} returned None result")

                            if init_run == 0:
                                result = result_i

                            samples_flat_i = np.asarray(result_i['samples']).reshape(-1, d)
                            init_results.append(np.mean(samples_flat_i, axis=0))

                        except Exception as e:
                            print(f"Error in initialization run {init_run}: {str(e)}")
                            # Stop the tracker only if the success path has not done so.
                            if tracker_running:
                                memory_tracker.stop()
                            break

                    init_sensitivity = float(np.std(init_results, axis=0).mean()) if init_results else None

                    if result is None:
                        raise ValueError(f"No valid result from {method}")

                    memory = _aggregate_memory_stats(all_memory_stats)

                    samples_flat = np.asarray(result['samples']).reshape(-1, d)
                    sample_means = np.mean(samples_flat, axis=0)
                    sample_std = np.std(samples_flat, axis=0)

                    # Create inference data for diagnostics
                    samples_az = None
                    if method in ['traditional', 'emcee', 'hmc', 'slice']:
                        try:
                            # Split into 2 equal chains; drop a trailing sample when
                            # the count is odd so the reshape cannot fail.
                            n_per_chain = len(samples_flat) // 2
                            chain_samples = samples_flat[:2 * n_per_chain].reshape(2, n_per_chain, d)
                            samples_az = az.convert_to_inference_data(
                                {"parameter": chain_samples},
                                group="posterior"
                            )
                        except Exception as e:
                            print(f"Failed to convert samples to InferenceData: {e}")
                            samples_az = None

                    weights = result.get('weights', None)
                    ess_mean, r_hat_mean = calculate_diagnostics(samples_az, method, weights)

                    runtime = result['runtime']
                    # Guard against a missing ESS and a zero runtime.
                    ess_per_sec = None if (ess_mean is None or not runtime) else ess_mean / runtime

                    # Calculate accuracy using true_params
                    parameter_deviations = np.abs(sample_means - true_params)
                    accuracy_metrics = {
                        'average_deviation': float(np.mean(parameter_deviations)),
                        'parameter_deviations': [float(x) for x in parameter_deviations],
                        'true_values': [float(x) for x in true_params],
                        'sampled_means': [float(x) for x in sample_means],
                        'sampled_std': [float(x) for x in sample_std]
                    }

                    record = {
                        'samples': samples_flat,
                        'raw_samples': result.get('raw_samples', None),
                        'metrics': {
                            'runtime': runtime,
                            'ess_per_sec': ess_per_sec,
                            'accuracy': accuracy_metrics,
                            'accept_rate': result.get('accept_rate', None),
                            'r_hat': r_hat_mean,
                            'memory_usage': memory,
                            'init_sensitivity': init_sensitivity,
                            'weights': weights,
                            'raw_weights': result.get('weights', None),
                            # Despite the key name, this stores the sampler's
                            # 'equal_weighted_samples' entry.
                            'log_weights': result.get('equal_weighted_samples', None)
                        }
                    }

                except Exception as e:
                    print(f"Error running {method} on {problem_name} in {d} dimensions:")
                    print(f"Error message: {str(e)}")
                    print(f"Error type: {type(e)}")
                    traceback.print_exc()

                # Commit exactly one entry per list so index i always matches dims[i].
                if record is None:
                    store['samples'][method].append(None)
                    store['raw_samples'][method].append(None)
                    for metric_name in store['metrics']:
                        store['metrics'][metric_name][method].append(None)
                else:
                    store['samples'][method].append(record['samples'])
                    store['raw_samples'][method].append(record['raw_samples'])
                    for metric_name in store['metrics']:
                        store['metrics'][metric_name][method].append(record['metrics'][metric_name])
                    print(f"Successfully stored results for {method}")

    return results

In [None]:
# Full run of all the samplers: all three test problems, six dimensionalities,
# and every sampler except 'polychord_clust' (supported by run_benchmark but not
# requested here). run_benchmark repeats each combination up to 3 times.
results_finalC = run_benchmark(
    problems=["gaussian", "rosenbrock", "mixture"],
    dims=[2, 3, 4, 6, 8, 10],
    methods=['traditional', 'emcee','hmc' , 'nested', 'slice', 'polychord']
)


In [None]:
def plot_summary(results, save_plot=True, filename='summary_comparison_T1000_clust.pdf'):
    """Create and save a summary plot of key metrics

    Draws a 2x2 grid (runtime, memory usage, ESS per second, initialisation
    sensitivity) with one line per (method, problem) pair. Colour encodes the
    method and line style the problem. Entries that are None (failed runs) are
    skipped together with their dimension, so x and y always stay aligned.

    Parameters
    ----------
    results : dict
        Output of run_benchmark.
    save_plot : bool
        When True the figure is written to `filename`.
    filename : str
        Output path used when `save_plot` is True.
    """
    problems = list(results.keys())
    methods = list(results[problems[0]]['metrics']['runtime'].keys())
    key_metrics = ['runtime', 'memory_usage', 'ess_per_sec', 'init_sensitivity']  #'accept_rate'
    colors = {'traditional': '#ff6b6b', 'emcee': '#9370db',
              'hmc': 'green', 'nested': '#45b7d1', 'slice': '#ffd93d', 'polychord': 'black', 'polychord_clust': 'gray'}
    line_styles = ['solid', 'dashed', 'dashdot']  # Cycle through these styles for problems
    axis_labels = {'ess_per_sec': 'ESS per sec'}  # Metrics whose label is not the title-cased key

    fig, axes = plt.subplots(2, 2, figsize=(15, 15), dpi=300)
    axes = axes.ravel()

    # Handles/labels of the first panel; they feed the single shared legend.
    legend_lines = []
    legend_labels = []

    for idx, metric in enumerate(key_metrics):
        ax = axes[idx]
        for problem_idx, problem in enumerate(problems):
            linestyle = line_styles[problem_idx % len(line_styles)]
            dims = results[problem]['dims']

            for method in methods:
                values = results[problem]['metrics'][metric][method]
                # .get() lets a method without a predefined colour fall back to
                # matplotlib's colour cycle instead of raising KeyError.
                color = colors.get(method)

                if metric == 'memory_usage':
                    # Only the peak VMS is plotted, scaled by 1/1000 as in the
                    # original figure (units depend on MemoryTracker — TODO confirm).
                    points = [(d, abs(v['peak']['vms']) / 1000)
                              for d, v in zip(dims, values)
                              if isinstance(v, dict) and v]
                    points.sort(key=lambda point: point[0])
                    if points:
                        xs, ys = zip(*points)
                        ax.plot(xs, ys, linestyle=linestyle, color=color,
                                marker='s', label=f'{method} ({problem}, VMS)')
                else:
                    # Filter dims and values together; filtering only one of them
                    # makes ax.plot fail on mismatched lengths.
                    points = [(d, abs(v)) for d, v in zip(dims, values) if v is not None]
                    if points:
                        xs, ys = zip(*points)
                        line = ax.plot(xs, ys, linestyle=linestyle, color=color,
                                       marker='o', label=f'{method} ({problem})')[0]

                        # Store for legend if this is the first metric
                        if idx == 0:
                            legend_lines.append(line)
                            legend_labels.append(f'{method} ({problem})')

        label = axis_labels.get(metric, metric.replace('_', ' ').title())
        ax.set_xlabel('Dimensions')
        ax.set_ylabel(label)
        ax.set_title(label)
        if metric in ['ess_per_sec', 'accuracy', 'memory_usage']:
            ax.set_yscale('log')

    # One legend, placed on the first panel, serves all subplots.
    axes[0].legend(legend_lines, legend_labels,
                   loc='center left', bbox_to_anchor=(-0.01, 0.665), fontsize=11)

    plt.tight_layout(rect=[0, 0, 1, 0.96])

    if save_plot:
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"Saved plot: {filename}")

    plt.show()
    plt.close(fig)
# Global matplotlib font size. This runs when the cell is executed, i.e. before
# plot_summary is called in a later cell, so it applies to figures drawn afterwards.
plt.rcParams.update({'font.size': 14})

In [None]:
# Draw the summary figure for the full benchmark; save_plot defaults to True,
# so the figure is also written to disk.
plot_summary(results_finalC) 

In [None]:
# Keep an independent snapshot of the raw benchmark output:
# update_accuracy_metrics_comparison (called below) overwrites the 'accuracy'
# entries of the dict it is given in place.
import copy
results_original = copy.deepcopy(results_finalC)  

In [None]:
from sklearn.cluster import KMeans
def calculate_distribution_accuracy(samples, problem_type, ndim, weights=None, method=None):
    """Calculate both mean-based and distribution-specific accuracy metrics

    Parameters
    ----------
    samples : array-like of shape (n_samples, ndim), or None
        Posterior samples. The input is not modified.
    problem_type : str
        One of 'gaussian', 'rosenbrock', 'mixture'.
    ndim : int
        Expected dimensionality; a mismatch with the samples only prints a warning.
    weights : array-like of shape (n_samples,), optional
        Per-sample weights. When given, all averages (and the KMeans fit) are weighted.
    method : str, optional
        Unused; accepted so existing callers keep working.

    Returns
    -------
    (mean_metric, dist_metric) : tuple of float
        Both are errors (lower is better). (None, None) is returned when
        `samples` is None or contains no rows.
        - gaussian: RMS distance of the sample mean from the origin (both metrics).
        - rosenbrock: RMS distance of the sample mean from the all-ones point, and
          the average per-sample Rosenbrock cost (mean over coordinate pairs).
        - mixture: distance of the two KMeans centres to the nearest mode at +/-2
          plus the deviation of the first cluster's weight from 0.5, and the
          negative average unnormalised log-density of the two-mode mixture.

    Raises
    ------
    ValueError
        If `problem_type` is unknown or `samples` is not 2-dimensional.
    """
    if samples is None:
        return None, None

    if problem_type not in ('gaussian', 'rosenbrock', 'mixture'):
        raise ValueError(f"Unknown problem_type: {problem_type!r}")

    # Nothing below mutates the samples, so a view is enough (no copy needed).
    samples = np.asarray(samples)
    if samples.ndim != 2:
        raise ValueError(f"samples must be 2-dimensional, got shape {samples.shape}")
    if samples.shape[0] == 0:
        return None, None

    if weights is not None:
        # Boolean-mask indexing below requires an ndarray, not a list.
        weights = np.asarray(weights, dtype=float)

    # Check dimensions match expected
    if samples.shape[1] != ndim:
        print(f"Warning: Sample dimension ({samples.shape[1]}) doesn't match expected ndim ({ndim})")

    if problem_type in ('gaussian', 'rosenbrock'):
        true_mode = np.zeros(ndim) if problem_type == 'gaussian' else np.ones(ndim)
        # np.average with weights=None is the plain mean.
        sample_mean = np.average(samples, weights=weights, axis=0)
        mean_metric = np.sqrt(np.mean((sample_mean - true_mode)**2))

        if problem_type == 'gaussian':
            dist_metric = mean_metric
        else:
            # Distribution-based metric: per-sample Rosenbrock cost, vectorised
            # over all samples at once.
            head, tail = samples[:, :-1], samples[:, 1:]
            errors = np.mean(100.0 * (tail - head**2.0)**2.0 + (1 - head)**2.0, axis=1)
            dist_metric = np.average(errors, weights=weights)

    else:  # mixture
        # Mean-based metric using KMeans
        kmeans = KMeans(n_clusters=2, random_state=42)
        kmeans.fit(samples, sample_weight=weights)

        found_modes = kmeans.cluster_centers_
        true_modes = np.array([2.0 * np.ones(ndim), -2.0 * np.ones(ndim)])

        # RMS distance of each found centre to its nearest true mode.
        # NOTE(review): both centres may match the same true mode, so a sampler
        # that finds only one mode is penalised solely through weight_balance.
        dist1, dist2 = (
            min(np.sqrt(np.mean((centre - mode)**2)) for mode in true_modes)
            for centre in found_modes
        )

        # Get cluster weights
        if weights is not None:
            labels = kmeans.predict(samples)
            weights_by_cluster = np.array([np.sum(weights[labels == i]) for i in range(2)])
            weights_by_cluster = weights_by_cluster / np.sum(weights)
        else:
            weights_by_cluster = np.array([np.mean(kmeans.labels_ == i) for i in range(2)])

        weight_balance = abs(weights_by_cluster[0] - 0.5)
        mean_metric = dist1 + dist2 + weight_balance

        # Distribution-based metric: unnormalised log-density of the two-mode mixture
        log_prob1 = -0.5 * np.sum((samples + 2)**2, axis=1)
        log_prob2 = -0.5 * np.sum((samples - 2)**2, axis=1)
        log_probs = np.logaddexp(log_prob1, log_prob2)
        dist_metric = -np.average(log_probs, weights=weights)

    return float(mean_metric), float(dist_metric)
def _entry_or_none(entries, index):
    """Return entries[index], or None when the list is missing or too short."""
    if entries is None or index >= len(entries):
        return None
    return entries[index]


def update_accuracy_metrics_comparison(results):
    """Update results with both types of accuracy metrics

    For every problem, method and dimension the stored samples are scored with
    calculate_distribution_accuracy. The resulting lists of {'value': ...} dicts
    replace metrics['accuracy'][method] and fill
    metrics['distribution_accuracy'][method].

    `results` is modified in place and the same object is returned. Runs with no
    stored samples (failed runs, or lists shorter than `dims`) get
    {'value': None} instead of raising.
    """
    for problem_name, problem_results in results.items():
        dims = problem_results['dims']
        metrics = problem_results['metrics']
        metrics.setdefault('distribution_accuracy', {})

        for method in problem_results['samples'].keys():
            mean_accuracy = []
            dist_accuracy = []

            for i, d in enumerate(dims):
                weights = None
                if method == 'polychord' or method == 'polychord_clust':
                    # For PolyChord, use the equal-weighted samples stored under
                    # 'log_weights'; columns 2..2+d are taken as the parameters
                    # (presumably after weight and likelihood columns — TODO confirm).
                    full_samples = _entry_or_none(metrics['log_weights'][method], i)
                    samples = None if full_samples is None else full_samples[:, 2:2+d]
                else:
                    samples = _entry_or_none(problem_results['samples'][method], i)
                    if method in ['nested']:
                        # Only the nested sampler's samples are scored with weights.
                        weights = _entry_or_none(metrics['weights'][method], i)
                    if samples is not None and len(samples.shape) == 3:
                        samples = samples.reshape(-1, samples.shape[-1])

                if samples is not None:
                    mean_metric, dist_metric = calculate_distribution_accuracy(
                        samples, problem_name, d, weights, method
                    )
                    mean_accuracy.append({'value': mean_metric})
                    dist_accuracy.append({'value': dist_metric})
                else:
                    mean_accuracy.append({'value': None})
                    dist_accuracy.append({'value': None})

            metrics['accuracy'][method] = mean_accuracy
            metrics['distribution_accuracy'][method] = dist_accuracy

    return results

In [None]:
# Recompute the accuracy metrics from the stored samples. The function mutates
# and returns its argument, so updated_resultsX00 and results_finalC refer to
# the same dict afterwards.
updated_resultsX00 = update_accuracy_metrics_comparison(results_finalC)

In [None]:
#final function used for the plots in the paper
def plot_accuracy_comparison_final(results, filename='accuracy_comparison_final.pdf'):
    """Plot accuracy metrics using fixed scale normalization

    Draws two panels (mean-based and distribution-based accuracy) with one line
    per (method, problem) pair. Each error value is mapped to
    1 - min(value / max_error, 1), so 1 is best and errors at or above the
    problem's fixed scale are clipped to 0.

    Parameters
    ----------
    results : dict
        Benchmark results whose 'accuracy' and 'distribution_accuracy' metrics
        hold numbers or {'value': number} dicts, as written by
        update_accuracy_metrics_comparison. None entries are skipped.
    filename : str
        Path the figure is saved to.

    Raises
    ------
    KeyError
        If a problem has no entry in the fixed normalisation scales below.
    """
    problems = list(results.keys())
    methods = list(results[problems[0]]['metrics']['runtime'].keys())

    # Color scheme for methods
    colors = {
        'traditional': '#ff6b6b',
        'emcee': '#9370db',
        'hmc': 'green',
        'nested': '#45b7d1',
        'slice': '#ffd93d',
        'polychord': 'black',
        'polychord_clust': 'gray'
    }

    # Line styles for different problems, assigned in order and cycled so that
    # any number of problems (not only exactly three) can be plotted.
    style_cycle = ['solid', 'dashed', 'dashdot']
    line_styles = {
        problem: style_cycle[i % len(style_cycle)]
        for i, problem in enumerate(problems)
    }

    # Create figure with mean-based and distribution-based metrics
    fig, axes = plt.subplots(1, 2, figsize=(20, 8), dpi=300)

    metrics = ['accuracy', 'distribution_accuracy']
    titles = ['Mean-based Accuracy', 'Distribution-based Accuracy']

    # Fixed error scale per problem; an error of this size or larger maps to 0.
    max_errors = {
        'gaussian': 0.5,
        'rosenbrock': 2.0,
        'mixture': 5.0  # Adjusted to match the scale of mixture distribution metric
    }

    for idx, (metric, title) in enumerate(zip(metrics, titles)):
        ax = axes[idx]

        for method in methods:
            for problem in problems:
                values = results[problem]['metrics'][metric][method]

                normalized_values = []
                valid_dims = []

                # Pair each value with its own problem's dims; zip also protects
                # against a value list shorter than dims.
                for d, value in zip(results[problem]['dims'], values):
                    if isinstance(value, dict):
                        value = value['value']
                    if value is None:
                        continue

                    # Apply fixed scale normalization
                    max_error = max_errors[problem]
                    normalized_values.append(1 - min(value / max_error, 1.0))  # Clip at 0
                    valid_dims.append(d)

                if normalized_values:
                    ax.plot(valid_dims, normalized_values,
                            linestyle=line_styles[problem],
                            color=colors.get(method),  # unknown methods use the colour cycle
                            marker='o',
                            label=f'{method} ({problem})')

        ax.set_xlabel('Dimensions')
        ax.set_ylabel('Normalized Accuracy (higher is better)')
        ax.set_ylim(-0.01, 1.03)  # Add some padding
        ax.set_title(title)

        if idx == 0:
            # Both panels share the same lines, so a single legend is enough.
            ax.legend(bbox_to_anchor=(1.01, 1), loc='upper left')

        ax.grid(True, linestyle='--', alpha=0.7)  # Grid for readability

    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)


In [None]:
# Plot the accuracy comparison; this needs the 'distribution_accuracy' entries
# added by update_accuracy_metrics_comparison in the earlier cell.
plot_accuracy_comparison_final(updated_resultsX00)