<a href="https://colab.research.google.com/github/careychou/exploration/blob/master/tfp_multi_level_bayesian_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import warnings

tf.enable_v2_behavior()

# Every prior draw, simulated dataset and MCMC chain below is random; seed the
# global RNGs up front so "Restart & Run All" reproduces the outputs.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

plt.style.use("ggplot")
# NOTE(review): this silences *all* warnings (including TFP deprecation
# notices); narrow it to specific categories if something looks off.
warnings.filterwarnings('ignore')

tfb = tfp.bijectors

In [None]:
def mcmc_run_simple(initial_states, target_log_prob,
                    num_results=100, num_burnin_steps=10,
                    step_size=0.4, num_leapfrog_steps=3):
  """Run a plain fixed-step-size HMC chain (no adaptation, no bijectors).

  Args:
    initial_states: list of tensors, one per latent variable, used as the
      starting point of the chain.
    target_log_prob: callable taking the state parts positionally and
      returning the (unnormalized) log density.
    num_results: number of post-burn-in draws to keep.
    num_burnin_steps: number of initial steps to discard.
    step_size: HMC leapfrog step size, fixed for the whole run.
    num_leapfrog_steps: leapfrog steps per HMC proposal.

  Returns:
    The two-part result of ``tfp.mcmc.sample_chain``: sampled states and the
    traced kernel results.

  NOTE(review): states are sampled in their original space, so scale-like
  parameters are not constrained to be positive here (see mcmc_run_dual for
  the bijector-based variant).
  """
  return tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      # Start from the caller's state, not a notebook global.
      current_state=initial_states,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=target_log_prob,
          step_size=step_size,
          num_leapfrog_steps=num_leapfrog_steps))

 
def mcmc_run_dual(initial_states, target_log_prob, bijectors,
                  num_results=200, num_burnin_steps=50,
                  num_adaptation_steps=40, step_size=1.,
                  num_leapfrog_steps=5):
  """Run HMC with dual-averaging step-size adaptation in unconstrained space.

  Kernel stack:
  DualAveragingStepSizeAdaptation(TransformedTransitionKernel(HMC)).

  Args:
    initial_states: list of tensors, one per latent variable, in the space
      expected by ``target_log_prob``.
    target_log_prob: callable taking the state parts positionally and
      returning the (unnormalized) log density.
    bijectors: one bijector per state part. Each one maps the sampler's
      unconstrained space onto that variable's support, e.g. ``Softplus`` for
      scales and ``Identity`` for real-valued parameters.
    num_results: number of post-burn-in draws to keep.
    num_burnin_steps: number of initial steps to discard.
    num_adaptation_steps: number of initial steps during which the step size
      is adapted. It must not exceed ``num_burnin_steps``, so that the kept
      draws come from a fixed kernel.
    step_size: initial HMC step size, which is then adapted.
    num_leapfrog_steps: leapfrog steps per HMC proposal.

  Returns:
    The two-part result of ``tfp.mcmc.sample_chain``: sampled states and the
    traced kernel results.

  Raises:
    ValueError: if ``num_adaptation_steps > num_burnin_steps``.
  """
  if num_adaptation_steps > num_burnin_steps:
    raise ValueError(
        'num_adaptation_steps (%d) must not exceed num_burnin_steps (%d)'
        % (num_adaptation_steps, num_burnin_steps))

  hmc = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      state_gradients_are_stopped=True)
  # HMC moves in unconstrained space; the bijectors map each part back to
  # its support before target_log_prob is evaluated.
  transformed = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=hmc,
      bijector=bijectors)
  adaptive = tfp.mcmc.DualAveragingStepSizeAdaptation(
      inner_kernel=transformed,
      num_adaptation_steps=num_adaptation_steps)
  return tfp.mcmc.sample_chain(
      num_results=num_results,
      current_state=initial_states,
      num_burnin_steps=num_burnin_steps,
      kernel=adaptive)


## Generate sample data - E(D | P) ~ Normal

