# Neural Likelihood Estimation with Nested Sampling

This example demonstrates how to use BlackJAX nested sampling as a posterior sampler for simulation-based inference (SBI) with neural likelihood estimation (NLE).

## Prerequisites

Install the required packages:
```bash
pip install git+https://github.com/handley-lab/blackjax
pip install sbi torch anesthetic numpy tqdm
```

## NSPosterior Implementation

First, we define a custom posterior class that uses BlackJAX nested sampling to sample from the posterior distribution given a trained likelihood estimator and prior.

In [None]:
import torch
import numpy as np
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import anesthetic
import tqdm


class NSPosterior:
    """Nested Sampling posterior for sbi.

    Uses BlackJAX nested sampling to sample from posterior distributions
    when given a likelihood estimator and prior.

    Args:
        likelihood: Trained likelihood estimator from NLE; called as
            ``likelihood.log_prob(x, condition=theta)``.
        prior: Prior distribution; must provide ``sample(shape)`` and
            ``log_prob(theta)``.
        num_live: Number of live points for nested sampling
        num_inner_steps: Number of slice sampling steps
        x_o: Observed data (can be set later with set_default_x)
        num_delete: Number of points to delete per iteration (default: 1)
    """

    def __init__(
        self,
        likelihood,
        prior,
        num_live,
        num_inner_steps,
        x_o=None,
        num_delete=1,
    ):
        self.likelihood = likelihood
        self.prior = prior
        self._x = x_o
        self.num_live = num_live
        self.num_delete = num_delete
        self.num_inner_steps = num_inner_steps

    @property
    def default_x(self):
        """Return default x used by .sample(), .log_prob() as conditioning context."""
        return self._x

    @default_x.setter
    def default_x(self, x):
        """See `set_default_x`."""
        self.set_default_x(x)

    def set_default_x(self, x):
        """Set new default x for .sample(), .log_prob() to use as conditioning context.

        This convenience is particularly useful when the posterior is focused, i.e.
        has been trained over multiple rounds to be accurate in the vicinity of a
        particular x=x_o.

        NOTE: this method is chainable, i.e. will return the NSPosterior object so
        that calls like posterior.set_default_x(my_x).sample((100,)) are possible.

        Args:
            x: The default observation to set for the posterior p(θ|x).
        Returns:
            NSPosterior that will use a default x when not explicitly passed.
        """
        self._x = x
        return self

    def _x_else_default_x(self, x):
        """Return ``x`` if given, otherwise the default x.

        Raises:
            ValueError: if ``x`` is None and no default has been set.
        """
        if x is not None:
            return x
        if self.default_x is None:
            raise ValueError(
                "Context `x` needed when a default has not been set. "
                "If you'd like to have a default, use the `.set_default_x()` method."
            )
        return self.default_x

    def _call_with_x(self, x, fn, *args):
        """Call ``fn(*args)`` with ``x`` (or the default) installed as default x.

        The previous default is restored afterwards, even if ``fn`` raises,
        so a failed call never leaves the posterior conditioned on the wrong x.

        Raises:
            ValueError: if ``x`` is None and no default has been set.
        """
        x_o = self._x_else_default_x(x)
        old_x = self._x
        self._x = x_o
        try:
            return fn(*args)
        finally:
            self._x = old_x

    def _loglikelihood_fn(self, theta):
        """Log-likelihood of the default x for each row of ``theta``."""
        # Repeat the single observation once per parameter row:
        # expand to (1, theta.shape[0], -1), then drop the leading axis of the result.
        x_o_batch = self.default_x.unsqueeze(0).expand(1, theta.shape[0], -1)
        return self.likelihood.log_prob(x_o_batch, condition=theta).squeeze(0)

    def _logprior_fn(self, theta):
        """Prior log-density of ``theta``."""
        return self.prior.log_prob(theta)

    def _unnormalized_log_posterior(self, theta):
        """Log-likelihood plus log-prior, evaluated at the current default x."""
        return self._loglikelihood_fn(theta) + self._logprior_fn(theta)

    def nested_samples(self, seed=42):
        """Run BlackJAX nested sampling conditioned on the default x.

        Args:
            seed: Seed for the JAX PRNG key driving the sampler. Defaults to
                42, the value previously hard-coded here.

        Returns:
            ``anesthetic.NestedSamples`` built from the finalised run.

        Raises:
            ValueError: if no default x has been set.
        """
        # Fail fast with a readable error instead of an AttributeError that
        # would otherwise surface from inside a jax.pure_callback.
        self._x_else_default_x(None)

        def wrap_fn(fn, vmap_method='legacy_vectorized'):
            """Expose a torch function of theta to JAX through a host callback."""
            def numpy_wrapper(theta):
                theta_np = np.asarray(theta)
                result = fn(torch.from_numpy(theta_np.copy()).float())
                # pure_callback requires the result dtype to match out_shape.
                return result.detach().numpy().astype(theta_np.dtype, copy=False)

            def jax_wrapper(x):
                # One scalar per parameter vector: drop the trailing axis.
                out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
                return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)

            return jax_wrapper

        algo = blackjax.nss(
            logprior_fn=wrap_fn(self._logprior_fn),
            loglikelihood_fn=wrap_fn(self._loglikelihood_fn),
            num_delete=self.num_delete,
            num_inner_steps=self.num_inner_steps,
        )
        prior_samples = self.prior.sample((self.num_live,))
        initial_live = jnp.array(prior_samples.numpy())

        rng_key = jax.random.PRNGKey(seed)
        live = algo.init(initial_live)
        step = jax.jit(algo.step)

        dead_points = []

        with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
            # Keep sampling until logZ_live - logZ drops below -3, i.e. the
            # live points' contribution to the evidence has become negligible.
            # Written as `not (... < -3)` so a NaN difference keeps sampling.
            while not (live.logZ_live - live.logZ < -3):
                rng_key, subkey = jax.random.split(rng_key)
                live, dead = step(subkey, live)
                dead_points.append(dead)
                pbar.update(len(dead.particles))

        ns_run = finalise(live, dead_points)

        return anesthetic.NestedSamples(
            data=ns_run.particles,
            logL=ns_run.loglikelihood,
            logL_birth=ns_run.loglikelihood_birth,
        )

    def sample(self, sample_shape=torch.Size(), x=None):
        """Return unweighted posterior samples.

        Performs a fresh nested sampling run and resamples its weighted output.

        Args:
            sample_shape: Desired shape of samples (an int ``n`` is treated
                as ``(n,)``).
            x: Optional observation (uses default if not provided)

        Returns:
            float32 PyTorch tensor of samples with shape (*sample_shape, param_dim)

        Raises:
            ValueError: if ``x`` is None and no default has been set.
        """
        if isinstance(sample_shape, int):
            sample_shape = (sample_shape,)
        sample_shape = torch.Size(sample_shape)
        ns = self._call_with_x(x, self.nested_samples)
        # Draw i.i.d. with replacement: weighted sampling *without* replacement
        # does not follow the posterior weights once n is a sizeable fraction
        # of the number of nested samples.
        draws = ns.sample(sample_shape.numel(), replace=True)
        params = draws.drop(columns=['logL', 'logL_birth', 'nlive']).values
        # pandas yields float64; return float32 to match the float32 inputs
        # used for the estimator (and what sbi training expects).
        return torch.as_tensor(params, dtype=torch.float32).reshape((*sample_shape, -1))

    def log_prob(self, theta, x=None):
        """Evaluate unnormalized log posterior.

        Args:
            theta: Parameters to evaluate
            x: Optional observation (uses default if not provided)

        Returns:
            Log posterior values (unnormalized)

        Raises:
            ValueError: if ``x`` is None and no default has been set.
        """
        theta = torch.as_tensor(theta)
        return self._call_with_x(x, self._unnormalized_log_posterior, theta)

## Example: Neural Likelihood Estimation

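Next, we set up a toy problem for neural likelihood estimation (NLE) with the sbi package: a uniform prior on the unit square, a simulator that adds Gaussian noise to the parameters, and an observation `x_o`.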

In [None]:
from sbi.inference import NLE
from sbi.utils import BoxUniform

# Define the prior: uniform on the unit box [0, 1]^num_dims
num_dims = 2
num_sims = 1000  # simulations per round
num_rounds = 2
prior = BoxUniform(low=torch.zeros(num_dims), high=torch.ones(num_dims))


def simulator(theta):
    """Simple simulator: add Gaussian noise (std 0.1) to the parameters."""
    return theta + torch.randn_like(theta) * 0.1


# Observed data
x_o = torch.tensor([0.5, 0.5])

## Sequential Neural Likelihood Estimation

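Train the likelihood estimator over multiple rounds. The first round draws parameters from the prior. Each later round draws them from the nested-sampling posterior (conditioned on `x_o`) built in the previous round.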

In [None]:
inference = NLE(prior)
proposal = prior

for round_idx in range(num_rounds):
    print(f"\nRound {round_idx + 1}/{num_rounds}")

    # Draw parameters from the current proposal (the prior in the first
    # round) and simulate a dataset for them.
    theta = proposal.sample((num_sims,))
    x = simulator(theta)

    # Add this round's simulations and retrain the likelihood estimator.
    likelihood = inference.append_simulations(theta, x).train()

    # Wrap the estimator in a nested-sampling posterior conditioned on x_o;
    # that posterior is the proposal for the next round.
    posterior = NSPosterior(
        likelihood,
        prior,
        num_live=500,
        num_inner_steps=20,
        num_delete=250,
    )
    posterior.default_x = x_o
    proposal = posterior

## Analyze Results

Run nested sampling with the final likelihood estimator, save the resulting nested samples, and summarize the posterior.

In [None]:
# Run nested sampling with the final posterior (conditioned on x_o)
nested_samples = posterior.nested_samples()

# Save to CSV for later analysis
nested_samples.to_csv("nested_samples.csv")

# Display summary statistics
print("\nNested Sampling Results:")
print(f"Log evidence: {nested_samples.logZ():.2f}")
print(f"Number of samples: {len(nested_samples)}")
print("\nPosterior mean:")
print(nested_samples.mean())
print("\nPosterior std:")
print(nested_samples.std())