# $n \hat R$ convergence

This notebook presents, in a reproducible fashion, the numerical experiments used to evaluate the behavior of $n \hat R$ across a range of models. Each section can be run independently once the "Setup" section has been run.

Copyright 2021 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Setup

In [None]:
import numpy as np
from matplotlib.pyplot import *
# %config InlineBackend.figure_format = 'retina'
# matplotlib.pyplot.style.use("dark_background")
# Global figure font. 'normal' is not a font *family* (it is a valid weight /
# style value), so requesting it only triggers "findfont: Font family
# ['normal'] not found" warnings and a silent fallback; ask for the default
# sans-serif family explicitly instead.
font = {'family': 'sans-serif',
        'weight': 'bold',
        'size': 14}

matplotlib.rc('font', **font)

In [None]:
import jax
from jax import random
from jax import numpy as jnp

from colabtools import adhoc_import

# import tensforflow_datasets
from inference_gym import using_jax as gym

from tensorflow_probability.spinoffs.fun_mc import using_jax as fun_mcmc


# import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import unnest


import tensorflow_probability as _tfp
tfp = _tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors

tfp_np = _tfp.substrates.numpy
tfd_np = tfp_np.distributions

import arviz as az
from tensorflow_probability.python.internal.unnest import get_innermost

# Theory

### Example A (Incomplete exploration)

In [None]:
# Example A: grid over a (multiples of 4), with the dimension fixed at d = 6.
a = np.arange(4, 1024, 4)
d = np.full(a.shape, 6.)

# The optimization has two roots (quadratic formula, + / -). Keep the "-"
# root: the "+" root gives a negative upper bound for delta_u.
alpha_1 = (2 * a + d / 2) - np.sqrt((2 * a + d / 2) ** 2 - 2 * a)
alpha_2 = a - alpha_1
delta_u = np.square(alpha_1 + d / 2) / (alpha_1 * alpha_2) / 2

# Tolerance on delta corresponding to an R-hat threshold of 1 + eps.
eps = 0.01
delta = np.square(1 + eps) - 1

In [None]:
# Upper bound delta_u against a / d, compared with the tolerance delta that
# corresponds to an R-hat threshold of 1.01. `legend()` is required for the
# labels to actually appear on the figure.
semilogy(a / d, delta_u, label = "delta_u")
hlines(delta, (a / d)[0], (a / d)[-1], linestyles = '--',
       label = "delta for 1.01 threshold")
xlabel("a / d")
legend(loc = "best")
show()

### Example B (Asymmetric binary initialization)

In [None]:
# Example B setup. Each chain is initialized on the left with probability p and
# on the right otherwise; alpha_l / alpha_r are the corresponding biases.
alpha_l = 0.5  # options: 0.1, 0.5, 1.2
alpha_r = alpha_l

# NOTE(review): this grid is replaced by a finer one further down in this cell
# before later cells use `p`.
p =  np.arange(0., 1., 0.01)

sigma_pi = 1   # variance at stationarity

# upper bound on initial variance
stationary_bound = False
theta_0_l = 3
theta_0_r = 3 

# conservative_bound=True bounds the initial variance by
# (theta_0_l + theta_0_r)^2 / 4; otherwise use the delta_L / delta_R terms.
conservative_bound = False
if conservative_bound:
  sigma_0 = np.power((theta_0_l + theta_0_r), 2) / 4
else:
  delta_L = (alpha_l / theta_0_l) * np.square(theta_0_l - alpha_l) +\
  (theta_0_l - alpha_l) * np.square(alpha_l)
  delta_R = (alpha_r / theta_0_r) * np.square(theta_0_r - alpha_r) +\
  (theta_0_r - alpha_r) * np.square(alpha_r)
  # Mixture over the initialization side; overwritten on the next statement.
  sigma_0 = p * delta_L + (1 - p) * delta_R
  # sigma_u = p * delta_L + (1 - p) * delta_R
  sigma_0 = delta_L  # (for now, assume symmetry)

# sigma_u: either the stationary variance, or twice the larger of the
# stationary variance and the initial-variance bound.
if stationary_bound:
  sigma_u = sigma_pi
else:
  sigma_u = 2 * max(sigma_pi, sigma_0)


p = np.arange(0, 1.01, 0.01)     # prob of initializing on the left.

# Set bias for chains coming from the left and right
# NOTE(review): this overrides alpha_l = 0.5 from the top of the cell, so
# sigma_0 / sigma_u above were computed with a different alpha than the one
# used by (and shown in the title of) the cells below — confirm intended.
alpha_l = 1.2
alpha_r = alpha_l

delta = 0.02  # Relative tolerance for squared bias

In [None]:
var_mc = p * (1 - p) * np.square(alpha_l + alpha_r)
bias_mc_squared = np.square(-p * alpha_l + (1 - p) * alpha_r)

rel_var = var_mc / sigma_u
rel_err = (var_mc + bias_mc_squared) /sigma_pi

In [None]:
plot(p, var_mc + bias_mc_squared, label = "squared error")
plot(p, var_mc, label = "var_super")
legend(loc = "best")

In [None]:
plot(p, rel_err, label = "rel_err")
plot(p, rel_var, label = "rel_var")
hlines(delta, 0, 1, linestyles = "--", label = "delta threshold")
legend(loc = "best")
xlabel("p")
title("alpha = " + str(alpha_l))

# Application to models

## Setup

### Nested $\hat R$

In [None]:
# Define nested Rhat for one parameter.
# Assume for now the indexed parameter is a scalar.
# TODO: deprecate state_is_list argument
def nested_rhat_1dim(result_state, num_super_chains, index_param, 
                     num_samples, warmup_length = 0, state_is_list = False,
                     vector_index = None):
  """Nested R-hat for a single scalar parameter.

  Chains are grouped contiguously into super-chains: with
  `M = num_chains // num_super_chains`, chains `[k * M, (k + 1) * M)` form
  super-chain `k`. This matches the `np.repeat(..., axis = 0)`
  initialization used in this notebook.

  Args:
    result_state: draws. If `state_is_list` is False, an array indexed as
      `[draw, chain, parameter]`. Otherwise a list of per-parameter arrays,
      each indexed as `[draw, chain]` or `[draw, chain, component]`.
    num_super_chains: number of super-chains; must divide the number of
      chains.
    index_param: parameter index (last-axis position when `state_is_list` is
      False, list position otherwise).
    num_samples: number of draws to use after the warmup.
    warmup_length: number of leading draws to discard.
    state_is_list: whether `result_state` is a list of per-parameter arrays.
      TODO: deprecate.
    vector_index: component to select when the listed parameter is a vector.

  Returns:
    Tuple `(nRhat, B, W)`. `B` is the variance between super-chain means and
    `W` is the mean over super-chains of the total (between-chain plus mean
    within-chain) variance, so `nRhat = sqrt(1 + B / W)`.

  Raises:
    ValueError: if the number of chains is not a positive multiple of
      `num_super_chains`. Previously the reshape silently regrouped the
      chains into a different number of super-chains.
  """
  window = slice(warmup_length, warmup_length + num_samples)
  if state_is_list:
    if vector_index is not None:
      state_param = result_state[index_param][window, :, vector_index]
    else:
      state_param = result_state[index_param][window, :]
  else:
    state_param = result_state[window, :, index_param]
  state_param = np.asarray(state_param)

  num_samples = state_param.shape[0]
  num_chains = state_param.shape[1]
  if num_super_chains <= 0 or num_chains % num_super_chains != 0:
    raise ValueError(
        'The number of chains ({}) must be a positive multiple of '
        'num_super_chains ({}).'.format(num_chains, num_super_chains))
  num_sub_chains = num_chains // num_super_chains

  # Axes: [draw, super-chain, chain within the super-chain].
  state_param = state_param.reshape(num_samples, num_super_chains,
                                    num_sub_chains)

  # Per-chain mean and within-chain variance, shape [super, sub].
  mean_chain = np.mean(state_param, axis = 0)
  within_chain_var = np.var(state_param, axis = 0, ddof = 1)
  # Between-chain variance inside each super-chain, shape [super].
  between_chain_var = np.var(mean_chain, axis = 1, ddof = 1)
  total_chain_var = between_chain_var + np.mean(within_chain_var, axis = 1)

  # Variance between the super-chain means (scalar).
  mean_super_chain = np.mean(state_param, axis = (0, 2))
  between_super_chain_var = np.var(mean_super_chain, ddof = 1)

  mean_total_chain_var = np.mean(total_chain_var)
  return (np.sqrt(1 + between_super_chain_var / mean_total_chain_var),
          between_super_chain_var, mean_total_chain_var)


