Skip to content

Commit

Permalink
trying out using input_signature
Browse files Browse the repository at this point in the history
  • Loading branch information
njtierney committed Jun 24, 2022
1 parent 14d3fac commit 4303bbb
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion R/inference_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,21 @@ sampler <- R6Class(
# define_tf_draws is now used in place of of run_burst
# self$define_tf_draws()
self$tf_evaluate_sample_batch <- tensorflow::tf_function(
self$define_tf_draws
f = self$define_tf_draws,
input_signature = list(
# free state
tf$TensorSpec(shape = list(NULL, self$n_free),
dtype = tf_float()),
# sampler_burst_length
tf$TensorSpec(shape = list(1L),
dtype = tf$int32),
# sampler_thin
tf$TensorSpec(shape = list(1L),
dtype = tf$int32),
# sampler_param_vec
tf$TensorSpec(shape = list(length(unlist(self$sampler_parameter_values()))),
dtype = tf_float())
)
)
},
run_chain = function(n_samples, thin, warmup,
Expand Down

0 comments on commit 4303bbb

Please sign in to comment.