In [None]:
# Simulator for one segment's observations (called once per segment below).
def gen_normal(hypermu, hypertau, sample_size,
               pooled_mean=(10., 1.), noise_scale=(5., 1.)):
  """Simulate ``sample_size`` observations for one segment.

  Generative process (every ``(a, b)`` pair is a Normal(loc=a, scale=b)):
    pooled  ~ Normal(*pooled_mean)
    seg_mu  ~ Normal(*hypermu)
    seg_tau ~ Normal(*hypertau)
    seg_eff ~ Normal(seg_mu, seg_tau)
    sigma   ~ Normal(*noise_scale)
    y_i     ~ Normal(pooled + seg_eff, sigma),  i = 1..sample_size

  Args:
    hypermu: (loc, scale) of the prior on the segment effect's mean.
    hypertau: (loc, scale) of the prior on the segment effect's scale.
    sample_size: number of observations to draw.
    pooled_mean: (loc, scale) of the prior on the pooled mean. Defaults to
      (10., 1.).
    noise_scale: (loc, scale) of the prior on the observation scale. Defaults
      to (5., 1.).

  Returns:
    Tensor of shape ``(sample_size, 1, 1)``; callers flatten it.

  NOTE(review): scales are themselves drawn from Normals, so a negative draw is
  possible in principle (vanishingly rare for the hyper-parameters used here).
  """
  # Draws happen in the same order as the original single-expression form.
  pooled = tfd.Normal(pooled_mean[0], pooled_mean[1]).sample(1)
  seg_eff = tfd.Normal(
      loc=tfd.Normal(hypermu[0], hypermu[1]).sample(1),
      scale=tfd.Normal(hypertau[0], hypertau[1]).sample(1)).sample(1)
  sigma = tfd.Normal(noise_scale[0], noise_scale[1]).sample(1)
  return tfd.Normal(loc=pooled + seg_eff, scale=sigma).sample(sample_size)


# (hypermu, hypertau) per segment; every segment gets samples_per_segment draws.
normal_segment_hypers = [
    ((20., 5.), (5., 1.)),   # segment 1
    ((-5., 5.), (5., 1.)),   # segment 2
    ((-10., 5.), (5., 1.)),  # segment 3
]
samples_per_segment = 10

# Flat list of observations, segment by segment. reshape(-1) (unlike squeeze)
# still yields an iterable 1-D array when samples_per_segment == 1.
sample_data = [
    y
    for hypermu, hypertau in normal_segment_hypers
    for y in gen_normal(hypermu, hypertau, samples_per_segment).numpy().reshape(-1)
]

# Derived from the config above so the segment bookkeeping cannot drift out
# of sync with the data.
num_segments = len(normal_segment_hypers)
sample_segment = [*np.repeat(np.arange(num_segments), samples_per_segment)]

# Empirical mean of each segment.
tuple(np.mean(sample_data[i * samples_per_segment:(i + 1) * samples_per_segment])
      for i in range(num_segments))


(28.560055, 9.511309, 3.5284638)

## Centered Multi Level with Pooled Mean: Normal Likelihood

In [None]:
# Centered multi-level model with a pooled mean.
# Components, in list order: mu, tau, seg_eff, sc, mean, y.
# JointDistributionSequential passes previously sampled values to each lambda
# in *reverse* list order (most recent first); the lambda argument names below
# rely on that. `num_segments` and `sample_segment` are notebook globals.
# NOTE(review): `mean` and `mu` enter the likelihood only through their sum
# (mean + seg_eff, with seg_eff centered at mu), so they are identified only
# through their priors.
model = tfd.JointDistributionSequential([
  tfd.Normal(loc=0., scale=10.),  # mu hyper prior
  tfd.Normal(loc=5., scale=1.),  # tau hyper prior (prior draws may be negative)

  # segment random effects ~ N(mu, tau), one per segment (centered: loc is mu)
  lambda tau, mu: tfd.Independent(
      tfd.Normal(
          loc=tf.ones(num_segments) * mu, 
          scale=tau), 
        reinterpreted_batch_ndims=1),

  tfd.Normal(loc=5., scale=1.),  # pooled sc (observation scale)     
  tfd.Normal(loc=10., scale=1.),  # pooled mean

  # likelihood: y_i ~ N(mean + seg_eff[sample_segment[i]], sc); 4th arg (tau) unused
  lambda mean, sc, seg_eff, _: tfd.Independent(
      tfd.Normal(
          loc=mean + tf.gather(seg_eff, sample_segment), 
          scale=sc), 
        reinterpreted_batch_ndims=1)
])


