In [1]:
import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
import time 
# Short alias for the TFP distributions namespace, then show the TFP version.
from tensorflow_probability import distributions as tfd
tfp.__version__

'0.8.0-rc0'

In [13]:
# Whether the generator's outputs are logits (passed as `from_logits` to the
# cross-entropy below).
FROM_LOGITS = True
SEED = 0

model_path   = '../data/saved_models/vae_mnist_generative_net.h5'
test_path    = '../data/datasets/test/mnist_single.npy'
mcmc_kernel  = tfp.mcmc.HamiltonianMonteCarlo
mcmc_kwargs  = {}
n_chains     = 100

# Seed NumPy and TensorFlow so chain initialisations and MCMC proposals are
# reproducible across Restart & Run All.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Import models
generator  = tf.keras.models.load_model(model_path)
latent_dim = generator.input_shape[-1]
z_shape    = [n_chains, latent_dim]
shape_single_draw = [latent_dim]

# Import test data.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
# which can execute code on load -- only use this with trusted files.
test = np.load(test_path, allow_pickle=True)
# One copy of the observation (and its mask) per chain. Presumably `test` is a
# numpy masked array whose mask is True on hidden pixels, so `1 - mask` keeps
# the observed ones -- TODO confirm against the data-generation script.
x_true = test.data.repeat(n_chains, axis=0).astype('float32')
mask = (1 - test.mask).repeat(n_chains, axis=0).astype('float32')

# Prepare log prior function: Normal(0, prior_sd) on every latent coordinate.
prior_sd          = 1

ind = tfd.Independent(
    distribution=tfd.MultivariateNormalDiag(
        loc=tf.zeros(z_shape),
        scale_identity_multiplier=prior_sd))

# NOTE(review): with TFP's default reinterpreted_batch_ndims, Independent
# appears to reinterpret no axes for this rank-1 batch, so log_prior(z) should
# return one value per chain -- confirm.
log_prior = ind.log_prob


# Prepare log likelihood function
def log_like(z):
    """Log-likelihood of the observed (unmasked) pixels given latent codes.

    Decodes ``z`` with ``generator`` and scores the result against ``x_true``
    with binary cross-entropy. Only pixels where ``mask`` is non-zero count.

    Args:
        z: latent codes, one row per chain (shape ``z_shape``).

    Returns:
        Tensor with one log-likelihood value per chain (higher is a better fit).
    """
    x_pred = generator(z)

    # Per-pixel binary cross-entropy, i.e. the *negative* Bernoulli
    # log-likelihood. Keras averages over the last (channel) axis, assumed to
    # have size 1 here, so re-expand it to line up with `mask`.
    nll_per_pixel = tf.keras.losses.binary_crossentropy(
        x_true, x_pred, from_logits=FROM_LOGITS)
    masked_nll = tf.expand_dims(nll_per_pixel, -1) * mask

    # Cross-entropy is a loss; negate it so this is a log-likelihood. Without
    # the minus sign the sampler would favour the *worst* reconstructions.
    return -tf.reduce_sum(masked_nll, axis=[1, 2, 3])
    
# Create target log density
def log_prob(z):
    """Target log density for the sampler: ``log_prior(z) + log_like(z)``."""
    prior_term = log_prior(z)
    likelihood_term = log_like(z)
    return prior_term + likelihood_term




In [14]:
# Tempering ladder for replica exchange: geometric spacing from beta = 1 (the
# target itself) down to exp(-2). ReplicaExchangeMC returns the states of the
# first replica, so inverse_temperatures[0] must be 1 for the returned samples
# to come from the posterior rather than a sharpened version of it.
N_TEMPERATURES = 100
inverse_temperatures = np.exp(np.linspace(0., -2., N_TEMPERATURES)).astype('float32')

In [15]:
# Burn-in length used to size the step-size adaptation. It must be defined
# before the kernel is built, so this cell also runs on a fresh kernel. Pass at
# least this many burn-in steps to run_chain; otherwise the step size is still
# being adapted while samples are recorded.
num_burnin_steps = 100

# HMC whose step size is adapted during the first 80% of burn-in.
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=log_prob,
        num_leapfrog_steps=8,
        step_size=1.),
    num_adaptation_steps=int(num_burnin_steps * 0.8))



