Skip to content

Commit

Permalink
fixes to match new TFP API for mcmc sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Nov 11, 2018
1 parent a4c63a3 commit 03ec782
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
Package: greta
Type: Package
Title: Simple and Scalable Statistical Modelling in R
Version: 0.3.0
Version: 0.3.0.9001
Date: 2018-10-30
Authors@R: c(
person("Nick", "Golding", role = c("aut", "cre"),
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
@@ -1,3 +1,8 @@
# greta 0.3.0.9001

* mcmc now works with TensorFlow Probability version 0.5.0 (#248)


# greta 0.3.0

This is a very large update which adds a number of features and major speed improvements. We now depend on the TensorFlow Probability Python package, and use functionality in that package wherever possible. Sampling a simple model now takes ~10s, rather than ~2m (>10x speedup).
Expand Down
7 changes: 4 additions & 3 deletions R/inference_class.R
Expand Up @@ -629,8 +629,9 @@ sampler <- R6Class(
num_results = sampler_burst_length %/% sampler_thin,
current_state = free_state,
kernel = sampler_kernel,
num_burnin_steps = 0L,
num_steps_between_results = sampler_thin)
num_burnin_steps = tf$constant(0L, dtype = tf$int64),
num_steps_between_results = tf$cast(sampler_thin, tf$int64),
parallel_iterations = 1L)
)

},
Expand Down Expand Up @@ -763,7 +764,7 @@ hmc_sampler <- R6Class(

# tensors for sampler parameters
dag$tf_run(hmc_epsilon <- tf$placeholder(dtype = tf_float()))
dag$tf_run(hmc_L <- tf$placeholder(dtype = tf$int32))
dag$tf_run(hmc_L <- tf$placeholder(dtype = tf$int64))

# need to pass in the value for this placeholder as a matrix (shape(n, 1))
dag$tf_run(
Expand Down

0 comments on commit 03ec782

Please sign in to comment.