From 15d6810a60e4c3e708845ae85ab9fb21f74a4ff2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kirill=20M=C3=BCller?= Date: Thu, 7 Mar 2024 15:43:46 +0100 Subject: [PATCH] fix: Forbid reuse of new columns created in `summarize()` --- R/relational.R | 18 ++++++++++++++++-- R/summarise.R | 2 +- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/R/relational.R b/R/relational.R index ec40a9ed..b146b771 100644 --- a/R/relational.R +++ b/R/relational.R @@ -47,9 +47,19 @@ rel_try <- function(rel, ...) { stop("Must use a return() in rel_try().") } -rel_translate_dots <- function(dots, data) { +rel_translate_dots <- function(dots, data, forbid_new = FALSE) { if (is.null(names(dots))) { map(dots, rel_translate, data) + } else if (forbid_new) { + out <- accumulate(seq_along(dots), .init = NULL, function(.x, .y) { + new <- names(dots)[[.y]] + translation <- rel_translate(dots[[.y]], alias = new, data, names_forbidden = .x$new) + list( + new = c(.x$new, new), + translation = c(.x$translation, list(translation)) + ) + }) + out[[length(out)]]$translation } else { imap(dots, rel_translate, data = data) } @@ -60,7 +70,8 @@ rel_translate <- function( alias = NULL, partition = NULL, need_window = FALSE, - names_data = names(data)) { + names_data = names(data), + names_forbidden = NULL) { if (is_expression(quo)) { expr <- quo env <- baseenv() @@ -84,6 +95,9 @@ rel_translate <- function( double = relexpr_constant(expr), # symbol = { + if (as.character(expr) %in% names_forbidden) { + abort(paste0("Can't reuse summary variable `", as.character(expr), "`.")) + } if (as.character(expr) %in% names_data) { ref <- as.character(expr) if (!(ref %in% used)) { diff --git a/R/summarise.R b/R/summarise.R index 03d3b908..97162069 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -21,7 +21,7 @@ summarise.duckplyr_df <- function(.data, ..., .by = NULL, .groups = NULL) { } groups <- lapply(by, relexpr_reference) - aggregates <- rel_translate_dots(dots, .data) + aggregates <- rel_translate_dots(dots, .data, forbid_new = TRUE) if (oo) { aggregates <- c(