In [None]:
# Prior-predictive sanity check: draw one joint sample (latents plus a
# simulated y) and score each component of the joint density.
samples = model.sample(1)
print(samples)
sample_priors = samples[:len(samples) - 1]  # latents only
# To fit against data simulated by the model itself, use instead:
#   sample_data = samples[-1]
model.log_prob_parts(samples)

[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([19.442276], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.416532], dtype=float32)>, <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1.0225685, 5.169071 , 1.648947 ], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([4.7356777], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([9.146925], dtype=float32)>, <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([4.6708636, 3.8781185, 2.1434076, 6.813615 , 2.912087 , 4.423038 ],
      dtype=float32)>]


[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.0744666], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.005688], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-7.505594>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.95387167], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.282807], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-14.939651>]

In [None]:
# Score a fresh prior draw of the latents against the observed sample_data.
sample_priors = model.sample(1)[:-1]  # drop the simulated y
model.log_prob_parts(sample_priors + [sample_data])

[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.97465146], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.9391118], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-8.071594>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.9728293], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.0571938], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-15.238572>]

In [None]:
def target_log_prob_fn(*params):
  """Joint log density with the observed component fixed to ``sample_data``.

  ``params`` are the latent values in the model's list order, i.e. every
  component except the final observed one. ``model`` and ``sample_data`` are
  looked up as globals at call time, so the function scores whichever model
  is currently bound to ``model``.
  """
  return model.log_prob([*params, sample_data])

target_log_prob_fn(*sample_priors)

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-27.253952], dtype=float32)>

In [None]:
# Start the chain from a prior draw of the latents. Squeezing drops the size-1
# sample dimension, so scalar latents become 0-d tensors.
init_state_prior = [tf.squeeze(part) for part in model.sample(1)[:-1]]
init_state_prior

[<tf.Tensor: shape=(), dtype=float32, numpy=8.298271>,
 <tf.Tensor: shape=(), dtype=float32, numpy=4.519633>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([4.903119 , 3.0838227, 6.646962 ], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=4.669758>,
 <tf.Tensor: shape=(), dtype=float32, numpy=8.922955>]

In [None]:
# One bijector per latent, in model order; Softplus keeps the scales positive.
states, kernels = mcmc_run_dual(
    init_state_prior,
    target_log_prob_fn,
    [
        tfb.Identity(),  # mu
        tfb.Softplus(),  # tau
        tfb.Identity(),  # seg_eff
        tfb.Softplus(),  # sc
        tfb.Identity(),  # pooled mean
    ])

In [None]:
# Posterior means; states[i] follows the model's list order
# (0: mu, 1: tau, 2: seg_eff, 3: sc, 4: pooled mean).
print('pool mean=', np.mean(states[4].numpy()))
print('pool sc=', np.mean(states[3].numpy()))
print('segment =', np.mean(states[2].numpy(), axis=0))
print('hyper tau =', np.mean(states[1].numpy()))
print('hyper mu =', np.mean(states[0].numpy()))

print('recover mean:')
# Location in the likelihood is mean + seg_eff (additive), so the posterior
# mean per segment is the sum of the two posterior means.
np.mean(states[4].numpy()) + np.mean(states[2].numpy(), axis=0)


pool mean= 9.877331
pool sc= 5.34104
segment = [17.802338   -0.16821226 -5.4932446 ]
hyper tau = 6.076803
hyper mu = 4.256919
recover mean:


array([27.679668,  9.709119,  4.384086], dtype=float32)

## Bootstrap MCMC and Expectation

In [None]:
# Inspect a single HMC transition by hand. HamiltonianMonteCarlo is a
# MetropolisHastings kernel wrapping UncalibratedHamiltonianMonteCarlo.
# Building that pair explicitly exposes `inner_kernel`, so we can call its
# one_step directly.
hmc = tfp.mcmc.MetropolisHastings(
    tfp.mcmc.UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=0.1,
        num_leapfrog_steps=3))

