# Section 6: Numerical experiments
Numerical experiments for nested $\widehat R$ ($\mathfrak n \widehat R$) on several test models. This notebook runs the experiments and saves the results.

## Setup: import libraries

To install `fun_mc` and `inference_gym`, run the following commands (in a virtual environment).
```
!rm -Rf probability
!rm -Rf fun_mc
!rm -Rf inference_gym
!git clone https://github.com/tensorflow/probability.git
!mv probability/spinoffs/fun_mc/fun_mc .
!mv probability/spinoffs/inference_gym/inference_gym .
!pip install tf-nightly tfp-nightly jax jaxlib
```
and
```
!pip install immutabledict
```

In [1]:
import logging

import matplotlib
import numpy as np
from matplotlib.pyplot import *

# Global font settings for all figures. 'normal' is not a font family that
# matplotlib knows (it only falls back to the default font after a findfont
# warning), so name the default generic family explicitly.
font = {'family' : 'sans-serif',
        'weight' : 'bold',
        'size'   : 14}
matplotlib.rc('font', **font)

# Use this to silence check type warning messages.
logging.disable(logging.WARNING)


In [3]:
import jax
from jax import random
from jax import numpy as jnp

# import tfp models and datasets
from inference_gym import using_jax as gym

from fun_mc import using_jax as fun_mcmc

from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import unnest
import tensorflow_probability as _tfp
tfp = _tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors

tfp_np = _tfp.substrates.numpy
tfd_np = tfp_np.distributions


# import arviz as az
from tensorflow_probability.python.internal.unnest import get_innermost  


In [4]:
# Directory to save results (adjust to your setting!)
deliv_dir = "/mnt/home/cmargossian/Code/nested-rhat/deliv/"
data_dir = "/mnt/home/cmargossian/Code/nested-rhat/data/"

In [45]:
# Execute utility.py in the notebook namespace. It presumably defines the
# helpers used below (construct_kernel, adaptive_kernels, run_forge_chain) —
# confirm. Compiling with the real file name makes tracebacks point at lines
# of utility.py instead of "<string>".
with open("utility.py") as f:
  exec(compile(f.read(), "utility.py", "exec"))

In [130]:
# Tuning for the numerical experiments
# Tuning for the numerical experiments
num_seed = 10
adapt_warmup = False
# Target nRhat passed to run_forge_chain. It must be defined for both values
# of adapt_warmup because every later cell reads it.
nRhat_adaptive_target = 1.01  # TODO(review): confirm threshold for adaptive warmup.
nRhat_upper = nRhat_adaptive_target if adapt_warmup else 1

num_chains = 2048      # total number of chains
num_super_chains = 16  # K, options: 2, 8, 16, 64, 256, 1024
assert num_chains % num_super_chains == 0, (
    "num_chains must be divisible by num_super_chains")
num_sub_chains = num_chains // num_super_chains  # M
num_samples = 1        # length of sampling phase

num_warmup = 10  # 1000
# Classic Rhat needs at least 2 iterations per chain, hence the extra one.
total_samples = num_warmup + num_samples + 1

max_warmup = 1000
warmup_window = 100

# Start with a series of short windows for the freshly initialized chains,
# then switch to wider warmup windows, so that the windows add up to (at
# most) max_warmup iterations.
init_window = 10
num_init_windows = 10
num_wide_windows = (max_warmup - init_window * num_init_windows) // warmup_window
window_array = np.append(np.repeat(init_window, num_init_windows),
                         np.repeat(warmup_window, num_wide_windows))

naive_super_chains = True

In [131]:
# check tuning parameters of experiment
print("num_samples:", num_samples)
print("num_super_chains:", num_super_chains)
print("(total) num_chains:", num_chains)
print("num_seed:", num_seed)
print("nRhat_upper:", nRhat_upper)
print("naive_super_chains:", naive_super_chains)

num_samples: 1
num_super_chains: 16
(total) num_chains: 2048
num_seed: 10
nRhat_upper: 1
naive_super_chains: True


## Run experiments

### Rosenbrock (banana) distribution

In [132]:
target = gym.targets.VectorModel(
    gym.targets.Banana(), flatten_sample_transformations=True)
num_dimensions = target.event_shape[0]
init_step_size = 1.

