### Testing Hamiltonian Annealed Importance Sampling

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
rng = random.PRNGKey(0)

import tensorflow_probability as tfp
import tensorflow as tf
seed = 0

## Uncomment these to test the jax backend
# import tenflow_probabilisorflow_probability.substrates.jax as tfp
# from tensorty.python.internal.backend.jax.compat import v2 as tf
# seed = rng

tfd = tfp.distributions

In [2]:
rng = random.PRNGKey(0)

#### LogGamma from the TFP test suite

In [3]:
num_chains = 100
dims = 20
dtype = np.float32

shape_param = 2. # α
rate_param = 3. # β
independent_chain_ndims = 1

def _log_gamma_log_prob(x, event_dims=()):
    """Computes unnormalized log-pdf of a log-gamma random variable.
    Args:
      x: Value of the random variable.
      event_dims: Dimensions not to treat as independent.
    Returns:
      log_prob: The log-pdf up to a normalizing constant.
    """
    return tf.reduce_sum(
        (shape_param - 1) * x - rate_param * tf.math.exp(x),
        axis=event_dims)

proposal = tfd.Normal(loc=0., scale=1.)

target = tfd.TransformedDistribution(
      tfd.Gamma(concentration=dtype(shape_param), rate=dtype(rate_param)),
  bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()))

def proposal_log_prob(x):
  event_dims = tf.range(independent_chain_ndims, tf.rank(x))
  return tf.reduce_sum(
      tfd.Normal(loc=0., scale=1.).log_prob(x), axis=event_dims)

def target_log_prob(x):
  event_dims = tf.range(independent_chain_ndims, tf.rank(x))
  return _log_gamma_log_prob(x, event_dims)

num_steps = 200

def make_kernel(tlp_fn):
        return tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn, step_size=0.5, num_leapfrog_steps=2)

init = tfd.Normal(loc=0., scale=1.).sample((100,), seed=seed)

event_shape = tf.shape(init)[independent_chain_ndims:]
event_size = tf.reduce_prod(event_shape)

log_true_normalizer = (-shape_param * tf.math.log(rate_param) +
                        tf.math.lgamma(shape_param))
log_true_normalizer *= tf.cast(event_size, log_true_normalizer.dtype)

chains_state, ais_weights, kernels_results = (
    tfp.mcmc.sample_annealed_importance_chain(
        num_steps=num_steps,
        # proposal_log_prob_fn=proposal_log_prob,
        # target_log_prob_fn=target_log_prob,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=lambda x: target.log_prob(x) + log_true_normalizer,
        current_state=init,
        make_kernel_fn=make_kernel, seed=seed))

ais_weights_size = tf.cast(tf.size(ais_weights), ais_weights.dtype)
log_estimated_normalizer = (
    tf.reduce_logsumexp(ais_weights) -
    tf.math.log(ais_weights_size))

print(log_estimated_normalizer)
print(log_true_normalizer)



tf.Tensor(-2.1896644, shape=(), dtype=float32)
tf.Tensor(-2.1972246, shape=(), dtype=float32)


#### LogGamma from the TFP docs

https://github.com/tensorflow/probability/blob/v0.19.0/tensorflow_probability/python/mcmc/sample_annealed_importance_test.py

In [4]:
num_chains = 100
dims = 20
dtype = np.float32

α = 2.
β = 3.

proposal = tfd.MultivariateNormalDiag(
   loc=tf.zeros([dims], dtype=dtype))

target = tfd.TransformedDistribution(
  distribution=tfd.Sample(
      tfd.Gamma(concentration=dtype(α), rate=dtype(β)),
      sample_shape=[dims]),
  bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()))

log_true_normalizer = tf.math.lgamma(α) - α * tf.math.log(β)

chains_state, ais_weights, kernels_results = (
    tfp.mcmc.sample_annealed_importance_chain(
        num_steps=200,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=lambda x: target.log_prob(x) + log_true_normalizer,
        current_state=proposal.sample(num_chains, seed=seed),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.2,
          num_leapfrog_steps=2),
          seed=seed))

log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                            - np.log(num_chains))

print(log_estimated_normalizer)
print(log_true_normalizer)


tf.Tensor(-2.0605202, shape=(), dtype=float32)
tf.Tensor(-2.1972246, shape=(), dtype=float32)


In [26]:
target

<tfp.distributions.TransformedDistribution 'invert_expSampleGamma' batch_shape=[] event_shape=[20] dtype=float32>

In [27]:
proposal

<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[] event_shape=[1] dtype=float32>

#### Bayesian Linear Regression from the TFP docs

In [45]:
dtype = np.float32
# Run 100 AIS chains in parallel
num_chains = 100
dims = 1

μ_0 = tf.zeros(dims, dtype=dtype)
Σ_0 = tf.eye(dims, dtype=dtype)
σ = .01



def make_prior():
  return tfd.MultivariateNormalDiag(
      loc=μ_0, scale_diag=tf.linalg.tensor_diag_part(Σ_0))

def make_likelihood(weights, x):
  # return tfd.MultivariateNormalDiag(
  #     loc=tf.einsum("nm,nm->nm", x, weights), scale_diag=σ * tf.ones((num_chains, dims), dtype=dtype))
  return tfd.Normal(
      loc=tf.einsum("nm,nm->n", x, weights),
      scale=σ * tf.ones((num_chains,), dtype=dtype))


