Copyright 2021 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from colabtools import adhoc_import
from contextlib import ExitStack

ADHOC = True
CLIENT = 'fig-export-fig_tree-change-451-3e0a679e9746'

import tensorflow_probability.substrates.jax as tfp
from fun_mc import using_jax as fun_mc

tfd = tfp.distributions

# Variance analysis

In [None]:
n_chains = 10240
n_super_chains = 4
n_steps = 100
# Number of independent synthetic experiments (one PRNG key each).
n_trials = 200
# Selects the nested-ESS formula in the loop below: False uses the direct
# overall / between variance ratio, True the calculation from Charles's
# notebook.
use_charles_nested_ess = False

# Chains are split evenly into super chains; reshape below requires this.
if n_chains % n_super_chains:
  raise ValueError(
      f'n_chains ({n_chains}) must be divisible by n_super_chains '
      f'({n_super_chains}).')
n_sub_chains = n_chains // n_super_chains

# ESS by looking at raw chains
ess_vals = []
# ESS computed by looking at super chains
pooled_ess_vals = []
# Like pooled ESS, but we account for the number of sub chains used
nested_ess_vals = []

# Each trial draws iid standard-normal "chains", so the true per-chain ESS is
# n_steps; what is studied is the spread of the estimates across trials.
for seed in jax.random.split(jax.random.PRNGKey(0), n_trials):
  chain = jax.random.normal(seed, [n_steps, n_chains])
  # Axis 1 indexes sub chains, axis 2 indexes super chains.
  pooled_chain = chain.reshape([n_steps, n_sub_chains, n_super_chains])

  # Regular ESS: overall variance over the variance of per-chain means.
  between = chain.mean(0).var(0, ddof=1)
  overall = chain.var((0, 1), ddof=1)
  ess_vals.append(overall / between)

  # Pooled ESS: each super chain is replaced by the average of its sub chains
  # and then treated as a single chain.
  super_chain = pooled_chain.mean(1)
  pooled_between = super_chain.mean(0).var(0, ddof=1)
  pooled_overall = super_chain.var((0, 1), ddof=1)
  pooled_ess_vals.append(pooled_overall / pooled_between)

  if use_charles_nested_ess:
    # Calculation from Charles's notebook.
    mean_chain = pooled_chain.mean(0)
    mean_super_chain = pooled_chain.mean((0, 1))
    variance_chain = pooled_chain.var(0, ddof=1)
    variance_nested_chain = mean_chain.var(0, ddof=1) + variance_chain.mean(0)

    within_var = variance_nested_chain.mean(0)
    between_var = mean_super_chain.var(0, ddof=1)

    nested_ess_vals.append((1 + within_var / between_var))
  else:
    # Variance of super-chain means (mathematically the same quantity as
    # pooled_between) compared against the variance of all individual draws.
    nested_between = pooled_chain.mean((0, 1)).var(0, ddof=1)
    nested_overall = pooled_chain.var((0, 1, 2), ddof=1)
    nested_ess_vals.append((nested_overall / nested_between))

ess_vals = jnp.array(ess_vals)
pooled_ess_vals = jnp.array(pooled_ess_vals)
nested_ess_vals = jnp.array(nested_ess_vals)
# Rescale the nested ESS by the number of sub chains per super chain, because
# each super chain pools n_sub_chains chains. This reads the pooling as an
# ostensibly denoised ESS estimator.
nested_normalized_ess_vals = nested_ess_vals / n_sub_chains

# Reciprocal of each ESS estimate (between / overall variance), reported
# below as "rhat - 1".
# NOTE(review): if rhat is defined as sqrt(1 + B/W), between / overall equals
# rhat**2 - 1 (roughly 2 * (rhat - 1) when small) — confirm the intended
# convention. The meaning of the normalized variant is unclear.
(rhat_vals, pooled_rhat_vals, nested_rhat_vals,
 nested_normalized_rhat_vals) = (
     1 / ess_estimates
     for ess_estimates in (ess_vals, pooled_ess_vals, nested_ess_vals,
                           nested_normalized_ess_vals))

