# Adaptive warmup length

This notebook experiments with adaptively choosing the warmup length, i.e. terminating the warmup phase as soon as a convergence diagnostic is satisfied.

Copyright 2021 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# Import tf first to enable eager mode.
import tensorflow as tf
tf.executing_eagerly()

import numpy as np
from matplotlib.pyplot import *
%config InlineBackend.figure_format = 'retina'
matplotlib.pyplot.style.use("dark_background")

import jax
from jax import random
from jax import numpy as jnp

from colabtools import adhoc_import

# import tensforflow_datasets
from inference_gym import using_jax as gym

from tensorflow_probability.spinoffs.fun_mc import using_jax as fun_mcmc


# import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import unnest


import tensorflow_probability as _tfp
tfp = _tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors

tfp_np = _tfp.substrates.numpy
tfd_np = tfp_np.distributions 


In [None]:
# Supported names: Bananas, GermanCredit, Brownian, EightSchools.
# (Tested option: Bananas. Others: GermanCredit, EightSchools.)
problem_name = 'EightSchools'

# Pick the raw inference-gym target and the initial step size for the chosen
# problem. The branches are mutually exclusive, so an unknown name fails
# loudly instead of leaving `target` undefined (or stale from an earlier run).
if problem_name == 'Bananas':
  target_raw = gym.targets.Banana()
  init_step_size = 1.
elif problem_name == 'GermanCredit':
  # This problem seems to require that we load TF datasets first.
  import tensorflow_datasets
  target_raw = gym.targets.GermanCreditNumericSparseLogisticRegression()
  init_step_size = 0.02
elif problem_name == 'Brownian':
  target_raw = gym.targets.BrownianMotionMissingMiddleObservations()
  init_step_size = 0.01
elif problem_name == 'EightSchools':
  # NOTE: this loads the centered parameterization... (use code below to
  # get non-centered parameterization). This code is still useful to get
  # the correct parameter values.
  target_raw = gym.targets.EightSchools()  # store raw to examine doc.
  init_step_size = 1
else:
  raise ValueError(
      "Unknown problem_name %r; expected one of 'Bananas', 'GermanCredit', "
      "'Brownian', 'EightSchools'." % (problem_name,))

# Flatten the target so that the state is a single vector per chain.
target = gym.targets.VectorModel(target_raw,
                                 flatten_sample_transformations = True)
num_dimensions = target.event_shape[0]


def target_log_prob_fn(x):
  """Unnormalized log density of the target on the unconstrained space.

  The unconstrained point `x` is pushed through the target's default event
  space bijector, and the forward log-det-Jacobian of that map is added to
  the target's unnormalized log probability, so callers can ignore any
  constraints on the original parameters.
  """
  constrained_point = target.default_event_space_bijector(x)
  log_det_jacobian = (
      target.default_event_space_bijector.forward_log_det_jacobian(x))
  log_density = target.unnormalized_log_prob(constrained_point)
  return log_density + log_det_jacobian

# Using underdispersed inits can showcase problems with our diagnostics.
underdispered = False

# NOTE: the default `key` of each `initialize` is created once, when the
# function is defined. No initializer is defined for the 'Brownian' problem.
if problem_name == 'Bananas':
  offset = 2
  def initialize(shape, key = random.PRNGKey(37272709)):
    """Draw initial states of shape `shape + (num_dimensions,)`."""
    standard_draws = random.normal(key, shape + (num_dimensions,))
    return 3 * standard_draws + offset

if problem_name == 'GermanCredit':
  offset = 0.1
  def initialize(shape, key = random.PRNGKey(37272709)):
    """Draw initial states of shape `shape + (num_dimensions,)`."""
    standard_draws = random.normal(key, shape + (num_dimensions,))
    return 0.5 * standard_draws + offset

