Skip to content

Commit

Permalink
allow optional arguments to curried functions to be waived
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay committed Mar 3, 2024
1 parent 6b55b03 commit 0376023
Show file tree
Hide file tree
Showing 20 changed files with 223 additions and 85 deletions.
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ export(thickness)
export(to_broom_names)
export(to_ggmcmc_names)
export(ul)
export(waiver)
export(weighted_ecdf)
export(weighted_quantile)
export(weighted_quantile_fun)
Expand All @@ -299,6 +300,7 @@ importFrom(ggplot2,.stroke)
importFrom(ggplot2,GeomPolygon)
importFrom(ggplot2,GeomSegment)
importFrom(ggplot2,has_flipped_aes)
importFrom(ggplot2,waiver)
importFrom(glue,glue)
importFrom(glue,glue_collapse)
importFrom(grDevices,nclass.FD)
Expand All @@ -315,7 +317,7 @@ importFrom(rlang,caller_env)
importFrom(rlang,enexpr)
importFrom(rlang,enexprs)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
importFrom(rlang,enquos0)
importFrom(rlang,eval_tidy)
importFrom(rlang,expr)
importFrom(rlang,get_expr)
Expand Down
6 changes: 5 additions & 1 deletion R/abstract_geom.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,9 @@ make_geom = function(geom,
#' (which will be expressions) match the formals of the generated code.
#' @noRd
to_expression = function(x) {
parse(text = deparse(x), keep.source = FALSE)[[1L]]
if (inherits(x, "waiver")) {
quote(waiver())
} else {
parse(text = deparse(x), keep.source = FALSE)[[1L]]
}
}
80 changes: 63 additions & 17 deletions R/auto_partial.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
#' Many other common arguments for \pkg{ggdist} functions work similarly; e.g.
#' `density`, `align`, `breaks`, `bandwidth`, and `point_interval` arguments.
#'
#' These function families (except [point_interval()]) also support passing
#' [waiver]s to their optional arguments: if [waiver()] is passed to any
#' of these arguments, their default value (or the most
#' recently-partially-applied non-`waiver` value) is used instead.
#'
#' Use the [auto_partial()] function to create new functions that support
#' automatic partial application.
#'
Expand Down Expand Up @@ -83,21 +88,35 @@ NULL
#' function with the arguments that were supplied replacing the defaults.
#' Can be called multiple times
#' @noRd
#' @importFrom rlang as_quosure enquos eval_tidy expr get_expr
partial_self = function(name = NULL) {
#' @importFrom rlang as_quosure enquos0 eval_tidy expr get_expr
partial_self = function(name = NULL, waivable = TRUE) {
f = sys.function(-1L)
call_expr = match.call(f, sys.call(-1L), TRUE, parent.frame(2L))
f_quo = as_quosure(call_expr[[1]], parent.frame(2L))
default_args = lapply(call_expr[-1], as_quosure, env = parent.frame(2L))
provided_args = lapply(call_expr[-1], as_quosure, env = parent.frame(2L))
name = name %||% deparse0(get_expr(call_expr[[1]]))

waivable_arg_names = if (waivable) {
f_args = formals(f)
is_required_arg = vapply(f_args, rlang::is_missing, FUN.VALUE = logical(1))
names(f_args)[!is_required_arg]
}

partial_f = function(...) {
new_args = enquos(...)
all_args = defaults(new_args, default_args)
new_args = enquos0(...)
if (waivable) {
is_waivable = rlang::names2(new_args) %in% waivable_arg_names
is_waived = is_waivable
is_waived[is_waivable] = map_lgl_(new_args[is_waivable], function(arg_expr) {
inherits(eval_tidy(arg_expr), "waiver")
})
new_args = new_args[!is_waived]
}
all_args = defaults(new_args, provided_args)
eval_tidy(expr((!!f_quo)(!!!all_args)))
}

attr(partial_f, "default_args") = default_args
attr(partial_f, "provided_args") = provided_args
attr(partial_f, "name") = name
class(partial_f) = c("ggdist_partial_function", "function")
partial_f
Expand All @@ -107,10 +126,12 @@ partial_self = function(name = NULL) {
#' @param f A function
#' @param name A character string giving the name of the function, to be used
#' when printing.
#' @param waivable logical: if `TRUE`, optional arguments that get
#' passed a [waiver()] will keep their default value (or whatever
#' non-`waiver` value has been most recently partially applied for that
#' argument).
#' @returns A modified version of `f` that will automatically be partially
#' applied if all of its required arguments are not given.
#' @export
#' @importFrom rlang new_function expr
#' @examples
#' # create a custom automatically partially applied function
#' f = auto_partial(function(x, y, z = 3) (x + y) * z)
Expand All @@ -119,7 +140,13 @@ partial_self = function(name = NULL) {
#' g = f(y = 2)(z = 4)
#' g
#' g(1)
auto_partial = function(f, name = NULL) {
#'
#' # pass waiver() to optional arguments to use existing values
#' f(z = waiver())(1, 2) # uses default z = 3
#' f(z = 4)(z = waiver())(1, 2) # uses z = 4
#' @export
#' @importFrom rlang new_function expr
auto_partial = function(f, name = NULL, waivable = TRUE) {
f_body = body(f)
# must ensure the function body is a { ... } block, not a single expression,
# so we can splice it in later with !!!f_body
Expand All @@ -131,17 +158,25 @@ auto_partial = function(f, name = NULL) {
f_args = formals(f)

# find the required arguments
required_args = f_args[vapply(f_args, rlang::is_missing, FUN.VALUE = logical(1))]
required_arg_names = names(required_args)
is_required_arg = vapply(f_args, rlang::is_missing, FUN.VALUE = logical(1))
required_arg_names = names(f_args)[is_required_arg]
required_arg_names = required_arg_names[required_arg_names != "..."]

# no required arguments => function will always fully evaluate when called
if (length(required_arg_names) == 0) return(f)
# build an expression to apply waivers to optional args
if (waivable) {
optional_args = f_args[!is_required_arg]
process_waivers = map2_(optional_args, names(optional_args), function(arg_expr, arg_name) {
arg_sym = as.symbol(arg_name)
expr(if (inherits(!!arg_sym, "waiver")) !!arg_sym = !!arg_expr)
})
} else {
process_waivers = list()
}

# build a logical expression testing to see if any required args are missing
any_required_args_missing = Reduce(
function(x, y) expr(!!x || !!y),
lapply(required_arg_names, function(arg_name) expr(missing(!!as.name(arg_name))))
lapply(required_arg_names, function(arg_name) expr(missing(!!as.symbol(arg_name))))
)

partial_self_f = if (identical(environment(f), environment(partial_self))) {
Expand All @@ -154,11 +189,19 @@ auto_partial = function(f, name = NULL) {
partial_self
}

if (length(required_arg_names) == 0) {
partial_self_if_missing_args = list()
} else {
partial_self_if_missing_args = list(expr(
if (!!any_required_args_missing) return((!!partial_self_f)(!!name, waivable = !!waivable))
))
}

new_f = new_function(
f_args,
expr({
if (!!any_required_args_missing) return((!!partial_self_f)(!!name))

!!!process_waivers
!!!partial_self_if_missing_args
!!!f_body
}),
env = environment(f)
Expand All @@ -172,7 +215,7 @@ auto_partial = function(f, name = NULL) {
#' @export
print.ggdist_partial_function = function(x, ...) {
f_sym = as.name(attr(x, "name"))
f_args = lapply(attr(x, "default_args"), get_expr)
f_args = lapply(attr(x, "provided_args"), get_expr)

cat(sep = "\n",
"<partial_function>: ",
Expand All @@ -184,3 +227,6 @@ print.ggdist_partial_function = function(x, ...) {

invisible(x)
}

#' @export
ggplot2::waiver
8 changes: 5 additions & 3 deletions R/density.R
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ density_bounded = auto_partial(name = "density_bounded", function(
#'
#' @param x numeric vector containing a sample to compute a density estimate for.
#' @param weights optional numeric vector of weights to apply to `x`.
#' @param breaks Determines the breakpoints defining bins. Similar to (but not
#' exactly the same as) the `breaks` argument to [graphics::hist()]. One of:
#' @param breaks Determines the breakpoints defining bins. Defaults to `"Scott"`.
#' Similar to (but not exactly the same as) the `breaks` argument to [graphics::hist()].
#' One of:
#' - A scalar (length-1) numeric giving the number of bins
#' - A vector numeric giving the breakpoints between histogram bins
#' - A function taking `x` and `weights` and returning either the
Expand All @@ -276,7 +277,8 @@ density_bounded = auto_partial(name = "density_bounded", function(
#' For example, `breaks = "Sturges"` will use the [breaks_Sturges()] algorithm,
#' `breaks = 9` will create 9 bins, and `breaks = breaks_fixed(width = 1)` will
#' set the bin width to `1`.
#' @param align Determines how to align the breakpoints defining bins. One of:
#' @param align Determines how to align the breakpoints defining bins. Default
#' (`"none"`) performs no alignment. One of:
#' - A scalar (length-1) numeric giving an offset that is subtracted from the breaks.
#' The offset must be between `0` and the bin width.
#' - A function taking a sorted vector of `breaks` (bin edges) and returning
Expand Down
2 changes: 1 addition & 1 deletion R/point_interval.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ point_interval = function(
.data, ..., .width = 0.95, .point = median, .interval = qi, .simple_names = TRUE,
na.rm = FALSE, .exclude = c(".chain", ".iteration", ".draw", ".row"), .prob
) {
if (missing(.data)) return(partial_self("point_interval"))
if (missing(.data)) return(partial_self("point_interval", waivable = FALSE))

UseMethod("point_interval")
}
Expand Down
5 changes: 3 additions & 2 deletions R/stat_slabinterval.R
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ compute_interval_slabinterval = function(
#' the data and then uses a bounded density estimator based on the reflection method.
#' @param adjust Passed to `density`: the bandwidth for the density estimator for sample data
#' is adjusted by multiplying it by this value. See e.g. [density_bounded()] for more information.
#' Default (`waiver()`) defers to the default of the density estimator, which is usually `1`.
#' @param trim For sample data, should the density estimate be trimmed to the range of the
#' data? Passed on to the density estimator; see the `density` parameter. Default `TRUE`.
#' @param expand For sample data, should the slab be expanded to the limits of the scale? Default `FALSE`.
Expand Down Expand Up @@ -492,10 +493,10 @@ StatSlabinterval = ggproto("StatSlabinterval", AbstractStatSlabinterval,
default_params = defaults(list(
p_limits = c(NA, NA),
density = "bounded",
adjust = 1,
adjust = waiver(),
trim = TRUE,
expand = FALSE,
breaks = "Scott",
breaks = waiver(),
align = "none",
outline_bars = FALSE,

Expand Down
6 changes: 5 additions & 1 deletion R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,12 @@ map_lgl_ = function(.x, .f, ...) {
vapply(.x, .f, FUN.VALUE = logical(1), ...)
}

map2_ = function(.x, .y, .f) {
.mapply(.f, list(.x, .y), NULL)
}

map2_chr_ = function(.x, .y, .f) {
vctrs::list_unchop(.mapply(.f, list(.x, .y), NULL), ptype = character())
vctrs::list_unchop(map2_(.x, .y, .f), ptype = character())
}

fct_rev_ = function(x) {
Expand Down
16 changes: 15 additions & 1 deletion man/auto_partial.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions man/density_histogram.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions man/stat_ccdfinterval.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0376023

Please sign in to comment.