Skip to content

Commit

Permalink
auto_partial improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay committed Apr 11, 2024
1 parent a8eabf9 commit fb7aa94
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 41 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Expand Up @@ -21,8 +21,8 @@ promise_env_ <- function(promise) {
.Call(`_ggdist_promise_env_`, promise)
}

is_waived_ <- function(x) {
.Call(`_ggdist_is_waived_`, x)
is_waiver_ <- function(x) {
.Call(`_ggdist_is_waiver_`, x)
}

dots_to_list_ <- function(dots) {
Expand Down
73 changes: 59 additions & 14 deletions R/auto_partial.R
Expand Up @@ -175,15 +175,16 @@ waiver = ggplot2::waiver
else x
}

is_waiver = function(x) {
if (typeof(x) == "promise") {
expr = promise_expr(x)
identical(expr, quote(waiver())) ||
(is.symbol(expr) && is_waiver(get0(as.character(expr), promise_env(x))))
} else {
inherits(x, "waiver")
}
}
is_waiver = is_waiver_
# function(x) {
# if (typeof(x) == "promise") {
# expr = promise_expr(x)
# identical(expr, quote(waiver())) ||
# (is.symbol(expr) && is_waiver(get0(as.character(expr), promise_env(x))))
# } else {
# inherits(x, "waiver")
# }
# }


# promise lists -----------------------------------------------------------------
Expand Down Expand Up @@ -236,6 +237,24 @@ format_promise = function(x) {

# promises ----------------------------------------------------------------

# NOTE: individual promises are always stored in promise lists because
# they are fragile and prone to evaluating themselves if assigned to variables.
new_promise = function(expr, env = parent.frame()) {
do.call(promise_list, list(expr), envir = env)
}

promise = function(x, env = parent.frame()) {
new_promise(substitute(x), env)
}

arg_promise = function(x, env = parent.frame()) {
expr = substitute(x)
stopifnot(is.symbol(expr))
# seem to need to do wrap the call to find_promise() in a list rather than
# returning directly to avoid evaluating the promise...
new_promise_list(list(find_promise(expr, env)))
}

#' Find a promise by name.
#' @param name the name of a promise as a string or symbol
#' @param env the environment to search
Expand All @@ -250,7 +269,7 @@ find_promise = find_promise_

promise_expr = function(x) {
if (typeof(x) == "promise") {
do.call(substitute, list(unwrap_promise_(x)))
promise_expr_(x)
} else {
x
}
Expand Down Expand Up @@ -346,13 +365,24 @@ new_auto_partial = function(
`>args` = update_args(`>args`, new_args)

if (all(`>required_arg_names` %in% names(`>args`))) {
# turn promises that would be evaluated in the calling frame anyway into
# expressions, since this works the same but will look better in
# match.call() in the most typical user-facing cases
caller_env = parent.frame()
# TODO: do this, but also don't unpromise expressions that contain the
# symbol in `>name`
# `>args` = unpromise_in_env(`>args`, caller_env)

# a simpler version of the below would be something like:
# > do.call(`>f`, args, envir = parent.frame())
# however this would lead to the function call appearing as the
# full function body in things like match.call(). So instead
# we construct a calling environment in which the function is
# defined under a readable name.
call_env = new.env(parent = parent.frame())
# defined under a readable name. This does mean that functions
# that inspect the calling context might have issues
# TODO: can we manually construct a call that sets the calling context
# to caller_env but also has nice function calls in match.call()?
call_env = new.env(parent = caller_env)
call_env[[`>name`]] = `>f`
do.call(`>name`, `>args`, envir = call_env)
} else {
Expand Down Expand Up @@ -489,10 +519,12 @@ arg_promise_list = function(which = sys.parent()) {
#' calling function
#' @noRd
named_arg_promise_list = function(which = sys.parent()) {
env = sys.frame(which)
dots_env = do.call(parent.frame, list(), envir = env)

f = sys.function(which)
call = match.call(f, sys.call(which), envir = sys.frame(which - 1L))
call = match.call(f, sys.call(which), envir = dots_env)
arg_names = intersect(names(call[-1]), names(formals(f)))
env = sys.frame(which)
promises = lapply(arg_names, find_promise, env)
names(promises) = arg_names
new_promise_list(promises)
Expand Down Expand Up @@ -524,6 +556,19 @@ remove_waivers = function(args) {
args[!waived]
}

#' turn arguments that are in the given environment into expressions instead
#' of promises
#' @param args a named list of promises
#' @param env an environment
#' @returns a named list of promises and unevaluated expressions, where promises
#' with the environment `env` have been turned into their corresponding
#' expressions
unpromise_in_env = function(args, env) {
lapply(args, function(arg) {
if (identical(promise_env(arg), env)) promise_expr(arg) else arg
})
}

cat0 = function(...) {
cat(..., sep = "")
}
10 changes: 5 additions & 5 deletions src/RcppExports.cpp
Expand Up @@ -67,14 +67,14 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// is_waived_
bool is_waived_(RObject x);
RcppExport SEXP _ggdist_is_waived_(SEXP xSEXP) {
// is_waiver_
bool is_waiver_(RObject x);
RcppExport SEXP _ggdist_is_waiver_(SEXP xSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< RObject >::type x(xSEXP);
rcpp_result_gen = Rcpp::wrap(is_waived_(x));
rcpp_result_gen = Rcpp::wrap(is_waiver_(x));
return rcpp_result_gen;
END_RCPP
}
Expand All @@ -96,7 +96,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_ggdist_find_promise_", (DL_FUNC) &_ggdist_find_promise_, 2},
{"_ggdist_promise_expr_", (DL_FUNC) &_ggdist_promise_expr_, 1},
{"_ggdist_promise_env_", (DL_FUNC) &_ggdist_promise_env_, 1},
{"_ggdist_is_waived_", (DL_FUNC) &_ggdist_is_waived_, 1},
{"_ggdist_is_waiver_", (DL_FUNC) &_ggdist_is_waiver_, 1},
{"_ggdist_dots_to_list_", (DL_FUNC) &_ggdist_dots_to_list_, 1},
{NULL, NULL, 0}
};
Expand Down
40 changes: 20 additions & 20 deletions src/promise.cpp
Expand Up @@ -11,10 +11,10 @@ using namespace Rcpp;
*/
// [[Rcpp::export]]
SEXP unwrap_promise_(SEXP x) {
SEXP expr = x;
RObject expr = x;
while (TYPEOF(expr) == PROMSXP) {
x = expr;
expr = PRCODE(x);
expr = PREXPR(x);
}
return x;
}
Expand All @@ -28,17 +28,20 @@ SEXP unwrap_promise_(SEXP x) {
*/
// [[Rcpp::export]]
SEXP find_promise_(Symbol name, Environment env) {
return unwrap_promise_(Rf_findVar(name, env));
RObject var = Rf_findVar(name, env);
return unwrap_promise_(var);
}

// [[Rcpp::export]]
SEXP promise_expr_(Promise promise) {
return PRCODE(unwrap_promise_(promise));
promise = unwrap_promise_(promise);
return PREXPR(promise);
}

// [[Rcpp::export]]
SEXP promise_env_(Promise promise) {
return PRENV(unwrap_promise_(promise));
promise = unwrap_promise_(promise);
return PRENV(promise);
}

// identical(x, quote(waiver()))
Expand All @@ -55,24 +58,21 @@ bool is_waiver_call(SEXP x) {
}

// [[Rcpp::export]]
bool is_waived_(RObject x) {
if (TYPEOF(x) != PROMSXP) {
return Rf_inherits(x, "waiver");
}

//TODO: fix this so we can use it instead of the R implementation
// the problem is bytecode (I think...); need to fix promise_expr
x = unwrap_promise_(x);
RObject expr = PRCODE(x);

if (is_waiver_call(expr)) return true;
bool is_waiver_(RObject x) {
if (TYPEOF(x) == PROMSXP) {
x = unwrap_promise_(x);
RObject expr = PREXPR(x);

if (TYPEOF(expr) == SYMSXP) {
RObject var = Rf_eval(expr, PRENV(x));
return is_waived_(var);
if (TYPEOF(expr) == SYMSXP) {
// TODO: should this be PRVALUE?
Environment env = PRENV(x);
x = Rcpp_eval(expr, env);
} else {
x = expr;
}
}

return false;
return is_waiver_call(x) || Rf_inherits(x, "waiver");
}

/**
Expand Down
34 changes: 34 additions & 0 deletions tests/testthat/test.auto_partial.R
Expand Up @@ -97,6 +97,25 @@ test_that("original function names are preserved in match.call after multiple pa
expect_equal(foo()()(1), quote(foo(x = 1)))
})


# waivers -----------------------------------------------------------------

test_that("is_waiver works", {
x = waiver()

expect_true(is_waiver(x))
expect_true(is_waiver(waiver()))

expect_true(is_waiver(new_promise(quote(x))[[1]]))
expect_true(is_waiver(new_promise(quote(waiver()))[[1]]))

f = function(x) promise_list(x)
g = function(y) f(y)
h = compiler::cmpfun(function(z) g(z))
expect_true(is_waiver(h(x)[[1]]))
expect_true(is_waiver(h(waiver())[[1]]))
})

test_that("waivers work", {
foo = auto_partial(function(x, a = 2) c(x, a))

Expand All @@ -110,3 +129,18 @@ test_that("waivers work", {

expect_equal(foo(a = waiver(), b = 5)(1)(y = -2, b = waiver()), c(1, -2, 3, 5))
})


# promises ----------------------------------------------------------------

test_that("promise expressions are not retrieved as byte code", {
f = function(...) {
lapply(promise_list(...), promise_expr_)
}
f = auto_partial(f)
g = compiler::cmpfun(function(...) {
gx = 5
f(x = gx, ...)
})
expect_equal(g(), list(x = quote(gx)))
})

0 comments on commit fb7aa94

Please sign in to comment.