if problem_name == 'EightSchools':
  if not underdispered:
    def initialize(shape, key = random.PRNGKey(37272709)):
      """Draw initial states with a per-coordinate scale and offset.

      The first two coordinates use scale (10, 1) and offset (0, 5); the
      remaining eight use scale 1 and offset 0.
      """
      scale_vector = jnp.append(jnp.array([10., 1.]), jnp.repeat(1., 8))
      offset_vector = jnp.append(jnp.array([0., 5.]), jnp.repeat(0., 8))
      standard_draws = random.normal(key, shape + (num_dimensions,))
      return scale_vector * standard_draws + offset_vector
  else:
    offset = 0.0
    def initialize(shape, key = random.PRNGKey(37272709)):
      """Draw unit-scale initial states of shape `shape + (num_dimensions,)`."""
      # A scale of 3 instead of 1 was also tried here.
      standard_draws = random.normal(key, shape + (num_dimensions,))
      return 1 * standard_draws + offset


In [None]:
# Instantiate the raw Eight Schools target and show the docstring of its class.
# (Swap in gym.targets.GermanCreditNumericSparseLogisticRegression() to
# inspect that target instead.)
target_raw = gym.targets.EightSchools()
target_class_doc = target_raw.__class__.__doc__
print(target_class_doc)

Specify eight school model using a non-centered parameterization.
\begin{eqnarray*}
\mu & \sim & N(0, 10)  \\
\log \tau & \sim & N(5, 1) \\
\eta & \sim & N(0, 1) \\
\theta & = & \mu + \tau \eta \\
y & \sim & N(\theta, \sigma)
\end{eqnarray*}

In [None]:
if problem_name == "EightSchools":
  num_schools = 8
  # Observed treatment effects and their standard errors, one per school.
  y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype = np.float32)
  sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype = np.float32)

  # Non-centered parameterization: theta = mu + exp(log_tau) * eta.
  # NOTE: the reinterpreted batch dimension specifies the dimension of
  # each independent variable, here the school.
  model = tfd.JointDistributionSequential([
    tfd.Normal(loc = 0., scale = 10., name = "mu"),
    tfd.Normal(loc = 5., scale = 1., name = "log_tau"),
    tfd.Independent(tfd.Normal(loc = jnp.zeros(num_schools),
                               scale = jnp.ones(num_schools),
                               name = "eta"),
                    reinterpreted_batch_ndims = 1),
    lambda eta, log_tau, mu: (
        tfd.Independent(tfd.Normal(loc = (mu[..., jnp.newaxis] +
                                        jnp.exp(log_tau[..., jnp.newaxis]) *
                                        eta),
                                   scale = sigma),
                        name = "y",
                        reinterpreted_batch_ndims = 1))
  ])

  def target_log_prob_fn(x):
    """Joint log density of the non-centered model at state `x`.

    Args:
      x: array whose last axis holds (mu, log_tau, eta_1, ..., eta_8). Any
        number of leading batch axes is accepted, including none.

    Returns:
      `model.log_prob` evaluated at the unpacked state together with the
      observed data `y`.
    """
    # Ellipsis indexing keeps this valid for unbatched and multi-batch states;
    # for the usual (num_chains, 10) input it matches `x[:, i]`.
    mu = x[..., 0]
    log_tau = x[..., 1]
    eta = x[..., 2:2 + num_schools]
    return model.log_prob((mu, log_tau, eta, y))


In [None]:
# Sanity check: evaluate the target log density at four freshly drawn
# initial states and display the result.
sanity_check_key = jax.random.PRNGKey(1)
initial_state = initialize((4,), key = sanity_check_key)
evaluated_density = target_log_prob_fn(initial_state)
evaluated_density

In [None]:
# Get some estimates of the mean and variance.
# WARNING: for EightSchool problem, the correct value is for the centered
# parameterization.
#
# NOTE(review): the fallbacks below read `result` and `num_warmup`, which are
# not defined in the visible cells; presumably they come from an earlier long
# MCMC run. On a fresh kernel the fallback raises NameError.

# Treat both a failed lookup and a `None` ground truth as "not available".
try:
  mean_est = target.sample_transformations['identity'].ground_truth_mean
except Exception:
  mean_est = None
if mean_est is None:
  print('no ground truth mean')
  mean_est = (result.all_states[num_warmup:, :]).mean(0).mean(0)

