From 03ec78229920a0a699c15bbee2fc67e549bb9951 Mon Sep 17 00:00:00 2001 From: Nick Golding Date: Mon, 12 Nov 2018 10:51:04 +1100 Subject: [PATCH] fixes to match new TFP API for mcmc sampling --- DESCRIPTION | 2 +- NEWS.md | 5 +++++ R/inference_class.R | 7 ++++--- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index bf263488..99d38f33 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: greta Type: Package Title: Simple and Scalable Statistical Modelling in R -Version: 0.3.0 +Version: 0.3.0.9001 Date: 2018-10-30 Authors@R: c( person("Nick", "Golding", role = c("aut", "cre"), diff --git a/NEWS.md b/NEWS.md index 845d7cbd..ffd56b6e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# greta 0.3.0.9001 + +* mcmc now works with TensorFlow Probability version 0.5.0 (#248) + + # greta 0.3.0 This is a very large update which adds a number of features and major speed improvements. We now depend on the TensorFlow Probability Python package, and use functionality in that package wherever possible. Sampling a simple model now takes ~10s, rather than ~2m (>10x speedup). diff --git a/R/inference_class.R b/R/inference_class.R index 0b65e3ae..06f9e2c7 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -629,8 +629,9 @@ sampler <- R6Class( num_results = sampler_burst_length %/% sampler_thin, current_state = free_state, kernel = sampler_kernel, - num_burnin_steps = 0L, - num_steps_between_results = sampler_thin) + num_burnin_steps = tf$constant(0L, dtype = tf$int64), + num_steps_between_results = tf$cast(sampler_thin, tf$int64), + parallel_iterations = 1L) ) }, @@ -763,7 +764,7 @@ hmc_sampler <- R6Class( # tensors for sampler parameters dag$tf_run(hmc_epsilon <- tf$placeholder(dtype = tf_float())) - dag$tf_run(hmc_L <- tf$placeholder(dtype = tf$int32)) + dag$tf_run(hmc_L <- tf$placeholder(dtype = tf$int64)) # need to pass in the value for this placeholder as a matrix (shape(n, 1)) dag$tf_run(