def nested_rhat(result_state, num_super_chains, index_param, 
                num_samples, warmup_length = 0, state_is_list = False):
  """Nested R-hat for each parameter listed in `index_param`.

  When `state_is_list` is True and a listed parameter has a third (component)
  axis, one result is produced per component. Otherwise there is one result
  per entry of `index_param`.

  Returns:
    Three 1-D float arrays `(nRhat, B, W)`, aligned entry by entry, holding
    the outputs of `nested_rhat_1dim` for every scalar parameter visited.
  """
  collected = ([], [], [])

  for position in range(index_param.shape[0]):
    param = index_param[position]
    # `None` asks nested_rhat_1dim for the whole (scalar) parameter.
    components = [None]
    if state_is_list:
      param_shape = result_state[param].shape
      if len(param_shape) != 2:
        components = range(param_shape[2])

    for component in components:
      stats = nested_rhat_1dim(result_state, num_super_chains, param,
                               num_samples, warmup_length, state_is_list,
                               vector_index = component)
      for bucket, value in zip(collected, stats):
        bucket.append(value)

  nRhat, B, W = (np.asarray(bucket, dtype = float) for bucket in collected)
  return nRhat, B, W


### Run fits

In [None]:
def construct_kernel(target_log_prob_fn, init_step_size, num_warmup):
  """Builds the adaptive HMC kernel used throughout the notebook.

  From the inside out:
    * HMC with a single leapfrog step per proposal, starting at
      `init_step_size`;
    * `GradientBasedTrajectoryLengthAdaptation`, adapting the trajectory
      length for `num_warmup` steps;
    * dual-averaging step-size adaptation for `num_warmup` steps, targeting
      an acceptance probability of 0.75 and aggregating across chains with
      `reduce_log_harmonic_mean_exp`.
  """
  hmc = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn = target_log_prob_fn,
      step_size = init_step_size,
      num_leapfrog_steps = 1)
  trajectory_adapted = (
      tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
          inner_kernel = hmc,
          num_adaptation_steps = num_warmup))
  return tfp.mcmc.DualAveragingStepSizeAdaptation(
      inner_kernel = trajectory_adapted,
      num_adaptation_steps = num_warmup,
      target_accept_prob = 0.75,
      reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)


In [None]:
def run_fits(num_seed, total_samples, initialize, kernel,
             num_super_chains, index_param, num_samples, num_warmup,
             state_is_list = False, num_chains = None):
  """Runs `num_seed` independent MCMC fits and collects convergence summaries.

  For each seed, `num_super_chains` initial points are drawn with
  `initialize`. Each point is repeated `num_chains // num_super_chains`
  times, so all chains of a super-chain start from the same point. The
  kernel runs for `total_samples` iterations; draws
  `[num_warmup, num_warmup + num_samples)` feed R-hat and nested R-hat.

  Args:
    num_seed: number of independent repetitions.
    total_samples: number of iterations passed to `tfp.mcmc.sample_chain`.
    initialize: callable `initialize(shape, key=...)` returning initial
      states.
    kernel: TFP transition kernel.
    num_super_chains: number of distinct initial points.
    index_param: integer array of parameter indices to monitor.
    num_samples: number of post-warmup draws used by the diagnostics.
    num_warmup: number of leading iterations treated as warmup.
    state_is_list: TODO: deprecate. The R-hat and Monte Carlo mean
      computations below only support array-valued states.
    num_chains: total number of chains. Defaults to the notebook-level
      `num_chains`, which this function used to read implicitly.

  Returns:
    `(Rhat_list, nRhat_list, B_list, W_list, mc_mean_list)`, each of shape
    `[num_seed, len(index_param)]`. `mc_mean_list` holds the mean across
    chains of the first post-warmup draw.
  """
  if num_chains is None:
    num_chains = globals()['num_chains']
  chains_per_super_chain = num_chains // num_super_chains
  num_parameters = index_param.shape[0]

  Rhat_list = np.zeros((num_seed, num_parameters))
  nRhat_list = np.zeros((num_seed, num_parameters))
  B_list = np.zeros((num_seed, num_parameters))
  W_list = np.zeros((num_seed, num_parameters))
  mc_mean_list = np.zeros((num_seed, num_parameters))

  seeds = jax.random.split(jax.random.PRNGKey(1), num_seed)
  for i, seed in enumerate(seeds):
    # Offsetting the raw key gives an initialization key distinct from the
    # `seed` used by the sampler below.
    initial_state = initialize((num_super_chains,), key = seed + 1954)
    if state_is_list:
      initial_state = [np.repeat(part, chains_per_super_chain, axis = 0)
                       for part in initial_state]
    else:
      initial_state = np.repeat(initial_state, chains_per_super_chain,
                                axis = 0)

    result = tfp.mcmc.sample_chain(
        total_samples, initial_state, kernel = kernel, seed = seed)
    draws = result.all_states

    Rhat_list[i, :] = tfp.mcmc.potential_scale_reduction(
        draws[num_warmup:(num_warmup + num_samples), :, index_param])

    nRhat_local, B_local, W_local = nested_rhat(
        draws,
        num_super_chains = num_super_chains,
        index_param = index_param,
        num_samples = num_samples,
        warmup_length = num_warmup,
        state_is_list = state_is_list)
    nRhat_list[i, :] = nRhat_local
    B_list[i, :] = B_local
    W_list[i, :] = W_local

    # Index `num_warmup` is the first draw after the `num_warmup` adaptation
    # steps, the same draw `run_forge_chain` uses. The scalar and
    # `index_param` are advanced indices separated by a slice, so the
    # parameter axis comes first: [num_parameters, num_chains]. Average over
    # chains.
    mc_mean_list[i, :] = np.mean(draws[num_warmup, :, index_param], axis = 1)

  return Rhat_list, nRhat_list, B_list, W_list, mc_mean_list


### Adaptive warmup (forge chain)

In [None]:
def forge_chain (kernel_cold, kernel_warm, initial_state, num_super_chains,
                 num_warmup, num_samples,
                 target_rhat, max_num_steps, index_param, seed,
                 num_nRhat_comp = 1,
                 state_is_list = False):
  """Adaptive warmup: extends warmup one window at a time until nRhat is small.

  Each iteration:
    1) runs `num_warmup` steps of `kernel_cold`, resuming from the previous
       window's final state and kernel results;
    2) draws `num_samples * num_nRhat_comp` candidate samples with
       `kernel_warm`, starting from those kernel results;
    3) splits the candidates into `num_nRhat_comp` consecutive, disjoint
       windows of `num_samples` draws, computes nested R-hat on each and
       averages the results per parameter.
  The loop stops once the largest averaged nRhat is below `target_rhat`, or
  after `max_num_steps` windows.

  Returns:
    `(result_warm, window_iteration)`: the candidate draws of the last
    iteration (`num_samples * num_nRhat_comp` draws along the leading axis)
    and the number of warmup windows that were run.

  Raises:
    ValueError: if `max_num_steps < 1`, since no candidates would be drawn.
  """
  if max_num_steps < 1:
    raise ValueError(
        'max_num_steps must be at least 1, got {}.'.format(max_num_steps))

  warmup_is_acceptable = False
  window_iteration = 0
  current_state = initial_state
  kernel_args = None

  while not warmup_is_acceptable and window_iteration < max_num_steps:
    window_iteration += 1

    # 1) Run MCMC on the warmup window, carrying over the adaptation state.
    result_cold, _, kernel_args = tfp.mcmc.sample_chain(
        num_results = num_warmup,
        current_state = current_state,
        kernel = kernel_cold,
        previous_kernel_results = kernel_args,
        trace_fn = lambda _, pkr: unnest.get_innermost(pkr, 'step_size'),
        return_final_kernel_results = True,
        seed = seed + window_iteration)  # Fresh seed for every window.
    current_state = result_cold[-1]

    # 2) Generate candidate samples.
    # NOTE(review): this seed does not change between windows — confirm that
    # reusing it is intended.
    result_warm, _ = tfp.mcmc.sample_chain(
        num_results = num_samples * num_nRhat_comp,
        current_state = current_state,
        kernel = kernel_warm,
        trace_fn = lambda _, pkr: unnest.get_innermost(pkr, 'step_size'),
        previous_kernel_results = kernel_args,
        seed = seed + 999999)

    # 3) Nested R-hat on each disjoint window of `num_samples` draws.
    nRhat = np.zeros((index_param.shape[0], num_nRhat_comp))
    for i in range(num_nRhat_comp):
      window = result_warm[i * num_samples:(i + 1) * num_samples]
      nRhat[:, i] = nested_rhat(window,
                                num_super_chains = num_super_chains,
                                index_param = index_param,
                                num_samples = num_samples,
                                state_is_list = state_is_list)[0]

    nRhat_max = np.max(np.mean(nRhat, axis = 1))
    print(nRhat_max)

    warmup_is_acceptable = nRhat_max < target_rhat

  return result_warm, window_iteration