def target_log_prob_fn(x):
  """Log density of the target on the unconstrained space (up to a constant).

  Pushes x through the target's default event-space bijector and adds the
  forward log-det-Jacobian, so the sampler never has to handle constraints.
  """
  bijector = target.default_event_space_bijector
  log_prob = target.unnormalized_log_prob(bijector(x))
  return log_prob + bijector.forward_log_det_jacobian(x)

# NOTE: Avoid initials centered around the true mean.
offset = 2
def initialize(shape, key=random.PRNGKey(37272709)):
  """Initial states drawn i.i.d. from N(offset, 10^2) in every coordinate."""
  noise = random.normal(key, shape + (num_dimensions,))
  return offset + 10 * noise


In [133]:
# Ground-truth mean and variance of the target, used as benchmarks for the
# squared error of the Monte Carlo estimates. There is no reference run in
# this notebook to estimate them from, so fail loudly if they are missing
# instead of falling back on stale kernel state.
identity_transform = target.sample_transformations['identity']
mean_est = identity_transform.ground_truth_mean
std_est = identity_transform.ground_truth_standard_deviation
if mean_est is None or std_est is None:
  raise ValueError(
      'Target has no ground-truth mean / standard deviation; provide '
      'benchmark estimates (e.g. from a long reference run) instead.')
var_est = std_est**2

In [134]:
# Build a single MCMC kernel for the current target.
# NOTE: classic Rhat needs at least 2 iterations per chain; this is why
# total_samples in the tuning cell is num_warmup + num_samples + 1.
# NOTE(review): `kernel` is not used by the cells shown here (the runs use
# kernel_cold / kernel_warm); confirm it is still needed.
kernel = construct_kernel(
    target_log_prob_fn=target_log_prob_fn,
    init_step_size=init_step_size,
    num_warmup=num_warmup,
)


In [135]:
# num_seed (the number of independent replications) is configured in the
# tuning cell. It is intentionally not re-assigned here, so a change there
# is not silently overridden.
assert num_seed >= 1, "num_seed must be a positive integer"

In [136]:
kernel_cold, kernel_warm = adaptive_kernels(target_log_prob_fn, init_step_size, num_warmup)
# Monitor every coordinate (two for this distribution). Derived from
# num_dimensions, as in the bimodal experiment, so it cannot drift from the
# model.
index_param = np.arange(num_dimensions)

(mc_mean_list, warmup_length,
 squared_error_list, nrhat_list) = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize,
    num_super_chains=num_super_chains,
    num_warmup=window_array,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=window_array.shape[0],
    index_param=index_param,
    mean_benchmark=mean_est,
    var_benchmark=var_est,
    naive_super_chains=naive_super_chains)


SEED : [696157669 674520459]


Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they used. (Deprecated NumPy 1.22)


SEED : [1771252691  383428323]
SEED : [3312214627   69373237]
SEED : [1094216230 2034336350]
SEED : [1788847948 1554193616]
SEED : [338758650 111605681]
SEED : [4073466969 2556348314]
SEED : [ 167784864 1021574613]
SEED : [3458295564 1820549682]
SEED : [3023805236 1296165206]


In [137]:
import os

# Save output into npy files
model_name = "rosenbrock"
exp_parm = ("_K" + str(num_super_chains) + "_M" + str(num_sub_chains)
            + "_N" + str(num_samples))

if naive_super_chains:
    exp_parm = "_naive" + exp_parm

# os.path.join tolerates a deliv_dir without a trailing "/". Creating the
# directory keeps np.save from raising FileNotFoundError on a fresh checkout.
os.makedirs(deliv_dir, exist_ok=True)
out_prefix = os.path.join(deliv_dir, model_name + exp_parm)
np.save(out_prefix + "_nrhat", nrhat_list)
np.save(out_prefix + "_squared_error", squared_error_list)

### German Credit Score

In [138]:
target = gym.targets.VectorModel(
    gym.targets.GermanCreditNumericLogisticRegression(),
    flatten_sample_transformations=True)
num_dimensions = target.event_shape[0]
init_step_size = 0.02

def target_log_prob_fn(x):
  """Log density of the target on the unconstrained space (up to a constant).

  Pushes x through the target's default event-space bijector and adds the
  forward log-det-Jacobian, so the sampler never has to handle constraints.
  """
  bijector = target.default_event_space_bijector
  log_prob = target.unnormalized_log_prob(bijector(x))
  return log_prob + bijector.forward_log_det_jacobian(x)