kernel_results = hmc.bootstrap_results(init_state_prior)
print(kernel_results.accepted_results)
print(hmc.inner_kernel)

# One proposal from the uncalibrated kernel; no Metropolis accept/reject step
# is applied here.
proposed_state, proposed_results = hmc.inner_kernel.one_step(
    init_state_prior, kernel_results.accepted_results)

UncalibratedHamiltonianMonteCarloKernelResults(
  log_acceptance_correction=<tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
  target_log_prob=<tf.Tensor: shape=(), dtype=float32, numpy=-501.19904>,
  grads_target_log_prob=[<tf.Tensor: shape=(), dtype=float32, numpy=-0.038234256>, <tf.Tensor: shape=(), dtype=float32, numpy=-2.2814758>, <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 17.365917 ,  -6.9865375, -16.19313  ], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=224.9885>, <tf.Tensor: shape=(), dtype=float32, numpy=-5.3730164>],
  initial_momentum=[<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>],
  final_momentum=[<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>, <tf.Tensor: shape=(3,), dty

## Non-Centered Multi-Level with Pooled Mean: Normal Likelihood

In [None]:
# Non-centered multi-level model with a pooled mean.
# Components, in list order: mean, tau, seg_eff, sc, y.
# Segment effects are standardized N(0, 1) draws that the likelihood rescales
# by tau, instead of being drawn from N(mu, tau) directly.
# Lambdas receive previously sampled values in *reverse* list order.
model = tfd.JointDistributionSequential([
  tfd.Normal(loc=0., scale=10.),  # pooled mean
  tfd.Normal(loc=5., scale=1.),  # tau hyper prior (prior draws may be negative)

  # standardized segment effects ~ N(0, 1); the `_` argument (tau) is ignored
  lambda _: tfd.Independent(
      tfd.Normal(
          loc=tf.zeros(num_segments),
          scale=tf.ones(num_segments)), 
        reinterpreted_batch_ndims=1),

  tfd.Normal(loc=5., scale=1.),  # pooled sc (observation scale)     

  # non-centered likelihood: y_i ~ N(mean + seg_eff[sample_segment[i]] * tau, sc)
  # tau scales the effects: as tau -> 0 every segment shrinks to the pooled mean
  lambda sc, seg_eff, tau, mean: 
    tfd.Independent(
        tfd.Normal(
            loc=mean + tf.gather(seg_eff, sample_segment) * tau, 
            scale=sc), 
          reinterpreted_batch_ndims=1) 
])


model.sample(1)

[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-5.298491], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.993776], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([2.3179882, 1.9950273, 1.065626 ], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.052293], dtype=float32)>,
 <tf.Tensor: shape=(30,), dtype=float32, numpy=
 array([10.127832 , 12.451334 ,  3.2078671, 10.156356 , 12.192538 ,
         9.801848 ,  8.25924  ,  8.719268 ,  9.96419  ,  7.3505316,
        13.731631 ,  6.212725 ,  1.3092108,  6.95438  , -2.9004512,
        11.709936 ,  3.838882 ,  3.4632978, -1.1662292,  3.649712 ,
        -4.195501 ,  5.607013 , -1.0003219, -5.7458754,  5.1385775,
        11.245331 ,  3.484091 , 13.492812 , -6.1119547, -5.8696256],
       dtype=float32)>]

In [None]:
# Start the chain from a prior draw of the latents. Squeezing drops the size-1
# sample dimension, so scalar latents become 0-d tensors.
init_state_prior = [tf.squeeze(part) for part in model.sample(1)[:-1]]
init_state_prior

[<tf.Tensor: shape=(), dtype=float32, numpy=1.3581139>,
 <tf.Tensor: shape=(), dtype=float32, numpy=3.4316115>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.01797873, -1.2139146 ,  0.2831507 ], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=6.054669>]

In [None]:
# target_log_prob_fn looks up the global `model` at call time, so it now scores
# the non-centered model. One bijector per latent, in model order.
states, kernels = mcmc_run_dual(
    init_state_prior,
    target_log_prob_fn,
    [
        tfb.Identity(),  # pooled mean
        tfb.Softplus(),  # tau
        tfb.Identity(),  # standardized seg_eff
        tfb.Softplus(),  # sc
    ])

