Skip to content

Commit

Permalink
tighten up TF & TFP version dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Oct 24, 2018
1 parent f85b93b commit be76c70
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 56 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Expand Up @@ -24,8 +24,8 @@ License: Apache License 2.0
URL: https://github.com/greta-dev/greta
BugReports: https://github.com/greta-dev/greta/issues
SystemRequirements: Python (>= 2.7.0) with header files and shared library;
TensorFlow (>= 1.8; https://www.tensorflow.org/);
Tensorflow Probability (>=0.0.1; https://github.com/tensorflow/probability)
TensorFlow (>= 1.10; https://www.tensorflow.org/);
Tensorflow Probability (>=0.3.0; https://www.tensorflow.org/probability/)
Encoding: UTF-8
LazyData: true
Depends:
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Expand Up @@ -217,6 +217,7 @@ importFrom(grDevices,col2rgb)
importFrom(grDevices,colorRampPalette)
importFrom(progress,progress_bar)
importFrom(reticulate,conda_binary)
importFrom(reticulate,py_available)
importFrom(reticulate,py_module_available)
importFrom(reticulate,py_set_attr)
importFrom(stats,na.omit)
Expand Down
1 change: 1 addition & 0 deletions R/dag_class.R
@@ -1,5 +1,6 @@
#' @importFrom reticulate py_set_attr
#' @importFrom tensorflow dict
#' @importFrom R6 R6Class

# create dag class
dag_class <- R6Class(
Expand Down
3 changes: 0 additions & 3 deletions R/package.R
Expand Up @@ -19,9 +19,6 @@
#'
#' @docType package
#' @importFrom tensorflow tf
#' @importFrom reticulate py_module_available
#' @importFrom R6 R6Class
#' @importFrom grDevices colorRampPalette
#' @examples
#' \dontrun{
#' # a simple Bayesian regression model for the iris data
Expand Down
143 changes: 92 additions & 51 deletions R/utils.R
Expand Up @@ -53,6 +53,22 @@ have_virtualenv <- function () {

}

#' @importFrom reticulate py_available
have_python <- function () {
tryCatch(reticulate::py_available(initialize = TRUE),
error = function(e) FALSE)
}

#' @importFrom reticulate py_module_available
have_tfp <- function () {
reticulate::py_module_available("tensorflow_probability")
}

#' @importFrom reticulate py_module_available
have_tf <- function () {
reticulate::py_module_available("tensorflow")
}

# check tensorflow and tensorflow-probability are installed and have valid
# versions. error, warn, or message if not and (if not an error) return an
# invisible logical saying whether it is valid
Expand All @@ -65,93 +81,117 @@ check_tf_version <- function(alert = c("none",
"startup")) {

alert <- match.arg(alert)
text <- NULL

py_available <- TRUE
tf_available <- TRUE
tfp_available <- TRUE

# check TF installation
if (!reticulate::py_module_available("tensorflow")) {
# check python installation
if (!have_python()) {

text <- "TensorFlow isn't installed"
tf_available <- FALSE
text <- paste0("\n\ngreta requires Python and several Python packages ",
"to be installed, but no Python installation was detected.\n",
"You can install Python directly from ",
"https://www.python.org/downloads/ ",
"or with the Anaconda distribution from ",
"https://www.anaconda.com/download/")

} else {
py_available <- tf_available <- tfp_available <- FALSE

tf_version <- tf$`__version__`
tf_version_valid <- utils::compareVersion("1.8", tf_version) != 1
}

if (!tf_version_valid) {
text <- paste0("you have TensorFlow version ", tf_version)
if (py_available) {

text <- NULL

# check TF installation
if (!have_tf()) {

text <- "TensorFlow isn't installed"
tf_available <- FALSE
}

}
} else {

# check TFP installation
if (!reticulate::py_module_available("tensorflow_probability")) {
tf_version <- tf$`__version__`
tf_version_valid <- utils::compareVersion("1.10.0", tf_version) != 1

text <- paste0(text,
ifelse(is.null(text), "", " and "),
"TensorFlow Probability isn't installed")
tfp_available <- FALSE
if (!tf_version_valid) {
text <- paste0("you have TensorFlow version ", tf_version)
tf_available <- FALSE
}

} else {
}

pkg <- reticulate::import("pkg_resources")
tfp_version <- pkg$get_distribution("tensorflow_probability")$version
tfp_version_valid <- utils::compareVersion("0.3.0", tfp_version) != 1
# check TFP installation
if (!have_tfp()) {

if (!tfp_version_valid) {
text <- paste0("you have TensorFlow Probability version ", tfp_version)
text <- paste0(text,
ifelse(is.null(text), "", " and "),
"TensorFlow Probability isn't installed")
tfp_available <- FALSE

} else {

pkg <- reticulate::import("pkg_resources")
tfp_version <- pkg$get_distribution("tensorflow_probability")$version
tfp_version_valid <- utils::compareVersion("0.3.0", tfp_version) != 1

if (!tfp_version_valid) {
text <- paste0("you have TensorFlow Probability version ", tfp_version)
tfp_available <- FALSE
}

}

}
# if there was a problem, append the solution
if (!tf_available | !tfp_available) {

if (!is.null(text)) {
# conda-specific installation instructions, to handle conda not having TFP
if (have_conda() & !have_virtualenv()) {

# conda-specific installation instructions, to handle conda not having TFP
if (have_conda() & !have_virtualenv()) {
tf_install <- tfp_install <- ""

tf_install <- tfp_install <- ""
if (!tf_available | !tfp_available) {
tf_install <- ' install_tensorflow(method = "conda")\n'
}

if (!tf_available | !tfp_available) {
tf_install <- ' install_tensorflow(method = "conda")\n'
}
if (!tfp_available) {
tfp_install <- paste0(' reticulate::conda_install("r-tensorflow", ',
'"tensorflow-probability", pip = TRUE)\n')
}

install <- paste(tf_install, tfp_install, collapse = "\n")

if (!tfp_available) {
tfp_install <- paste0(' reticulate::conda_install("r-tensorflow", ',
'"tensorflow-probability", pip = TRUE)\n')
} else {
# non-conda installation instructions
install <- sprintf("install_tensorflow(%s) ",
ifelse(tfp_available,
"",
"extra_packages = \"tensorflow-probability\""))
}

install <- paste(tf_install, tfp_install, collapse = "\n")
# combine the problem and solution messages
text <- paste0("\n\ngreta requires TensorFlow (>=1.10.0) ",
"and Tensorflow Probability (>=0.3.0), ",
"but ", text, ". Use:\n\n",
install,
"\nto install the latest version.",
"\n\n")

} else {
# non-conda installation instructions
install <- sprintf("install_tensorflow(%s) ",
ifelse(tfp_available,
"",
"extra_packages = \"tensorflow-probability\""))
}

# combine the problem and solution messages
text <- paste0("\n\ngreta requires TensorFlow (>=1.8) or higher ",
"and Tensorflow Probability (>=0.3.0), ",
"but ", text, ". Use:\n\n",
install,
"\nto install the latest version.",
"\n\n")
}

if (!is.null(text)) {
switch(alert,
error = stop(text, call. = FALSE),
warn = warning(text, call. = FALSE),
message = message(text),
startup = packageStartupMessage(text),
none = NULL)

}

invisible(tf_available & tfp_available)
invisible(py_available & tf_available & tfp_available)

}

Expand Down Expand Up @@ -816,6 +856,7 @@ dummy_array_module <- module(flatten_rowwise,
# returning a colour linearly interpolated between black, the colour and white,
# so that values close to 0.5 match the base colour, values close to 0 are
# nearer black, and values close to 1 are nearer white
#' @importFrom grDevices colorRampPalette
palettize <- function(base_colour) {
pal <- colorRampPalette(c("#000000", base_colour, "#ffffff"))
function(val) {
Expand Down

0 comments on commit be76c70

Please sign in to comment.