offset = 0.5
def initialize(shape, key=random.PRNGKey(37272709)):
  """Initial states drawn i.i.d. from N(offset, 3^2) in every coordinate."""
  noise = random.normal(key, shape + (num_dimensions,))
  return offset + 3 * noise

In [139]:
# Ground-truth mean and variance of the target, used as benchmarks for the
# squared error of the Monte Carlo estimates. There is no reference run in
# this notebook to estimate them from, so fail loudly if they are missing
# instead of falling back on stale kernel state.
identity_transform = target.sample_transformations['identity']
mean_est = identity_transform.ground_truth_mean
std_est = identity_transform.ground_truth_standard_deviation
if mean_est is None or std_est is None:
  raise ValueError(
      'Target has no ground-truth mean / standard deviation; provide '
      'benchmark estimates (e.g. from a long reference run) instead.')
var_est = std_est**2

In [140]:
kernel_cold, kernel_warm = adaptive_kernels(target_log_prob_fn, init_step_size, num_warmup)
# Monitor every coordinate. Derived from num_dimensions (previously the
# literal 25), as in the bimodal experiment, so it cannot drift from the model.
index_param = np.arange(num_dimensions)

(mc_mean_list, warmup_length,
 squared_error_list, nrhat_list) = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize,
    num_super_chains=num_super_chains,
    num_warmup=window_array,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=window_array.shape[0],
    index_param=index_param,
    mean_benchmark=mean_est,
    var_benchmark=var_est,
    naive_super_chains=naive_super_chains)

SEED : [696157669 674520459]


2024-03-19 11:24:37.695920: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they used. (Deprecated NumPy 1.22)


SEED : [1771252691  383428323]
SEED : [3312214627   69373237]
SEED : [1094216230 2034336350]
SEED : [1788847948 1554193616]
SEED : [338758650 111605681]
SEED : [4073466969 2556348314]
SEED : [ 167784864 1021574613]
SEED : [3458295564 1820549682]
SEED : [3023805236 1296165206]


In [141]:
import os

# Save output into npy files
model_name = "german"
exp_parm = ("_K" + str(num_super_chains) + "_M" + str(num_sub_chains)
            + "_N" + str(num_samples))

if naive_super_chains:
    exp_parm = "_naive" + exp_parm

# os.path.join tolerates a deliv_dir without a trailing "/". Creating the
# directory keeps np.save from raising FileNotFoundError on a fresh checkout.
os.makedirs(deliv_dir, exist_ok=True)
out_prefix = os.path.join(deliv_dir, model_name + exp_parm)
np.save(out_prefix + "_nrhat", nrhat_list)
np.save(out_prefix + "_squared_error", squared_error_list)

### Eight Schools

In [142]:
# NOTE: inference gym stores the centered parameterization
target_raw = gym.targets.EightSchools()  # store raw to examine doc.
target = gym.targets.VectorModel(target_raw,
                                 flatten_sample_transformations=True)
num_dimensions = target.event_shape[0]
init_step_size = 1

# Using underdispersed inits can showcase problems with our diagnostics.
# Options: "underdispersed", "overdispersed", "prior".
init_type = "prior"
if init_type == "underdispersed":
  offset = 0.0
  def initialize(shape, key=random.PRNGKey(37272709)):
    return 1 * random.normal(key, shape + (num_dimensions,)) + offset
elif init_type == "overdispersed":
  offset = 0.0
  def initialize(shape, key=random.PRNGKey(37272709)):
    return 100 * random.normal(key, shape + (num_dimensions,)) + offset
elif init_type == "prior":
  def initialize(shape, key=random.PRNGKey(37272709)):
    """Draw inits from the model priors on (mu, log_tau, eta_1..eta_8)."""
    prior_scale = jnp.append(jnp.array([10., 1.]), jnp.repeat(1., 8))
    prior_offset = jnp.append(jnp.array([0., 5.]), jnp.repeat(0., 8))
    return prior_scale * random.normal(key, shape + (num_dimensions,)) + prior_offset
else:
  # Without this branch, the `initialize` left over from a previously run
  # model cell would be reused silently.
  raise ValueError(f"Unknown init_type: {init_type!r}")


In [143]:
num_schools = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