In [None]:

# Posterior means; states[i] follows the model's list order
# (0: pooled mean, 1: tau, 2: standardized seg_eff, 3: sc).
print('pool sc=', states[3].numpy().mean())
print('segment =', states[2].numpy().mean(axis=0))
print('hyper tau =', states[1].numpy().mean())
print('pool mean =', states[0].numpy().mean())

print('recover mean:')
# Non-centered: each segment's location is mean + seg_eff * tau, so the
# standardized effects must be rescaled by tau. Combine per draw, then average,
# because E[seg_eff * tau] != E[seg_eff] * E[tau] in general.
(states[0].numpy()[:, np.newaxis]
 + states[2].numpy() * states[1].numpy()[:, np.newaxis]).mean(axis=0)

pool sc= 5.256248
segment = [ 3.0016313   0.06285957 -0.8820417 ]
hyper tau = 6.08828
pool mean = 9.052127
recover mean:


array([12.053759,  9.114986,  8.170085], dtype=float32)

## Inverse Gaussian Likelihood - E(D) ~ InvGaussian

In [None]:
# Simulator for one segment's Inverse Gaussian observations.
def gen_invgauss(hypermu, hypertau, sample_size,
                 pooled_mean=(10., 1.), noise_concentration=(5., 1.)):
  """Simulate ``sample_size`` Inverse Gaussian observations for one segment.

  Generative process (every ``(a, b)`` pair is a Normal(loc=a, scale=b)):
    pooled  ~ Normal(*pooled_mean)
    seg_mu  ~ Normal(*hypermu)
    seg_con ~ Normal(*hypertau)
    seg_eff ~ InverseGaussian(seg_mu, seg_con)
    con     ~ Normal(*noise_concentration)
    y_i     ~ InverseGaussian(pooled + seg_eff, con),  i = 1..sample_size

  Args:
    hypermu: (loc, scale) of the prior on the segment effect's location.
    hypertau: (loc, scale) of the prior on the segment effect's concentration.
    sample_size: number of observations to draw.
    pooled_mean: (loc, scale) of the prior on the pooled mean. Defaults to
      (10., 1.).
    noise_concentration: (loc, scale) of the prior on the observation
      concentration. Defaults to (5., 1.).

  Returns:
    Tensor of shape ``(sample_size, 1, 1)``; callers flatten it.

  NOTE(review): InverseGaussian needs positive loc/concentration, but these
  come from Normal draws. A non-positive draw is possible in principle,
  though very unlikely for the hyper-parameters used here.
  """
  # Draws happen in the same order as the original single-expression form.
  pooled = tfd.Normal(pooled_mean[0], pooled_mean[1]).sample(1)
  seg_eff = tfd.InverseGaussian(
      loc=tfd.Normal(hypermu[0], hypermu[1]).sample(1),
      concentration=tfd.Normal(hypertau[0], hypertau[1]).sample(1)).sample(1)
  con = tfd.Normal(noise_concentration[0], noise_concentration[1]).sample(1)
  return tfd.InverseGaussian(
      loc=pooled + seg_eff, concentration=con).sample(sample_size)


# (hypermu, hypertau) per segment; every segment gets samples_per_segment draws.
# NOTE(review): segments 2 and 3 share identical hyper-parameters.
invgauss_segment_hypers = [
    ((20., 1.), (5., 1.)),  # segment 1
    ((15., 1.), (5., 1.)),  # segment 2
    ((15., 1.), (5., 1.)),  # segment 3
]
samples_per_segment = 10

# Flat list of observations, segment by segment. reshape(-1) (unlike squeeze)
# still yields an iterable 1-D array when samples_per_segment == 1.
sample_data = [
    y
    for hypermu, hypertau in invgauss_segment_hypers
    for y in gen_invgauss(hypermu, hypertau, samples_per_segment).numpy().reshape(-1)
]

# Derived from the config above so the segment bookkeeping cannot drift out
# of sync with the data.
num_segments = len(invgauss_segment_hypers)
sample_segment = [*np.repeat(np.arange(num_segments), samples_per_segment)]

