# C++

This example demonstrates how to use BlackJAX nested sampling with C++ implementations of likelihood and prior functions. The C++ code is compiled using pybind11 to create a Python module, with JAX's `pure_callback` providing the bridge.

## Setup Instructions

### 1. Create the C++ implementation

First, create a file `model.cpp` with your likelihood and prior functions using pybind11:

```cpp
/* model.cpp */
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <cmath>
#include <vector>

namespace py = pybind11;

// log(2 * pi), shared by the Gaussian log-densities below.
constexpr double LOG_2PI = 1.8378770664093454;

// Log-likelihood of one parameter vector.
//
// theta: pointer to `d` contiguous doubles.
// d:     number of entries read from `theta`.
//
// Evaluates the log-density of independent Gaussians with mean 1.0 and
// variance 0.01 in each of the `d` coordinates.
double loglikelihood_scalar(const double* theta, int d) {
    const double mean = 1.0;
    const double precision = 1.0 / 0.01;

    // Accumulate the precision-weighted squared distance from the mean.
    double quad_form = 0.0;
    int k = 0;
    while (k < d) {
        const double delta = theta[k] - mean;
        quad_form += delta * delta * precision;
        ++k;
    }

    // Log-determinant of the diagonal covariance: d * log(variance).
    const double log_det_cov = d * std::log(0.01);
    return -0.5 * (d * LOG_2PI + log_det_cov + quad_form);
}

// Log-prior of one parameter vector.
//
// theta: pointer to `d` contiguous doubles.
// d:     number of entries read from `theta`.
//
// Evaluates the log-density of a standard normal (zero mean, unit variance)
// in each of the `d` coordinates.
double logprior_scalar(const double* theta, int d) {
    double sum_sq = 0.0;
    int k = 0;
    while (k < d) {
        sum_sq += theta[k] * theta[k];
        ++k;
    }
    return -0.5 * (d * LOG_2PI + sum_sq);
}

// Batched likelihood: evaluates loglikelihood_scalar sequentially on every
// parameter vector stored along the last axis of `theta`.
// This reduces Python callbacks by a factor of num_delete (typically 50-100),
// which is the dominant cost for fast likelihoods.
//
// theta:   array with shape (..., d). pybind11 converts the argument to a
//          C-contiguous array of doubles on entry (c_style | forcecast).
// returns: array of doubles with shape (...), one value per parameter vector.
//          A 1-D input of shape (d,) therefore yields a 0-d array.
// throws:  py::value_error (ValueError in Python) if `theta` is 0-dimensional.
py::array_t<double> loglikelihood(py::array_t<double, py::array::c_style | py::array::forcecast> theta) {
    // Everything that touches Python/NumPy objects (shape queries, allocating
    // the result) must run while the GIL is held.
    const py::ssize_t ndim = theta.ndim();
    if (ndim < 1) {
        throw py::value_error("loglikelihood: theta must have shape (..., d)");
    }

    const py::ssize_t* dims = theta.shape();
    const py::ssize_t d = dims[ndim - 1];

    // Output shape is the input shape without its last axis.
    const std::vector<py::ssize_t> batch_shape(dims, dims + (ndim - 1));
    py::ssize_t batch = 1;
    for (const py::ssize_t extent : batch_shape) {
        batch *= extent;
    }

    py::array_t<double> result(batch_shape);
    const double* theta_ptr = theta.data();
    double* result_ptr = result.mutable_data();

    // The scalar kernel takes an int dimension; d is assumed to fit in int.
    const int dim = static_cast<int>(d);

    {
        // Only plain C++ arithmetic on raw buffers happens in this scope, so
        // the GIL can be released here and is re-acquired before returning.
        py::gil_scoped_release release;
        for (py::ssize_t b = 0; b < batch; ++b) {
            result_ptr[b] = loglikelihood_scalar(theta_ptr + b * d, dim);
        }
    }

    return result;
}

// Batched prior: evaluates logprior_scalar sequentially on every parameter
// vector stored along the last axis of `theta`.
//
// theta:   array with shape (..., d). pybind11 converts the argument to a
//          C-contiguous array of doubles on entry (c_style | forcecast).
// returns: array of doubles with shape (...), one value per parameter vector.
//          A 1-D input of shape (d,) therefore yields a 0-d array.
// throws:  py::value_error (ValueError in Python) if `theta` is 0-dimensional.
py::array_t<double> logprior(py::array_t<double, py::array::c_style | py::array::forcecast> theta) {
    // Everything that touches Python/NumPy objects (shape queries, allocating
    // the result) must run while the GIL is held.
    const py::ssize_t ndim = theta.ndim();
    if (ndim < 1) {
        throw py::value_error("logprior: theta must have shape (..., d)");
    }

    const py::ssize_t* dims = theta.shape();
    const py::ssize_t d = dims[ndim - 1];

    // Output shape is the input shape without its last axis.
    const std::vector<py::ssize_t> batch_shape(dims, dims + (ndim - 1));
    py::ssize_t batch = 1;
    for (const py::ssize_t extent : batch_shape) {
        batch *= extent;
    }

    py::array_t<double> result(batch_shape);
    const double* theta_ptr = theta.data();
    double* result_ptr = result.mutable_data();

    // The scalar kernel takes an int dimension; d is assumed to fit in int.
    const int dim = static_cast<int>(d);

    {
        // Only plain C++ arithmetic on raw buffers happens in this scope, so
        // the GIL can be released here and is re-acquired before returning.
        py::gil_scoped_release release;
        for (py::ssize_t b = 0; b < batch; ++b) {
            result_ptr[b] = logprior_scalar(theta_ptr + b * d, dim);
        }
    }

    return result;
}

// Python module definition: the module is importable as `model` and exposes
// the two batched C++ functions under their own names.
PYBIND11_MODULE(model, m) {
    m.attr("__doc__") = "C++ model for BlackJAX nested sampling";

    m.def("loglikelihood", loglikelihood, "Compute log likelihood");
    m.def("logprior", logprior, "Compute log prior");
}
```