# Non-centered parameterization: theta_j = mu + exp(log_tau) * eta_j.
# NOTE: the reinterpreted batch dimension specifies the dimension of
# each independent variable, here the school.
model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=10., name="mu"),
    tfd.Normal(loc=5., scale=1., name="log_tau"),
    tfd.Independent(tfd.Normal(loc=jnp.zeros(num_schools),
                               scale=jnp.ones(num_schools),
                               name="eta"),
                    reinterpreted_batch_ndims=1),
    lambda eta, log_tau, mu: (
        tfd.Independent(tfd.Normal(loc=(mu[..., jnp.newaxis] +
                                        jnp.exp(log_tau[..., jnp.newaxis]) *
                                        eta),
                                   scale=sigma),
                        name="y",
                        reinterpreted_batch_ndims=1))
  ])

def target_log_prob_fn(x):
  """Log joint density of the non-centered Eight Schools model at x.

  The last axis of x stacks (mu, log_tau, eta_1, ..., eta_num_schools);
  indexing with `...` keeps any leading (e.g. chain) axes, so both a single
  state and a batch of states work. The observed data y is held fixed.
  """
  mu = x[..., 0]
  log_tau = x[..., 1]
  eta = x[..., 2:2 + num_schools]
  return model.log_prob((mu, log_tau, eta, y))


In [144]:
# Benchmark moments from a reference run of 128 chains with 1000 + 5000
# iterations each (presumably warmup + sampling), non-centered
# parameterization. Coordinate order: mu, log_tau, eta_1, ..., eta_8.
mean_est = np.array([
    5.8006573, 2.4502006, 0.6532423, 0.09639207, -0.23725411,
    0.04723661, -0.33556408, -0.19666635, 0.5390533, 0.14633301,
])
var_est = np.array([
    29.60382, 0.26338503, 0.6383733, 0.4928926, 0.65307987,
    0.52441144, 0.46658015, 0.5248887, 0.49544162, 0.690975,
])

In [145]:
kernel_cold, kernel_warm = adaptive_kernels(target_log_prob_fn, init_step_size, num_warmup)
# Monitor every coordinate. Derived from num_dimensions (previously the
# literal 10), as in the bimodal experiment, so it cannot drift from the model.
index_param = np.arange(num_dimensions)

(mc_mean_list, warmup_length,
 squared_error_list, nrhat_list) = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize,
    num_super_chains=num_super_chains,
    num_warmup=window_array,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=window_array.shape[0],
    index_param=index_param,
    mean_benchmark=mean_est,
    var_benchmark=var_est,
    naive_super_chains=naive_super_chains)

SEED : [696157669 674520459]


Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they used. (Deprecated NumPy 1.22)


SEED : [1771252691  383428323]
SEED : [3312214627   69373237]
SEED : [1094216230 2034336350]
SEED : [1788847948 1554193616]
SEED : [338758650 111605681]
SEED : [4073466969 2556348314]
SEED : [ 167784864 1021574613]
SEED : [3458295564 1820549682]
SEED : [3023805236 1296165206]


In [146]:
import os

# Save output into npy files
model_name = "schools"
exp_parm = ("_K" + str(num_super_chains) + "_M" + str(num_sub_chains)
            + "_N" + str(num_samples))

if naive_super_chains:
    exp_parm = "_naive" + exp_parm

# os.path.join tolerates a deliv_dir without a trailing "/". Creating the
# directory keeps np.save from raising FileNotFoundError on a fresh checkout.
os.makedirs(deliv_dir, exist_ok=True)
out_prefix = os.path.join(deliv_dir, model_name + exp_parm)
np.save(out_prefix + "_nrhat", nrhat_list)
np.save(out_prefix + "_squared_error", squared_error_list)

### Pharmacokinetic

#### Load simulated data

In [147]:
# Execute pk_model.py in the notebook namespace. It presumably defines
# pop_target_log_prob_fn_flat and initialize_flat used below — confirm.
# Compiling with the real file name makes tracebacks point at lines of
# pk_model.py instead of "<string>".
with open("pk_model.py") as f:
  exec(compile(f.read(), "pk_model.py", "exec"))

In [148]:
n_patients = 20
# Observed PK data plus benchmark moments from files under data_dir.
# NOTE(review): the y_obs file name does not depend on n_patients; confirm
# it matches the benchmark files loaded below.
y_obs = jnp.array(np.load(f"{data_dir}pk_y_obs.npy"))
mean_est = jnp.array(np.load(f"{data_dir}pk_npatients_{n_patients}_mean_est.npy"))
var_est = jnp.array(np.load(f"{data_dir}pk_npatients_{n_patients}_var_est.npy"))

