# Blobs and Metadata

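This page introduces the blobs interface, an easy way for the user to track arbitrary metadata for every sample of the chain. A blob is any extra value that the log-probability function returns after the log-probability itself.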

## Tracking the value of the log-prior

We can use blobs to store the value of the log-prior at each step in the chain by returning it from the log-probability function as a second value, for example:

In [1]:
import zeus

import numpy as np

def log_prior(x):
    return -0.5 * np.dot(x,x)

def log_like(x):
    return -0.5 * np.dot(x,x) / 0.1**2.0

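def log_prob(x):
    # Return (log-probability, blob); the blob here is the log-prior.
    lp = log_prior(x)
    if not np.isfinite(lp):
        return -np.inf, -np.inf
    ll = log_like(x)
    if not np.isfinite(ll):
        # Reject the point, but still record its finite log-prior.
        return -np.inf, lp
    return lp + ll, lp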

nwalkers, ndim = 32, 3
start = np.random.randn(nwalkers, ndim)
sampler = zeus.EnsembleSampler(nwalkers, ndim, log_prob)
sampler.run_mcmc(start, 100)

log_prior_samps = sampler.get_blobs()
flat_log_prior_samps = sampler.get_blobs(flat=True)

print(log_prior_samps.shape)  # (100, 32)
print(flat_log_prior_samps.shape)  # (3200,)

Initialising ensemble of 32 walkers...
Sampling progress : 100%|██████████| 100/100 [00:00<00:00, 160.45it/s]
(100, 32)
(3200,)



Once sampling has finished, the “blobs” stored by the sampler form a ``(nsteps, nwalkers)`` NumPy array holding the value of the log-prior at every sample.
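As a rough sketch of what can be done with an array of this shape, the snippet below discards an initial burn-in, flattens the rest, and averages per walker. It uses NumPy only: ``fake_log_prior`` is a synthetic stand-in for the output of ``sampler.get_blobs()``, and the burn-in length is a hypothetical choice made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
nsteps, nwalkers = 100, 32

# Synthetic stand-in for sampler.get_blobs(): non-positive values with
# shape (nsteps, nwalkers), mimicking a Gaussian log-prior in 3 dimensions.
fake_log_prior = -0.5 * rng.chisquare(df=3, size=(nsteps, nwalkers))

burn = 20  # hypothetical burn-in length, chosen only for this example

kept = fake_log_prior[burn:]         # drop the first `burn` steps
flat_kept = kept.reshape(-1)         # one entry per remaining sample
per_walker_mean = kept.mean(axis=0)  # average over steps for each walker

print(kept.shape)             # (80, 32)
print(flat_kept.shape)        # (2560,)
print(per_walker_mean.shape)  # (32,)
```

The first axis indexes steps and the second indexes walkers, so slicing along the first axis removes early iterations for all walkers at once.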

## Tracking multiple species of metadata

When handling multiple species of metadata, it can be useful to name them. This can be done using the ``blobs_dtype`` argument of the ``EnsembleSampler``. For instance, to save the mean of the parameters as well as the log-prior we could do something like:

In [2]:
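def log_prob(params):
    # Return (log-probability, log-prior, mean): one blob per field in blobs_dtype.
    # Every branch returns the same number of values.
    lp = log_prior(params)
    if not np.isfinite(lp):
        return -np.inf, -np.inf, np.mean(params)
    ll = log_like(params)
    if not np.isfinite(ll):
        # Reject the point, but still record its finite log-prior.
        return -np.inf, lp, np.mean(params)
    return lp + ll, lp, np.mean(params)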

nwalkers, ndim = 32, 3
start = np.random.randn(nwalkers, ndim)

# Here are the important lines
dtype = [("log_prior", float), ("mean", float)]
sampler = zeus.EnsembleSampler(nwalkers, ndim, log_prob, blobs_dtype=dtype)

sampler.run_mcmc(start, 100)

blobs = sampler.get_blobs()
log_prior_samps = blobs["log_prior"]
mean_samps = blobs["mean"]
print(log_prior_samps.shape)
print(mean_samps.shape)

flat_blobs = sampler.get_blobs(flat=True)
flat_log_prior_samps = flat_blobs["log_prior"]
flat_mean_samps = flat_blobs["mean"]
print(flat_log_prior_samps.shape)
print(flat_mean_samps.shape)

Initialising ensemble of 32 walkers...
Sampling progress : 100%|██████████| 100/100 [00:00<00:00, 137.06it/s]
(100, 32)
(100, 32)
(3200,)
(3200,)
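Indexing the blobs by field name behaves like indexing a NumPy structured array, which is what a dtype given as a list of ``(name, type)`` pairs describes. The following is a minimal NumPy-only sketch of that behaviour, independent of zeus; the array sizes and values are made up, and how the sampler stores blobs internally is not assumed here.

```python
import numpy as np

# Same field layout as the blobs_dtype used above.
dtype = [("log_prior", float), ("mean", float)]

# Small made-up dimensions so the arrays are easy to inspect.
nsteps, nwalkers = 4, 2
blobs = np.zeros((nsteps, nwalkers), dtype=dtype)

# Each field can be read or written as an ordinary (nsteps, nwalkers) array.
blobs["log_prior"] = -1.5
blobs["mean"] = np.arange(nsteps * nwalkers, dtype=float).reshape(nsteps, nwalkers)

print(blobs.dtype.names)    # ('log_prior', 'mean')
print(blobs["mean"].shape)  # (4, 2)

# Flattening keeps the fields; only the sample axes are merged.
flat = blobs.reshape(-1)
print(flat["mean"].shape)   # (8,)
```

Naming the fields avoids having to remember the position of each quantity in the tuple returned by the log-probability function.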