# For iid draws the population ratio is n_steps for the regular, pooled and
# normalized ESS, and n_steps * n_sub_chains for the un-normalized nested ESS.
ess_summaries = (
    ('ESS mean + std:', ess_vals),
    ('pooled ESS mean + std:', pooled_ess_vals),
    ('nested ESS mean + std:', nested_ess_vals),
    ('nested normalized ESS mean + std:', nested_normalized_ess_vals),
)
rhat_summaries = (
    ('rhat - 1 mean + std:', rhat_vals),
    ('pooled rhat - 1 mean + std:', pooled_rhat_vals),
    ('nested rhat - 1 mean + std:', nested_rhat_vals),
    ('nested normalized rhat - 1 mean + std:', nested_normalized_rhat_vals),
)
for summary_label, summary_vals in ess_summaries:
  print(summary_label, summary_vals.mean(), summary_vals.std())
print()
for summary_label, summary_vals in rhat_summaries:
  print(summary_label, summary_vals.mean(), summary_vals.std())

# Histograms of the per-trial estimates on a log10 scale: ESS-type quantities
# on the top row, the matching "rhat - 1" quantities on the bottom row.
histogram_panels = (
    ('log10 ESS', ess_vals),
    ('log10 pooled ESS', pooled_ess_vals),
    ('log10 nested ESS', nested_ess_vals),
    ('log10 nested normalized ESS', nested_normalized_ess_vals),
    ('log10 rhat - 1', rhat_vals),
    ('log10 pooled rhat - 1', pooled_rhat_vals),
    ('log10 nested rhat - 1', nested_rhat_vals),
    ('log10 nested normalized rhat - 1', nested_normalized_rhat_vals),
)

fig = plt.figure(figsize=(24, 6))
for panel_index, (panel_title, panel_values) in enumerate(
    histogram_panels, start=1):
  plt.subplot(2, 4, panel_index)
  plt.title(panel_title)
  plt.hist(jnp.log10(panel_values), histtype='step', density=True, bins=50)

fig.tight_layout()

# MCMC test

In [None]:
# Target distribution and chain layout for the HMC experiment.
dist = tfd.Normal(0., 1.)
n_chains = 10240
n_super_chains = 8
# NOTE(review): n_steps is not referenced by the HMC cells that follow (their
# trace length is set separately) — confirm whether it is still needed.
n_steps = 100
# The nested run uses n_super_chains * n_sub_chains chains; without exact
# divisibility it would silently have fewer chains than the regular run.
if n_chains % n_super_chains:
  raise ValueError(
      f'n_chains ({n_chains}) must be divisible by n_super_chains '
      f'({n_super_chains}).')
n_sub_chains = n_chains // n_super_chains

def target_log_prob_fn(x):
  """Log density of `dist` evaluated at `x`.

  Returned as a `(log_prob, extra)` pair whose `extra` is an empty tuple; this
  function is passed to the fun_mc HMC init/step calls.
  """
  log_density = dist.log_prob(x)
  no_extra = ()
  return log_density, no_extra


def kernel(hmc_state, seed):
  """Advance the HMC chain(s) by one transition.

  Args:
    hmc_state: Current fun_mc HMC state.
    seed: JAX PRNG key. It is split: the first half drives this transition and
      the second half is carried forward for the next call.

  Returns:
    `(carry, traced)` where `carry` is `(new_hmc_state, next_seed)` and
    `traced` is `(new_hmc_state.state, is_accepted)` for this transition.
  """
  step_seed, next_seed = jax.random.split(seed)
  new_hmc_state, step_extra = fun_mc.hamiltonian_monte_carlo_step(
      hmc_state,
      target_log_prob_fn=target_log_prob_fn,
      step_size=0.5,
      num_integrator_steps=1,
      seed=step_seed,
  )
  carry = (new_hmc_state, next_seed)
  traced = (new_hmc_state.state, step_extra.is_accepted)
  return carry, traced



# Split each run's key so the initial state and the HMC transitions draw on
# independent randomness. A JAX key should not be both consumed directly and
# split for further use.
init_seed, mcmc_seed = jax.random.split(jax.random.PRNGKey(0))
init_seed2, mcmc_seed2 = jax.random.split(jax.random.PRNGKey(3))
# Number of HMC transitions traced for each run.
n_mcmc_steps = 10000

# Regular run: n_chains chains, each started from an independent draw of the
# target distribution.
init_x = dist.sample([n_chains], seed=init_seed)