In [149]:
init_step_size = 1e-3
kernel_cold, kernel_warm = adaptive_kernels(
    pop_target_log_prob_fn_flat, init_step_size, num_warmup)
# NOTE(review): 45 should be the number of PK model parameters; confirm
# against pk_model.py.
index_param = np.arange(45)

(mc_mean_list, warmup_length,
 squared_error_list, nrhat_list) = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize_flat,
    num_super_chains=num_super_chains,
    num_warmup=window_array,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=window_array.shape[0],
    index_param=index_param,
    mean_benchmark=mean_est,
    var_benchmark=var_est,
    naive_super_chains=naive_super_chains)

SEED : [696157669 674520459]


Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they used. (Deprecated NumPy 1.22)


SEED : [1771252691  383428323]
SEED : [3312214627   69373237]
SEED : [1094216230 2034336350]
SEED : [1788847948 1554193616]
SEED : [338758650 111605681]
SEED : [4073466969 2556348314]
SEED : [ 167784864 1021574613]
SEED : [3458295564 1820549682]
SEED : [3023805236 1296165206]


In [150]:
import os

# Save output into npy files
model_name = "pk"
exp_parm = ("_K" + str(num_super_chains) + "_M" + str(num_sub_chains)
            + "_N" + str(num_samples))

if naive_super_chains:
    exp_parm = "_naive" + exp_parm

# os.path.join tolerates a deliv_dir without a trailing "/". Creating the
# directory keeps np.save from raising FileNotFoundError on a fresh checkout.
os.makedirs(deliv_dir, exist_ok=True)
out_prefix = os.path.join(deliv_dir, model_name + exp_parm)
np.save(out_prefix + "_nrhat", nrhat_list)
np.save(out_prefix + "_squared_error", squared_error_list)


### Item Response Theory

In [151]:
target = gym.targets.VectorModel(
    gym.targets.SyntheticItemResponseTheory(),
    flatten_sample_transformations=True)
num_dimensions = target.event_shape[0]
init_step_size = 1.

def target_log_prob_fn(x):
  """Log density of the target on the unconstrained space (up to a constant).

  Pushes x through the target's default event-space bijector and adds the
  forward log-det-Jacobian, so the sampler never has to handle constraints.
  """
  bijector = target.default_event_space_bijector
  log_prob = target.unnormalized_log_prob(bijector(x))
  return log_prob + bijector.forward_log_det_jacobian(x)

offset = 0
def initialize(shape, key=random.PRNGKey(37272709)):
  """Initial states drawn i.i.d. from N(offset, 10^2) in every coordinate."""
  noise = random.normal(key, shape + (num_dimensions,))
  return offset + 10 * noise


In [152]:
# Ground-truth mean and variance of the target, used as benchmarks for the
# squared error of the Monte Carlo estimates. There is no reference run in
# this notebook to estimate them from, so fail loudly if they are missing
# instead of falling back on stale kernel state.
identity_transform = target.sample_transformations['identity']
mean_est = identity_transform.ground_truth_mean
std_est = identity_transform.ground_truth_standard_deviation
if mean_est is None or std_est is None:
  raise ValueError(
      'Target has no ground-truth mean / standard deviation; provide '
      'benchmark estimates (e.g. from a long reference run) instead.')
var_est = std_est**2

In [153]:
kernel_cold, kernel_warm = adaptive_kernels(target_log_prob_fn, init_step_size, num_warmup)
# Monitor every coordinate. Derived from num_dimensions (previously the
# literal 501), as in the bimodal experiment, so it cannot drift from the model.
index_param = np.arange(num_dimensions)

(mc_mean_list, warmup_length,
 squared_error_list, nrhat_list) = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize,
    num_super_chains=num_super_chains,
    num_warmup=window_array,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=window_array.shape[0],
    index_param=index_param,
    mean_benchmark=mean_est,
    var_benchmark=var_est,
    naive_super_chains=naive_super_chains)

SEED : [696157669 674520459]


Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they used. (Deprecated NumPy 1.22)


SEED : [1771252691  383428323]
SEED : [3312214627   69373237]
SEED : [1094216230 2034336350]
SEED : [1788847948 1554193616]
SEED : [338758650 111605681]
SEED : [4073466969 2556348314]
SEED : [ 167784864 1021574613]
SEED : [3458295564 1820549682]
SEED : [3023805236 1296165206]


