In [1]:
# Autoreload in mode 2 picks up edits to imported modules live, without a
# kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os
import json
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('notebook')
from matplotlib import cm
import corner
import datetime

In [3]:
# Switch TensorFlow to eager execution. The call goes through the `compat.v1`
# alias because the top-level `tf.enable_eager_execution()` only exists in
# TF 1.x, whereas the alias is available in both TF 1.x and TF 2.x.
# This must run at start-up, before any other TensorFlow op is created.
tf.compat.v1.enable_eager_execution()

In [4]:
import tensorflow_probability as tfp
tfd = tfp.distributions

In [5]:
# Target distribution: an equally weighted mixture of two 2-D Gaussians with
# diagonal covariance. Both components use the same scales; their means differ
# only in the first coordinate (0.8 vs 0.2).
dist = tfd.MixtureSameFamily(
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[.8, .5],
             [.2, .5]],
        scale_diag=[[.2, .4],
                    [.2, .4]],
    ),
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
)

Instructions for updating:
The `logits` property will return `None` when the distribution is parameterized with `logits=None`. Use `logits_parameter()` instead.


In [17]:

# Target density handed to the HMC kernel: the mixture's log-probability.
log_prob = dist.log_prob

# Sampler settings.
n_chains = 128        # number of chains run in parallel
num_results = 100000  # samples kept per chain
num_burnin = 10000    # burn-in steps per chain
step_size = .01       # starting HMC step size

# Draw the starting points from a dedicated, seeded generator so that
# re-running the notebook starts every chain from the same place. Only the
# initial states are seeded here; the sampler's own randomness is not.
initial_state_seed = 0
initial_state = tf.constant(
    np.random.RandomState(initial_state_seed).rand(n_chains, 2),
    dtype=tf.float32,
)

In [18]:
# Base HMC transition kernel on `log_prob`, starting from `step_size` and
# taking 3 leapfrog steps per proposal.
hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    step_size=step_size,
    num_leapfrog_steps=3,
    target_log_prob_fn=log_prob,
)

# Wrap the base kernel in dual-averaging step-size adaptation. The number of
# adaptation steps is set to 80% of the burn-in length.
hmc_adaptive_kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    num_adaptation_steps=int(num_burnin * 0.8),
    inner_kernel=hmc_kernel,
)


In [32]:

@tf.function(experimental_compile=True)
def run_sampler(initial_state, num_results, num_burnin):
    """Run `tfp.mcmc.sample_chain` with the notebook's `hmc_adaptive_kernel`.

    Args:
        initial_state: Starting state passed to `sample_chain` as
            `current_state`.
        num_results: Forwarded to `sample_chain` as `num_results`.
        num_burnin: Forwarded to `sample_chain` as `num_burnin_steps`.

    Returns:
        Whatever `sample_chain` returns when called with `trace_fn=None`.
        In this notebook that is a single tensor of sampled states, not a
        structure with `all_states` / `final_kernel_results` fields.
    """
    samples = tfp.mcmc.sample_chain(
        kernel=hmc_adaptive_kernel,
        current_state=initial_state,
        num_burnin_steps=num_burnin,
        num_results=num_results,
        trace_fn=None,
    )
    return samples


# Time the full call with wall-clock timestamps taken before and after it.
start_time = datetime.datetime.now()
chain_output = run_sampler(initial_state, num_results, num_burnin)
time_elapsed = datetime.datetime.now() - start_time

print(f'HMC sampling complete in {time_elapsed}')

HMC sampling complete in 0:00:14.419344


100,000 results, 10,000 burn-in, 128 chains:


2 min 11 seconds with the `tf.function` decorator (no XLA)
8.4 seconds with XLA


1,000 results, 100 burn-in, 30 chains:

46 seconds with eager execution

1.15 seconds with graph

3.2 seconds with the `tf.function` decorator, no XLA

4.0 seconds with XLA (slower than without XLA, but probably due to overheads)

In [34]:
# Display the value returned by the run_sampler call above.
chain_output

<tf.Tensor: id=15681, shape=(100000, 128, 2), dtype=float32, numpy=
array([[[ 0.79394245, -0.0798227 ],
        [ 0.14911179,  0.67889696],
        [ 0.75479525, -0.02186025],
        ...,
        [ 0.87894416,  0.83362263],
        [ 0.86337537,  0.54784644],
        [ 0.51741296,  1.4322187 ]],

       [[ 0.79394245, -0.0798227 ],
        [ 0.14911179,  0.67889696],
        [ 0.648572  ,  1.4233072 ],
        ...,
        [ 0.8513606 ,  0.38271025],
        [ 0.86337537,  0.54784644],
        [ 0.9816398 , -0.56908023]],

       [[ 0.05410323,  1.1697688 ],
        [ 0.15783443,  0.16707058],
        [ 0.70495224, -0.2439361 ],
        ...,
        [ 0.61279726,  0.7344026 ],
        [ 0.86337537,  0.54784644],
        [ 0.72114486,  1.3763219 ]],

       ...,

       [[ 0.09594403,  0.24278541],
        [ 0.3862634 ,  0.07676389],
        [ 0.70959586,  1.2336473 ],
        ...,
        [ 0.6963318 ,  0.20357248],
        [ 0.28985682,  0.90122175],
        [ 0.43198845,  0.7305501 

In [33]:
# chain_output.all_states

AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'all_states'

In [None]:
# `run_sampler` calls sample_chain with trace_fn=None and without
# return_final_kernel_results, so `chain_output` is the bare states tensor and
# has no `.all_states` field. Fall back to the tensor itself in that case; if
# the sampler is changed to return a structured result, this picks up its
# `all_states` field instead.
getattr(chain_output, 'all_states', chain_output)

In [None]:
# The sampler currently returns a bare tensor, which carries no
# `final_kernel_results`. Look the field up defensively so the cell evaluates
# to None instead of raising AttributeError; it shows the kernel results once
# sample_chain is asked to return them.
getattr(chain_output, 'final_kernel_results', None)

In [None]:
# List the public attributes of the final kernel results, which is what the
# unfinished `chain_output.final_kernel_results.` expression was exploring.
# Evaluates to an empty list when the sampler output has no such field.
kernel_results = getattr(chain_output, 'final_kernel_results', None)
(
    []
    if kernel_results is None
    else [name for name in dir(kernel_results) if not name.startswith('_')]
)

In [35]:
# Start a second run from the last entry of the first run's output, i.e. the
# final sampled state. The same `num_results` and `num_burnin` are passed, so
# burn-in is performed again for this run.
second_output = run_sampler(
    chain_output[-1],
    num_results,
    num_burnin,
)