In [None]:
def run_forge_chain (num_seed, kernel_cold, kernel_warm, initialize,
                     num_super_chains, num_warmup, num_samples,
                     target_rhat, max_num_steps, index_param,
                     num_nRhat_comp = 1,
                     state_is_list = False, num_chains = None):
  """Runs `forge_chain` for `num_seed` seeds and records the outcome.

  For each seed, `num_super_chains` initial points are drawn with
  `initialize`. Each point is repeated `num_chains // num_super_chains`
  times before calling `forge_chain`.

  Args:
    num_chains: total number of chains. Defaults to the notebook-level
      `num_chains`, which this function used to read implicitly. All other
      arguments are forwarded to `forge_chain`.

  Returns:
    `(mc_mean_list, warmup_length)`:
      * `mc_mean_list` has shape `[num_seed, len(index_param)]` and holds the
        mean across chains of the first candidate draw;
      * `warmup_length[i]` is the number of warmup iterations run for seed
        `i` (`window_iteration * num_warmup`).
  """
  if num_chains is None:
    num_chains = globals()['num_chains']
  chains_per_super_chain = num_chains // num_super_chains

  mc_mean_list = np.zeros((num_seed, index_param.shape[0]))
  warmup_length = np.zeros(num_seed)

  seeds = jax.random.split(jax.random.PRNGKey(1), num_seed)
  for i, seed in enumerate(seeds):
    print("NEW SEED")
    initial_state = initialize((num_super_chains,), key = seed + 1954)
    initial_state = np.repeat(initial_state, chains_per_super_chain,
                              axis = 0)

    result, window_iteration = forge_chain(kernel_cold, kernel_warm,
                                           initial_state, num_super_chains,
                                           num_warmup, num_samples,
                                           target_rhat, max_num_steps,
                                           index_param, seed,
                                           num_nRhat_comp, state_is_list)

    warmup_length[i] = window_iteration * num_warmup
    # Scalar + index array separated by a slice: the parameter axis comes
    # first ([num_parameters, num_chains]); average over chains.
    mc_mean_list[i, :] = np.mean(result[0, :, index_param], axis = 1)

  return mc_mean_list, warmup_length

For this experiment, we compute $n \hat R$ using `num_samples = 5` draws per chain to stabilize the estimators. To check whether the chains are properly warmed up, we only examine the first sample.

In [None]:
num_chains = 128
num_super_chains = 4
num_samples = 5
num_seed = 30

If the chains have converged (i.e. they behave as though independent from one another), then we expect the effective sample size "reported" by $n \hat R$ to be lower-bounded by $M$, the number of chains in each super-chain. This gives us an upper bound for $n \hat R$:

\begin{eqnarray*}
  u_{n \hat R}  = \sqrt{1 + \frac{1}{M}}.
\end{eqnarray*}

In [None]:
# Upper bound u on nRhat if the chains behave as though independent:
# u = sqrt(1 + 1 / M), with M = num_chains / num_super_chains chains per
# super-chain. Also used as the stopping threshold of the adaptive warmup.
nRhat_upper = np.sqrt(1 + 1 / (num_chains / num_super_chains))
print("Convergence upper bound for nRhat:", nRhat_upper) 

The bound $u_{n \hat R}$ is also used as the stopping threshold (`target_rhat`) of the adaptive warmup scheme below.

## Banana

In [None]:
target = gym.targets.VectorModel(gym.targets.Banana(),
                                  flatten_sample_transformations=True)
num_dimensions = target.event_shape[0]  
init_step_size = 1.

def target_log_prob_fn(x):
  """Log density of `target` on the unconstrained space, up to a constant.

  `x` is pushed through the model's default event-space bijector, and the
  forward log-det-Jacobian of that map is added. This lets the sampler ignore
  any constraints.
  """
  bijector = target.default_event_space_bijector
  log_jacobian = bijector.forward_log_det_jacobian(x)
  return target.unnormalized_log_prob(bijector(x)) + log_jacobian

# NOTE: Avoid initials centered around the true mean. 
offset = 2
def initialize (shape, key = random.PRNGKey(37272709)):
  """Overdispersed initial states: N(offset, 10^2) draws per coordinate.

  Returns an array of shape `shape + (num_dimensions,)`.
  """
  noise = random.normal(key, shape + (num_dimensions,))
  return offset + 10 * noise

In [None]:
# Reference mean and variance used to normalize the squared errors below.
# They must come from the model's ground truth: this cell runs before any
# sampler, so there are no module-level draws to estimate them from.
identity_transform = target.sample_transformations['identity']
mean_est = getattr(identity_transform, 'ground_truth_mean', None)
std_est = getattr(identity_transform, 'ground_truth_standard_deviation', None)
if mean_est is None or std_est is None:
  raise ValueError('The target provides no ground-truth mean / standard '
                   'deviation; set mean_est and var_est by hand.')
var_est = std_est**2

In [None]:
# Define MCMC kernel
num_warmup, num_sampling = 10, 10
total_samples = num_warmup + num_sampling

kernel = construct_kernel(target_log_prob_fn = target_log_prob_fn,
                          init_step_size = init_step_size,
                          num_warmup = num_warmup)

In [None]:
index_param = np.array([0, 1])

### $n \hat R$ diagnostic

In [None]:
Rhat_list, nRhat_list, B_list, W_list, mc_mean_list = run_fits(
           num_seed = num_seed, total_samples = total_samples,
           initialize = initialize, kernel = kernel,
           num_super_chains = num_super_chains, index_param = index_param,
           num_samples = num_samples, num_warmup = num_warmup)

In [None]:
square_error = np.square(mc_mean_list - mean_est[index_param])
expected_error = var_est[index_param] / num_chains

fig = figure(figsize =(6, index_param.shape[0]))
ax = fig.add_axes([0, 0, 1, 1])
plot_data = [square_error[:, 0] / expected_error[0], 
             square_error[:, 1] / expected_error[1]]
ax.boxplot(plot_data)
axhline(y = 1, linestyle = "--")
title("Squared Error over expected squared error (if chains converged)")
show()

In [None]:
fig = figure(figsize =(6, 6))
ax = fig.add_axes([0, 0, 1, 1])
plot_data = [nRhat_list[:, 0], Rhat_list[:, 0],
             nRhat_list[:, 1], Rhat_list[:, 1]]
axvline(x = 1, color = 'y', linestyle = '--', linewidth = 0.5,
        label = "x = 1")
axvline(x = 1.05, color = 'r', linestyle = '--', linewidth = 0.5,
        label = "x = 1.05")
axvline(x = nRhat_upper, color = 'c', linestyle = '--', linewidth = 0.5,
        label = "x = u")
ax.boxplot(plot_data, vert = 0)
ax.set_yticklabels(['nRhat[0]', 'Rhat[0]',
                    'nRhat[1]', 'Rhat[1]'])
title('Warmup = ' + str(num_warmup) + ', samples = ' + str(num_samples))
legend(loc = "best")
show()

In [None]:
# Let's take a closer look at the nRhat's and compare them
# to what we would expect from independent samples.
fig = figure(figsize = (6, 2))
ax = fig.add_axes([0, 0, 1, 1])
plot_data = [nRhat_list[:, 0], nRhat_list[:, 1]]
axvline(x = 1, color = 'y', linestyle = '--', linewidth = 0.5,
        label = "x = 1")
axvline(x = nRhat_upper, color = 'c', linestyle = '--', linewidth = 0.5,
        label = "x = u")
ax.boxplot(plot_data, vert = 0)
ax.set_yticklabels(['nRhat[0]', 'nRhat[1]'])
title('Warmup = ' + str(num_warmup) + ', samples = ' + str(num_samples))
legend(loc = "best")
show()

### Adaptive warmup length

Remark: using a short window gives the algorithm more opportunities to stop, which also means more chances to stop early just by luck. A reasonable compromise seems to be a window of size 100.

In [None]:
warmup_window = 100

kernel_cold = construct_kernel(target_log_prob_fn = target_log_prob_fn,
                               init_step_size = init_step_size,
                               num_warmup = warmup_window)

kernel_warm = construct_kernel(target_log_prob_fn = target_log_prob_fn,
                               init_step_size = init_step_size,
                               num_warmup = 0)

