Dealing with API changes in TF 2.0 #306

Closed
goldingn opened this issue Aug 8, 2019 · 47 comments

@goldingn
Member

goldingn commented Aug 8, 2019

This issue is about planning ahead to dealing with major TensorFlow API changes when a greta version has to work with TensorFlow 2.0. It doesn't affect anything yet.

optimisers

TensorFlow's various optimiser interfaces are being unified (details here). Some of the old versions will still be available in the compat compatibility module (though there may be a speed benefit to switching to the newer versions).
In addition, the TF 1.x contrib module will no longer be available in TF 2.0, and won't be in compat.

There is a TensorFlow Addons repo with an 'optimizers' module (though with no apparent overlap with the existing optimisers), and a TFP 'optimizer' module that has a couple of the methods.

| greta optimiser | TF 1.x function | TF 2.0 function |
| --- | --- | --- |
| nelder_mead | `tf$contrib$opt$ScipyOptimizerInterface` | `tfp$optimizer$nelder_mead_minimize` |
| powell | `tf$contrib$opt$ScipyOptimizerInterface` | ? |
| cg | `tf$contrib$opt$ScipyOptimizerInterface` | ? |
| bfgs | `tf$contrib$opt$ScipyOptimizerInterface` | `tfp$optimizer$bfgs_minimize` |
| newton_cg | `tf$contrib$opt$ScipyOptimizerInterface` | ? |
| l_bfgs_b | `tf$contrib$opt$ScipyOptimizerInterface` | ? |
| tnc | `tf$contrib$opt$ScipyOptimizerInterface` | ? |
| cobyla | `tf$contrib$opt$ScipyOptimizerInterface` | ? |
| slsqp | `tf$contrib$opt$ScipyOptimizerInterface` | ? |
| gradient_descent | `tf$compat$v1$train$GradientDescentOptimizer` | `tf$keras$optimizers$SGD` |
| adadelta | `tf$compat$v1$train$AdadeltaOptimizer` | `tf$keras$optimizers$Adadelta` |
| adagrad | `tf$compat$v1$train$AdagradOptimizer` | `tf$keras$optimizers$Adagrad` |
| adagrad_da | `tf$compat$v1$train$AdagradDAOptimizer` | `tf$compat$v1$train$AdagradDAOptimizer` |
| momentum | `tf$compat$v1$train$MomentumOptimizer` | `tf$keras$optimizers$SGD` (with its momentum argument) |
| adam | `tf$compat$v1$train$AdamOptimizer` | `tf$keras$optimizers$Adam` |
| ftrl | `tf$compat$v1$train$FtrlOptimizer` | `tf$keras$optimizers$Ftrl` |
| proximal_gradient_descent | `tf$compat$v1$train$ProximalGradientDescentOptimizer` | `tf$compat$v1$train$ProximalGradientDescentOptimizer` |
| proximal_adagrad | `tf$compat$v1$train$ProximalAdagradOptimizer` | `tf$compat$v1$train$ProximalAdagradOptimizer` |
| rms_prop | `tf$compat$v1$train$RMSPropOptimizer` | `tf$keras$optimizers$RMSprop` (though arguments may have changed, so `tf$compat$v1$train$RMSPropOptimizer` as a backup) |

Unless another interface to the SciPy optimisers is developed, the optimisers without replacements will have to be removed from the API. It'll be a little hard to deprecate them without knowing whether a replacement will be coming.
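One way to handle this is to keep the mapping from the table as a lookup in code and warn whenever an optimiser without a known replacement is requested. This is a minimal plain-R sketch that only stores the TF 2.0 function names as strings; tf2_optimiser_replacements and tf2_replacement() are hypothetical names, not part of greta.

```r
# Hypothetical lookup built from the TF 1.x -> TF 2.0 table above.
# NA marks optimisers with no known TF 2.0 replacement. Keras spells two of
# the class names Ftrl and RMSprop; momentum maps to SGD's momentum argument.
tf2_optimiser_replacements <- c(
  nelder_mead = "tfp$optimizer$nelder_mead_minimize",
  powell = NA,
  cg = NA,
  bfgs = "tfp$optimizer$bfgs_minimize",
  newton_cg = NA,
  l_bfgs_b = NA,
  tnc = NA,
  cobyla = NA,
  slsqp = NA,
  gradient_descent = "tf$keras$optimizers$SGD",
  adadelta = "tf$keras$optimizers$Adadelta",
  adagrad = "tf$keras$optimizers$Adagrad",
  adagrad_da = "tf$compat$v1$train$AdagradDAOptimizer",
  momentum = "tf$keras$optimizers$SGD",
  adam = "tf$keras$optimizers$Adam",
  ftrl = "tf$keras$optimizers$Ftrl",
  proximal_gradient_descent = "tf$compat$v1$train$ProximalGradientDescentOptimizer",
  proximal_adagrad = "tf$compat$v1$train$ProximalAdagradOptimizer",
  rms_prop = "tf$keras$optimizers$RMSprop"
)

# Hypothetical helper: return the TF 2.0 replacement for a greta optimiser,
# warning that it may be removed if no replacement is known.
tf2_replacement <- function(optimiser) {
  if (!optimiser %in% names(tf2_optimiser_replacements)) {
    stop("unknown optimiser: ", optimiser, call. = FALSE)
  }
  replacement <- tf2_optimiser_replacements[[optimiser]]
  if (is.na(replacement)) {
    warning(
      "the '", optimiser, "' optimiser has no known TF 2.0 replacement ",
      "and may be removed in a future version of greta",
      call. = FALSE
    )
  }
  replacement
}

# Optimisers that would need a deprecation warning
no_replacement <- names(tf2_optimiser_replacements)[is.na(tf2_optimiser_replacements)]
```

Keeping the table as data means that if a replacement does turn up later, only one entry needs to change.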

ode solvers

The greta.dynamics package's ode_solve() functionality also depends on contrib, using the tf$contrib$integrate$odeint and odeint_fixed methods. tfp$math$ode$Solver$solve may be a viable replacement.

@goldingn goldingn changed the title use of contrib functions in TF 2.0 Dealing with API changes in TF 2.0 Aug 8, 2019
@goldingn
Member Author

According to this page, the TF 2.0 release was planned for Q2 2019, which has passed, and a beta is available, so I'm assuming the release will happen soon.

@goldingn
Member Author

goldingn commented Aug 10, 2019

Added deprecation warnings to the irreplaceable optimisers in 27b1ed3

@goldingn
Member Author

goldingn commented Aug 10, 2019

Suppressed TensorFlow warning about contrib being deprecated in 87b29cb

@goldingn
Member Author

It would be worthwhile installing the v2.0 beta and attempting to test the current version of the package to look for the biggest pain points.

@goldingn
Member Author

There's an experimental branch for running greta on TF 2.0.0. It's currently failing, apparently because of some eager-mode thing.

@cboettig

cboettig commented Dec 7, 2020

any updates?

@goldingn
Copy link
Member Author

goldingn commented Dec 9, 2020

This will be a top priority for the new greta dev, who will be starting at the end of March. I haven't had time to look at this this year. It's all doable; the biggest hurdle is switching everything internally from graph mode to eager mode 🙄

@cboettig

cboettig commented Dec 9, 2020

Thanks, looking forward to it! Yup, having eager execution instead of reserving all the GPU memory will be much friendlier for multi-tasking on the GPU, as will support for Python 3.8 (now the default flavor on the current Ubuntu LTS).

@goldingn
Member Author

goldingn commented Dec 9, 2020

Good to know. I hadn't realised there were performance improvements with eager mode!
We'll also be looking at adding in other backends for greta, including PyTorch. Should be a lot of cool stuff we can do!

@skeydan

skeydan commented Dec 9, 2020

@goldingn exciting news :-)

You've seen the new python-less torch for R, right? e.g. this intro series https://blogs.rstudio.com/ai/posts/2020-10-09-torch-optim/

Let us know how we can help :-)

cc @dfalbel

@goldingn
Member Author

goldingn commented Dec 9, 2020

We have! Cutting that python installation step out is the main motivation. Though the new reticulate python dependency system (#357) looks like it will be helpful for TF too. Looking forward to having some time and some help to lessen greta's installation pain.

@cboettig

cboettig commented Dec 9, 2020

Sweet, I hadn't seen the python-less torch either. Very nice! The new reticulate R packaging stuff is also rather slick.

Anecdotally, the (Python) libraries we rely on for RL (stable-baselines) recently switched from TensorFlow 1.x to PyTorch, and it's been a huge improvement on the GPU side. Not only do we get eager execution and a shared GPU, but it's also a lot less finicky about being tied to a precise version of some CUDA lib. No idea how well the torch suite matches up with what you need on the greta end, though.

@njtierney njtierney added the 0.5.0 label Jul 2, 2021
@njtierney njtierney modified the milestones: 0.4.0, 0.5.0 Jul 2, 2021
@njtierney njtierney removed the 0.5.0 label Jul 2, 2021
@njtierney
Collaborator

One possible path forwards for this:

  • Use TF 2.0.0 to 2.3.0 with eager evaluation mode turned off, to get us part way there (apparently 2.6.0 is out, according to this SO thread, and TF 2.6.0 doesn't allow you to turn off eager mode)
  • TF 2.0.0 has a native build on ARM (M1) Macs; this would help resolve #458 (Fix M1 greta installation issues) if we can specify exactly TF 2.0.0 for install in the first instance
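The first bullet implies a version check before deciding whether to disable eager mode. A minimal plain-R sketch of that check follows; in_graph_mode_range() is a hypothetical name, and treating 2.3.x patch releases as inside the range is an assumption. In practice the version string could come from as.character(tensorflow::tf_version()).

```r
# Hypothetical helper: is this TensorFlow version in the range where the plan
# is to run greta with eager execution disabled? Assumes 2.3.x patch releases
# count as "2.3.0", so the upper bound is exclusive at 2.4.0.
in_graph_mode_range <- function(tf_version) {
  v <- package_version(tf_version)
  v >= "2.0.0" && v < "2.4.0"
}

in_graph_mode_range("2.0.0")  # TRUE
in_graph_mode_range("2.6.0")  # FALSE
```

Comparing package_version objects avoids string comparison pitfalls, where "2.10.0" would sort before "2.4.0".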

@njtierney
Collaborator

It seems like it might be possible to switch off eager mode (maybe!) in 2.6 or 2.7 (rstudio/tensorflow#498); however, there are still issues with installing all of the Python requirements on an M1 Mac.

@njtierney
Collaborator

OK, so we can happily create normal and other distributions, but something funky happens in (I believe) the log Jacobian, where something that greta expects to be an integer is coerced to float64:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# initialise greta
normal(0,1)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 
#> ℹ Disabling TF eager execution with: `tf$compat$v1$disable_eager_execution()`
#> greta array (variable following a normal distribution)
#> 
#>      [,1]
#> [1,]  ?

# simple model
m <- model(normal(0, 1))
#> Loaded Tensorflow version 2.6.0
#> Error in py_call_impl(callable, dots$args, dots$keywords): ValueError: Expected integer, got dtype <class 'float'>.

Created on 2022-01-31 by the reprex package (v2.0.1)

Also, there's an annoying "Loaded Tensorflow version 2.6.0" message that appears... I swear that was supposed to happen earlier in the install process.
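One common source of "Expected integer, got dtype <class 'float'>" errors when driving TensorFlow from R is that a bare number such as 2 is an R double, which reticulate hands to Python as a float; only 2L is an integer. This may or may not be what is going wrong in the log Jacobian here, but it's cheap to rule out. as_tf_int() is a hypothetical helper, not a greta function.

```r
# A bare number is a double in R; the L suffix makes an integer.
is.integer(2)   # FALSE
is.integer(2L)  # TRUE

# Hypothetical helper: coerce whole-number arguments (shapes, axes, counts)
# to integer before passing them to TensorFlow, failing on non-whole values.
as_tf_int <- function(x) {
  if (!is.numeric(x) || any(x != round(x))) {
    stop("expected whole numbers, got: ", paste(x, collapse = ", "), call. = FALSE)
  }
  as.integer(x)
}

shape <- as_tf_int(c(1, 3))
```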

@njtierney
Collaborator

Tried to add more of the "disabling" features of V1 into version 2, but this did not seem to stop this overall error:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# initialise greta
normal(0,1)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 
#> ℹ Disabling TF eager execution with: `tf$compat$v1$disable_eager_execution()`
#> ℹ Disabling TF V2 behaviour: `tf$compat$v1$disable_v2_behavior()`
#> ℹ Disabling TF V2 tensorshape: `tf$compat$v1$disable_v2_tensorshape()`
#> ℹ Disabling TF V2 control flow: `tf$compat$v1$disable_control_flow_v2()`
#> ℹ Disabling TF V2 resource variables:
#>   `tf$compat$v1$disable_resource_variables()`
#> ℹ Disabling TF V2 tensor equality: `tf$compat$v1$disable_tensor_equality()`
#> greta array (variable following a normal distribution)
#> 
#>      [,1]
#> [1,]  ?

# simple model
m <- model(normal(0, 1))
#> Loaded Tensorflow version 2.6.0
#> Error in py_call_impl(callable, dots$args, dots$keywords): ValueError: Expected integer, got dtype <class 'float'>.

Created on 2022-01-31 by the reprex package (v2.0.1)

At least it's a consistent bug?

Note that these changes are being implemented/tested in this branch: #482

@njtierney
Collaborator

From memory, the instinct is to debug at the log_jacobian_bijector part, but the right way to debug is to step inside the variable or node creation and compare step by step there.

@njtierney
Collaborator

OK, it turns out that a bunch of the errors were actually because this PR wasn't in sync with master, so we were dealing with the issues from #463.

Now...we have a new set of spicier issues.

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# initialise greta
normal(0,1)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 
#> ℹ Disabling TF eager execution with: `tf$compat$v1$disable_eager_execution()`
#> greta array (variable following a normal distribution)
#> 
#>      [,1]
#> [1,]  ?

# simple model
m <- model(normal(0, 1))
#> Loaded Tensorflow version 2.6.0

draws <- mcmc(m) 
#> Error in py_call_impl(callable, dots$args, dots$keywords): TypeError: __init__() got an unexpected keyword argument 'seed'

Created on 2022-02-07 by the reprex package (v2.0.1)

And

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
x <- rnorm(10)
y <- as_data(rnorm(10))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 
#> ℹ Disabling TF eager execution with: `tf$compat$v1$disable_eager_execution()`
alpha <- normal(0,1)
beta <- normal(0,1)
mu <- alpha + beta * x
sigma <- normal(0, 0.1, truncation = c(0, 0.5))
distribution(y) <- normal(mu, sigma)
m <- model(y)
#> Loaded Tensorflow version 2.6.0
#> Error in py_get_attr_impl(x, name, silent): AttributeError: module 'tensorflow_probability.python.bijectors' has no attribute 'AffineScalar'

Created on 2022-02-07 by the reprex package (v2.0.1)
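The AffineScalar error comes from TensorFlow Probability having removed tfb.AffineScalar; the suggested replacement is to compose the Shift and Scale bijectors, which from R might look like tfb$Chain(list(tfb$Shift(shift), tfb$Scale(scale))), where tfb is assumed to be an alias for tfp$bijectors. Chain applies its bijectors right to left, so Scale runs first. The plain-R sketch below only checks the maths: that this composition reproduces AffineScalar's forward transform and forward log-det Jacobian. The shift of 0 and scale of 0.5 mirror the truncation = c(0, 0.5) bounds above, on the unverified assumption that greta uses this bijector to map a unit-interval value onto truncation bounds.

```r
# Plain-R versions of the bijector maths (forward transform and forward
# log-det Jacobian); these are illustrations, not TFP calls.

# AffineScalar(shift, scale): y = shift + scale * x, log|dy/dx| = log|scale|
affine_scalar_forward <- function(x, shift, scale) shift + scale * x
affine_scalar_fldj <- function(x, shift, scale) rep(log(abs(scale)), length(x))

# Scale(scale) and Shift(shift) as separate bijectors
scale_forward <- function(x, scale) scale * x
scale_fldj <- function(x, scale) rep(log(abs(scale)), length(x))
shift_forward <- function(x, shift) x + shift
shift_fldj <- function(x, shift) rep(0, length(x))

# Chain(list(Shift, Scale)): Scale is applied first, then Shift. The log-det
# Jacobians of a composition add, each evaluated at its own input.
chain_forward <- function(x, shift, scale) {
  shift_forward(scale_forward(x, scale), shift)
}
chain_fldj <- function(x, shift, scale) {
  scale_fldj(x, scale) + shift_fldj(scale_forward(x, scale), shift)
}

x <- c(0.1, 0.5, 0.9)  # e.g. values already squashed into (0, 1)
all.equal(affine_scalar_forward(x, 0, 0.5), chain_forward(x, 0, 0.5))  # TRUE
all.equal(affine_scalar_fldj(x, 0, 0.5), chain_fldj(x, 0, 0.5))        # TRUE
```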

@njtierney
Collaborator

More curiosities arise - working on #482

When debugging:

mcmc(model(normal(0,1)))

I get:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# initialise greta
mcmc(model(normal(0,1)))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 
#> ℹ Disabling TF eager execution with: `tf$compat$v1$disable_eager_execution()`
#> Loaded Tensorflow version 2.6.0
#> Error in py_call_impl(callable, dots$args, dots$keywords): TypeError: __init__() got an unexpected keyword argument 'seed'

Created on 2022-02-07 by the reprex package (v2.0.1)

Using options(error = recover) and dropping into frame number 18, I get the following error message:

Selection: 18
Called from: top level 
Error during wrapup: tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

Error: no more error handlers available (recursive errors?); invoking 'abort' restart

Initially, I thought that this meant we were successfully disabling eager evaluation, with the code:

tf$compat$v1$disable_eager_execution()

I tested this again by removing that part of the startup/check process when greta initialises python and checks if TF is installed etc.

Unfortunately, I got exactly the same error message.

So I'm not sure how to approach things from here, and I'm not sure whether

tf$compat$v1$disable_eager_execution()

is actually disabling eager execution, or whether it's doing something spooky,

or whether we need to work out a new way to pass arguments through to tfp.

@njtierney
Collaborator

njtierney commented Mar 17, 2022

So what I notice is that on an M1 Mac I get the same error regardless of whether eager mode is turned on. Here I control whether eager mode is turned on with an environment variable, GRETA_ENABLE_EAGER.
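A minimal plain-R sketch of how a switch like GRETA_ENABLE_EAGER might be read. The main gotcha is that Sys.setenv() stores strings, so FALSE comes back as "FALSE" and has to be converted with as.logical(). eager_flag() and its default are hypothetical, not the branch's actual code. The TensorFlow docs also state that tf.compat.v1.disable_eager_execution() can only be called before any graphs, ops or tensors have been created, so a check like this would need to run before greta builds anything.

```r
# Sys.setenv() coerces values to character, so FALSE is stored as "FALSE".
Sys.setenv(GRETA_ENABLE_EAGER = FALSE)
Sys.getenv("GRETA_ENABLE_EAGER")  # "FALSE"

# Hypothetical helper: parse the flag as a logical. Unset, empty or
# unrecognised values fall back to `default` (an assumed default).
eager_flag <- function(default = FALSE) {
  value <- as.logical(Sys.getenv("GRETA_ENABLE_EAGER", unset = NA))
  if (is.na(value)) default else value
}

# Branch on the parsed flag, so the "Disabling TF eager execution" message
# and the disabling call only ever happen together.
if (!eager_flag()) {
  message("Disabling TF eager execution")
  # tf$compat$v1$disable_eager_execution()  # needs TensorFlow, so omitted
}
```

Tying the message to the same branch as the call makes it easier to tell from the log whether eager execution was actually disabled.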

Below, it is turned off:

# install with: 
# remotes::install_github("greta-dev/greta#482")
Sys.setenv("GRETA_ENABLE_EAGER"=FALSE)
library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# check various python deps are installed correctly
greta_sitrep()
#> ℹ checking if python available
#> ✓ python (version 3.8) available
#> 
#> ℹ checking if TensorFlow available
#> ✓ TensorFlow (version 2.6.0) available
#> 
#> ℹ checking if TensorFlow Probability available
#> ✓ TensorFlow Probability (version 0.14.1) available
#> 
#> ℹ checking if greta conda environment available
#> ✓ greta conda environment available
#> 
#> Warning: greta does not currently work with Apple Silicon (M1)
#> We are working on getting this resolved ASAP, see
#> <https://github.com/greta-dev/greta/issues/458> for current progress.
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 
#> eager flag value is FALSE
#> ℹ Disabling TF eager execution with: `tf$compat$v1$disable_eager_execution()`
#> ℹ Show current state of eager execution with:
#>   `tf$compat$v1$executing_eagerly()`
#> Loaded Tensorflow version 2.6.0
# you may need to do
# install_greta_deps()
tensorflow::tf$compat$v1$executing_eagerly()
#> [1] FALSE
m <- model(normal(0,1))
#> Warning: greta does not currently work with Apple Silicon (M1)
#> We are working on getting this resolved ASAP, see
#> <https://github.com/greta-dev/greta/issues/458> for current progress.

#> Warning: greta does not currently work with Apple Silicon (M1)
#> We are working on getting this resolved ASAP, see
#> <https://github.com/greta-dev/greta/issues/458> for current progress.

mcmc(m)
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: ValueError: Tensor("mcmc_sample_chain/trace_scan/while/mcmc_sample_chain_trace_scan_while_chain_of_mcmc_sample_chain_trace_scan_while_reshape_of_mcmc_sample_chain_trace_scan_while_identity/forward_log_det_jacobian/mcmc_sample_chain_trace_scan_while_reshape/forward/Const:0", shape=(2,), dtype=int32) must be from the same graph as Tensor("leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/Shape:0", shape=(2,), dtype=int32) (graphs are <tensorflow.python.framework.ops.Graph object at 0x14e9feee0> and FuncGraph(name=mh_one_step_hmc_kernel_one_step_leapfrog_integrate_while_body_440, id=5902822128)).
#> .

Created on 2022-03-17 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info  ↘️  👌  🌧️   ───────────────────────────────────────────────────
#>  hash: down-right arrow, OK hand, cloud with rain
#> 
#>  setting  value
#>  version  R version 4.1.2 (2021-11-01)
#>  os       macOS Big Sur 11.2.2
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-03-17
#>  pandoc   2.16.1 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-5      2016-07-21 [1] CRAN (R 4.1.0)
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.1.1)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.1.0)
#>  callr         3.7.0      2021-04-20 [1] CRAN (R 4.1.0)
#>  cli           3.2.0      2022-02-14 [1] CRAN (R 4.1.1)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.1.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.1.2)
#>  crayon        1.5.0      2022-02-14 [1] CRAN (R 4.1.1)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.1.1)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate      0.15       2022-02-18 [1] CRAN (R 4.1.1)
#>  fansi         1.0.2      2022-01-14 [1] CRAN (R 4.1.1)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.1.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.1.1)
#>  future        1.23.0     2021-10-31 [1] CRAN (R 4.1.1)
#>  globals       0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.1.1)
#>  greta       * 0.4.1.9000 2022-03-17 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.1.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.1.1)
#>  htmltools     0.5.2      2021-08-25 [1] CRAN (R 4.1.1)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.1.1)
#>  knitr         1.37       2021-12-16 [1] CRAN (R 4.1.1)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.1.2)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.1.1)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  magrittr      2.0.2      2022-01-26 [1] CRAN (R 4.1.1)
#>  Matrix        1.3-4      2021-06-01 [1] CRAN (R 4.1.2)
#>  parallelly    1.30.0     2021-12-17 [1] CRAN (R 4.1.1)
#>  pillar        1.7.0      2022-02-01 [1] CRAN (R 4.1.1)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.1.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.1.0)
#>  processx      3.5.2      2021-04-30 [1] CRAN (R 4.1.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.1.0)
#>  ps            1.6.0      2021-02-28 [1] CRAN (R 4.1.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
#>  R.cache       0.15.0     2021-04-30 [1] CRAN (R 4.1.0)
#>  R.methodsS3   1.8.1      2020-08-26 [1] CRAN (R 4.1.0)
#>  R.oo          1.24.0     2020-08-26 [1] CRAN (R 4.1.0)
#>  R.utils       2.11.0     2021-09-26 [1] CRAN (R 4.1.1)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.1.1)
#>  Rcpp          1.0.8      2022-01-13 [1] CRAN (R 4.1.1)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.1.1)
#>  reticulate    1.24-9000  2022-01-28 [1] Github (rstudio/reticulate@8347417)
#>  rlang         1.0.2      2022-03-04 [1] CRAN (R 4.1.1)
#>  rmarkdown     2.11       2021-09-14 [1] CRAN (R 4.1.1)
#>  rprojroot     2.0.2      2020-11-15 [1] CRAN (R 4.1.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  sessioninfo   1.2.1      2021-11-02 [1] CRAN (R 4.1.1)
#>  stringi       1.7.6      2021-11-29 [1] CRAN (R 4.1.1)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.1.1)
#>  styler        1.6.2      2021-09-23 [1] CRAN (R 4.1.1)
#>  tensorflow    2.7.0.9000 2022-01-28 [1] Github (rstudio/tensorflow@ff9eaeb)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.1.0)
#>  tibble        3.1.6      2021-11-07 [1] CRAN (R 4.1.1)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs         0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.1.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.1.1)
#>  xfun          0.30       2022-03-02 [1] CRAN (R 4.1.1)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.1.1)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env:/Users/nick/Library/r-miniconda-arm64/envs/greta-env
#>  version:        3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:13:24)  [Clang 11.1.0 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.19.5
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/python3.8/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

And here it is turned on:

# Now leave eager execution turned on
Sys.setenv("GRETA_ENABLE_EAGER"=TRUE)
library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# check various python deps are installed correctly
greta_sitrep()
#> ℹ checking if python available
#> ✓ python (version 3.8) available
#> 
#> ℹ checking if TensorFlow available
#> ✓ TensorFlow (version 2.6.0) available
#> 
#> ℹ checking if TensorFlow Probability available
#> ✓ TensorFlow Probability (version 0.14.1) available
#> 
#> ℹ checking if greta conda environment available
#> ✓ greta conda environment available
#> 
#> Warning: greta does not currently work with Apple Silicon (M1)
#> We are working on getting this resolved ASAP, see
#> <https://github.com/greta-dev/greta/issues/458> for current progress.
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 
#> eager flag value is TRUE
#> ℹ Disabling TF eager execution with: `tf$compat$v1$disable_eager_execution()`
#> Loaded Tensorflow version 2.6.0
tensorflow::tf$compat$v1$executing_eagerly()
#> [1] TRUE

m <- model(normal(0,1))
#> Warning: greta does not currently work with Apple Silicon (M1)
#> We are working on getting this resolved ASAP, see
#> <https://github.com/greta-dev/greta/issues/458> for current progress.

#> Warning: greta does not currently work with Apple Silicon (M1)
#> We are working on getting this resolved ASAP, see
#> <https://github.com/greta-dev/greta/issues/458> for current progress.

mcmc(m)
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: ValueError: Tensor("mcmc_sample_chain/trace_scan/while/mcmc_sample_chain_trace_scan_while_chain_of_mcmc_sample_chain_trace_scan_while_reshape_of_mcmc_sample_chain_trace_scan_while_identity/forward_log_det_jacobian/mcmc_sample_chain_trace_scan_while_reshape/forward/Const:0", shape=(2,), dtype=int32) must be from the same graph as Tensor("leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/Shape:0", shape=(2,), dtype=int32) (graphs are <tensorflow.python.framework.ops.Graph object at 0x117977ee0> and FuncGraph(name=mh_one_step_hmc_kernel_one_step_leapfrog_integrate_while_body_439, id=5875312768)).
#> .

Created on 2022-03-17 by the reprex package (v2.0.1)


So there is probably some other change in the API here that I need to consider.

@njtierney njtierney removed this from the 0.7.0 milestone Mar 28, 2022
@njtierney
Collaborator

njtierney commented Jul 14, 2022

For reference, here's the input_signature specified in 4303bbb:

input_signature = list(
  # free state
  tf$TensorSpec(shape = list(NULL, self$n_free),
                dtype = tf_float()),
  # sampler_burst_length
  tf$TensorSpec(shape = list(1L),
                dtype = tf$int32),
  # sampler_thin
  tf$TensorSpec(shape = list(1L),
                dtype = tf$int32),
  # sampler_param_vec
  tf$TensorSpec(shape = list(length(unlist(self$sampler_parameter_values()))),
                dtype = tf_float())
)

And after that commit (here: a2d625b), here's what these parameter values are:

| parameter | pre_input_signature | with_input_signature |
| --- | --- | --- |
| hmc_epsilon | 0.0 | `Tensor("strided_slice:0", shape=(), dtype=float64)` |
| hmc_diag_sd | 1.0 | `Tensor("strided_slice_2:0", shape=(0), dtype=float64)` |
| hmc_l | 0.1 | `Tensor("strided_slice_1:0", shape=(), dtype=float64)` |

@njtierney
Collaborator

So my understanding is that we either need to change something about the hmc_* objects, or rewrite the reshape step at https://github.com/njtierney/greta/blob/tf2-poke-tf-fun/R/inference_class.R#L951-L958

@njtierney
Collaborator

OK, so after a bit of poking around and fixing parts of input_signature, we've now got it down to a respectable 20 seconds or so. That's still slow compared with about 3 seconds on a 2016 MBP, but an improvement on the previous 60 seconds.

library(tictoc)
devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.9.2
x <- normal(0, 1)
m <- model(x)
tic()
draws <- mcmc(m, n_samples = 500, warmup = 500)
#> running 4 chains simultaneously on up to 8 cores
#>     warmup ======================================== 500/500 | eta:  0s
#>   sampling ======================================== 500/500 | eta:  0s
toc()
#> 17.709 sec elapsed
library(coda)
#> 
#> Attaching package: 'coda'
#> 
#> The following object is masked from 'package:greta':
#> 
#>     mcmc
plot(draws)

Created on 2022-07-26 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-07-26
#>  pandoc   2.17.1.1 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version    date (UTC) lib source
#>    abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    brio          1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    cachem        1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr         3.7.0      2021-04-20 [1] CRAN (R 4.2.0)
#>    cli           3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>    coda        * 0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>    crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>    curl          4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>    desc          1.4.1      2022-03-06 [1] CRAN (R 4.2.0)
#>    devtools      2.4.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>    ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate      0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>    fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>    fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future        1.25.0     2022-04-24 [1] CRAN (R 4.2.0)
#>    globals       0.15.0     2022-05-09 [1] CRAN (R 4.2.0)
#>    glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  P greta       * 0.4.2.9000 2022-05-26 [?] load_all()
#>    here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>    htmltools     0.5.2      2021-08-25 [1] CRAN (R 4.2.0)
#>    httr          1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>    jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>    knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>    lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>    lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>    listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>    memoise       2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    parallelly    1.31.1     2022-04-22 [1] CRAN (R 4.2.0)
#>    pillar        1.7.0      2022-02-01 [1] CRAN (R 4.2.0)
#>    pkgbuild      1.3.1      2021-12-20 [1] CRAN (R 4.2.0)
#>    pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload       1.3.0      2022-06-27 [1] CRAN (R 4.2.0)
#>    png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>    prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx      3.6.1      2022-06-17 [1] CRAN (R 4.2.0)
#>    progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>    purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>    R.cache       0.15.0     2021-04-30 [1] CRAN (R 4.2.0)
#>    R.methodsS3   1.8.1      2020-08-26 [1] CRAN (R 4.2.0)
#>    R.oo          1.24.0     2020-08-26 [1] CRAN (R 4.2.0)
#>    R.utils       2.11.0     2021-09-26 [1] CRAN (R 4.2.0)
#>    R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp          1.0.8.3    2022-03-17 [1] CRAN (R 4.2.0)
#>    remotes       2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>    reticulate    1.24-9000  2022-05-11 [1] Github (rstudio/reticulate@451fbff)
#>    rlang         1.0.3      2022-06-27 [1] CRAN (R 4.2.0)
#>    rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>    rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>    sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    stringi       1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>    stringr       1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>    styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>    tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat    * 3.1.4      2022-04-26 [1] CRAN (R 4.2.0)
#>    tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>    tibble        3.1.7      2022-05-03 [1] CRAN (R 4.2.0)
#>    tictoc      * 1.0.1      2021-04-19 [1] CRAN (R 4.2.0)
#>    usethis       2.1.5      2021-12-09 [1] CRAN (R 4.2.0)
#>    utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>    vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>    whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun          0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>    xml2          1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>    yesno         0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env:/Users/nick/Library/r-miniconda-arm64/envs/greta-env
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16)  [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.22.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/python3.8/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Not super happy with how well it is sampling though...

@njtierney
Collaborator

The next step is to replace the tuning algorithm we currently run internally in R with the TFP version:

https://www.tensorflow.org/probability/api_docs/python/tfp/mcmc/DualAveragingStepSizeAdaptation
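For context, the TFP docs describe DualAveragingStepSizeAdaptation as based on the dual-averaging scheme from Hoffman and Gelman's NUTS paper. A minimal base-R sketch of that textbook update is below, so the moving parts are visible before swapping it in. `dual_averaging()` and `toy_accept()` are hypothetical names; the toy acceptance curve stands in for running an HMC step, and the defaults (target 0.8, gamma = 0.05, t0 = 10, kappa = 0.75) are assumptions for illustration, not greta's or TFP's settings. TFP's implementation may differ in details such as how it combines acceptance probabilities across chains.

```r
# Hypothetical helper: textbook dual-averaging step-size adaptation.
# accept_prob_fn(eps) stands in for "run one HMC step with step size eps
# and return its acceptance probability".
dual_averaging <- function(accept_prob_fn, eps0 = 1, n_adapt = 2000,
                           target = 0.8, gamma = 0.05, t0 = 10,
                           kappa = 0.75) {
  mu <- log(10 * eps0)   # point the log step size is shrunk towards
  h_bar <- 0             # running average of (target - acceptance prob)
  log_eps <- log(eps0)   # current (noisy) log step size
  log_eps_bar <- 0       # averaged log step size, used after warmup
  for (m in seq_len(n_adapt)) {
    alpha <- accept_prob_fn(exp(log_eps))
    w <- 1 / (m + t0)
    h_bar <- (1 - w) * h_bar + w * (target - alpha)
    log_eps <- mu - sqrt(m) / gamma * h_bar
    eta <- m^(-kappa)
    log_eps_bar <- eta * log_eps + (1 - eta) * log_eps_bar
  }
  list(step_size = exp(log_eps), averaged_step_size = exp(log_eps_bar))
}

# Toy acceptance curve: acceptance falls as the step size grows
toy_accept <- function(eps) exp(-eps)

res <- dual_averaging(toy_accept)
# exp(-eps) = 0.8 is solved by eps = -log(0.8), roughly 0.223
res$averaged_step_size
```

The averaged step size, rather than the last noisy iterate, is what would be fixed for sampling once warmup ends.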

@njtierney
Collaborator

OK, so it turns out that we had hmc_l and hmc_epsilon indexed the wrong way round 😅

The upshot is that we now get proper sampling:
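In plain R terms, one way to guard against this kind of index mix-up is to unpack the flat parameter vector by name, using the original list as a template, instead of hard-coding positions. A minimal sketch using base R's `unlist()`/`relist()`; the values are illustrative, and greta's actual fix operates on TF tensors inside the sampler, so it may look quite different.

```r
# Hypothetical sampler parameters; names follow the thread, values illustrative
params <- list(hmc_l = 10, hmc_epsilon = 0.1, hmc_diag_sd = c(1, 1, 1))

# Flatten into one numeric vector, as is done for sampler_param_vec
param_vec <- unlist(params)

# Fragile: a hard-coded position silently picks the wrong parameter
hmc_l_wrong <- param_vec[[2]]  # this is actually hmc_epsilon

# Safer: rebuild the list from the flat vector using the original as a template
unpacked <- relist(param_vec, skeleton = params)
unpacked$hmc_l        # 10
unpacked$hmc_epsilon  # 0.1
```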

library(tictoc)
devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.9.2
x <- normal(0, 1)
m <- model(x)
tic()
draws <- mcmc(m, n_samples = 500, warmup = 500)
#> running 4 chains simultaneously on up to 8 cores
#>     warmup ======================================== 500/500 | eta:  0s
#>   sampling ======================================== 500/500 | eta:  0s
toc()
#> 53.136 sec elapsed
library(coda)
#> 
#> Attaching package: 'coda'
#> 
#> The following object is masked from 'package:greta':
#> 
#>     mcmc
plot(draws)

Created on 2022-07-26 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-07-26
#>  pandoc   2.17.1.1 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version    date (UTC) lib source
#>    abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    brio          1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    cachem        1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr         3.7.0      2021-04-20 [1] CRAN (R 4.2.0)
#>    cli           3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>    coda        * 0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>    crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>    curl          4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>    desc          1.4.1      2022-03-06 [1] CRAN (R 4.2.0)
#>    devtools      2.4.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>    ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate      0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>    fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>    fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future        1.25.0     2022-04-24 [1] CRAN (R 4.2.0)
#>    globals       0.15.0     2022-05-09 [1] CRAN (R 4.2.0)
#>    glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  P greta       * 0.4.2.9000 2022-05-26 [?] load_all()
#>    here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>    htmltools     0.5.2      2021-08-25 [1] CRAN (R 4.2.0)
#>    httr          1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>    jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>    knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>    lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>    lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>    listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>    memoise       2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    parallelly    1.31.1     2022-04-22 [1] CRAN (R 4.2.0)
#>    pillar        1.7.0      2022-02-01 [1] CRAN (R 4.2.0)
#>    pkgbuild      1.3.1      2021-12-20 [1] CRAN (R 4.2.0)
#>    pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload       1.3.0      2022-06-27 [1] CRAN (R 4.2.0)
#>    png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>    prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx      3.6.1      2022-06-17 [1] CRAN (R 4.2.0)
#>    progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>    purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>    R.cache       0.15.0     2021-04-30 [1] CRAN (R 4.2.0)
#>    R.methodsS3   1.8.1      2020-08-26 [1] CRAN (R 4.2.0)
#>    R.oo          1.24.0     2020-08-26 [1] CRAN (R 4.2.0)
#>    R.utils       2.11.0     2021-09-26 [1] CRAN (R 4.2.0)
#>    R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp          1.0.8.3    2022-03-17 [1] CRAN (R 4.2.0)
#>    remotes       2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>    reticulate    1.24-9000  2022-05-11 [1] Github (rstudio/reticulate@451fbff)
#>    rlang         1.0.3      2022-06-27 [1] CRAN (R 4.2.0)
#>    rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>    rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>    sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    stringi       1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>    stringr       1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>    styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>    tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat    * 3.1.4      2022-04-26 [1] CRAN (R 4.2.0)
#>    tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>    tibble        3.1.7      2022-05-03 [1] CRAN (R 4.2.0)
#>    tictoc      * 1.0.1      2021-04-19 [1] CRAN (R 4.2.0)
#>    usethis       2.1.5      2021-12-09 [1] CRAN (R 4.2.0)
#>    utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>    vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>    whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun          0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>    xml2          1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>    yesno         0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env:/Users/nick/Library/r-miniconda-arm64/envs/greta-env
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16)  [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.22.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env/lib/python3.8/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

So overall we now get correct sampling, but it still takes about 50-60 seconds for something that normally runs in about 3 seconds.

We're going to explore where and why this slowdown happens. Some of the approaches we'll try:

  • Evaluating the log probability in TF1 vs TF2 on the same machine. If there are substantial differences between TF versions, that would point to an issue with how we are interfacing with TF2.
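# TF1 code to try out and profile
library(greta)
library(tictoc)
x <- normal(0, 1)
m <- model(x)
tic()
draws <- mcmc(m, n_samples = 500, warmup = 500)
toc()
# extract the log-probability function and evaluate it at a single value
glp_fun <- m$dag$generate_log_prob_function()
glp_fun(array(1, c(1, 1)))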

# profile glp_fun at increasing input sizes (TF2 code)
# build_array(n) returns an n x 1 array filled with the value n
build_array <- function(n){
  array(n, c(n, 1))
}

bm <- bench::mark(
  x_0 = glp_fun(build_array(1)),
  x_3 = glp_fun(build_array(1e3)),
  x_6 = glp_fun(build_array(1e6)),
  x_7 = glp_fun(build_array(1e7)),
  check = FALSE
)

plot(bm)
  • Evaluating the same TF code directly under TF1 and TF2
library(greta)
library(tensorflow)
library(tfprobability)
tfp <- greta:::tfp


fun <- function(free_state) {
  norm <- tfp$distributions$Normal(0, 1)
  norm$log_prob(free_state)
}

free_state_value <- array(1, c(1, 1))

# TF 1 (graph mode): fun() only builds the graph, so time the session run
res <- fun(free_state_value)
sess <- tf$compat$v1$Session()
system.time(
  sess$run(res)
)

# TF 2 (eager mode): calling fun() executes it directly
system.time(
  fun(free_state_value)
)
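A single `system.time()` call is noisy, and in TF2 the first call often carries one-off costs (for example, tracing when a function is wrapped in `tf_function`). A minimal base-R sketch for separating the first call from the steady state; `time_calls()` is a hypothetical helper, and in the comparisons above `f` would wrap something like `function() fun(free_state_value)`.

```r
# Hypothetical helper: time n calls of f, reporting the first call separately
# from the median of the remaining calls
time_calls <- function(f, n = 20) {
  stopifnot(n >= 2)
  times <- vapply(
    seq_len(n),
    function(i) system.time(f())[["elapsed"]],
    numeric(1)
  )
  c(first = times[1], median_rest = median(times[-1]))
}

# Example with a cheap base-R function standing in for the TF call
timings <- time_calls(function() sum(rnorm(1e5)))
timings
```

A large gap between `first` and `median_rest` would suggest one-off setup cost rather than slow per-call evaluation.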

@njtierney
Collaborator

OK so the mystery continues...

Looks like running the same code on a 2016 MacBook Pro, using TF2 on the tf-poke branch, runs MCMC about 5 times faster than on a 2021 MacBook Pro with an M1 chip.

devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.9.1
greta_sitrep()
#> ℹ checking if python available
#> ✔ python (v3.8) available
#> 
#> ℹ checking if TensorFlow available
#> ✔ TensorFlow (v2.9.1) available
#> 
#> ℹ checking if TensorFlow Probability available
#> ✔ TensorFlow Probability (v0.17.0) available
#> 
#> ℹ checking if greta conda environment available
#> ✔ greta conda environment available
#> 
#> ℹ greta is ready to use!
library(tictoc)
x <- normal(0,1)
m <- model(x)
tic()
draws <- mcmc(m, n_samples = 500, warmup = 500)
#> running 4 chains simultaneously on up to 8 cores
#>     warmup ======================================== 500/500 | eta:  0s
#>   sampling ======================================== 500/500 | eta:  0s
toc()
#> 9.699 sec elapsed
library(coda)
#> 
#> Attaching package: 'coda'
#> 
#> The following object is masked from 'package:greta':
#> 
#>     mcmc
plot(draws)

Created on 2022-07-27 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Big Sur/Monterey 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-07-27
#>  pandoc   2.17.1.1 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version    date (UTC) lib source
#>    abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>    backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>    base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>    brio          1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    cachem        1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>    callr         3.7.1      2022-07-13 [1] CRAN (R 4.2.0)
#>    cli           3.3.0      2022-04-25 [1] CRAN (R 4.2.0)
#>    coda        * 0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>    codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>    crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>    curl          4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>    desc          1.4.1      2022-03-06 [1] CRAN (R 4.2.0)
#>    devtools      2.4.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>    ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>    evaluate      0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>    fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>    fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>    future        1.26.1     2022-05-27 [1] CRAN (R 4.2.0)
#>    globals       0.15.1     2022-06-24 [1] CRAN (R 4.2.0)
#>    glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  P greta       * 0.4.2.9000 2022-07-27 [?] load_all()
#>    here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>    highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>    hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>    htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>    httr          1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>    jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>    knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>    lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>    lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>    listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>    magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>    Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>    memoise       2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>    mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>    parallelly    1.32.0     2022-06-07 [1] CRAN (R 4.2.0)
#>    pkgbuild      1.3.1      2021-12-20 [1] CRAN (R 4.2.0)
#>    pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>    pkgload       1.3.0      2022-06-27 [1] CRAN (R 4.2.0)
#>    png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>    prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>    processx      3.7.0      2022-07-07 [1] CRAN (R 4.2.0)
#>    progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>    ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>    purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>    R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>    Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>    remotes       2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>    reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>    reticulate    1.25       2022-05-11 [1] CRAN (R 4.2.0)
#>    rlang         1.0.4      2022-07-12 [1] CRAN (R 4.2.0)
#>    rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>    rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>    rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>    sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>    stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>    stringr       1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>    tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>    testthat    * 3.1.4      2022-04-26 [1] CRAN (R 4.2.0)
#>    tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>    tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>    tictoc      * 1.0.1      2021-04-19 [1] CRAN (R 4.2.0)
#>    usethis       2.1.6      2022-05-25 [1] CRAN (R 4.2.0)
#>    vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>    whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>    withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>    xfun          0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>    xml2          1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>    yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>    yesno         0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/njtierney/Library/r-miniconda/envs/greta-env/bin/python
#>  libpython:      /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/libpython3.8.dylib
#>  pythonhome:     /Users/njtierney/Library/r-miniconda/envs/greta-env:/Users/njtierney/Library/r-miniconda/envs/greta-env
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:47)  [Clang 12.0.1 ]
#>  numpy:          /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.1
#>  tensorflow:     /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.8/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Weird.

Still about twice as slow as on TF1... but still, so strange.
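For a rough sense of scale, the elapsed times reported in the reprexes above can be converted into per-iteration costs and a ratio. This is only arithmetic on the printed `toc()` values (each run is 500 warmup + 500 sampling iterations with 4 chains), not a controlled benchmark.

```r
# Elapsed times (seconds) copied from the toc() output above
elapsed <- c(
  m1_tf2 = 53.136,   # M1 MacBook Pro, after fixing the hmc_l/hmc_epsilon indexing
  intel_tf2 = 9.699  # 2016 Intel MacBook Pro, same code
)
n_iter <- 500 + 500  # warmup + n_samples

per_iter_ms <- elapsed / n_iter * 1000
slowdown <- elapsed[["m1_tf2"]] / elapsed[["intel_tf2"]]

round(per_iter_ms, 1)  # 53.1 and 9.7 ms per iteration
round(slowdown, 1)     # 5.5
```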

@njtierney
Collaborator

And on the same Intel Mac, running 500 samples with TF1:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
library(tictoc)
x <- normal(0,1)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
m <- model(x)
tic()
draws <- mcmc(m, n_samples = 500, warmup = 500)
#> running 4 chains simultaneously on up to 8 cores
#> 
#>     warmup ======================================== 500/500 | eta:  0s
#>   sampling ======================================== 500/500 | eta:  0s
toc()
#> 6.293 sec elapsed

library(coda)
#> 
#> Attaching package: 'coda'
#> 
#> The following object is masked from 'package:greta':
#> 
#>     mcmc
plot(draws)

Created on 2022-08-05 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Big Sur/Monterey 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-08-05
#>  pandoc   2.17.1.1 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>  callr         3.7.1      2022-07-13 [1] CRAN (R 4.2.0)
#>  cli           3.3.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  coda        * 0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  curl          4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate      0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>  fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  future        1.26.1     2022-05-27 [1] CRAN (R 4.2.0)
#>  globals       0.15.1     2022-06-24 [1] CRAN (R 4.2.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  greta       * 0.4.2.9000 2022-08-05 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>  httr          1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>  parallelly    1.32.0     2022-06-07 [1] CRAN (R 4.2.0)
#>  pillar        1.8.0      2022-07-18 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>  processx      3.7.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>  ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils       2.12.0     2022-06-28 [1] CRAN (R 4.2.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>  reticulate    1.25       2022-05-11 [1] CRAN (R 4.2.0)
#>  rlang         1.0.4      2022-07-12 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>  rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>  styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>  tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>  tibble        3.1.7      2022-05-03 [1] CRAN (R 4.2.0)
#>  tictoc      * 1.0.1      2021-04-19 [1] CRAN (R 4.2.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>  xml2          1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/njtierney/Library/r-miniconda/envs/greta-env/bin/python
#>  libpython:      /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/libpython3.7m.dylib
#>  pythonhome:     /Users/njtierney/Library/r-miniconda/envs/greta-env:/Users/njtierney/Library/r-miniconda/envs/greta-env
#>  version:        3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 05:59:23)  [Clang 11.1.0 ]
#>  numpy:          /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/numpy
#>  numpy_version:  1.16.4
#>  tensorflow:     /Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

@njtierney
Collaborator

So, after some exploration of the TF profiler on greta, we need to do the following

  • Run the profiler on TF1 on the same intel mac
  • Try following the M1-specific installation instructions for TF2
  • Work out why there is a retracing warning
  • Explore doing the tuning in TensorFlow, instead of returning to R repeatedly during the tuning phase of warmup

@njtierney
Collaborator

njtierney commented Aug 8, 2022

OK so there's something very interesting.

Turning off the GPU makes the M1 Mac... about 19 times faster? 53 seconds compared to 2.7 seconds.

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
library(tictoc)
greta_sitrep()
#> ℹ checking if python available
#> ✔ python (v3.8) available
#> 
#> ℹ checking if TensorFlow available
#> ✔ TensorFlow (v2.9.2) available
#> 
#> ℹ checking if TensorFlow Probability available
#> ✔ TensorFlow Probability (v0.17.0) available
#> 
#> ℹ checking if greta conda environment available
#> ✔ greta conda environment available
#> 
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
#> ℹ greta is ready to use!

tensorflow::tf$config$get_visible_devices("CPU") |> 
  tensorflow::tf$config$set_visible_devices()
#> Loaded Tensorflow version 2.9.2

x <- normal(0,1)
m <- model(x)
tic()
draws <- mcmc(m, n_samples = 500, warmup = 500)
#> running 4 chains simultaneously on up to 8 cores
#> 
#>     warmup ======================================== 500/500 | eta:  0s
#>   sampling ======================================== 500/500 | eta:  0s
toc()
#> 2.732 sec elapsed
plot(draws)

Created on 2022-08-08 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_AU.UTF-8
#>  ctype    en_AU.UTF-8
#>  tz       Australia/Perth
#>  date     2022-08-08
#>  pandoc   2.18 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>  backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>  callr         3.7.0      2021-04-20 [1] CRAN (R 4.2.0)
#>  cli           3.3.0.9000 2022-06-15 [1] Github (r-lib/cli@31a5db5)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  crayon        1.5.1      2022-03-26 [1] CRAN (R 4.2.0)
#>  curl          4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>  digest        0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate      0.15       2022-02-18 [1] CRAN (R 4.2.0)
#>  fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  future        1.25.0     2022-04-24 [1] CRAN (R 4.2.0)
#>  globals       0.15.0     2022-05-09 [1] CRAN (R 4.2.0)
#>  glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  greta       * 0.4.2.9000 2022-08-08 [1] local
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  hms           1.1.1      2021-09-26 [1] CRAN (R 4.2.0)
#>  htmltools     0.5.2      2021-08-25 [1] CRAN (R 4.2.0)
#>  httr          1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>  jsonlite      1.8.0      2022-02-22 [1] CRAN (R 4.2.0)
#>  knitr         1.39       2022-04-26 [1] CRAN (R 4.2.0)
#>  lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lifecycle     1.0.1      2021-09-24 [1] CRAN (R 4.2.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  Matrix        1.4-1      2022-03-23 [1] CRAN (R 4.2.0)
#>  mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>  parallelly    1.31.1     2022-04-22 [1] CRAN (R 4.2.0)
#>  pillar        1.7.0      2022-02-01 [1] CRAN (R 4.2.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.2.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>  processx      3.6.1      2022-06-17 [1] CRAN (R 4.2.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>  ps            1.7.1      2022-06-18 [1] CRAN (R 4.2.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache       0.15.0     2021-04-30 [1] CRAN (R 4.2.0)
#>  R.methodsS3   1.8.1      2020-08-26 [1] CRAN (R 4.2.0)
#>  R.oo          1.24.0     2020-08-26 [1] CRAN (R 4.2.0)
#>  R.utils       2.11.0     2021-09-26 [1] CRAN (R 4.2.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.2.0)
#>  reticulate    1.24-9000  2022-05-11 [1] Github (rstudio/reticulate@451fbff)
#>  rlang         1.0.4      2022-07-12 [1] CRAN (R 4.2.0)
#>  rmarkdown     2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>  rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi       1.7.6      2021-11-29 [1] CRAN (R 4.2.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.2.0)
#>  styler        1.7.0      2022-03-13 [1] CRAN (R 4.2.0)
#>  tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.2.0)
#>  tibble        3.1.7      2022-05-03 [1] CRAN (R 4.2.0)
#>  tictoc      * 1.0.1      2021-04-19 [1] CRAN (R 4.2.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs         0.4.1      2022-04-13 [1] CRAN (R 4.2.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>  withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun          0.31       2022-05-10 [1] CRAN (R 4.2.0)
#>  xml2          1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  yaml          2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16)  [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.22.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Thanks to rstudio/tensorflow#541

and to @t-kalinowski for that!

@goldingn
Member Author

goldingn commented Aug 8, 2022

Aha! So the GPU was being turned on by default then? If so, the fix should be to turn off the GPU by default (but make it easy to run with the GPU if wanted)?

The fact that the GPU is slower (for this model) is expected, as the overhead in passing things to the GPU is higher, and there's next to no benefit of parallelising a vector length 4 on a GPU. In a big model with large matrix multiplies, the GPU will likely be faster. Looking forward to finding out how much faster!

@goldingn
Member Author

goldingn commented Aug 8, 2022

We should think about the best interface for turning the GPU on/off.

Assuming we can use the same interface on all platforms, then just a GPU = FALSE (the default) or GPU = TRUE argument in the call to mcmc() would be simple.

If they set GPU = TRUE but don't have a GPU listed, we'd want to let them know at that point.
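A minimal sketch of that check, assuming a hypothetical internal helper `check_gpu_request()` (not part of greta). It asks TensorFlow for visible GPUs via `tf$config$list_physical_devices("GPU")` and stops with an informative error if none are found; the device lister is an argument so the logic can be exercised without TensorFlow installed.

```r
# Hypothetical helper (not part of greta): validate a GPU = TRUE/FALSE flag
# before any TensorFlow work starts.
check_gpu_request <- function(gpu = FALSE,
                              list_gpus = function() {
                                tensorflow::tf$config$list_physical_devices("GPU")
                              }) {
  # CPU is the default, so there is nothing to check
  if (!isTRUE(gpu)) {
    return("CPU")
  }
  # Ask TensorFlow (or a stand-in) which GPU devices it can see
  gpus <- list_gpus()
  if (length(gpus) == 0) {
    stop("`GPU = TRUE` was requested, but TensorFlow cannot see any GPU devices.\n",
         "Use `GPU = FALSE` to run on the CPU instead.",
         call. = FALSE)
  }
  "GPU"
}
```

On a machine with no visible GPU, `check_gpu_request(TRUE)` would fail straight away, before any sampling starts.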

@njtierney
Collaborator

OK that sounds good to me!

I imagine some kind of on.exit pattern is what we'll need to use here, as this code:

tensorflow::tf$config$get_visible_devices("CPU") |> 
  tensorflow::tf$config$set_visible_devices()

changes TensorFlow's state permanently, I think.
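A sketch of the on.exit idea: record the old value, apply the new one, and restore it with `on.exit()`, which also runs if `expr` errors. It uses a hypothetical R option, `greta_compute_device`, rather than TensorFlow itself, because TensorFlow generally only lets visible devices be set before its runtime is initialised, so restoring them mid-session may not be possible.

```r
# Evaluate `expr` with a temporary device setting, restoring the previous
# setting on exit (including when `expr` throws an error).
# `greta_compute_device` is a hypothetical option name, for illustration only.
with_compute_device <- function(device, expr) {
  old <- options(greta_compute_device = device)  # options() returns the old values
  on.exit(options(old), add = TRUE)
  force(expr)  # the lazy `expr` is evaluated here, after the option is set
}

# Usage, in a fresh session where the option is unset:
with_compute_device("CPU", getOption("greta_compute_device"))
#> [1] "CPU"
getOption("greta_compute_device")
#> NULL
```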

Another thought I had: are there other operations, like calculate() or other TF operations, where we'd want to give users the option of using the GPU or CPU?

@goldingn
Member Author

goldingn commented Aug 9, 2022

Yep, that makes sense.

calculate() and opt() are probably the only two other places to do this.

Is there a good way of checking the available devices?
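TensorFlow 2 provides `tf$config$list_physical_devices()`, which returns one entry per device with a `device_type` field such as `"CPU"` or `"GPU"`. A small sketch that tallies them; the helper name `count_devices()` is hypothetical, and the lister is swappable so the logic can be checked without TensorFlow.

```r
# Hypothetical helper: count the physical devices TensorFlow can see, by type.
# `list_devices` defaults to TensorFlow's own listing, but can be replaced
# (e.g. by a stand-in list when TensorFlow is not installed).
count_devices <- function(list_devices = function() {
                            tensorflow::tf$config$list_physical_devices()
                          }) {
  devices <- list_devices()
  types <- vapply(devices, function(d) d$device_type, character(1))
  counts <- table(types)
  stats::setNames(as.integer(counts), names(counts))
}

# With TensorFlow available:
# count_devices()
```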

@cboettig

cboettig commented Aug 9, 2022

Just lurking here, but in the Python modules I've worked with there's usually an option to set GPU use (and resources generally) as part of a configuration when initializing intensive operations. Just about every library respects the CUDA_VISIBLE_DEVICES environment variable, but it's usually nicer to configure resources with more granularity since, as you've noted, which device is fastest can vary a lot between circumstances. (I'm not sure whether the TensorFlow libraries here can use multiple GPUs, but interfaces typically let you set the number of GPUs as well.) For example:

  • darts gpu config, a user-friendly forecasting module
  • ray, a popular but more complex system for distributed ML, particularly RL. Some methods can set fractional GPU usage as well as the number of CPUs, though that's probably less relevant here.
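For reference, the environment-variable route looks like this from R. `CUDA_VISIBLE_DEVICES` is read by the CUDA runtime, so it has to be set before TensorFlow is first loaded in the session, and it only applies to NVIDIA GPUs; it presumably has no effect on the Metal-based GPU support on an M1 Mac.

```r
# Hide every CUDA (NVIDIA) GPU from TensorFlow for this R session.
# Must run before TensorFlow is initialised, e.g. at the top of a script.
Sys.setenv(CUDA_VISIBLE_DEVICES = "-1")

# Alternatively, expose only the first two GPUs:
# Sys.setenv(CUDA_VISIBLE_DEVICES = "0,1")
```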

@goldingn
Copy link
Member Author

goldingn commented Aug 9, 2022

Thanks @cboettig! Yeah, that's a whole other level of interface detail we should think about.

It does suggest that a different interface might be more future proof. E.g.:

use_gpu()
  CPU execution disabled and GPU execution enabled for all greta operations
use_cpu()
 GPU execution disabled and CPU execution enabled for all greta operations

Which could then be changed in the future to target specific operations, like:

beta <- normal(0, 1, dim = 10)
eta <- X %*% beta

use_gpu(eta)
 GPU execution enabled for the operation greta array 'eta' (matrix multiply)
devices_in_use()
GPU execution enabled for the operation greta array 'eta' (matrix multiply)
CPU execution enabled for all other greta operations

What do you think of that?
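A rough sketch of how `use_gpu()`, `use_cpu()` and `devices_in_use()` could store their choices, assuming a private environment that greta would read later, when translating the model to TensorFlow. All three functions are hypothetical; `substitute()` captures the greta array's name without evaluating it, so the sketch runs without greta. This is global, side-effecting state.

```r
# Hypothetical sketch: device choices live in a private environment that
# greta would consult when it translates the model into TensorFlow.
.compute_state <- new.env(parent = emptyenv())
.compute_state$default <- "CPU"
.compute_state$per_op <- list()

use_cpu <- function() {
  .compute_state$default <- "CPU"
  message("GPU execution disabled and CPU execution enabled for all greta operations")
  invisible(NULL)
}

use_gpu <- function(operation) {
  if (missing(operation)) {
    .compute_state$default <- "GPU"
    message("CPU execution disabled and GPU execution enabled for all greta operations")
  } else {
    # record the operation by the name the user typed, e.g. "eta"
    name <- deparse(substitute(operation))
    .compute_state$per_op[[name]] <- "GPU"
    message("GPU execution enabled for the operation greta array '", name, "'")
  }
  invisible(NULL)
}

devices_in_use <- function() {
  for (name in names(.compute_state$per_op)) {
    message(.compute_state$per_op[[name]],
            " execution enabled for the operation greta array '", name, "'")
  }
  message(.compute_state$default,
          " execution enabled for all other greta operations")
  invisible(c(.compute_state$per_op, default = .compute_state$default))
}
```

Calling `use_gpu(eta)` and then `devices_in_use()` would report the GPU setting for `eta` and the CPU default for everything else.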

@cboettig
Copy link

cboettig commented Aug 9, 2022

Neat! I think you have a stronger intuition than I do here, and I'm not entirely sure when evaluation takes place in each case. I was thinking of mimicking the {darts} interface via an argument to either the model() or mcmc() call, e.g. in {darts} this looks like:

my_model = RNNModel(
    model="RNN",
    ...
    force_reset=True,
    pl_trainer_kwargs={
      "accelerator": "gpu",
      "gpus": [0]
    },
)

my_model.train(training_data)

Maybe less elegant, and I'm not sure how it generalizes to other places where you might want a GPU, but passing the GPU setting as a function argument seems easier to me. I'm always a bit nervous about the 'side-effect' mechanism: it's harder for me as a user to reason about, since there are so many places an R package can 'stash' these additional states about execution (especially if a user is attempting to wrap greta functions in some derivative package).

@goldingn
Member Author

goldingn commented Aug 9, 2022

Yeah, good point about the side-effects. greta already has a few spooky things that can happen because of the unorthodox interface to constructing a model.

The fact that we can create the greta arrays before having to worry about the TF implementation of them means we have a lot of flexibility in the interface. And the fact that the user can refer to specific operation greta arrays when specifying devices should make it fairly intuitive. So a simpler approach that also enables using specific devices for specific operations might be something like:

draws <- mcmc(...,
              compute_options = compute_setup(gpu_operations = list(eta),
                                              gpus_to_use = c(0, 1)))
running 4 chains simultaneously on up to 8 CPU cores and 2 GPU devices
GPU execution enabled for the operation greta array 'eta' (matrix multiply)
CPU execution enabled for all other greta operations
GPU devices in use: 0 and 1

Though we'd want to think carefully about how to make it so beginner users (or people with small models who don't need GPU acceleration) don't have to confront this stuff. So we could just do a default of:

draws <- mcmc(...,
              compute_options = cpu_only())
running 4 chains simultaneously on up to 8 CPU cores

Would that feel more natural @cboettig?
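As a sketch of how little the default path needs, `compute_setup()` and `cpu_only()` could be plain constructors for an options object that `mcmc()` inspects. The names follow the proposal above, but the fields, the validation rule and the `compute_summary()` helper are assumptions, not an existing greta API.

```r
# Hypothetical constructors for a `compute_options` argument to mcmc().
compute_setup <- function(gpu_operations = list(), gpus_to_use = integer()) {
  if (length(gpu_operations) > 0 && length(gpus_to_use) == 0) {
    stop("`gpu_operations` were given but `gpus_to_use` is empty", call. = FALSE)
  }
  structure(
    list(gpu_operations = gpu_operations,
         gpus_to_use = as.integer(gpus_to_use)),
    class = "greta_compute_options"
  )
}

# The beginner-friendly default: no GPU operations and no GPUs.
cpu_only <- function() compute_setup()

# Build the start-up message that mcmc() might print.
compute_summary <- function(compute_options, n_chains = 4, n_cores = 8) {
  msg <- sprintf("running %d chains simultaneously on up to %d CPU cores",
                 n_chains, n_cores)
  n_gpus <- length(compute_options$gpus_to_use)
  if (n_gpus > 0) {
    msg <- sprintf("%s and %d GPU devices", msg, n_gpus)
  }
  msg
}

compute_summary(cpu_only())
#> [1] "running 4 chains simultaneously on up to 8 CPU cores"
```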

@goldingn
Member Author

goldingn commented Aug 9, 2022

We could also change the interface so it's more like:

compute_setup(gpu_operations = list(with_gpu(0, eta),
                                    with_gpu(1, some_other_op)))

which is maybe closer to how people will want to configure these things? We should explore TF's functionality and idioms a bit before deciding on that level of detail, but it's worth planning the general structure first, so we can at least roll out a CPU-only (or GPU-only) version of TF2 greta soon without getting too bogged down in this extra functionality.
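With the per-operation form, the main job would be turning each `with_gpu()` pairing into a TensorFlow device string such as `"/GPU:0"`, which greta could then use with `tf$device()` when building that op. A hypothetical sketch of the mapping; `with_gpu()` and `device_map()` are illustrative names, and operation names are captured unevaluated so it runs without greta.

```r
# Hypothetical: pair a GPU index with an operation greta array.
with_gpu <- function(gpu, operation) {
  structure(
    list(gpu = as.integer(gpu), name = deparse(substitute(operation))),
    class = "greta_gpu_assignment"
  )
}

# Map each operation's name to a TensorFlow device string, e.g. "/GPU:0".
device_map <- function(gpu_operations) {
  op_names <- vapply(gpu_operations, function(a) a$name, character(1))
  devices <- vapply(gpu_operations,
                    function(a) sprintf("/GPU:%d", a$gpu),
                    character(1))
  stats::setNames(devices, op_names)
}

# Usage:
# device_map(list(with_gpu(0, eta), with_gpu(1, some_other_op)))
```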

@cboettig
Copy link

cboettig commented Aug 9, 2022

Nice, yeah I like that. Another simple advantage of this construction is it is a bit easier to document and for users to discover by browsing the docs for the function, though maybe that's just a crutch for projects that lack greta's attention to good user documentation.

I do think having something like compute_options argument is probably the most intuitive and future-proof construction to build from though.

Definitely agree about sensible defaults. Defaulting all evaluation to CPU-only is, I think, a pretty common strategy (just like most CPU parallelization: simple cases are better with such defaults).

I'm not sure on the precise syntax. My intuition is that compute_setup() object wouldn't reference specific ops, but rather, the user attaches the desired config to the operation, but I realize that only works for syntax that involves a function call and not the generic matrix math example. But yeah, I think there's space to define the precise syntax down the road. Like you say, it's nice to have a simple syntax option since for most users today it will just be gpu on or off, but large models elsewhere might benefit from configuration for multi-gpus, TPUs, etc.

@njtierney
Collaborator

I've moved some of the discussion of these features into #545 - just to try and help isolate some of the tasks and discussion around this. 😄

@goldingn
Member Author

goldingn commented Aug 9, 2022

Thanks Nick!

Just to wrap this discussion up, I wasn't quite sure what you meant by this @cboettig: "My intuition is that compute_setup() object wouldn't reference specific ops, but rather, the user attaches the desired config to the operation, but I realize that only works for syntax that involves a function call and not the generic matrix math example."

Do you mean that you were expecting an interface a bit like this?:

eta <- matrix_multiply(X, beta, compute_on = gpu(0))

That's a more TF-like interface, so it will definitely be more familiar to people coming from that background.

It's a bit different to the current greta design of defining the mathematical model in greta arrays, and then only getting to specific aspects of the computation (like floating point precision, inference algorithm) later. I quite like that distinction, because it focuses the user on the model rather than the implementation, and because beginners (or just users without very heavy models) don't need to worry about what the additional argument does. But on the other hand, many users will be more used to specifying all the details when defining the operation, so it might be simpler.

Definitely an option to consider for a more detailed interface later, since I think this is completely compatible with the planned first steps.

@cboettig

cboettig commented Aug 9, 2022

Right, I generally like the feel of greta interface and wasn't advocating for any change to the existing feel. My comment was merely that mcmc(..., compute_setup = ) feels like the natural way to tell mcmc to run on GPU, but that I didn't have any equivalent suggestion for eta <- X %*% beta.

Like you say, this could be turned into a function call, but that feels awkward. Is the above matrix operation a "promise" or an execution anyway? E.g. I could almost imagine the above eta being lazy and remaining abstract, such that I could later decide to evaluate it on either GPU or CPU (model(eta, compute_setup = ...) or compute(eta, compute_setup = ...)).

@goldingn
Member Author

That makes sense, thanks!

It's a promise, not an execution. The greta array objects are all constructed and linked together (to create an abstract syntax tree made up of R6 objects) when writing the greta code, but they don't do any computation. Then, when mcmc(), opt(), or calculate() is called, the model written in greta arrays is translated into the TF code appropriate for that task, and then executed. For MCMC, a log_prob TF function is created that takes in the proposed parameters and computes the log unnormalised posterior density at those parameter values. That function is called repeatedly by the MCMC code (also TF code).
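To illustrate the promise/execution split, here is a toy version in base R: building `eta` only records the operation and its inputs, and nothing is computed until `evaluate()` walks the graph with a chosen backend. This is a simplified stand-in (plain lists, a made-up `%mm%` operator and a plain-R backend), not greta's R6 implementation or its TensorFlow translation.

```r
# Toy lazy graph: each node records an operation and its parent nodes.
node <- function(op, parents = list(), value = NULL) {
  structure(list(op = op, parents = parents, value = value), class = "lazy_node")
}
data_node <- function(value) node("data", value = value)

# Building an expression only links nodes together; no arithmetic happens here.
`%mm%` <- function(a, b) node("matmul", parents = list(a, b))

# Walk the graph recursively, computing each operation with the given backend.
evaluate <- function(n, backend) {
  if (n$op == "data") {
    return(n$value)
  }
  inputs <- lapply(n$parents, evaluate, backend = backend)
  do.call(backend[[n$op]], inputs)
}

X <- data_node(matrix(1:4, nrow = 2))
beta <- data_node(c(1, 1))
eta <- X %mm% beta  # a promise: just a description of the computation

# Only now is anything computed; a different (e.g. TF) backend could be swapped in.
r_backend <- list(matmul = function(a, b) a %*% b)
evaluate(eta, r_backend)
```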

So when telling mcmc to run on GPU, it's actually just being lazy and saying that all the greta ops should run on GPU when evaluating the log_prob.

But because it's a promise, we could also link each operation to a different compute setting at the point of running the MCMC, (like in this interface), so that in evaluating the log_prob, there is a mix of GPU and CPU ops, rather than all one or the other. That will probably be optimal in many cases.

This also means the user can develop the model, and then trial different compute options, like in Nick's example here

@njtierney
Collaborator

For those following along, TF2 greta is approaching an alpha release, with just a couple of small issues to be ironed out: how we represent Cholesky factors, and suppressing some Python warnings about retracing.

If you would like to install this version of greta, you can do the following:

remotes::install_github("greta-dev/greta#534")

As we get closer to a release, I'm cleaning up issues, and I think this issue has served its purpose of helping implement most of the TF2 changes - the remaining issues can be seen in milestone 0.5: https://github.com/greta-dev/greta/milestone/2, and other TF2 related issues are marked with the TF2 label: https://github.com/greta-dev/greta/issues?q=is%3Aopen+is%3Aissue+label%3ATF2

Thank you for the discussion, everyone!

This issue was closed.
4 participants