# Empirical mean of each segment.
tuple(np.mean(sample_data[i * samples_per_segment:(i + 1) * samples_per_segment])
      for i in range(num_segments))

(46.24634, 26.008976, 21.900906)

In [None]:
# Centered multi-level model with an Inverse Gaussian likelihood.
# Components, in list order: mu, tau, seg_eff, con, mean, y.
# Lambdas receive previously sampled values in *reverse* list order.
# `num_segments` and `sample_segment` are notebook globals.
model = tfd.JointDistributionSequential([
  tfd.Normal(loc=20., scale=1.),  # mu hyper prior
  tfd.Normal(loc=5., scale=1.),  # tau hyper prior (used as IG concentration)

  # segment effects ~ InverseGaussian(mu, tau), one per segment (positive support)
  lambda tau, mu: tfd.Independent(
      tfd.InverseGaussian(
          loc=tf.ones(num_segments) * mu,
          concentration=tau), 
        reinterpreted_batch_ndims=1),

  tfd.Normal(loc=5., scale=1.),  # pooled con (observation concentration)     
  tfd.Normal(loc=10., scale=1.),  # pooled mean     

  # likelihood: y_i ~ IG(mean + seg_eff[sample_segment[i]], con); tau and mu unused here
  lambda mean, con, seg_eff, tau, mu: tfd.Independent(
      tfd.InverseGaussian(
          loc=mean + tf.gather(seg_eff, sample_segment), 
          concentration=con), 
        reinterpreted_batch_ndims=1)
])


model.sample(1)



[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([20.855316], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([4.580132], dtype=float32)>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([34.3951   , 19.64031  ,  3.3982825], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([4.2118573], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([8.726359], dtype=float32)>,
 <tf.Tensor: shape=(30,), dtype=float32, numpy=
 array([  7.700204  ,   1.3017842 , 103.44781   ,   1.6039978 ,
          2.0240557 ,  10.114746  ,   0.6932087 ,  15.266415  ,
          4.83981   ,   2.492891  ,  34.25634   ,   1.4207306 ,
         29.137371  ,   4.397866  ,   3.5568922 ,  33.987255  ,
          3.2681527 ,  10.790403  ,   2.6242285 ,   4.026872  ,
          9.565543  ,   1.3787609 ,   0.93280584,  10.192049  ,
          3.2980587 ,   7.5055037 ,   4.8420825 ,   3.5081747 ,
         12.46052   ,  14.481946  ], dtype=float32)>]

In [None]:
# Prior-predictive sanity check: draw one joint sample (latents plus a
# simulated y) and score each component of the joint density.
samples = model.sample(1)
print(samples)
sample_priors = samples[:len(samples) - 1]  # latents only
# To fit against data simulated by the model itself, use instead:
#   sample_data = samples[-1]
model.log_prob_parts(samples)

[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([20.40563], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([4.476732], dtype=float32)>, <tf.Tensor: shape=(3,), dtype=float32, numpy=array([13.847706 ,  3.9496722,  2.1079385], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.599988], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([10.752052], dtype=float32)>, <tf.Tensor: shape=(30,), dtype=float32, numpy=
array([ 6.56479  ,  1.6680896,  1.1255519,  2.356922 , 16.02023  ,
        2.7315688,  9.514123 ,  1.890434 , 33.071724 , 93.20746  ,
       12.143097 ,  6.0066495, 18.136318 ,  3.365411 , 10.415132 ,
       11.090991 ,  1.5251104,  5.0145874, 15.211153 , 11.743227 ,
       10.726705 , 11.186402 ,  6.7708626,  2.9365609,  6.5919466,
        4.288242 ,  6.5893703,  4.4199357, 12.470299 ,  4.8206515],
      dtype=float32)>]


[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.0012064], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.0558434], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-8.868751>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-2.1989193], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.2017299], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-98.32765>]

In [None]:
# Score a fresh prior draw of the latents against the observed sample_data.
sample_priors = model.sample(1)[:-1]  # drop the simulated y
model.log_prob_parts(sample_priors + [sample_data])