num_samples = 3

In [None]:
mc_mean_list, warmup_length = run_forge_chain(num_seed = num_seed,
                                          kernel_cold = kernel_cold,
                                          kernel_warm = kernel_warm,
                                          initialize = initialize,
                                          num_super_chains = num_super_chains,
                                          num_warmup = warmup_window,
                                          num_samples = num_samples,
                                          target_rhat = nRhat_upper,
                                          max_num_steps = 1000 // warmup_window,
                                          index_param = index_param,
                                          num_nRhat_comp = 3)


In [None]:
square_error = np.square(mc_mean_list - mean_est[index_param])
expected_error = var_est[index_param] / num_chains

plot_data = [square_error[:, 0] / expected_error[0]]
for i in range(1, index_param.shape[0]):
  plot_data.append(square_error[:, i] / expected_error[i])

fig = figure(figsize =(6, 3))
ax = fig.add_axes([0, 0, 1, 1])
ax.boxplot(plot_data)
axhline(y = 1, linestyle = "--")
title("Squared Error over expected squared error (if chains converged)")
show()

In [None]:
scatter(warmup_length, square_error[:, 0] / expected_error[0])
axhline(y = 1, linestyle = "--")
show()

## German credit score

In [None]:
target = gym.targets.VectorModel(gym.targets.GermanCreditNumericLogisticRegression(),
                                  flatten_sample_transformations=True)
num_dimensions = target.event_shape[0]  
init_step_size = 0.02

def target_log_prob_fn(x):
  """Log density of `target` on the unconstrained space, up to a constant.

  `x` is pushed through the model's default event-space bijector, and the
  forward log-det-Jacobian of that map is added. This lets the sampler ignore
  any constraints.
  """
  bijector = target.default_event_space_bijector
  log_jacobian = bijector.forward_log_det_jacobian(x)
  return target.unnormalized_log_prob(bijector(x)) + log_jacobian

offset = 0.1
def initialize (shape, key = random.PRNGKey(37272709)):
  """Initial states drawn from N(offset, 1), shape `shape + (num_dimensions,)`."""
  standard_normal = random.normal(key, shape + (num_dimensions,))
  return offset + standard_normal

In [None]:
# Reference mean and variance used to normalize the squared errors below.
# They must come from the model's ground truth: this cell runs before any
# sampler, so there are no module-level draws to estimate them from.
identity_transform = target.sample_transformations['identity']
mean_est = getattr(identity_transform, 'ground_truth_mean', None)
std_est = getattr(identity_transform, 'ground_truth_standard_deviation', None)
if mean_est is None or std_est is None:
  raise ValueError('The target provides no ground-truth mean / standard '
                   'deviation; set mean_est and var_est by hand.')
var_est = std_est**2

In [None]:
num_seed = 10
num_samples = 5
num_warmup, num_sampling = 500, 10
total_samples = num_warmup + num_sampling

kernel = construct_kernel(target_log_prob_fn = target_log_prob_fn,
                          init_step_size = init_step_size,
                          num_warmup = num_warmup)

In [None]:
index_param = np.arange(0, 25)

### $n \hat R$ diagnostic

In [None]:
Rhat_list, nRhat_list, B_list, W_list, mc_mean_list = run_fits(
           num_seed = num_seed, total_samples = total_samples,
           initialize = initialize, kernel = kernel,
           num_super_chains = num_super_chains, index_param = index_param,
           num_samples = num_samples, num_warmup = num_warmup)

In [None]:
# OPTIONAL: correct Rhat to avoid having negative values.
if True:
  Rhat_list = np.sqrt(np.square(Rhat_list) - (num_samples - 1) / num_samples + 1)

In [None]:
square_error = np.square(mc_mean_list - mean_est[index_param])
expected_error = var_est[index_param] / num_chains

plot_data = [square_error[:, 0] / expected_error[0]]
for i in range(1, index_param.shape[0]):
  plot_data.append(square_error[:, i] / expected_error[i])

fig = figure(figsize =(6, 3))
ax = fig.add_axes([0, 0, 1, 1])
ax.boxplot(plot_data)
axhline(y = 1, linestyle = "--")
title("Squared Error over expected squared error (if chains converged)")
show()

In [None]:
plot_data = [nRhat_list[:, 0], Rhat_list[:, 0]]
ylabels = ['nRhat[0]', 'Rhat[0]']
plot_data_nRhat = [nRhat_list[:, 0]]
ylabels_nRhat = ['nRhat[0]']
for i in range(1, index_param.shape[0]):
  plot_data.append(nRhat_list[:, i])
  plot_data.append(Rhat_list[:, i])
  ylabels.append(('nRhat[' + str(i) + ']'))
  ylabels.append(('Rhat[' + str(i) + ']'))

  plot_data_nRhat.append(nRhat_list[:, i])
  ylabels_nRhat.append(('nRhat[' + str(i) + ']'))


In [None]:
fig = figure(figsize =(6, 0.5 * index_param.shape[0]))
ax = fig.add_axes([0, 0, 1, 1])
axvline(x = 1, color = 'y', linestyle = '--', linewidth = 0.5,
        label = "x = 1")
axvline(x = 1.05, color = 'r', linestyle = '--', linewidth = 0.5,
        label = "x = 1.05")
axvline(x = nRhat_upper, color = 'c', linestyle = '--', linewidth = 0.5,
        label = "x = u")
ax.boxplot(plot_data, vert = 0)
ax.set_yticklabels(ylabels)
title('Warmup = ' + str(num_warmup) + ', samples = ' + str(num_samples))
legend(loc = "best")
show()

In [None]:
# Let's take a closer look at the nRhat's and compare them
# to what we would expect from independent samples.
fig = figure(figsize =(6, 0.5 * index_param.shape[0]))
ax = fig.add_axes([0, 0, 1, 1])
axvline(x = 1, color = 'y', linestyle = '--', linewidth = 0.5,
        label = "x = 1")
axvline(x = nRhat_upper, color = 'c', linestyle = '--', linewidth = 0.5,
        label = "x = u")
ax.boxplot(plot_data_nRhat, vert = 0)
ax.set_yticklabels(ylabels_nRhat)
title('Warmup = ' + str(num_warmup) + ', samples = ' + str(num_samples))
legend(loc = "best")
show()

### Adaptive warmup length

In [None]:
warmup_window = 100
max_warmup = 1000

kernel_cold = construct_kernel(target_log_prob_fn = target_log_prob_fn,
                               init_step_size = init_step_size,
                               num_warmup = max_warmup)

kernel_warm = construct_kernel(target_log_prob_fn = target_log_prob_fn,
                               init_step_size = init_step_size,
                               num_warmup = 0)

num_samples = 5

In [None]:
mc_mean_list, warmup_length = run_forge_chain(num_seed = num_seed,
                                  kernel_cold = kernel_cold,
                                  kernel_warm = kernel_warm,
                                  initialize = initialize,
                                  num_super_chains = num_super_chains,
                                  num_warmup = warmup_window,
                                  num_samples = num_samples,
                                  target_rhat = nRhat_upper,
                                  max_num_steps = max_warmup // warmup_window,
                                  index_param = index_param,
                                  num_nRhat_comp = 1)


In [None]:
square_error = np.square(mc_mean_list - mean_est[index_param])
expected_error = var_est[index_param] / num_chains

plot_data = [square_error[:, 0] / expected_error[0]]
for i in range(1, index_param.shape[0]):
  plot_data.append(square_error[:, i] / expected_error[i])

fig = figure(figsize =(6, 3))
ax = fig.add_axes([0, 0, 1, 1])
ax.boxplot(plot_data)
axhline(y = 1, linestyle = "--")
title("Squared Error over expected squared error (if chains converged)")
show()

In [None]:
scatter(warmup_length, square_error[:, 0] / expected_error[0])
axhline(y = 1, linestyle = "--")
show()

### Draft

In [None]:
seed = jax.random.PRNGKey(1954)
initial_state = initialize((num_super_chains,), key = seed + 1954)
initial_state = np.repeat(initial_state, num_chains // num_super_chains,
                          axis = 0)
current_state = initial_state
kernel_args = None

In [None]:
num_warmup = 250

In [None]:
result_cold, trace0, kernel_args = tfp.mcmc.sample_chain(
        num_results = num_warmup,
        current_state = current_state,
        kernel = kernel_cold,
        previous_kernel_results = kernel_args,
        trace_fn = lambda _, pkr: unnest.get_innermost(pkr, 'step_size'),
        return_final_kernel_results = True,
        seed = seed)  # Update seed during while loop