In [154]:
import os

# Save output into npy files
# NOTE(review): "itr" looks like a typo for "irt" (item response theory); it
# is kept because downstream loaders may expect this file name.
model_name = "itr"
exp_parm = ("_K" + str(num_super_chains) + "_M" + str(num_sub_chains)
            + "_N" + str(num_samples))

if naive_super_chains:
    exp_parm = "_naive" + exp_parm

# os.path.join tolerates a deliv_dir without a trailing "/". Creating the
# directory keeps np.save from raising FileNotFoundError on a fresh checkout.
os.makedirs(deliv_dir, exist_ok=True)
out_prefix = os.path.join(deliv_dir, model_name + exp_parm)
np.save(out_prefix + "_nrhat", nrhat_list)
np.save(out_prefix + "_squared_error", squared_error_list)

### Bimodal Gaussian

In [155]:
num_dimensions = 100
# Two-component Gaussian mixture with isotropic components that share one scale.
mixture_probs = [0.3, 0.7]
component_locs = [-5., 5.]
component_scale = 1.
dist = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=mixture_probs),
    components_distribution=tfd.MultivariateNormalDiag(
      loc=[jnp.repeat(loc, num_dimensions) for loc in component_locs],
      scale_diag=jnp.repeat(component_scale, num_dimensions)))

def target_log_prob_fn(x):
  """Log density of the bimodal mixture at x."""
  return dist.log_prob(x)

offset = 0
def initialize(shape, key=random.PRNGKey(37272709)):
  """Initial states drawn i.i.d. from N(offset, 10^2) in every coordinate."""
  return 10 * random.normal(key, shape + (num_dimensions,)) + offset

# Exact per-coordinate moments of the mixture, derived from its parameters
# (2 and 22 for the defaults) and kept as floats like the other benchmarks:
#   mean = sum_k p_k mu_k
#   var  = sum_k p_k (sigma^2 + mu_k^2) - mean^2
mixture_mean = sum(p * m for p, m in zip(mixture_probs, component_locs))
mixture_var = (sum(p * (component_scale**2 + m**2)
                   for p, m in zip(mixture_probs, component_locs))
               - mixture_mean**2)
mean_est = jnp.repeat(mixture_mean, num_dimensions)
var_est = jnp.repeat(mixture_var, num_dimensions)

init_step_size = 1

In [156]:
kernel_cold, kernel_warm = adaptive_kernels(
    target_log_prob_fn, init_step_size, num_warmup)
# Monitor every coordinate of the mixture.
index_param = np.arange(num_dimensions)

(mc_mean_list, warmup_length,
 squared_error_list, nrhat_list) = run_forge_chain(
    num_seed=num_seed,
    kernel_cold=kernel_cold,
    kernel_warm=kernel_warm,
    initialize=initialize,
    num_super_chains=num_super_chains,
    num_warmup=window_array,
    num_samples=num_samples,
    target_rhat=nRhat_upper,
    max_num_steps=window_array.shape[0],
    index_param=index_param,
    mean_benchmark=mean_est,
    var_benchmark=var_est,
    naive_super_chains=naive_super_chains)


SEED : [696157669 674520459]


Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they used. (Deprecated NumPy 1.22)


SEED : [1771252691  383428323]
SEED : [3312214627   69373237]
SEED : [1094216230 2034336350]
SEED : [1788847948 1554193616]
SEED : [338758650 111605681]
SEED : [4073466969 2556348314]
SEED : [ 167784864 1021574613]
SEED : [3458295564 1820549682]
SEED : [3023805236 1296165206]


In [157]:
import os

# Save output into npy files
model_name = "bimodal"
exp_parm = ("_K" + str(num_super_chains) + "_M" + str(num_sub_chains)
            + "_N" + str(num_samples))

if naive_super_chains:
    exp_parm = "_naive" + exp_parm

# os.path.join tolerates a deliv_dir without a trailing "/". Creating the
# directory keeps np.save from raising FileNotFoundError on a fresh checkout.
os.makedirs(deliv_dir, exist_ok=True)
out_prefix = os.path.join(deliv_dir, model_name + exp_parm)
np.save(out_prefix + "_nrhat", nrhat_list)
np.save(out_prefix + "_squared_error", squared_error_list)

: 