Skip to content

Commit

Permalink
Hand-roll count()
Browse files Browse the repository at this point in the history
  • Loading branch information
krlmlr committed Aug 5, 2023
1 parent 7495ccf commit 9c8affa
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 30 deletions.
39 changes: 29 additions & 10 deletions R/count.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,42 @@ count.duckplyr_df <- function(x, ..., wt = NULL, sort = FALSE, name = NULL, .dro

dplyr_local_error_call()

quos <- enquos(...)
exprs <- unname(map(quos, quo_get_expr))
is_name <- map_lgl(exprs, is_symbol)
by <- dplyr_quosures(...)
by <- fix_auto_name(by)

# FIXME: Use rel_try() for accurate stats
if (!is_grouped_df(x) && all(is_name) && .drop && !sort) {
by_chr <- map_chr(exprs, as_string)
name <- check_n_name(name, by_chr)
by_exprs <- unname(map(by, quo_get_expr))
is_name <- map_lgl(by_exprs, is_symbol)

rel_try(
"count() needs all(is_name)" = !all(is_name),
"count() only implemented for .drop = TRUE" = !.drop,
"count() only implemented for sort = FALSE" = sort,
{
by_chr <- map_chr(by_exprs, as_string)
name <- check_n_name(name, by_chr)

if (name %in% by_chr) {
abort("Name clash in count()")
}

if (!(name %in% by_chr)) {
n <- tally_n(x, {{ wt }})

out <- summarise(x, !!name := !!n, .by = c(!!!exprs))
rel <- duckdb_rel_from_df(x)

groups <- rel_translate_dots(by, x)
aggregates <- list(rel_translate(as_quosure(n, baseenv()), x, alias = name))

out_rel <- rel_aggregate(rel, groups, unname(aggregates))
if (length(groups) > 0) {
out_rel <- rel_order(out_rel, groups)
}

out <- rel_to_df(out_rel)
out <- dplyr_reconstruct(out, x)

return(out)
}
}
)

# FIXME: optimize, no need to forward dots
# out <- count(x_df, !!!quos, wt = {{ wt }}, sort = sort, name = name, .drop = .drop)
Expand Down
53 changes: 35 additions & 18 deletions patch/count.patch
Original file line number Diff line number Diff line change
@@ -1,40 +1,57 @@
diff --git b/R/count.R a/R/count.R
index 9120a9d..7d99ff9 100644
index 9120a9d..438a529 100644
--- b/R/count.R
+++ a/R/count.R
@@ -1,13 +1,30 @@
@@ -1,14 +1,50 @@
# Generated by 02-duckplyr_df-methods.R
#' @export
count.duckplyr_df <- function(x, ..., wt = NULL, sort = FALSE, name = NULL, .drop = group_by_drop_default(x)) {
- # Our implementation
- rel_try(
- "No relational implementation for count()" = TRUE,
- {
+ force(x)
+
+ dplyr_local_error_call()
+
+ quos <- enquos(...)
+ exprs <- unname(map(quos, quo_get_expr))
+ is_name <- map_lgl(exprs, is_symbol)
+ by <- dplyr_quosures(...)
+ by <- fix_auto_name(by)
+
+ by_exprs <- unname(map(by, quo_get_expr))
+ is_name <- map_lgl(by_exprs, is_symbol)
+
rel_try(
- "No relational implementation for count()" = TRUE,
+ "count() needs all(is_name)" = !all(is_name),
+ "count() only implemented for .drop = TRUE" = !.drop,
+ "count() only implemented for sort = FALSE" = sort,
{
+ by_chr <- map_chr(by_exprs, as_string)
+ name <- check_n_name(name, by_chr)
+
+ # FIXME: Use rel_try() for accurate stats
+ if (!is_grouped_df(x) && all(is_name) && .drop && !sort) {
+ by_chr <- map_chr(exprs, as_string)
+ name <- check_n_name(name, by_chr)
+ if (name %in% by_chr) {
+ abort("Name clash in count()")
+ }
+
+ if (!(name %in% by_chr)) {
+ n <- tally_n(x, {{ wt }})
+
+ out <- summarise(x, !!name := !!n, .by = c(!!!exprs))
+ rel <- duckdb_rel_from_df(x)
+
+ groups <- rel_translate_dots(by, x)
+ aggregates <- list(rel_translate(as_quosure(n, baseenv()), x, alias = name))
+
+ out_rel <- rel_aggregate(rel, groups, unname(aggregates))
+ if (length(groups) > 0) {
+ out_rel <- rel_order(out_rel, groups)
+ }
+
+ out <- rel_to_df(out_rel)
+ out <- dplyr_reconstruct(out, x)
+
return(out)
}
- )
+ }
+
)

+ # FIXME: optimize, no need to forward dots
+ # out <- count(x_df, !!!quos, wt = {{ wt }}, sort = sort, name = name, .drop = .drop)
+
# dplyr forward
count <- dplyr:::count.data.frame
out <- count(x, ..., wt = {{ wt }}, sort = sort, name = name, .drop = .drop)
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/count-tally.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
Code
duckplyr_count(df, x, name = 1)
Condition
Error in `count()`:
Error in `tally()`:
! `name` must be a single string, not the number 1.

---

Code
duckplyr_count(df, x, name = letters)
Condition
Error in `count()`:
Error in `tally()`:
! `name` must be a single string, not a character vector.

# can only explicitly chain together multiple tallies
Expand Down

0 comments on commit 9c8affa

Please sign in to comment.