current_state = result_cold[-1]


In [None]:
plot(trace0)

In [None]:
result_cold, trace1, kernel_args = tfp.mcmc.sample_chain(
        num_results = num_warmup,
        current_state = current_state,
        kernel = kernel_cold,
        previous_kernel_results = kernel_args,
        trace_fn = lambda _, pkr: unnest.get_innermost(pkr, 'step_size'),
        return_final_kernel_results = True,
        seed = seed)  # Update seed during while loop

current_state = result_cold[-1]

In [None]:
plot(trace1)

In [None]:
result_warm, trace = tfp.mcmc.sample_chain(
        num_results = num_samples,
        current_state = current_state,
        kernel = kernel_warm,
        trace_fn = lambda _, pkr: unnest.get_innermost(pkr, 'step_size'),
        previous_kernel_results = kernel_args,
        seed = seed + 999999)

In [None]:
nRhat, _B, _W = nested_rhat(result_warm, num_super_chains, index_param, num_samples)
nRhat

## Eight Schools

In [None]:
# NOTE: inference gym stores the centered parameterization
target_raw = gym.targets.EightSchools()  # store raw to examine doc.
target = gym.targets.VectorModel(target_raw,
                                  flatten_sample_transformations = True)
num_dimensions = target.event_shape[0]
init_step_size = 1

# Using underdispersed inits can showcase problems with our diagnostics.
# Options: "underdispersed", "overdispersed", "prior".
init_type = "prior"
if init_type == "underdispersed":
  offset = 0.0
  def initialize (shape, key = random.PRNGKey(37272709)):
    """Draws N(offset, 1) initial states of shape `shape + (num_dimensions,)`."""
    return 1 * random.normal(key, shape + (num_dimensions,)) + offset
elif init_type == "overdispersed":
  offset = 0.0
  def initialize (shape, key = random.PRNGKey(37272709)):
    """Draws N(offset, 100^2) initial states, shape `shape + (num_dimensions,)`."""
    return 100 * random.normal(key, shape + (num_dimensions,)) + offset
elif init_type == "prior":
  def initialize (shape, key = random.PRNGKey(37272709)):
    """Draws initial states from the model prior.

    Coordinates are (mu, log_tau, eta[0:8]) with mu ~ N(0, 10^2),
    log_tau ~ N(5, 1) and eta ~ N(0, 1), matching the priors of `model`.
    """
    prior_scale = jnp.append(jnp.array([10., 1.]), jnp.repeat(1., 8))
    prior_offset = jnp.append(jnp.array([0., 5.]), jnp.repeat(0., 8))
    return prior_scale * random.normal(key, shape + (num_dimensions,)) + prior_offset
else:
  # Fail loudly. Otherwise `initialize` would silently remain whatever an
  # earlier section of the notebook defined.
  raise ValueError('Unknown init_type: {!r}'.format(init_type))


In [None]:
num_schools = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype = np.float32)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype = np.float32)


def _make_school_likelihood(obs_scale):
  """Builds the JointDistributionSequential conditional for the observed y.

  `obs_scale` is bound when the model is built. Reading the global `sigma` at
  evaluation time would break if `sigma` is reassigned later (the
  pharmacokinetic section of this notebook sets `sigma = 0.1`).
  """
  def y_given_params(eta, log_tau, mu):
    # JointDistributionSequential passes earlier components in reverse order.
    loc = mu[..., jnp.newaxis] + jnp.exp(log_tau[..., jnp.newaxis]) * eta
    return tfd.Independent(tfd.Normal(loc = loc, scale = obs_scale),
                           name = "y",
                           reinterpreted_batch_ndims = 1)
  return y_given_params


# Non-centered eight-schools model.
# NOTE: reinterpreted_batch_ndims = 1 turns the per-school batch dimension into
# a single event, so each component's log_prob sums over the independent schools.
model = tfd.JointDistributionSequential([
    tfd.Normal(loc = 0., scale = 10., name = "mu"),
    tfd.Normal(loc = 5., scale = 1., name = "log_tau"),
    tfd.Independent(tfd.Normal(loc = jnp.zeros(num_schools),
                               scale = jnp.ones(num_schools),
                               name = "eta"),
                    reinterpreted_batch_ndims = 1),
    _make_school_likelihood(sigma),
  ])


def target_log_prob_fn(x, observed_y = y):
  """Log joint density of the flattened state `x`, with y fixed to the data.

  Args:
    x: array indexed as x[:, k], with columns (mu, log_tau, eta_1, ...,
      eta_num_schools); i.e. expected to be 2-D (chains, parameters).
    observed_y: observed school effects. The default is the `y` defined in this
      cell, bound at definition time so later reassignments of `y` do not leak in.
  """
  mu = x[:, 0]
  log_tau = x[:, 1]
  eta = x[:, 2:(2 + num_schools)]
  return model.log_prob((mu, log_tau, eta, observed_y))


In [None]:
# Reference posterior moments for the non-centered parameterization, taken
# from a long run: 128 chains, 1000 warmup + 5000 sampling iterations each.
# Order: mu, log_tau, eta[0..7].
mean_est = np.array([
    5.8006573, 2.4502006,
    0.6532423, 0.09639207, -0.23725411, 0.04723661,
    -0.33556408, -0.19666635, 0.5390533, 0.14633301,
])

var_est = np.array([
    29.60382, 0.26338503,
    0.6383733, 0.4928926, 0.65307987, 0.52441144,
    0.46658015, 0.5248887, 0.49544162, 0.690975,
])

In [None]:
# Every flattened parameter: mu, log_tau and one eta per school.
index_param = np.arange(2 + num_schools)

### $n \hat R$ diagnostic

In [None]:
# Short run: 5 warmup and 10 sampling iterations, repeated over `num_seed`
# seeds. `num_samples` is forwarded to run_fits (presumably the number of
# post-warmup draws used by the diagnostics; TODO confirm).
num_seed = 10
num_samples = 5
num_warmup = 5
num_sampling = 10
total_samples = num_warmup + num_sampling

kernel = construct_kernel(
    target_log_prob_fn=target_log_prob_fn,
    init_step_size=init_step_size,
    num_warmup=num_warmup,
)


In [None]:
# One replication per seed; collect the diagnostics and Monte Carlo means.
Rhat_list, nRhat_list, B_list, W_list, mc_mean_list = run_fits(
    num_seed=num_seed,
    total_samples=total_samples,
    initialize=initialize,
    kernel=kernel,
    num_super_chains=num_super_chains,
    index_param=index_param,
    num_samples=num_samples,
    num_warmup=num_warmup,
)

In [None]:
# OPTIONAL: correct Rhat to avoid having negative values (presumably of
# Rhat - 1; TODO confirm). The update equals sqrt(Rhat^2 + 1 / num_samples).
apply_rhat_correction = True
# Remember the corrected array, so re-running this cell is a no-op instead of
# applying the correction twice. A fresh `run_fits` result is a new object and
# is corrected again.
if apply_rhat_correction and Rhat_list is not globals().get("Rhat_list_corrected"):
  Rhat_list = np.sqrt(np.square(Rhat_list) - (num_samples - 1) / num_samples + 1)
  Rhat_list_corrected = Rhat_list

In [None]:
# Squared error of each Monte Carlo mean, relative to the error expected if
# the chains had converged (reference variance / number of chains).
square_error = np.square(mc_mean_list - mean_est[index_param])
expected_error = var_est[index_param] / num_chains

relative_error = square_error / expected_error
plot_data = [relative_error[:, j] for j in range(index_param.shape[0])]

fig = figure(figsize=(6, 3))
ax = fig.add_axes([0, 0, 1, 1])
ax.boxplot(plot_data)
ax.axhline(y=1, linestyle="--")
ax.set_title("Squared Error over expected squared error (if chains converged)")
show()

In [None]:
# Interleave nRhat and Rhat for each parameter (combined plot), and keep the
# nRhat values on their own (closer look).
plot_data, ylabels = [], []
plot_data_nRhat, ylabels_nRhat = [], []
for i in range(index_param.shape[0]):
  plot_data.extend([nRhat_list[:, i], Rhat_list[:, i]])
  ylabels.extend([f'nRhat[{i}]', f'Rhat[{i}]'])
  plot_data_nRhat.append(nRhat_list[:, i])
  ylabels_nRhat.append(f'nRhat[{i}]')