# Squaring a missing (`None`) standard deviation raises, which lands in the
# same fallback as a failed lookup.
try:
  var_est = target.sample_transformations['identity'].ground_truth_standard_deviation**2
except Exception:
  var_est = None
if var_est is None:
  print('no ground truth std dev')
  var_est = ((result.all_states[num_warmup:, :]**2).mean(0).mean(0) -
             mean_est**2)

In [None]:
# Show the reference mean and variance, one per line.
print(mean_est, var_est, sep = "\n")

In [None]:
# Follow procedure described in source code for potential scale reduction.
# NOTE: some of the tf argument need to be adjusted (e.g. keepdims = False,
# instead of True). Not quite sure why.
# QUESTION: can these be accessed as internal functions of tf?
# TODO: following Pavel's example, rewrite this without using tf.
# TODO: add error message when the number of samples is less than 2.

# REMARK: this function doesn't seem to work, returns NaN.
# As a result, can only use _reduce_variance with biased =  False.
def _axis_size(x, axis = None):
  """Number of elements of `x` along `axis`, cast to `x.dtype`.

  With `axis=None` the total number of elements of `x` is returned;
  otherwise the product of the sizes of the requested axes.
  """
  if axis is not None:
    selected_dims = ps.gather(ps.shape(x), axis)
    return ps.cast(ps.reduce_prod(selected_dims), x.dtype)
  return ps.cast(ps.size(x), x.dtype)

def _reduce_variance(x, axis=None, biased=True, keepdims=False):
  """Variance of `x` along `axis`.

  Args:
    x: value convertible to a tensor.
    axis: axis or axes to reduce over; `None` reduces over all of them.
    biased: if True, divide by n; otherwise rescale by n / (n - 1).
    keepdims: whether the reduced axes are kept with size 1.

  Returns:
    The biased or unbiased variance as a tensor.
  """
  with tf.name_scope('reduce_variance'):
    x = tf.convert_to_tensor(x, name='x')
    # The mean keeps its reduced axes so that it broadcasts against `x`.
    centered_squares = tf.math.squared_difference(
        x, tf.reduce_mean(x, axis=axis, keepdims=True))
    variance = tf.reduce_mean(centered_squares, axis=axis, keepdims=keepdims)
    if not biased:
      n = _axis_size(x, axis)
      variance = (n / (n - 1.)) * variance
    return variance

def nested_rhat(result_state, num_super_chain):
  """Nested R-hat computed from chains grouped into super chains.

  Consecutive chains along axis 1 are grouped into `num_super_chain` super
  chains of equal size. The returned value is sqrt((W + B) / W), where W is
  the average within-super-chain variance and B is the variance of the super
  chain means.

  Args:
    result_state: array of shape (num_draws, num_chains, num_dimensions) that
      supports `.reshape`.
    num_super_chain: number of super chains; must divide `num_chains`.

  Returns:
    A tensor with one value per dimension (shape `(num_dimensions,)`).

  Raises:
    ValueError: if `num_chains` is not a multiple of `num_super_chain`.
  """
  used_samples = result_state.shape[0]
  num_chains = result_state.shape[1]
  num_dimensions = result_state.shape[2]

  # Use the argument rather than a notebook-level global, and refuse chain
  # counts that would silently reshape into a different number of super chains.
  if num_chains % num_super_chain != 0:
    raise ValueError(
        "The number of chains (%d) must be a multiple of num_super_chain (%d)."
        % (num_chains, num_super_chain))
  num_sub_chains = num_chains // num_super_chain

  chain_states = result_state.reshape(used_samples, num_super_chain,
                                      num_sub_chains, num_dimensions)

  state = tf.convert_to_tensor(chain_states, name = 'state')
  mean_chain = tf.reduce_mean(state, axis = 0)
  mean_super_chain = tf.reduce_mean(state, axis = [0, 2])
  # NOTE: the unbiased variances divide by (n - 1), so a single draw or a
  # single sub-chain per super chain gives a non-finite result.
  variance_chain = _reduce_variance(state, axis = 0, biased = False)
  variance_super_chain = _reduce_variance(mean_chain, axis = 1, biased = False) \
     + tf.reduce_mean(variance_chain, axis = 1)

  W = tf.reduce_mean(variance_super_chain, axis = 0)
  B = _reduce_variance(mean_super_chain, axis = 0, biased = False)

  return tf.sqrt((W + B) / W)