**Note:** This implementation defines scalar likelihood and prior functions and wraps them in batched versions that loop over the batch sequentially. Although the loop is sequential rather than parallel, batching is still significantly faster than calling back into Python (or other non-JIT-compiled code) once per point: it reduces the number of Python callbacks by a factor of `num_delete` (typically 50-100), and callback overhead is the dominant cost for fast likelihoods.

Save this as `model.cpp` in your working directory.

### 2. Create the setup script

Create a `setup_model_cpp.py` file to compile the C++ module:

```python
# setup_model_cpp.py
"""Build script that compiles model.cpp into the `model` extension module."""
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup

# Single extension: module name "model", built from model.cpp as C++11.
model_extension = Pybind11Extension("model", ["model.cpp"], cxx_std=11)

setup(
    name="model",
    python_requires=">=3.7",
    zip_safe=False,
    ext_modules=[model_extension],
    # Use pybind11's build_ext command for the build step.
    cmdclass={"build_ext": build_ext},
)
```

Save this as `setup_model_cpp.py` in your working directory.

### 3. Compile the C++ module

Install pybind11 and compile the module:

```bash
# Install pybind11, which provides the headers and the setup helpers used by setup_model_cpp.py
pip install pybind11
# Compile model.cpp; --inplace puts the built extension next to the sources so `import model` works from this directory
python setup_model_cpp.py build_ext --inplace
```

This will create a `model` module that can be imported directly in Python.

### 4. Run nested sampling with C++ functions

The following Python code wraps the compiled functions with `jax.pure_callback` and runs the nested sampler:
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import model  # The compiled C++ module

# Root PRNG key; every later key is split from this one.
rng_key = jax.random.PRNGKey(0)

# Raw entry points of the compiled C++ extension (not yet usable inside JAX).
loglikelihood_fn, logprior_fn = model.loglikelihood, model.logprior

def wrap_fn(fn, vmap_method='legacy_vectorized'):
    def jax_wrapper(x):
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        return jax.pure_callback(fn, out_shape, x, vmap_method=vmap_method)
    
    return jax_wrapper

# Route both C++ functions through jax.pure_callback so JAX code can call them.
loglikelihood_fn = wrap_fn(loglikelihood_fn)
logprior_fn = wrap_fn(logprior_fn)

algo = blackjax.nss(
    logprior_fn=logprior_fn,
    loglikelihood_fn=loglikelihood_fn,
    num_delete=50,
    num_inner_steps=20,
)

rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)

# Initial live set: 1000 points in 5 dimensions drawn from a standard normal.
initial_particles = jax.random.normal(initialization_key, (1000, 5))
live = algo.init(initial_particles)
step = jax.jit(algo.step)

dead_points = []

with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    # Keep stepping until logZ_live - logZ falls below -3
    # (presumably the point where the live set no longer contributes
    # appreciably to the evidence -- confirm against the blackjax docs).
    while not (live.logZ_live - live.logZ < -3):
        rng_key, subkey = jax.random.split(rng_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        n_dead = len(dead.particles)
        pbar.update(n_dead)

# Combine the final live state with all collected dead points.
ns_run = finalise(live, dead_points)