In [5]:
# Run the chain (with burn-in).
@tf.function
def run_chain(kernel, n_samples, n_burnin, initial_state=None):
    """Run ``n_chains`` parallel MCMC chains and return the recorded states.

    Args:
        kernel: a ``tfp.mcmc`` transition kernel (e.g. targeting ``log_prob``).
        n_samples: number of states to record after burn-in.
        n_burnin: number of initial steps to run and discard.
        initial_state: optional starting state of shape ``z_shape``. When
            omitted, chains start from a fresh standard-normal draw each call.

    Returns:
        Tensor of recorded states whose leading axis has length ``n_samples``.
    """
    if initial_state is None:
        # Draw inside the graph. A NumPy draw here would be evaluated once at
        # trace time and baked in as a constant, so every call reusing the
        # trace would restart from exactly the same point.
        initial_state = tf.random.normal(z_shape)

    # The second output is the per-step kernel-results trace, not an
    # acceptance flag; it is not needed here.
    samples, _kernel_results = tfp.mcmc.sample_chain(
        num_results=n_samples,
        num_burnin_steps=n_burnin,
        current_state=initial_state,
        kernel=kernel,
        parallel_iterations=n_chains,
    )
    return samples

### Short timing runs: replica-exchange HMC vs. step-size-adaptive HMC

In [None]:
def _make_replica_kernel(target_log_prob_fn, seed=None):
    """Build the HMC kernel each tempered replica runs.

    ``seed`` is accepted because some TFP versions pass it to
    ``make_kernel_fn``; it is not used.
    """
    # NOTE(review): the step size is fixed (not adapted) -- tune if the
    # acceptance rate is poor.
    return tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=8,
        step_size=1.)


# Replica-exchange MC over the `inverse_temperatures` ladder. It is built here
# so this cell runs on a fresh kernel.
remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=log_prob,
    inverse_temperatures=inverse_temperatures,
    make_kernel_fn=_make_replica_kernel)

num_results = 100
num_burnin_steps = 100
total_iter = num_results + num_burnin_steps

# perf_counter is the monotonic, high-resolution clock meant for intervals. On
# a fresh trace the measurement also includes tf.function tracing time.
start = time.perf_counter()
samples = run_chain(remc, num_results, num_burnin_steps)
end = time.perf_counter()
total = end - start
rate = total_iter / total
print(f'{rate} MCMC iterations per second ({n_chains} chains advanced per iteration).')



Instructions for updating:
This op will be removed after the deprecation date. Please switch to tf.sets.difference().


In [9]:
num_results = 100
num_burnin_steps = 100
total_iter = num_results + num_burnin_steps

# Time burn-in plus recorded iterations with the monotonic perf_counter clock.
# On a fresh trace the measurement also includes tf.function tracing time.
start = time.perf_counter()
samples = run_chain(adaptive_hmc, num_results, num_burnin_steps)
end = time.perf_counter()
total = end - start
rate = total_iter / total
print(f'{rate} MCMC iterations per second ({n_chains} chains advanced per iteration).')



9.448367300124692 samples drawn per second.


In [10]:
# Buffer for decoded images: [recorded steps, chains, *generator image shape].
# The image shape is taken from the generator instead of hard-coding 28x28x1.
sampled_x = np.zeros([*samples.shape[:2], *generator.output_shape[1:]])

In [11]:
# Decode every recorded latent state into image space. Apply the sigmoid only
# when the generator emits logits (FROM_LOGITS), matching how the likelihood
# interprets the generator's output.
for i in range(sampled_x.shape[0]):
    decoded = generator(samples[i])
    if FROM_LOGITS:
        decoded = tf.math.sigmoid(decoded)
    sampled_x[i] = decoded.numpy()

In [None]:
from visualize import samples2gif

output_path = '../data/visualizations/hmc_chain.gif'

# Keep at most MAX_GIF_FRAMES evenly spaced steps. A fixed stride of 500 would
# keep only the first frame of a short (100-step) run.
MAX_GIF_FRAMES = 100
gif_stride = max(1, len(sampled_x) // MAX_GIF_FRAMES)
samples2gif(sampled_x[::gif_stride], output_path, 10, 10)

import utils
import matplotlib.pyplot as plt

# Show the chains at 10 evenly spaced recorded steps. Reuse the already
# decoded images in `sampled_x` instead of running the generator again.
for step in np.linspace(0, len(sampled_x) - 1, 10):
    idx = int(step)
    fig, ax = plt.subplots()
    x = sampled_x[idx].squeeze()
    ax.imshow(utils.flatten_image_batch(x, 3, 3))
    ax.set_title(f'Decoded chain states at MCMC step {idx}')

In [12]:
from IPython.display import HTML

# Display the gif written above. Reuse `output_path` so the displayed file
# always matches the one just saved.
HTML(f'<img src="{output_path}">')