_, (chain, is_accepted) = fun_mc.trace(
    (fun_mc.hamiltonian_monte_carlo_init(init_x, target_log_prob_fn),
     mcmc_seed),
    kernel,
    n_mcmc_steps)

# Nested run: one draw per super chain, shared by all of its sub chains, laid
# out as [n_super_chains, n_sub_chains].
init_x2 = dist.sample([n_super_chains], seed=init_seed2)
init_x2 = jnp.repeat(init_x2, n_sub_chains)
# Alternative: an independent starting point for every sub chain.
# init_x2 = dist.sample([n_chains], seed=init_seed2)
init_x2 = init_x2.reshape([n_super_chains, n_sub_chains])

_, (chain2, is_accepted2) = fun_mc.trace(
    (fun_mc.hamiltonian_monte_carlo_init(init_x2, target_log_prob_fn),
     mcmc_seed2),
    kernel,
    n_mcmc_steps)

# Prepend the initial state so index 0 along the time axis is the start point.
chain = jnp.concatenate([init_x[jnp.newaxis], chain], 0)
chain2 = jnp.concatenate([init_x2[jnp.newaxis], chain2], 0)

In [None]:
# Trace plot of the first four regular chains over time.
plt.plot(chain[:, 0:4])

In [None]:
# Trace plot of the first four sub chains belonging to the first super chain.
plt.plot(chain2[:, 0, 0:4])

In [None]:
# Spread of the starting points: variance across super chains of each super
# chain's mean initial value, and variance across regular chains of their
# initial values.
(chain2[0].mean(axis=-1).var(axis=0), chain[0].var(axis=0))

In [None]:
def _running_mean(x):
  """Cumulative mean of `x` along axis 0 (entry t averages x[0], ..., x[t])."""
  counts = jnp.arange(1, x.shape[0] + 1).reshape((-1,) + (1,) * (x.ndim - 1))
  return jnp.cumsum(x, 0) / counts


# Between-chain variance of the running mean as a function of chain length.
# ddof=1 (unbiased sample variance across chains) matches the between-chain
# estimates in the variance-analysis cell; it matters most for the super
# chains, of which there are only n_super_chains.
between_reg = _running_mean(chain).var(1, ddof=1)
# Alternative: treat every sub chain of the nested run as a regular chain.
# between_reg = _running_mean(chain2).var((1, 2), ddof=1)
super_chain = chain2.mean(-1)
between_nested = _running_mean(super_chain).var(1, ddof=1)

In [None]:
# Between-chain variance vs chain length for regular chains and super chains,
# with the regular-chain curve rescaled by n_sub_chains for comparison.
between_curves = (
    (between_reg, 'regular chain'),
    (between_nested, 'super chain'),
    (between_reg / n_sub_chains, 'regular chain / n_sub_chains'),
)
plt.title('between chain variance')
for curve_values, curve_label in between_curves:
  plt.plot(curve_values, label=curve_label)
plt.axhline(1e-2, ls='--', color='black', lw=2)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('chain length')
plt.legend()

In [None]:
# Keep this run's results under "2" names. Presumably the cells above are then
# re-run (e.g. with different settings) and the two runs are compared in the
# next cell.
between_reg2, between_nested2, n_sub_chains2 = (
    between_reg, between_nested, n_sub_chains)

In [None]:
print(n_sub_chains)
print(n_sub_chains2)
plt.figure(figsize=(12, 8))
plt.title('between chain variance')

# One entry per run: (regular, super chain, sub chains per super chain, labels).
runs_to_compare = (
    (between_reg, between_nested, n_sub_chains,
     ('regular chain', 'super chain', 'regular chain / n_sub_chains')),
    (between_reg2, between_nested2, n_sub_chains2,
     ('regular chain 2', 'super chain 2',
      'regular chain 2 / n_sub_chains 2')),
)
for run_reg, run_nested, run_n_sub, run_labels in runs_to_compare:
  reg_label, nested_label, scaled_label = run_labels
  plt.plot(run_reg, label=reg_label)
  plt.plot(run_nested, label=nested_label)
  plt.plot(run_reg / run_n_sub, label=scaled_label)

plt.axhline(1e-2, ls='--', color='black', lw=2)
plt.xscale('log')
plt.yscale('log')
plt.xlabel('chain length')
plt.legend()