In [None]:
fig = figure(figsize=(6, 0.5 * index_param.shape[0]))
ax = fig.add_axes([0, 0, 1, 1])
# Reference lines, in legend order: x = 1, the conventional 1.05 cutoff, and
# the nRhat threshold u (`nRhat_upper`).
reference_lines = [(1, 'y', "x = 1"),
                   (1.05, 'r', "x = 1.05"),
                   (nRhat_upper, 'c', "x = u")]
for x_ref, line_color, line_label in reference_lines:
  ax.axvline(x=x_ref, color=line_color, linestyle='--', linewidth=0.5,
             label=line_label)
ax.boxplot(plot_data, vert=0)
ax.set_yticklabels(ylabels)
ax.set_title(f'Warmup = {num_warmup}, samples = {num_samples}')
ax.legend(loc="best")
show()

In [None]:
# Closer look at nRhat alone, to compare it with what independent samples
# would give.
fig = figure(figsize=(6, 0.5 * index_param.shape[0]))
ax = fig.add_axes([0, 0, 1, 1])
for x_ref, line_color, line_label in ((1, 'y', "x = 1"),
                                      (nRhat_upper, 'c', "x = u"),
                                      (1.05, 'r', "x = 1.05")):
  ax.axvline(x=x_ref, color=line_color, linestyle='--', linewidth=0.5,
             label=line_label)
ax.boxplot(plot_data_nRhat, vert=0)
ax.set_yticklabels(ylabels_nRhat)
ax.set_title(f'Warmup = {num_warmup}, samples = {num_samples}')
ax.legend(loc="best")
show()

### Adaptive warmup length

In [None]:
# Adaptive warmup budget: windows of `warmup_window` iterations, at most
# `max_warmup` in total (as passed to run_forge_chain).
warmup_window = 100
max_warmup = 1000

# The cold kernel gets the full warmup budget; the warm kernel gets none.
kernel_cold, kernel_warm = (
    construct_kernel(target_log_prob_fn=target_log_prob_fn,
                     init_step_size=init_step_size,
                     num_warmup=warmup_budget)
    for warmup_budget in (max_warmup, 0))

num_samples = 5

In [None]:
# Run the adaptive-warmup procedure for every seed; keep the Monte Carlo means
# and the warmup length each run ended up using.
mc_mean_list, warmup_length = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize,
    num_super_chains=num_super_chains,
    num_warmup=warmup_window,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=max_warmup // warmup_window,
    index_param=index_param,
)


In [None]:
# Same error check as before, now for the adaptive-warmup runs.
square_error = np.square(mc_mean_list - mean_est[index_param])
expected_error = var_est[index_param] / num_chains

relative_error = square_error / expected_error
plot_data = [relative_error[:, j] for j in range(index_param.shape[0])]

fig = figure(figsize=(6, 3))
ax = fig.add_axes([0, 0, 1, 1])
ax.boxplot(plot_data)
ax.axhline(y=1, linestyle="--")
ax.set_title("Squared Error over expected squared error (if chains converged)")
show()

In [None]:
# Relative squared error of the first parameter against the warmup length
# returned by run_forge_chain (presumably one point per seed).
relative_error_first = square_error[:, 0] / expected_error[0]
scatter(warmup_length, relative_error_first)
axhline(y=1, linestyle="--")
show()

## Pharmacokinetic model

### Simulate data

In [None]:
# Sampling times after a dose, within one 12-unit dosing interval (time units
# are not stated here; presumably hours).
time_after_dose = np.array([0.083, 0.167, 0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 6, 8])

# Event/observation times: rich sampling after the doses at t = 0 and t = 12,
# one point every 12 units from 24 to 156, rich sampling after the dose at
# t = 168, then two late points. Everything stays float64 numpy. Routing part
# of the grid through jnp would round it to float32 under JAX's default
# precision.
t = np.concatenate([
    [0.], time_after_dose,
    [12.], time_after_dose + 12,
    np.linspace(start = 24, stop = 156, num = 12),
    [168.], 168. + time_after_dose,
    [180., 192.],
])

start_event = np.array([], dtype = int)
dosing_time = range(0, 192, 12)

# Use dosing events to determine times of integration between
# exterior interventions on the system: indices of t matching a dosing time.
eps = 1e-4  # tolerance for floating-point offsets in t.
start_event = np.concatenate(
    [np.flatnonzero(np.abs(t - t_dose) <= eps) for t_dose in dosing_time])

amt = jnp.array([1000., 0.])
n_dose = start_event.shape[0]

# The final time point closes the last integration interval.
start_event = np.append(start_event, t.shape[0] - 1)

In [None]:
n_patients = 100
pop_location = jnp.log(jnp.array([1.5, 0.25]))
pop_scale = jnp.array([0.15, 0.35])
# Per-patient rate constants (k1, k2): log-normal around the population values.
patient_noise = random.normal(random.PRNGKey(37272709), (n_patients, 2))
theta_patient = jnp.exp(pop_scale * patient_noise + pop_location)

# Initial amounts per compartment, one column per patient: shape (2, n_patients).
amt = np.array([1000., 0.])
amt_patient = np.repeat(amt[:, np.newaxis], n_patients, axis = 1)

# Redefine variables from the previous section (in case we only run the
# population model): a dose of 1000 at every dosing event, none at the final time.
t_jax = jnp.array(t)
amt_vec = np.zeros(t.shape[0])
amt_vec[start_event] = 1000
amt_vec[-1] = 0.
amt_vec_jax = jnp.array(amt_vec)

In [None]:
# TODO: remove 'use_second_axis' hack.
def ode_map (theta, dt, current_state, use_second_axis = False):
  """Advances the linear two-compartment model by a time step `dt`.

  Closed-form solution of
    y0' = -k1 * y0
    y1' =  k1 * y0 - k2 * y1
  over an interval of length `dt`, starting from `current_state`.

  Args:
    theta: rate constants. With use_second_axis=False they are read as
      theta[:, 0] (k1) and theta[:, 1] (k2); with True, as theta[0, :] and
      theta[1, :].
    dt: length of the time step.
    current_state: array read as current_state[0, :] (y0) and
      current_state[1, :] (y1).
    use_second_axis: selects the theta layout described above.

  Returns:
    jnp array [y0(dt), y1(dt)] stacked along the first axis.

  NOTE: divides by (k1 - k2), so the result is undefined when k1 == k2.
  """
  if use_second_axis:
    k1 = theta[0, :]
    k2 = theta[1, :]
  else:
    k1 = theta[:, 0]
    k2 = theta[:, 1]

  y0_hat = jnp.exp(- k1 * dt) * current_state[0, :]
  y1_hat = jnp.exp(- k2 * dt) / (k1 - k2) * (current_state[0, :] * k1 *
                (1 - jnp.exp((k2 - k1) * dt)) + (k1 - k2) * current_state[1, :])
  return jnp.array([y0_hat, y1_hat])

# NOTE(review): jit should be possible if `use_second_axis` is marked static,
# e.g. functools.partial(jax.jit, static_argnames = 'use_second_axis');
# untested here.
def ode_map_event(theta, use_second_axis = False):
  """Integrates the model across all event times in `t_jax`, applying doses.

  Starts from `amt_patient` at t_jax[0]. For each later event index it advances
  the state with `ode_map`, records the second-compartment value before that
  event's dose, then adds the dose (`amt_vec_jax`) to the first compartment.

  Returns:
    Array of y1 values with one row per event after the first, i.e. aligned
    with t[1:].
  """
  def ode_map_step (current_state, event_index):
    dt = t_jax[event_index] - t_jax[event_index - 1]
    y_sln = ode_map(theta, dt, current_state, use_second_axis)
    # Doses enter only the first compartment. The (2, 1) column broadcasts
    # across patients, so no per-patient dose array is needed.
    dose = jnp.array([amt_vec_jax[event_index], 0.])[:, jnp.newaxis]
    return (y_sln + dose, y_sln[1])

  (__, yhat) = jax.lax.scan(ode_map_step, amt_patient,
                            np.arange(1, t.shape[0]),
                            unroll = 20)
  return yhat

In [None]:
# Simulate noise-free concentrations for every patient.
y_hat = ode_map_event(theta_patient)

# Multiplicative (log-normal) measurement noise.
# NOTE: no observation at time t = 0.
sigma = 0.1
noise = random.normal(random.PRNGKey(1954), y_hat.shape)
log_y = jnp.log(y_hat) + sigma * noise
y_obs = jnp.exp(log_y)

figure(figsize=[6, 6])
plot(t[1:], y_hat)
plot(t[1:], y_obs, 'o', markersize=2)
show()

### Fit model with TFP

As a golden benchmark, we draw 1,000 samples per chain after warming up each chain for 1,000 iterations. With 128 chains, this gives a total of 128,000 approximate samples.

