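# The 16-dimensional *Funnel* example from the [GBS paper](http://proceedings.mlr.press/v118/jia20a.html)

The target is Neal's funnel:
- $x_0 \sim \mathcal{N}(0, a^2)$.
- Given $x_0$, each of $x_1, \dots, x_{15}$ is independently $\mathcal{N}(0, e^{2 b x_0})$.
- $a = 1$ and $b = 0.5$.
- The prior is flat on the box $x_0 \in [-4, 4]$, $x_{i \ge 1} \in [-30, 30]$.

The notebook samples this posterior with BayesFast and estimates the log-evidence $\log Z$ with Gaussianized Bridge Sampling (GBS).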

*last tested with bayesfast commit 8d6efa1*

In [1]:
import bayesfast as bf
import numpy as np
from threadpoolctl import threadpool_limits
# Limit the native thread pools managed by threadpoolctl to a single thread.
# The limiter is bound to a name rather than left as the cell's last
# expression: this stops its repr from being echoed as cell output and keeps
# the object reachable for the rest of the session.
_thread_limiter = threadpool_limits(limits=1) # TODO: implement a bayesfast global thread controller

<threadpoolctl.threadpool_limits at 0x2aaaaebc75b0>

In [2]:
D = 16  # number of dims
a = 1.  # standard deviation of x_0
b = 0.5  # x_i (i >= 1) has standard deviation exp(b * x_0)

# Flat prior on a box: x_0 in [-4, 4], every other coordinate in [-30, 30].
lower = np.full(D, -30.)
upper = np.full(D, 30.)
lower[0], upper[0] = -4, 4
bound = np.vstack((lower, upper)).T  # shape (D, 2): one (lower, upper) row per dim
diff = upper - lower  # side length of the box along each dim
const = np.log(diff).sum()  # log prior volume, i.e. normalization of the flat prior

def logp(x):
    """Log posterior of the funnel (normalized Gaussian likelihood, flat box prior).

    The density is x_0 ~ N(0, a**2) and, given x_0, each of x_1, ..., x_{n-1}
    is independently N(0, exp(2 * b * x_0)). The flat prior's log volume
    `const` is subtracted at the end.

    Parameters
    ----------
    x : ndarray, shape (..., n)
        One point, or a batch of points along the leading axes.

    Returns
    -------
    ndarray or float
        Log density for each point (shape ``x.shape[:-1]``).
    """
    x0 = x[..., 0]
    m = x.shape[-1] - 1  # number of coordinates whose scale depends on x_0
    log_gauss_x0 = -0.5 * x0**2 / a**2
    log_gauss_rest = -0.5 * np.sum(x[..., 1:]**2, axis=-1) * np.exp(-2 * b * x0)
    # Normalization terms of both Gaussians; the last one comes from the
    # x_0-dependent variance exp(2 b x_0) of each of the m other coordinates.
    log_norm = (-0.5 * np.log(2 * np.pi * a**2) -
                0.5 * m * np.log(2 * np.pi) - m * b * x0)
    return log_gauss_x0 + log_gauss_rest + log_norm - const

def grad(x):
    """Gradient of the funnel log posterior with respect to ``x``.

    d/dx_0 = -x_0 / a**2 + b * sum_{i>=1} x_i**2 * exp(-2 b x_0) - (n - 1) * b
    d/dx_i = -x_i * exp(-2 b x_0)                       for i >= 1

    The prior constant `const` does not depend on ``x``, so it does not appear.

    Parameters
    ----------
    x : ndarray, shape (..., n)
        One point, or a batch of points along the leading axes (the same
        layout the ``...``-indexed log density accepts).

    Returns
    -------
    ndarray, shape (..., n)
        A new array; ``x`` is not modified.
    """
    n = x.shape[-1]
    # exp(-2 b x_0) = 1 / Var(x_i | x_0), kept with a trailing length-1 axis so
    # it broadcasts against every coordinate of each point in a batch. (The
    # previous np.full-based version broadcast x_0 along the wrong axis for
    # batched input, and `foo[0]` edited the first row instead of the x_0
    # component of every point.)
    inv_var = np.exp(-2 * b * x[..., :1])
    # Neck coordinates: d logp / d x_i = -x_i * exp(-2 b x_0).
    out = -x * inv_var
    # x_0 collects the terms from its own Gaussian, from the x_0-dependent
    # variance of the other coordinates, and from their normalization.
    out[..., 0] = (-x[..., 0] / a**2
                   + b * np.sum(x[..., 1:]**2, axis=-1) * inv_var[..., 0]
                   - (n - 1) * b)
    return out

In [3]:
bf.utils.random.set_generator(16) # set up the global random number generator
bf.utils.parallel.set_backend(8) # set up the global parallel backend

# Wrap the log density and its analytic gradient for BayesFast; the box in
# `bound` is passed as input_scales.
den = bf.DensityLite(
    logp=logp,
    grad=grad,
    input_size=D,
    input_scales=bound,
    hard_bounds=True,
)

# The funnel structure can be pathological for HMC-style samplers, so a
# higher target acceptance rate is chosen here; see
# https://mc-stan.org/users/documentation/case-studies/divergences_and_bias.html
sample_trace = {
    'n_chain': 8,
    'n_iter': 2500,
    'n_warmup': 1000,
    'target_accept': 0.95,
}

# Sampling settings from above; the post step estimates the evidence with GBS.
rec = bf.Recipe(
    density=den,
    sample={'sample_trace': sample_trace},
    post={'evidence_method': 'GBS'},
)

In [4]:
# Run the recipe: the sampling step configured by `sample_trace`, then the post
# step, which computes the evidence with the 'GBS' method requested in `post`.
rec.run()


 *** StaticSample: returning the #0 SampleStep. *** 



  return np.sum(np.log(np.abs(self.to_original_grad(x_trans))),
  _grad += self.to_original_grad2(x) / _tog
  _grad += self.to_original_grad2(x) / _tog
  return np.sum(np.log(np.abs(self.to_original_grad(x_trans))),
  return np.sum(np.log(np.abs(self.to_original_grad(x_trans))),
  _grad += self.to_original_grad2(x) / _tog
  return np.sum(np.log(np.abs(self.to_original_grad(x_trans))),
  _grad += self.to_original_grad2(x) / _tog
  return np.sum(np.log(np.abs(self.to_original_grad(x_trans))),
  _grad += self.to_original_grad2(x) / _tog
  return np.sum(np.log(np.abs(self.to_original_grad(x_trans))),
  return np.sum(np.log(np.abs(self.to_original_grad(x_trans))),
  _grad += self.to_original_grad2(x) / _tog
  _grad += self.to_original_grad2(x) / _tog


 CHAIN #5 : sampling proceeding [ 500 / 2500 ], last 500 samples used 3.43 seconds. (warmup)
 CHAIN #0 : sampling proceeding [ 500 / 2500 ], last 500 samples used 3.48 seconds. (warmup)
 CHAIN #1 : sampling proceeding [ 500 / 2500 ], last 500 samples used 3.70 seconds. (warmup)
 CHAIN #2 : sampling proceeding [ 500 / 2500 ], last 500 samples used 3.79 seconds. (warmup)
 CHAIN #7 : sampling proceeding [ 500 / 2500 ], last 500 samples used 3.77 seconds. (warmup)
 CHAIN #3 : sampling proceeding [ 500 / 2500 ], last 500 samples used 3.86 seconds. (warmup)
 CHAIN #4 : sampling proceeding [ 500 / 2500 ], last 500 samples used 3.96 seconds. (warmup)
 CHAIN #6 : sampling proceeding [ 500 / 2500 ], last 500 samples used 4.39 seconds. (warmup)
 CHAIN #0 : sampling proceeding [ 1000 / 2500 ], last 500 samples used 2.83 seconds. (warmup)
 CHAIN #5 : sampling proceeding [ 1000 / 2500 ], last 500 samples used 3.24 seconds. (warmup)
 CHAIN #2 : sampling proceeding [ 1000 / 2500 ], last 500 samples us




 ***** PostStep finished. ***** 



In [5]:
# List the field names of the recipe's result (presumably a namedtuple, since it
# exposes `_fields`; verify against bayesfast).
rec.get()._fields

('samples',
 'weights',
 'weights_trunc',
 'logp',
 'logq',
 'logz',
 'logz_err',
 'x_p',
 'x_q',
 'logp_p',
 'logq_q',
 'trace_p',
 'trace_q',
 'n_call',
 'x_max',
 'f_max')

In [6]:
# Log-evidence estimate from the GBS post step, together with its reported error.
rec.get().logz, rec.get().logz_err # fiducial value: logz = -63.4988

(-63.47878850208712, 0.017037028847386972)