[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-2.3039825], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.9715383], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-9.87714>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.08834], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-2.0942154], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-111.52952>]

In [None]:
def target_log_prob_fn(*params):
  """Joint log density with the observed component fixed to ``sample_data``.

  ``params`` are the latent values in the model's list order, i.e. every
  component except the final observed one. ``model`` and ``sample_data`` are
  looked up as globals at call time. Redefining the function here means this
  section doesn't rely on the earlier definition.
  """
  return model.log_prob([*params, sample_data])

target_log_prob_fn(*sample_priors)

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-128.86473], dtype=float32)>

In [None]:
# Start the chain from a prior draw of the latents. Squeezing drops the size-1
# sample dimension, so scalar latents become 0-d tensors.
init_state_prior = [tf.squeeze(part) for part in model.sample(1)[:-1]]
init_state_prior

[<tf.Tensor: shape=(), dtype=float32, numpy=19.535088>,
 <tf.Tensor: shape=(), dtype=float32, numpy=3.535799>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([3.817508 , 4.6954713, 0.6031735], dtype=float32)>,
 <tf.Tensor: shape=(), dtype=float32, numpy=5.2614326>,
 <tf.Tensor: shape=(), dtype=float32, numpy=13.092587>]

In [None]:
# One bijector per latent, in model order. seg_eff ~ InverseGaussian lives on
# (0, inf), so it needs Softplus too. With Identity, HMC can propose negative
# effects that fall outside the support and get rejected.
# NOTE(review): mu is also an IG location (must be > 0); its Normal(20, 1)
# prior keeps it far from 0, so Identity is kept there.
states, kernels = mcmc_run_dual(
    init_state_prior,
    target_log_prob_fn,
    [
        tfb.Identity(),  # mu
        tfb.Softplus(),  # tau (segment concentration)
        tfb.Softplus(),  # seg_eff (positive support)
        tfb.Softplus(),  # pooled con
        tfb.Identity(),  # pooled mean
    ])


In [None]:
# Spot check: first 50 post-burn-in draws of the pooled mean (states[4]).
# Runs of identical consecutive values are rejected proposals.
states[4][0:50]

<tf.Tensor: shape=(50,), dtype=float32, numpy=
array([10.894014 , 10.894014 , 10.921763 , 10.599963 ,  9.915589 ,
       10.763045 ,  9.8496   , 10.61969  ,  9.903763 , 10.628729 ,
        9.831019 , 10.544575 ,  9.728827 , 10.3594885,  9.7000265,
       10.475151 ,  9.686661 , 10.475242 ,  9.736767 , 10.4396515,
        9.778593 , 10.435515 ,  9.667163 , 10.419586 ,  9.764664 ,
       10.324897 ,  9.761891 , 10.282962 ,  9.689324 , 10.369272 ,
        9.651732 , 10.411853 ,  9.795025 , 10.468335 ,  9.919574 ,
       10.785205 ,  9.855764 , 10.843621 ,  9.967306 , 10.812167 ,
        9.78414  , 10.678479 , 10.678479 , 10.678479 ,  9.691377 ,
       10.6689415, 10.6689415,  9.7830515, 10.610923 ,  9.808554 ],
      dtype=float32)>

In [None]:
# Posterior means; states[i] follows the model's list order
# (0: mu, 1: tau, 2: seg_eff, 3: pooled concentration, 4: pooled mean).
print('pool mean=', states[4].numpy().mean())
# states[3] is the pooled *concentration* in this model, not a scale.
print('pool con=', states[3].numpy().mean())
print('segment =', states[2].numpy().mean(axis=0))
print('hyper tau =', states[1].numpy().mean())
print('hyper mu =', states[0].numpy().mean())

print('recover mean:')
# Location in the likelihood is mean + seg_eff (additive), so the posterior
# mean per segment is the sum of the two posterior means.
states[4].numpy().mean() + states[2].numpy().mean(axis=0)

pool mean= 10.329773
pool sc= 5.2572002
segment = [13.436426 11.54477   8.80277 ]
hyper tau = 5.17677
hyper mu = 19.989313
recover mean:


array([23.7662  , 21.874542, 19.132542], dtype=float32)