In [None]:
def _patient_effects_prior(name):
  """Standard-normal prior on one non-centered effect per patient."""
  return tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),
                                    scale = jnp.ones(n_patients),
                                    name = name),
                         reinterpreted_batch_ndims = 1)


def _individual_rates(log_k1_pop, log_k2_pop, log_scale_k1, log_scale_k2,
                      eta_k1, eta_k2):
  """Per-patient rate constants, stacked as [k1, k2] along the first axis."""
  k1 = jnp.exp(log_k1_pop[..., jnp.newaxis]
               + eta_k1 * jnp.exp(log_scale_k1[..., jnp.newaxis]))
  k2 = jnp.exp(log_k2_pop[..., jnp.newaxis]
               + eta_k2 * jnp.exp(log_scale_k2[..., jnp.newaxis]))
  return jnp.array([k1, k2])


pop_model = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc = jnp.log(1.), scale = 0.1, name = "log_k1_pop"),
    tfd.Normal(loc = jnp.log(0.3), scale = 0.1, name = "log_k2_pop"),
    tfd.Normal(loc = jnp.log(0.15), scale = 0.1, name = "log_scale_k1"),
    tfd.Normal(loc = jnp.log(0.35), scale = 0.1, name = "log_scale_k2"),
    tfd.Normal(loc = -1., scale = 1., name = "log_sigma"),
    # Non-centered parameterization for the patient-level hierarchy.
    _patient_effects_prior("eta_k1"),
    _patient_effects_prior("eta_k2"),
    # Log-normal observations around the ODE solution. The lambda receives the
    # earlier components in reverse order.
    lambda eta_k2, eta_k1, log_sigma, log_scale_k2, log_scale_k1,
           log_k2_pop, log_k1_pop: tfd.Independent(tfd.LogNormal(
        loc = jnp.log(ode_map_event(
            theta = _individual_rates(log_k1_pop, log_k2_pop, log_scale_k1,
                                      log_scale_k2, eta_k1, eta_k2),
            use_second_axis = True)),
        scale = jnp.exp(log_sigma[..., jnp.newaxis]),
        name = "y_obs")),
])


def pop_target_log_prob_fn(log_k1_pop, log_k2_pop, log_scale_k1, log_scale_k2,
                           log_sigma, eta_k1, eta_k2):
  """Unnormalized log posterior, with the simulated `y_obs` held fixed.

  All model parts are passed to log_prob as a single tuple.
  """
  parameters = (log_k1_pop, log_k2_pop, log_scale_k1, log_scale_k2,
                log_sigma, eta_k1, eta_k2)
  return pop_model.log_prob(parameters + (y_obs,))


def pop_target_log_prob_fn_flat(x):
  """Same log posterior, with all parameters packed as columns of a 2-D `x`.

  Column layout: the 5 population-level parameters, then eta_k1 and eta_k2
  (n_patients columns each).
  """
  population_params = [x[:, j] for j in range(5)]
  eta_k1 = x[:, 5:(5 + n_patients)]
  eta_k2 = x[:, (5 + n_patients):(5 + 2 * n_patients)]
  return pop_target_log_prob_fn(*population_params, eta_k1, eta_k2)


In [None]:
def initialize(shape, key=random.PRNGKey(37272709)):
  """Draws initial parameter values from the prior of `pop_model`.

  Returns the first seven components of a joint draw (the parameters) and
  drops the last one, the simulated observations `y_obs`.
  """
  joint_draw = pop_model.sample(sample_shape=shape, seed=key)
  return joint_draw[:7]


In [None]:
num_dimensions = 5 + 2 * n_patients
def initialize_flat (shape, key = random.PRNGKey(37272709)):
  """Draws prior initial values and packs them into one flat parameter vector.

  Layout of the last axis matches `pop_target_log_prob_fn_flat`: the 5
  population-level parameters, then eta_k1 and eta_k2 (n_patients each).

  Args:
    shape: tuple batch shape of the draw, e.g. (num_super_chains,).
    key: JAX PRNG key forwarded to `initialize`.

  Returns:
    float64 numpy array of shape `shape + (num_dimensions,)`.
  """
  initial = initialize(shape, key)
  initial_flat = np.zeros(shape + (num_dimensions,))
  # Index with `...` (not `[:, i]`) so any batch `shape` works, not only 1-D ones.
  for i in range(0, 5):
    initial_flat[..., i] = initial[i]
  initial_flat[..., 5:(5 + n_patients)] = initial[5]
  initial_flat[..., (5 + n_patients):(5 + 2 * n_patients)] = initial[6]

  return initial_flat


In [None]:
# Golden-benchmark run settings: 1000 warmup + 1000 sampling iterations.
num_seed = 10
num_samples = 5
num_warmup = 1000
num_sampling = 1000
total_samples = num_warmup + num_sampling
init_step_size = 0.001

# This kernel targets the list-valued state (one argument per model part).
kernel = construct_kernel(
    target_log_prob_fn=pop_target_log_prob_fn,
    init_step_size=init_step_size,
    num_warmup=num_warmup,
)

In [None]:
def trace_fn(current_state, pkr):
  """Per-iteration trace: (divergence proxy, step size, max trajectory length).

  A log acceptance ratio below -1000 is treated as a proxy for a divergent
  transition. `current_state` is not used.
  """
  divergence_proxy = get_innermost(pkr, 'log_accept_ratio') < -1000
  step_size = get_innermost(pkr, 'step_size')
  trajectory_length = get_innermost(pkr, 'max_trajectory_length')
  return divergence_proxy, step_size, trajectory_length


In [None]:
# One prior draw per super chain, replicated consecutively along axis 0 so that
# all chains of a super chain share a starting point.
chains_per_super_chain = num_chains // num_super_chains
initial_state = [np.repeat(part, chains_per_super_chain, axis=0)
                 for part in initialize((num_super_chains,))]

In [None]:
# Golden-benchmark run. With a `trace_fn`, sample_chain returns
# (states, trace); despite its name, `diverged` holds the whole tuple built by
# `trace_fn`.
mcmc_states, diverged = tfp.mcmc.sample_chain(
    num_results=total_samples,
    current_state=initial_state,
    kernel=kernel,
    trace_fn=trace_fn,
    seed=random.PRNGKey(1954),
)


### Check the inference

Check that the inference is reliable. If it is, use it to construct a "golden benchmark".

In [None]:
# Count divergences over ALL post-warmup iterations. `diverged[0]` is the
# divergence proxy from `trace_fn`. Slicing to num_warmup + num_samples would
# inspect only the first `num_samples` draws.
print("Divergent transitions after warmup:",
      np.sum(diverged[0][num_warmup:]))

In [None]:
# Extract samples after warmup from the list
mcmc_states_sample = mcmc_states
for i in range(0, len(mcmc_states)):
  mcmc_states_sample[i] = mcmc_states[i][num_warmup:]

In [None]:
# NOTE: the last resolved name is an 'x', not a parameter. It presumably is the
# default name of the final (observed `y_obs`) component, which no later lambda
# refers to by argument name (TODO confirm). It is dropped here.
parameter_names = pop_model._flat_resolve_names()[:-1]

# sample_chain stacks draws along axis 0 (draw, chain, ...); ArviZ expects
# (chain, draw, ...), hence the axis swap.
posterior_draws = {name: np.swapaxes(draws, 0, 1)
                   for name, draws in zip(parameter_names, mcmc_states)}
az_states = az.from_dict(posterior=posterior_draws)

summary_columns = ["mean", "sd", "mcse_sd", "hdi_3%",
                   "hdi_97%", "ess_bulk", "ess_tail",
                   "r_hat"]
fit_summary = az.summary(az_states).filter(items=summary_columns)

In [None]:
# Display the ArviZ summary (mean, sd, mcse_sd, HDI bounds, ESS, r_hat) for
# every parameter. Check r_hat and ESS before using it as a golden benchmark.
fit_summary

In [None]:
# Golden-benchmark posterior means and variances, one entry per summary row.
num_dimensions = fit_summary.shape[0]
mean_est = fit_summary["mean"].to_numpy(dtype=float, copy=True)
var_est = np.square(fit_summary["sd"].to_numpy(dtype=float, copy=True))

### $n \hat R$ diagnostic

In [None]:
# nRhat experiment: long warmup (1000), then only a handful of draws.
num_seed = 10
num_samples = 5
num_warmup = 1000
num_sampling = 10
total_samples = num_warmup + num_sampling
init_step_size = 0.001

