In [2]:
import os
import sys
import numpy as np

import matplotlib as mpl
mpl.use("Agg")

import tensorflow_probability as tfp
from tensorflow_probability import edward2 as ed
import tensorflow as tf


In [23]:
m = 3.5
c = 2

M = 100
xmin = 0.
xmax = 10.
stepsize = (xmax-xmin)/M
x = np.arange(xmin, xmax, stepsize)

def straight_line(x, m, c):
    return m*x + c

sigma = 0.5
data = straight_line(x, m, c) + sigma * np.random.randn(M)

In [34]:
cmin, cmax = -10., 10.
mmu = 0.
msigma = 10.

# create a log-likelihood function
def log_likelihood(x, cmin, mmu, msigma, sigma):
    m = ed.Normal(loc=mmu, scale=msigma, name="m")
    c = ed.Poisson(cmin, name="c")
    y = ed.Normal(loc=(m*x + c), scale=sigma, name="y")
    return y

 
qm = tf.random_normal(shape=[], mean=mmu, stddev=msigma, dtype=tf.float32)
qc = tf.random_poisson(shape=[], lam=10, dtype=tf.int32)

In [35]:
x = tf.convert_to_tensor(x, dtype=tf.float32)
data = tf.convert_to_tensor(data, dtype=tf.float32)

In [36]:
log_joint = ed.make_log_joint_fn(log_likelihood)

In [39]:
def target_log_prob_fn(m, c):
    """Target log-probability as a function of states."""
    return log_joint(x, cmin, mmu, msigma, sigma, m=m, c=c,
                     y=data)

Nsamples = 2000
Nburn = 2000

# set up Hamiltonian MC
hmc_kernel = tfp.mcmc.MetropolisHastings(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.01,
    num_leapfrog_steps=5)

states, kernel_results = tfp.mcmc.sample_chain(
    num_results=Nsamples,
    current_state=[qm, qc],
    kernel=hmc_kernel,
    num_burnin_steps=Nburn)

TypeError: __init__() got an unexpected keyword argument 'target_log_prob_fn'

In [22]:
with tf.Session() as sess:
    states, is_accepted_ = sess.run([states, kernel_results.is_accepted])
    accepted = np.sum(is_accepted_)
    print("Acceptance rate: {}".format(accepted / Nsamples))

results = dict(zip(['m', 'c'], states))

postsamples = np.vstack((results['m'], results['c'])).T
print(postsamples)

Acceptance rate: 0.0
[[ 3.2864695 -5.240934 ]
 [ 3.2864695 -5.240934 ]
 [ 3.2864695 -5.240934 ]
 ...
 [ 3.2864695 -5.240934 ]
 [ 3.2864695 -5.240934 ]
 [ 3.2864695 -5.240934 ]]
