In [None]:
""" 
Implementation of https://arxiv.org/abs/1903.09556 with JAX
"""

#%%
import sys
sys.path.append("..")
from opt_einsum import contract
import jax.numpy as jnp
import numpy as np
from functools import partial
import jax
from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np

class hybrid_rosenbrock:
    def __init__(self, n2, n1, mu, b, seed=1):
        """Hybrid Rosenbrock class

        Args:
            mu (float): mean
            b  (array): d-dimensional array. b[0] = a from the paper
            n2 (int):   Number of blocks
            n1 (int):   Block size
        """
        self.n2 = n2
        self.n1 = n1
        self.mu = mu
        self.b = b
        self.DoF = self.n2 * (self.n1 - 1) + 1

        self.B = self._getDependencyStructure(self.b)

        self.Z = self.getPartitionFunction() # Inverse of normalization constant
        self.priorDict = None
        self.id = 'jax_hybrid_rosenbrock'

        np.random.seed(seed)
        self.lower_bound = jnp.array(np.random.uniform(0.2, 1, self.DoF))
        self.upper_bound = jnp.array(np.random.uniform(2, 4.5, self.DoF))


    def _getDependencyStructure(self, x):
        """Get the matrix representation of the dependency structure denoted in Figure 7 - https://arxiv.org/abs/1903.09556
        Remark: This method will be used to express $x_{j,i}$, and $b_{j,i}$ in tensor form. This simplifies the implementation!

        Args:
            x (array): (DoF,) shaped array

        Returns:
            array: (n2, n1) shaped array
        """

        structure = jnp.zeros((self.n2, self.n1))
        structure = structure.at[:, 0].set(x[0])
        structure = structure.at[:, 1:].set(x[1:].reshape(self.n2, self.n1-1))
        return structure

    def _getResiduals(self, x):
        """Get residuals so that Hybrid Rosenbrock may be expressed in "least squares" form. That is,
        $$-\ln \pi(x) = \frac{1}{2} \sum_{d=1}^{DoF} r(x;d)^2$$,

        Args:
            x (array): (DoF,) shaped array representing the point at which we wish to evaluate the residual

        Returns:
            array: (DoF,) shaped array representing the residuals evaluated at 'x'.
        """
        X = self._getDependencyStructure(x)
        res = jnp.zeros(self.DoF)
        res = res.at[0].set(jnp.sqrt(self.b[0]) * (x[0] - self.mu))
        res = res.at[1:].set((jnp.sqrt(self.B[:,1:]) * (X[:, 1:] - X[:,:-1] ** 2)).flatten())
        return res * jnp.sqrt(2)

    def getMinusLogPosterior(self, x):
        """Get the minus log likelihood, also known as the "potential"

        Args:
            x (array): (DoF,) shaped array representing the point at which we wish to evaluate the potential

        Returns:
            float: Potential evaluated at 'x'.
        """
        res = self._getResiduals(x)
        return jnp.sum(res ** 2) / 2 + jnp.log(self.Z)

    def getPartitionFunction(self):
        """Get the partition function for the hybrid Rosenbrock
        Remark: The partition function is the inverse of the so called "normalization constant"

        Returns:
            float: The partition function.
        """
        return (jnp.pi ** (self.DoF / 2)) / (jnp.prod(jnp.sqrt(self.b)))
    
    def newDrawFromPosterior(self, nSamples):
        """Get i.i.d samples from hybrid Rosenbrock

        Args:
            nSamples (int): The number of samples to draw

        Returns:
            array: (nSamples, DoF) shaped array, where each row corresponds to a sample.
        """
        samples = np.zeros((nSamples, self.DoF))
        index_structure = self._getDependencyStructure(np.arange(self.DoF))
        for d in range(self.DoF):
            standard_deviation = 1 / np.sqrt(2 * self.b[d])
            if d == 0:
                samples[:, d] = np.random.normal(self.mu, standard_deviation, nSamples)
            elif d in index_structure[:, 1]:
                samples[:, d] = samples[:, 0] ** 2 + np.random.normal(0, 1, nSamples) * standard_deviation
            else:
                samples[:, d] = samples[:, d - 1] ** 2 + np.random.normal(0, 1, nSamples) * standard_deviation
        return samples

In [None]:
# Import Hybrid Rosenbrock
import sys, os
sys.path.append("..")
import numpy as np

# Define Hybrid Rosenbrock model
n2 = 3
n1 = 4
DoF = n2 * (n1 - 1) + 1
B = np.zeros(DoF)
B[0] = 30
B[1:] = 20
mu=1
model = hybrid_rosenbrock(n2, n1, mu, B, seed=35)

In [3]:
import os

from jax.config import config

config.update("jax_enable_x64", True)

import pylab as plt
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jax import vmap

from jaxns import DefaultNestedSampler
from jaxns import Model
from jaxns import Prior
from jaxns import bruteforce_evidence
from jaxns import TerminationCondition

tfpd = tfp.distributions

def prior_model():
    x = yield Prior(tfpd.Uniform(low=model.lower_bound, high=model.upper_bound, name='x'))
    return x

log_like = lambda x: -1 * model.getMinusLogPosterior(x)

jaxns_model = Model(prior_model=prior_model,
              log_likelihood=log_like)

# Create the nested sampler class. In this case without any tuning.
ns = DefaultNestedSampler(model=jaxns_model, max_samples=1e5)

termination_reason, state = ns(random.PRNGKey(420))
results = ns.to_results(termination_reason=termination_reason, state=state)

  from jax.config import config


In [4]:
results.samples

{}