# Make training data.
x_rng, w_rng, y_rng = random.split(rng, 3)
x = random.normal(x_rng, (num_chains, dims), dtype)
true_weights = random.normal(w_rng, (dims,), dtype)
y = np.dot(x, true_weights) + σ * random.normal(y_rng, (num_chains,), dtype)

# Setup model.
prior = make_prior()
def target_log_prob_fn(weights):
  return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

proposal = tfd.MultivariateNormalDiag(
    loc=tf.zeros(dims, dtype))

weight_samples, ais_weights, kernel_results = (
    tfp.mcmc.sample_annealed_importance_chain(
      num_steps=200,
      proposal_log_prob_fn=proposal.log_prob,
      target_log_prob_fn=target_log_prob_fn,
      current_state=prior.sample(num_chains, seed=seed), #tf.zeros([num_chains, dims], dtype),
      make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=tlp_fn,
        step_size=0.1,
        num_leapfrog_steps=2), 
      # make_kernel_fn=lambda tlp_fn: tfp.mcmc.NoUTurnSampler(
      #   target_log_prob_fn=tlp_fn,
      #   step_size=0.1),
        seed=seed))
log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                           - np.log(num_chains))

print(log_normalizer_estimate)


tf.Tensor(-0.18235588, shape=(), dtype=float32)


In [46]:
def make_evidence(X):
    mean = tf.einsum("nm,m->n", X, μ_0)
    var = X @ Σ_0 @ tf.transpose(X) + σ**2 * tf.eye(num_chains)

    return tfd.MultivariateNormalFullCovariance(
        loc=mean,
        covariance_matrix=var)

print(make_evidence(x).log_prob(y))

tf.Tensor(304.22253, shape=(), dtype=float32)


In [44]:
np.dot(x, true_weights)

array([-0.555008  ,  0.64811164,  0.36106637, -0.62956095,  0.6839977 ,
        1.1524882 ,  0.4709547 , -0.6581082 ,  0.24107417, -0.02283283,
        0.19492728, -0.9294237 , -1.0637267 ,  0.448683  , -0.17867398,
        0.0054085 , -0.76171285, -0.33724046,  0.21614343, -0.04126224,
        0.2985184 , -0.95977587,  0.48037645,  0.08838923, -0.43512648,
       -1.1337018 ,  0.20574404, -0.02180363, -0.5019763 ,  0.04223837,
       -0.32269415,  0.3342762 , -0.7076715 , -0.5723325 ,  0.6807139 ,
        0.8538231 ,  0.5951085 , -0.7882813 ,  0.35678887,  0.07028113,
        0.11483616, -0.15576555, -0.7715699 , -0.72422713,  0.37517452,
       -0.20552973,  0.5832004 ,  0.27629483, -0.28311706, -0.94119465,
       -0.18647191, -1.0897077 ,  0.80816424,  0.4845704 ,  0.02891979,
        0.38963977, -0.12723228, -0.19480483, -0.21068431, -0.5622996 ,
        0.57014126,  0.5398224 ,  0.69954157, -0.5477924 ,  0.6984031 ,
        0.15822472, -0.6824107 ,  0.32123387,  0.5557636 ,  0.91

In [30]:
make_evidence(x)

<tfp.distributions.MultivariateNormalFullCovariance 'MultivariateNormalFullCovariance' batch_shape=[] event_shape=[200] dtype=float32>

In [31]:
prior

<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[] event_shape=[2] dtype=float32>

In [36]:
make_likelihood(weight_samples, x)

<tfp.distributions.Normal 'Normal' batch_shape=[200] event_shape=[] dtype=float32>

In [37]:
prior.log_prob(weight_samples) + make_likelihood(weight_samples, x).log_prob(y)

<tf.Tensor: shape=(200,), dtype=float32, numpy=
array([-2.04182816e+00, -4.01725494e+02, -3.13177252e+00, -7.44948669e+02,
       -3.51725098e+02, -2.33914137e+00, -2.11889410e+00, -1.64997101e+00,
       -8.52647171e+01, -4.06426764e+00, -2.78958845e+00, -4.45868683e+02,
       -5.83225203e+00, -1.79815888e+00, -4.45596924e+01, -2.29333925e+00,
       -2.35432100e+00, -2.03694687e+02, -8.50046814e+02, -1.48989685e+03,
       -3.25154572e+02, -8.26057911e-01, -2.61331696e+02, -2.10980368e+00,
       -2.55607056e+02, -2.23822260e+00, -7.40023017e-01, -1.59524930e+00,
       -3.49467712e+02, -8.16077948e-01, -9.21989868e+02, -1.58126402e+00,
       -1.31552029e+00, -2.59153628e+00, -2.36172342e+00, -2.29546677e+02,
       -3.33462739e+00, -4.44852783e+02, -2.12858891e+00, -5.40152588e+02,
       -5.00225878e+00, -1.87100816e+00, -1.90876591e+00, -5.47283268e+00,
       -2.35615158e+02, -1.01408434e+00, -7.35563517e-01, -1.70783126e+00,
       -4.72567439e-01, -4.81881653e+02, -4.69158077