In [None]:
def forge_chain (target_rhat, warmup_window_size, kernel_cold, initial_state,
                 max_num_steps, seed, monitor = False,
                 use_nested_rhat = True, use_log_joint = False,
                 num_super_chains = 4):
  """Run warmup in short windows until R-hat falls below `target_rhat`.

  After every window an R-hat diagnostic is computed on that window only.
  Warmup stops as soon as the diagnostic is below `target_rhat`, or after
  `max_num_steps` windows.

  Args:
    target_rhat: threshold below which warmup is deemed acceptable.
    warmup_window_size: number of MCMC iterations per warmup window.
    kernel_cold: transition kernel used during warmup.
    initial_state: state at which the first window starts.
    max_num_steps: maximum number of warmup windows to run (at least 1).
    seed: JAX PRNG key; a fresh sub-key is split off for every window.
    monitor: unused; kept for backward compatibility.
    use_nested_rhat: if True use `nested_rhat`, otherwise
      `tfp.mcmc.potential_scale_reduction`.
    use_log_joint: if True the diagnostic is computed on the traced target
      log probability rather than on the states.
    num_super_chains: number of super chains passed to `nested_rhat`.

  Returns:
    A tuple `(store_results, final_kernel_args, rhat_forge)`: the states of
    all windows concatenated along axis 0, the final kernel results, and the
    array of per-window diagnostics.

  Raises:
    ValueError: if `max_num_steps` is smaller than 1.
  """
  if max_num_steps < 1:
    raise ValueError("max_num_steps must be at least 1.")

  rhat_forge = np.array([])
  warmup_is_acceptable = False
  window_states = []

  warmup_iteration = 0

  current_state = initial_state
  final_kernel_args = None

  # Strict inequality: run at most `max_num_steps` windows (a `<=` bound with
  # the pre-increment counter would run one extra window).
  while (not warmup_is_acceptable and warmup_iteration < max_num_steps):
    warmup_iteration += 1

    # Never reuse a PRNG key: reusing `seed` would feed every window the same
    # random stream.
    seed, window_seed = jax.random.split(seed)

    # 1) Run MCMC on short warmup window
    result_cold, target_log_prob, final_kernel_args = tfp.mcmc.sample_chain(
        num_results = warmup_window_size,
        current_state = current_state,
        kernel = kernel_cold,
        previous_kernel_results = final_kernel_args,
        seed = window_seed,
        trace_fn = lambda _, pkr: unnest.get_innermost(pkr, 'target_log_prob'),
        return_final_kernel_results = True)

    window_states.append(result_cold)
    current_state = result_cold[-1]

    # 2) Check if warmup is acceptable
    if use_nested_rhat:
      if use_log_joint:
        shape_lp = target_log_prob.shape
        rhat_warmup = nested_rhat(target_log_prob.reshape(shape_lp[0], shape_lp[1], 1),
                                  num_super_chains)
      else:
        rhat_warmup = max(nested_rhat(result_cold, num_super_chains))
    else:
      if use_log_joint:
        rhat_warmup = tfp.mcmc.potential_scale_reduction(target_log_prob)
      else:
        rhat_warmup = max(tfp.mcmc.potential_scale_reduction(result_cold))

    if rhat_warmup < target_rhat: warmup_is_acceptable = True

    rhat_forge = np.append(rhat_forge, rhat_warmup)
    # While loop ends

  # Concatenate once at the end instead of re-copying the whole history after
  # every window. A single window is returned as is.
  if len(window_states) == 1:
    store_results = window_states[0]
  else:
    store_results = np.concatenate(window_states, axis = 0)

  return store_results, final_kernel_args, rhat_forge