# This kernel targets the flat (single-array) parameterization.
kernel = construct_kernel(
    target_log_prob_fn=pop_target_log_prob_fn_flat,
    init_step_size=init_step_size,
    num_warmup=num_warmup,
)

In [None]:
# Every flattened parameter: 5 population-level ones plus eta_k1 and eta_k2
# (n_patients each), matching pop_target_log_prob_fn_flat.
index_param = np.arange(5 + 2 * n_patients)

In [None]:
# One replication per seed, using the flat state (not a list of parts).
Rhat_list, nRhat_list, B_list, W_list, mc_mean_list = run_fits(
    num_seed=num_seed,
    total_samples=total_samples,
    initialize=initialize_flat,
    kernel=kernel,
    num_super_chains=num_super_chains,
    index_param=index_param,
    num_samples=num_samples,
    num_warmup=num_warmup,
    state_is_list=False,
)

In [None]:
# OPTIONAL: correct Rhat to avoid having negative values (presumably of
# Rhat - 1; TODO confirm). The update equals sqrt(Rhat^2 + 1 / num_samples).
apply_rhat_correction = True
# Remember the corrected array, so re-running this cell is a no-op instead of
# applying the correction twice. A fresh `run_fits` result is a new object and
# is corrected again.
if apply_rhat_correction and Rhat_list is not globals().get("Rhat_list_corrected"):
  Rhat_list = np.sqrt(np.square(Rhat_list) - (num_samples - 1) / num_samples + 1)
  Rhat_list_corrected = Rhat_list

In [None]:
# Squared error of each Monte Carlo mean, relative to the error expected if
# the chains had converged (golden-benchmark variance / number of chains).
square_error = np.square(mc_mean_list - mean_est[index_param])
expected_error = var_est[index_param] / num_chains

relative_error = square_error / expected_error
plot_data = [relative_error[:, j] for j in range(index_param.shape[0])]

fig = figure(figsize=(6, 3))
ax = fig.add_axes([0, 0, 1, 1])
ax.boxplot(plot_data)
ax.axhline(y=1, linestyle="--")
ax.set_title("Squared Error over expected squared error (if chains converged)")
show()

In [None]:
# Interleave nRhat and Rhat for each parameter (combined plot), and keep the
# nRhat values on their own (closer look).
plot_data, ylabels = [], []
plot_data_nRhat, ylabels_nRhat = [], []
for i in range(index_param.shape[0]):
  plot_data.extend([nRhat_list[:, i], Rhat_list[:, i]])
  ylabels.extend([f'nRhat[{i}]', f'Rhat[{i}]'])
  plot_data_nRhat.append(nRhat_list[:, i])
  ylabels_nRhat.append(f'nRhat[{i}]')


In [None]:
fig = figure(figsize=(6, 0.5 * index_param.shape[0]))
ax = fig.add_axes([0, 0, 1, 1])
# Reference lines, in legend order: x = 1, the conventional 1.05 cutoff, and
# the nRhat threshold u (`nRhat_upper`).
reference_lines = [(1, 'y', "x = 1"),
                   (1.05, 'r', "x = 1.05"),
                   (nRhat_upper, 'c', "x = u")]
for x_ref, line_color, line_label in reference_lines:
  ax.axvline(x=x_ref, color=line_color, linestyle='--', linewidth=0.5,
             label=line_label)
ax.boxplot(plot_data, vert=0)
ax.set_yticklabels(ylabels)
ax.set_title(f'Warmup = {num_warmup}, samples = {num_samples}')
ax.legend(loc="best")
show()

In [None]:
# Closer look at nRhat alone, to compare it with what independent samples
# would give.
fig = figure(figsize=(6, 0.5 * index_param.shape[0]))
ax = fig.add_axes([0, 0, 1, 1])
for x_ref, line_color, line_label in ((1, 'y', "x = 1"),
                                      (nRhat_upper, 'c', "x = u"),
                                      (1.05, 'r', "x = 1.05")):
  ax.axvline(x=x_ref, color=line_color, linestyle='--', linewidth=0.5,
             label=line_label)
ax.boxplot(plot_data_nRhat, vert=0)
ax.set_yticklabels(ylabels_nRhat)
ax.set_title(f'Warmup = {num_warmup}, samples = {num_samples}')
ax.legend(loc="best")
show()

### Adaptive warmup

In [None]:
warmup_window = 100
max_warmup = 2000

# Cold kernel: built with the full warmup budget `max_warmup`.
kernel_cold = construct_kernel(target_log_prob_fn = pop_target_log_prob_fn_flat,
                               init_step_size = init_step_size,
                               num_warmup = max_warmup)

# Warm kernel: built with num_warmup = 0, as in the eight-schools
# adaptive-warmup cell. With num_warmup = max_warmup it would be identical to
# kernel_cold, and the cold/warm distinction in run_forge_chain would be lost.
kernel_warm = construct_kernel(target_log_prob_fn = pop_target_log_prob_fn_flat,
                               init_step_size = init_step_size,
                               num_warmup = 0)

num_samples = 5

In [None]:
# Adaptive-warmup runs for the population model (flat state).
mc_mean_list, warmup_length = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize_flat,
    num_super_chains=num_super_chains,
    num_warmup=warmup_window,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=max_warmup // warmup_window,
    index_param=index_param,
)


# Draft code

In [None]:
# Draft: nested Rhat of the golden-benchmark states.
# NOTE(review): `mcmc_states` was truncated in place to post-warmup draws
# earlier, so also skipping `num_warmup` iterations here may leave no draws.
# Confirm what nested_rhat expects.
nRhat, _B, _W = nested_rhat(
    mcmc_states,
    num_super_chains,
    index_param,
    num_samples,
    warmup_length=num_warmup,
    state_is_list=True,
)

In [None]:
seed = jax.random.PRNGKey(1954)
# Use separate keys for initialization and sampling instead of reusing `seed`.
init_key, sample_key = jax.random.split(seed)

# The current `kernel` uses the flat target, which needs a single
# (chains, dims) array. `initialize` returns a list of differently shaped
# parameter arrays that np.repeat cannot stack, so use `initialize_flat`.
initial_state = initialize_flat((num_super_chains,), key = init_key)
initial_state = np.repeat(initial_state, num_chains // num_super_chains,
                          axis = 0)

result = tfp.mcmc.sample_chain(
      total_samples, initial_state, kernel = kernel,
      seed = sample_key)

# Posterior mean per parameter, pooled over post-warmup draws and chains.
np.mean(result.all_states[num_warmup:, :, :], axis = (0, 1))
# np.var(result.all_states[num_warmup:, :, :], axis = (0, 1))

In [None]:
# Examine B and W.
# B is rescaled by 1 / (2 * (nRhat_upper - 1)). Presumably this puts it on the
# scale of the posterior variance shown by the dashed line (TODO confirm).
B_rescale = B_list / (2 * (nRhat_upper - 1))
fig = figure(figsize = (6, 2))
ax = fig.add_axes([0, 0, 1, 1])
plot_data = [B_rescale[:, 0], W_list[:, 0]]
ax.boxplot(plot_data, vert = 1)
ax.set_xticklabels(['B_rescale[0]', 'W[0]'])
axhline(y = var_est[0], linestyle = '--', linewidth = 0.5,
        label = "var_est[0]")
show()

# The current `kernel` uses the flat target, which needs a single
# (chains, dims) array, so use `initialize_flat` rather than `initialize`
# (a list of differently shaped parameter arrays). Separate keys are used for
# initialization and sampling instead of reusing one key.
init_key, sample_key = jax.random.split(jax.random.PRNGKey(1954))
initial_state = initialize_flat((num_super_chains,), key = init_key)
initial_state = np.repeat(initial_state, num_chains // num_super_chains,
                          axis = 0)

result = tfp.mcmc.sample_chain(
      total_samples, initial_state, kernel = kernel,
      seed = sample_key)

# Average the first parameter over the sub chains of each super chain,
# giving shape (total_samples, num_super_chains).
num_sub_chains = num_chains // num_super_chains
state_param = result.all_states[:, :, 0].reshape(total_samples, -1, num_sub_chains, 1)
mean_super_chain = np.mean(state_param, axis = (2, 3))
plot(mean_super_chain)
show()

# Between-super-chain variance of those means, at every iteration.
B = np.var(mean_super_chain, axis = 1)
plot(B)
show()

In [None]:
# Draft: a single forge_chain run, starting from the current `initial_state`.
result, num_windows = forge_chain(
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initial_state=initial_state,
    num_super_chains=num_super_chains,
    num_warmup=warmup_window,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=1000 // warmup_window,
    index_param=index_param,
    seed=seed,
)