In [None]:
def mc_est_warm(x, axis = 0):
  """Compute the running average of `x` along `axis`, keeping all samples.

  Args:
    x: array-like with at least one dimension.
    axis: axis along which the running average is taken.

  Returns:
    An array of the same shape as `x` whose k-th entry along `axis` is the
    mean of the first k + 1 entries of `x` along that axis.
  """
  x = np.asarray(x)
  num_draws = x.shape[axis]
  # Put the running counts on `axis` (not always on axis 0) so that they
  # broadcast against the cumulative sum for any choice of axis.
  counts_shape = [1] * x.ndim
  counts_shape[axis] = num_draws
  counts = np.arange(1, num_draws + 1).reshape(counts_shape)
  return np.cumsum(x, axis) / counts


In [None]:
# Set up adaptive warmup scheme
# NOTE: this assignment replaces any step size chosen earlier in the notebook.
init_step_size = 1 # 0.1  # CHECK: how should this be set?
max_warmup_length = 1000  # CHECK: how should this be set?

def _make_adaptive_hmc(num_adaptation_steps):
  """Build HMC wrapped in trajectory-length and step-size adaptation.

  Both adaptation wrappers receive the same `num_adaptation_steps`; the inner
  kernel starts from `init_step_size` with one leapfrog step.
  """
  hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn, init_step_size, 1)
  trajectory_adapted = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
      hmc_kernel, num_adaptation_steps)
  return tfp.mcmc.DualAveragingStepSizeAdaptation(
      trajectory_adapted, num_adaptation_steps, target_accept_prob = 0.75,
      reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)

# Warmup kernel: adapts for up to `max_warmup_length` steps.
# (tfp.mcmc.NoUTurnSampler(target_log_prob_fn, init_step_size) is an
# alternative that was considered here.)
kernel_cold = _make_adaptive_hmc(max_warmup_length)

# Kernel for the sampling phase: same structure, zero adaptation steps.
kernel_warm = _make_adaptive_hmc(0)


In [None]:
# initial_state = initialize((4,), key = jax.random.PRNGKey(1))
# result_short = tfp.mcmc.sample_chain(
#     num_results = 2000, current_state = initial_state, kernel = kernel_cold,
#     seed = random.PRNGKey(1954))

# states = result_short.all_states[1000:2000, :, :]
# print("rhat:", tfp.mcmc.potential_scale_reduction(states, independent_chain_ndims = 1).T)
# print("mean:", np.mean(states.mean(0), axis = 0))

Intuitively, if we warm up our chains well and they forget their initial point, we expect each chain to generate independent samples. Hence, if our goal is to generate an effective sample size of 100 (and we're committed to properly warming up our chains), then we should be able to reach our goal using 100 chains, each with one sampling iteration. One way to check this is to do the adaptive warmup, generate one sample, and then check if the squared error of the Monte Carlo estimate is about $\mathrm{Var}(\theta) / n_\mathrm{chains}$.

One issue is that our estimated squared error can be quite noisy. Suppose we run a model fit $M$ times, each generating a Monte Carlo estimator $\hat \theta^{(m)}$. Leveraging the precise estimator from the inference gym, we then compute the squared error as
$$
  (\hat \theta^{(m)} - \theta^*)^2.
$$
Let's investigate the property of this estimator.

Under a CLT, meaning the effective sample size $N$ of each fit is relatively large, we get 
$$
  \hat \theta^{(m)} \overset{\mathrm{approx}}{\sim} \mathrm{Normal} \left(\theta^*, \frac{\sigma^2}{N} \right),
$$
where $N$ is the effective sample size. Then
$$
  (\hat \theta^{(m)} - \theta^*)^2 \overset{\mathrm{approx}}{\sim} \frac{\sigma^2}{N} \chi^2_1,
$$
and the sum of $M$ independent such terms is approximately distributed as $\frac{\sigma^2}{N} \chi^2_M$: we keep all $M$ degrees of freedom, since $\theta^*$ is not estimated using our sample. The estimator
$$
\hat \tau^2 = \frac{1}{M} \sum_{m = 1}^M (\hat \theta^{(m)} - \theta^*)^2
$$
has the following properties:
$$
\mathbb E \hat \tau^2 \approx \sigma^2 / N
$$
and
$$
\mathrm{Var} \hat \tau^2 \approx \frac{2 \sigma^4}{M N^2}.
$$

In [None]:
# Total number of chains used in each experiment; the experiment loop runs
# once per entry.
# NOTE: need to control the experiment to avoid exhausting the memory
num_chains_array = [16, 32, 64, 128, 256]  # (recommended for Bananas)
# Alternative settings that were tried:
# num_chains_array = [256, 512]  # [16, 32, 64, 128, 256]
# num_chains_array = [16, 20, 24, 28, 32, 64, 128]
# num_chains_array = [512]
# num_chains_array = [2048]  # [32, 64]

In [None]:

# Experiment configuration.
num_super_chains = 4  # 4
target_rhat = 1.01
warmup_window_size = 50  # 25  # 300
max_num_steps = 1000 // warmup_window_size  # 1000  # 1

# Per-setting summaries: one entry is appended per element of
# `num_chains_array`.
mc_err_mean = np.array([])
mc_err_median = np.array([])
mc_err_sd = np.array([])
length_warmup_mean = np.array([])
length_warmup_sd = np.array([])
# One entry per (setting, seed) pair.
nested_rhat_store = np.array([])

parameter_index = 0  # parameter of interest

# TODO: parallelize this while still allowing each run to have a different
# warmup length.
for num_chains_short in num_chains_array:
  if num_chains_short % num_super_chains != 0:
    raise ValueError(
        "Each entry of num_chains_array must be a multiple of "
        "num_super_chains; got %d and %d." % (num_chains_short,
                                              num_super_chains))

  length_warmup = np.array([])
  mc_err = np.array([])

  for seed in jax.random.split(jax.random.PRNGKey(1), 10):  # 10  # 30
    # Derive independent keys for the three random stages by splitting,
    # rather than by integer arithmetic on the key.
    init_key, warmup_key, sampling_key = jax.random.split(seed, 3)

    # All sub-chains of a super chain start from the same point.
    initial_state = initialize((num_super_chains,), key = init_key)
    initial_state = np.repeat(initial_state, num_chains_short // num_super_chains,
                              axis = 0)

    # 1) Warmup phase
    result_cold, final_kernel_args, rhat_forge = \
      forge_chain(target_rhat = target_rhat,
                  warmup_window_size = warmup_window_size,
                  kernel_cold = kernel_cold,
                  initial_state = initial_state,
                  max_num_steps = max_num_steps,
                  seed = warmup_key, monitor = False,
                  use_nested_rhat = True,
                  use_log_joint = False,
                  num_super_chains = num_super_chains)

    # One diagnostic is stored per warmup window.
    length_warmup = np.append(length_warmup,
                              len(rhat_forge) * warmup_window_size)

    # 2) Sampling phase
    current_state = result_cold[-1]

    result_warm = tfp.mcmc.sample_chain(
        num_results = 50,
        current_state = current_state,
        kernel = kernel_warm,
        previous_kernel_results = final_kernel_args,
        seed = sampling_key,
        return_final_kernel_results = False, trace_fn = None)

    # Compute error based on first iteration of the sampling phase.
    mc_err = np.append(mc_err,
                       np.square(result_warm[0, :, parameter_index].mean()
                                 - mean_est[parameter_index]))

    # Store nested Rhat computed using first 5 sampling iterations
    nested_rhat_store = np.append(
        nested_rhat_store,
        nested_rhat(result_warm[0:5, :, :],
                    num_super_chain = num_super_chains)[0])

    # END seed for loop
  mc_err_mean = np.append(mc_err_mean, mc_err.mean())
  mc_err_median = np.append(mc_err_median, np.median(mc_err))
  mc_err_sd = np.append(mc_err_sd, mc_err.std())
  length_warmup_mean = np.append(length_warmup_mean, length_warmup.mean())
  length_warmup_sd = np.append(length_warmup_sd, length_warmup.std())


In [None]:
# Estimate from the first sampling iteration of the last run, followed by the
# reference means.
first_draw_estimate = result_warm[0, :, parameter_index].mean()
print(first_draw_estimate)
print(mean_est)

In [None]:
# # Compute school effect
# theta = (result_warm[:, :, 0, jnp.newaxis] +
#   jnp.exp(result_warm[:, :, 1, jnp.newaxis]) * result_warm[:, :, 2:10])

# # print("Our estimates:", mc_est_warm(result_warm.mean(0))[0:2])
# print("Our estimates:", np.mean(result_warm.mean(0), axis = 0)[0:2])
# print("Correct values:", mean_est[0:2])

# print("Our school estimates:", theta.mean(axis = [0, 1]))
# print("True estimates:", mean_est[2:10])

# Diagnostics for the last setting of the experiment loop. A bare expression
# in the middle of a cell is not displayed, so print the chain count.
print("num_chains_short:", num_chains_short)
print("length warmup: ", length_warmup)
print("mc_err", mc_err)
print("nested-Rhat", nested_rhat_store)
# print("mc_mean (excluding outlier):", mc_err[1:10].mean())

In [None]:
# Summary over all settings: observed squared error, warmup length, and the
# variance estimate divided by the number of chains.
print("mc_err_mean:", mc_err_mean, "+/-", mc_err_sd)
print("mc_err_median:", mc_err_median)
print("warmup length:", length_warmup_mean, "+/-", length_warmup_sd)
variance_over_num_chains = var_est[parameter_index] / np.array(num_chains_array)
print(variance_over_num_chains)
result_warm.shape

In [None]:
# For experiments that exhaust memory, enter results from previous kernels
# Set to True to replace the computed summaries with the hard-coded values
# below (one entry per element of `num_chains_array`).
enter_manually = False  # True
if enter_manually:
  num_chains_array = np.array([16, 32, 64, 128, 256, 512])
  # Mean, standard deviation and median of the squared error.
  # NOTE(review): 0.693 appears in both mc_err_sd and mc_err_median at the
  # second setting; confirm this is not a copy-paste slip.
  mc_err_mean = np.array([1.517, 0.877, 0.474, 0.388, 0.139, 0.576])
  mc_err_sd = np.array([2.45, 0.693, 0.629, 0.791, 0.14, 1.304])
  mc_err_median = np.array([0.476, 0.693, 0.173, 0.0460, 0.075, 0.054])
  # Mean and standard deviation of the warmup length.
  length_warmup_mean = np.array([1025., 1025., 722, 460, 437, 400])
  length_warmup_sd = np.array([0., 0., 160, 124, 127, 115])

In [None]:
# Values used in the error plot: means, chain counts, and the error-bar
# half-width 3 * sd / sqrt(10).
print(mc_err_mean)
print(num_chains_array)
error_bar_half_width = 3 * mc_err_sd / np.sqrt(10)
print(error_bar_half_width)

In [None]:
# Observed mean squared error (with error bars of 3 * sd / sqrt(10)) against
# the variance estimate divided by the number of chains.
figure(figsize = (6, 6))
observed_yerr = 3 * mc_err_sd / np.sqrt(10)
errorbar(x = num_chains_array, y = mc_err_mean, yerr = observed_yerr,
         label = 'Observed mean')
# The observed median (mc_err_median) could be added as a further curve.
expected_curve = var_est[parameter_index] / np.array(num_chains_array)
plot(num_chains_array, expected_curve, linestyle = '--', label = 'Expected')
legend(loc = 'best')
xlabel("Number of chains")
ylabel("Squared error")
show()

In [None]:
# Mean warmup length (+/- one standard deviation) per number of chains.
figure(figsize = (6, 6))
errorbar(x = num_chains_array, y = length_warmup_mean,
         yerr = length_warmup_sd)
plot(num_chains_array, length_warmup_mean, 'o')
ylabel("Warmup length")
xlabel("Number of chains")
show()

### Further analysis for last run

In [None]:
# Takes at least 10 iterations to estimate ESS?
ess_warm = np.sum(tfp.mcmc.effective_sample_size(result_warm[1:10, :, :]))
result_warm.shape

In [None]:
# Diagnostics for the last run. Use the configured number of super chains
# rather than a hard-coded 4, so this stays consistent with the experiment.
print("nested_rhat:", nested_rhat(result_warm[0:5, :, :],
                                  num_super_chain = num_super_chains))
print("Parameter index: ", parameter_index)
# NOTE: this prints the full shape of `result_warm`, not just the number of
# iterations.
print("Number of iterations: ", result_warm.shape)
print("Ess: ", ess_warm)

In [None]:
# Reference mean of the first parameter, then its estimate from the first
# sampling iteration.
print(mean_est[0])
first_iteration_draws = result_warm[0, :, 0]
print(np.mean(first_iteration_draws))

In [None]:
# Mean of the first parameter over the first five sampling iterations.
# (The closing parenthesis was missing, which made this cell a SyntaxError.)
print(np.mean(result_warm[0:5, :, 0]))

In [None]:
# Plot draws 1..4 of the second coordinate for chains 256..511.
# NOTE(review): `num_samples_plot` is not used by the slice below, and the
# chain range 256:512 is hard-coded; it is empty whenever `result_warm` holds
# 256 chains or fewer. Confirm the intended ranges.
num_samples_plot = 4  # target_iter_mean
plot(result_warm[1:5, 256:512, 1])
show()

In [None]:
# Compute the mean of each super chain and the between-super-chain variance,
# using only the first sampling iteration.
result_state = result_warm[0, :, :]

used_samples = 1 # result_state.shape[0]
num_sub_chains = result_state.shape[0] // num_super_chains
num_dimensions = result_state.shape[1]

chain_states = result_state.reshape(used_samples, -1, num_sub_chains,
                                    num_dimensions)

# Multi-axis reductions take a tuple of axes; a list is not accepted by
# NumPy and is not portable across array libraries.
mean_superchain = np.mean(chain_states, axis = (0, 2))

# B: variance of the super chain means; V: variance over all entries.
# NOTE(review): V reduces over chains AND dimensions, giving a single number;
# confirm that a per-dimension variance (axis 0 only) was not intended.
B = mean_superchain.var(0, ddof = 1)
V = result_state.var(axis = (0, 1), ddof = 1)

print(used_samples)
print(result_state.shape)
print(chain_states.shape)
print(mean_superchain.shape, mean_superchain[:, 0])
print(B)
print(V)

In [None]:
# Distinct initial values of the first coordinate across chains.
# Presumably one value per super chain, since the experiment loop repeats
# each initial point across its sub-chains; verify against the output.
np.unique(initial_state[:, 0])

In [None]:
# Hand-computed unbiased sample variance of four values: mean `m`, sum of
# squared deviations `S`, divided by (4 - 1).
hand_values = [2.26, 1.78, 2.12, 2.01]
m = np.mean(hand_values)
S = 0
for hand_value in hand_values:
  S = S + (hand_value - m)**2
S / 3

# Draft Code

In [None]:
# Keep the summaries of the adaptive-warmup experiment under separate names.
mc_err_mean_adaptive, mc_err_sd_adaptive = mc_err_mean, mc_err_sd

In [None]:
# Summary and histogram of the squared errors for the last setting.
print("Expected squared error:", var_est / num_chains_short)
print("Mean warmup length:", length_warmup.mean(), "+/-", length_warmup.std())
print("Mean Monte Carlo squared error:", mc_err.mean(), "+/-", mc_err.std())
print("Median Monte Carlo squared error:", np.median(mc_err))

# A log-scale histogram of all states of a long run was also tried here.
hist(mc_err, bins = 30)
show()

In [None]:


# Draw five sampling iterations again from the end of the last warmup.
# NOTE: this overwrites `result_warm` from the experiment loop.
current_state = result_cold[-1]

resample_key = random.PRNGKey(100001)
result_warm = tfp.mcmc.sample_chain(
    num_results = 5,
    current_state = current_state,
    kernel = kernel_warm,
    previous_kernel_results = final_kernel_args,
    seed = resample_key,
    return_final_kernel_results = None,
    trace_fn = None)


In [None]:

# Target precision: the variance estimate divided by 100.
target_precision = var_est / 100
# Running average (over iterations) of the across-chain mean; show the first
# iteration of the first parameter.
across_chain_mean = result_warm.mean(1)
mc_est_warm(across_chain_mean)[0, 0]