From f36d0ed5b3a802891b376f913206c42c6ce4b272 Mon Sep 17 00:00:00 2001
From: RinChanNOWWW
Date: Mon, 13 Mar 2023 19:00:35 +0800
Subject: [PATCH 1/7] Parser for GROUPING SETS.

---
 src/query/ast/src/ast/format/ast_format.rs    |  50 +-
 src/query/ast/src/ast/format/syntax/query.rs  |  30 +-
 src/query/ast/src/ast/query.rs                |  33 +-
 src/query/ast/src/parser/query.rs             |  31 +-
 src/query/ast/src/parser/token.rs             |   4 +
 src/query/ast/src/visitors/visitor.rs         |  16 +-
 src/query/ast/src/visitors/visitor_mut.rs     |  16 +-
 src/query/ast/tests/it/parser.rs              |   6 +
 src/query/ast/tests/it/testdata/query.txt     | 146 ++--
 .../ast/tests/it/testdata/statement-error.txt |  22 +
 src/query/ast/tests/it/testdata/statement.txt | 623 ++++++++++++++++--
 11 files changed, 819 insertions(+), 158 deletions(-)

diff --git a/src/query/ast/src/ast/format/ast_format.rs b/src/query/ast/src/ast/format/ast_format.rs
index 4fff8bc958be4..f945502679714 100644
--- a/src/query/ast/src/ast/format/ast_format.rs
+++ b/src/query/ast/src/ast/format/ast_format.rs
@@ -2074,19 +2074,47 @@ impl<'ast> Visitor<'ast> for AstFormatVisitor {
                 FormatTreeNode::with_children(selection_format_ctx, vec![selection_child]);
             children.push(selection_node);
         }
-        if !stmt.group_by.is_empty() {
-            let mut group_by_list_children = Vec::with_capacity(stmt.group_by.len());
-            for group_by in stmt.group_by.iter() {
-                self.visit_expr(group_by);
-                group_by_list_children.push(self.children.pop().unwrap());
+        match &stmt.group_by {
+            Some(GroupBy::Normal(exprs)) => {
+                let mut group_by_list_children = Vec::with_capacity(exprs.len());
+                for group_by in exprs.iter() {
+                    self.visit_expr(group_by);
+                    group_by_list_children.push(self.children.pop().unwrap());
+                }
+                let group_by_list_name = "GroupByList".to_string();
+                let group_by_list_format_ctx = AstFormatContext::with_children(
+                    group_by_list_name,
+                    group_by_list_children.len(),
+                );
+                let group_by_list_node =
+                    FormatTreeNode::with_children(group_by_list_format_ctx, group_by_list_children);
+                children.push(group_by_list_node);
             }
-            let group_by_list_name = "GroupByList".to_string();
-            let group_by_list_format_ctx =
-                AstFormatContext::with_children(group_by_list_name, group_by_list_children.len());
-            let group_by_list_node =
-                FormatTreeNode::with_children(group_by_list_format_ctx, group_by_list_children);
-            children.push(group_by_list_node);
+            Some(GroupBy::GroupingSets(sets)) => {
+                let mut grouping_sets = Vec::with_capacity(sets.len());
+                for set in sets.iter() {
+                    let mut grouping_set = Vec::with_capacity(set.len());
+                    for expr in set.iter() {
+                        self.visit_expr(expr);
+                        grouping_set.push(self.children.pop().unwrap());
+                    }
+                    let name = "GroupingSet".to_string();
+                    let grouping_set_format_ctx =
+                        AstFormatContext::with_children(name, grouping_set.len());
+                    let grouping_set_node =
+                        FormatTreeNode::with_children(grouping_set_format_ctx, grouping_set);
+                    grouping_sets.push(grouping_set_node);
+                }
+                let group_by_list_name = "GroupByList".to_string();
+                let group_by_list_format_ctx =
+                    AstFormatContext::with_children(group_by_list_name, grouping_sets.len());
+                let group_by_list_node =
+                    FormatTreeNode::with_children(group_by_list_format_ctx, grouping_sets);
+                children.push(group_by_list_node);
+            }
+            _ => {}
         }
+
         if let Some(having) = &stmt.having {
             self.visit_expr(having);
             let having_child = self.children.pop().unwrap();
diff --git a/src/query/ast/src/ast/format/syntax/query.rs b/src/query/ast/src/ast/format/syntax/query.rs
index e4bbec89bd2ce..27bacd3213777 100644
--- a/src/query/ast/src/ast/format/syntax/query.rs
+++ 
b/src/query/ast/src/ast/format/syntax/query.rs @@ -21,6 +21,7 @@ use crate::ast::format::syntax::interweave_comma; use crate::ast::format::syntax::parenthenized; use crate::ast::format::syntax::NEST_FACTOR; use crate::ast::Expr; +use crate::ast::GroupBy; use crate::ast::JoinCondition; use crate::ast::JoinOperator; use crate::ast::OrderByExpr; @@ -194,12 +195,19 @@ fn pretty_selection(selection: Option) -> RcDoc<'static> { } } -fn pretty_group_by(group_by: Vec) -> RcDoc<'static> { - if !group_by.is_empty() { - RcDoc::line() +fn pretty_group_set(set: Vec) -> RcDoc<'static> { + RcDoc::nil() + .append(RcDoc::text("(")) + .append(inline_comma(set.into_iter().map(pretty_expr))) + .append(RcDoc::text(")")) +} + +fn pretty_group_by(group_by: Option) -> RcDoc<'static> { + match group_by { + Some(GroupBy::Normal(exprs)) => RcDoc::line() .append( RcDoc::text("GROUP BY").append( - if group_by.len() > 1 { + if exprs.len() > 1 { RcDoc::line() } else { RcDoc::space() @@ -208,12 +216,20 @@ fn pretty_group_by(group_by: Vec) -> RcDoc<'static> { ), ) .append( - interweave_comma(group_by.into_iter().map(pretty_expr)) + interweave_comma(exprs.into_iter().map(pretty_expr)) + .nest(NEST_FACTOR) + .group(), + ), + Some(GroupBy::GroupingSets(sets)) => RcDoc::line() + .append(RcDoc::text("GROUP BY GROUPING SETS (").append(RcDoc::line().nest(NEST_FACTOR))) + .append( + interweave_comma(sets.into_iter().map(pretty_group_set)) .nest(NEST_FACTOR) .group(), ) - } else { - RcDoc::nil() + .append(RcDoc::line()) + .append(RcDoc::text(")")), + _ => RcDoc::nil(), } } diff --git a/src/query/ast/src/ast/query.rs b/src/query/ast/src/ast/query.rs index f5c56fdae7a30..7b1d58e04a89b 100644 --- a/src/query/ast/src/ast/query.rs +++ b/src/query/ast/src/ast/query.rs @@ -82,11 +82,22 @@ pub struct SelectStmt { // `WHERE` clause pub selection: Option, // `GROUP BY` clause - pub group_by: Vec, + pub group_by: Option, // `HAVING` clause pub having: Option, } +/// Group by Clause. +#[derive(Debug, Clone, PartialEq)] +pub enum GroupBy { + /// GROUP BY expr [, expr]* + Normal(Vec), + /// GROUP BY GROUPING SETS ( GroupSet [, GroupSet]* ) + /// + /// GroupSet := (expr [, expr]*) | expr + GroupingSets(Vec>), +} + /// A relational set expression, like `SELECT ... FROM ... {UNION|EXCEPT|INTERSECT} SELECT ... FROM ...` #[derive(Debug, Clone, PartialEq)] pub enum SetExpr { @@ -442,9 +453,25 @@ impl Display for SelectStmt { } // GROUP BY clause - if !self.group_by.is_empty() { + if self.group_by.is_some() { write!(f, " GROUP BY ")?; - write_comma_separated_list(f, &self.group_by)?; + match self.group_by.as_ref().unwrap() { + GroupBy::Normal(exprs) => { + write_comma_separated_list(f, exprs)?; + } + GroupBy::GroupingSets(sets) => { + write!(f, "GROUPING SETS (")?; + for (i, set) in sets.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "(")?; + write_comma_separated_list(f, set)?; + write!(f, ")")?; + } + write!(f, ")")?; + } + } } // HAVING clause diff --git a/src/query/ast/src/parser/query.rs b/src/query/ast/src/parser/query.rs index 263edbaa6f7b5..1513ac426e1dc 100644 --- a/src/query/ast/src/parser/query.rs +++ b/src/query/ast/src/parser/query.rs @@ -555,7 +555,7 @@ pub enum SetOperationElement { select_list: Box>, from: Box>, selection: Box>, - group_by: Box>, + group_by: Option, having: Box>, }, SetOperation { @@ -565,6 +565,25 @@ pub enum SetOperationElement { Group(SetExpr), } +pub fn group_by_items(i: Input) -> IResult { + let normal = map(rule! 
{ ^#comma_separated_list1(expr) }, |groups| { + GroupBy::Normal(groups) + }); + let group_set = alt(( + map(rule! {"(" ~ ")"}, |(_, _)| vec![]), // empty grouping set + map( + rule! {"(" ~ #comma_separated_list1(expr) ~ ")"}, + |(_, sets, _)| sets, + ), + map(rule! { #expr }, |e| vec![e]), + )); + let group_sets = map( + rule! { GROUPING ~ SETS ~ "(" ~ ^#comma_separated_list1(group_set) ~ ")" }, + |(_, _, _, sets, _)| GroupBy::GroupingSets(sets), + ); + rule!(#group_sets | #normal)(i) +} + pub fn set_operation_element(i: Input) -> IResult> { let set_operator = map( rule! { @@ -588,7 +607,7 @@ pub fn set_operation_element(i: Input) -> IResult> SELECT ~ DISTINCT? ~ ^#comma_separated_list1(select_target) ~ ( FROM ~ ^#comma_separated_list1(table_reference) )? ~ ( WHERE ~ ^#expr )? - ~ ( GROUP ~ ^BY ~ ^#comma_separated_list1(expr) )? + ~ ( GROUP ~ ^BY ~ ^#group_by_items )? ~ ( HAVING ~ ^#expr )? }, |( @@ -609,11 +628,7 @@ pub fn set_operation_element(i: Input) -> IResult> .unwrap_or_default(), ), selection: Box::new(opt_where_block.map(|(_, selection)| selection)), - group_by: Box::new( - opt_group_by_block - .map(|(_, _, group_by)| group_by) - .unwrap_or_default(), - ), + group_by: opt_group_by_block.map(|(_, _, group_by)| group_by), having: Box::new(opt_having_block.map(|(_, having)| having)), } }, @@ -667,7 +682,7 @@ impl<'a, I: Iterator>> PrattParser select_list: *select_list, from: *from, selection: *selection, - group_by: *group_by, + group_by, having: *having, })), _ => unreachable!(), diff --git a/src/query/ast/src/parser/token.rs b/src/query/ast/src/parser/token.rs index 18321a9cdaf4e..b995424b504aa 100644 --- a/src/query/ast/src/parser/token.rs +++ b/src/query/ast/src/parser/token.rs @@ -875,6 +875,10 @@ pub enum TokenKind { LAST, #[token("IGNORE_RESULT", ignore(ascii_case))] IGNORE_RESULT, + #[token("GROUPING", ignore(ascii_case))] + GROUPING, + #[token("SETS", ignore(ascii_case))] + SETS, } // Reference: https://www.postgresql.org/docs/current/sql-keywords-appendix.html diff --git a/src/query/ast/src/visitors/visitor.rs b/src/query/ast/src/visitors/visitor.rs index f0a18b5379b48..f27c17bd3c185 100644 --- a/src/query/ast/src/visitors/visitor.rs +++ b/src/query/ast/src/visitors/visitor.rs @@ -568,8 +568,20 @@ pub trait Visitor<'ast>: Sized { walk_expr(self, selection); } - for expr in group_by.iter() { - walk_expr(self, expr); + match group_by { + Some(GroupBy::Normal(exprs)) => { + for expr in exprs { + walk_expr(self, expr); + } + } + Some(GroupBy::GroupingSets(sets)) => { + for set in sets { + for expr in set { + walk_expr(self, expr); + } + } + } + _ => {} } if let Some(having) = having { diff --git a/src/query/ast/src/visitors/visitor_mut.rs b/src/query/ast/src/visitors/visitor_mut.rs index df2310a01d326..99489982c3011 100644 --- a/src/query/ast/src/visitors/visitor_mut.rs +++ b/src/query/ast/src/visitors/visitor_mut.rs @@ -574,8 +574,20 @@ pub trait VisitorMut: Sized { walk_expr_mut(self, selection); } - for expr in group_by.iter_mut() { - walk_expr_mut(self, expr); + match group_by { + Some(GroupBy::Normal(exprs)) => { + for expr in exprs { + walk_expr_mut(self, expr); + } + } + Some(GroupBy::GroupingSets(sets)) => { + for set in sets { + for expr in set { + walk_expr_mut(self, expr); + } + } + } + _ => {} } if let Some(having) = having { diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index 503ba95104393..cbdea6ea92941 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -368,6 +368,10 @@ fn test_statement() { 
type = CSV field_delimiter = ',' record_delimiter = '\n' skip_header = 1;"#, r#"SHOW FILE FORMATS"#, r#"DROP FILE FORMAT my_csv"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d)"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, (c, d))"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (c), (d, e))"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (), (d, e))"#, ]; for case in cases { @@ -424,6 +428,8 @@ fn test_statement_error() { r#"CALL system$test(a"#, r#"show settings ilike 'enable%'"#, r#"PRESIGN INVALID @my_stage/path/to/file"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS a, b"#, + r#"SELECT * FROM t GROUP BY GROUPING SETS ()"#, ]; for case in cases { diff --git a/src/query/ast/tests/it/testdata/query.txt b/src/query/ast/tests/it/testdata/query.txt index a5ef21e8584ea..72b687c0753e4 100644 --- a/src/query/ast/tests/it/testdata/query.txt +++ b/src/query/ast/tests/it/testdata/query.txt @@ -150,7 +150,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -238,7 +238,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -317,7 +317,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -431,7 +431,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -554,7 +554,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -651,7 +651,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -765,7 +765,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -863,7 +863,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -965,7 +965,7 @@ Query { }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -1055,7 +1055,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1157,7 +1157,7 @@ Query { }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -1255,7 +1255,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1317,7 +1317,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1419,7 +1419,7 @@ Query { }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -1603,7 +1603,7 @@ Query { }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -1708,7 +1708,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1756,7 +1756,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1823,7 +1823,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2175,22 +2175,26 @@ Query { }, ], selection: None, - group_by: [ - ColumnRef { - span: Some( - 479..488, - ), - database: None, - table: None, - column: Identifier { - name: "c_custkey", - quote: None, - span: Some( - 479..488, - ), - }, - }, - ], + group_by: Some( + Normal( + [ + ColumnRef { + span: Some( + 479..488, + ), + database: None, + table: None, + column: Identifier { + name: "c_custkey", + quote: None, + span: Some( + 479..488, + ), + }, + }, + ], + ), + ), having: None, }, ), @@ -2214,22 +2218,26 @@ Query { }, ], selection: None, - group_by: [ - ColumnRef { - span: Some( - 540..547, - ), - database: None, - table: None, - column: Identifier { - name: "c_count", - quote: None, - span: Some( - 540..547, - ), - }, - }, - ], + group_by: Some( + 
Normal( + [ + ColumnRef { + span: Some( + 540..547, + ), + database: None, + table: None, + column: Identifier { + name: "c_count", + quote: None, + span: Some( + 540..547, + ), + }, + }, + ], + ), + ), having: None, }, ), @@ -2376,7 +2384,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2413,7 +2421,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2476,7 +2484,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2513,7 +2521,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2583,7 +2591,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2620,7 +2628,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2659,7 +2667,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2729,7 +2737,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2766,7 +2774,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2805,7 +2813,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2868,7 +2876,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2912,7 +2920,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2949,7 +2957,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3021,7 +3029,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3058,7 +3066,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3097,7 +3105,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3160,7 +3168,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3204,7 +3212,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3241,7 +3249,7 @@ Query { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), diff --git a/src/query/ast/tests/it/testdata/statement-error.txt b/src/query/ast/tests/it/testdata/statement-error.txt index ebde07c72e3f4..76ba129a6a040 100644 --- a/src/query/ast/tests/it/testdata/statement-error.txt +++ b/src/query/ast/tests/it/testdata/statement-error.txt @@ -352,3 +352,25 @@ error: | while parsing `PRESIGN [{DOWNLOAD | UPLOAD}] [EXPIRE = 3600]` +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS a, b +---------- Output --------- +error: + --> SQL:1:35 + | +1 | SELECT * FROM t GROUP BY GROUPING SETS a, b + | ^^^^ expected `,`, `HAVING`, `(`, `UNION`, `EXCEPT`, `INTERSECT`, or 7 more ... + + +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS () +---------- Output --------- +error: + --> SQL:1:41 + | +1 | SELECT * FROM t GROUP BY GROUPING SETS () + | ------ ^ expected `(`, `IS`, `IN`, `EXISTS`, `BETWEEN`, `+`, or 58 more ... 
+ | | + | while parsing `SELECT ...` + + diff --git a/src/query/ast/tests/it/testdata/statement.txt b/src/query/ast/tests/it/testdata/statement.txt index 271e780968bd1..806f8cb29a842 100644 --- a/src/query/ast/tests/it/testdata/statement.txt +++ b/src/query/ast/tests/it/testdata/statement.txt @@ -263,7 +263,7 @@ Explain { }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -549,7 +549,7 @@ CreateTable( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1337,7 +1337,7 @@ CreateView( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1446,7 +1446,7 @@ AlterView( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1585,7 +1585,7 @@ CreateView( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -1702,7 +1702,7 @@ AlterView( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2324,22 +2324,26 @@ Query( }, }, ), - group_by: [ - ColumnRef { - span: Some( - 70..71, - ), - database: None, - table: None, - column: Identifier { - name: "a", - quote: None, - span: Some( - 70..71, - ), - }, - }, - ], + group_by: Some( + Normal( + [ + ColumnRef { + span: Some( + 70..71, + ), + database: None, + table: None, + column: Identifier { + name: "a", + quote: None, + span: Some( + 70..71, + ), + }, + }, + ], + ), + ), having: Some( BinaryOp { span: Some( @@ -2424,7 +2428,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2488,7 +2492,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2576,7 +2580,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2664,7 +2668,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2837,7 +2841,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -2969,7 +2973,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3101,7 +3105,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3233,7 +3237,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3365,7 +3369,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3497,7 +3501,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3629,7 +3633,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3761,7 +3765,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -3893,7 +3897,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4025,7 +4029,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4157,7 +4161,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4248,7 +4252,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4339,7 +4343,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4430,7 +4434,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4521,7 +4525,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4668,7 +4672,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ 
-4680,7 +4684,7 @@ Query( }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -4827,7 +4831,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4839,7 +4843,7 @@ Query( }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -4986,7 +4990,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -4998,7 +5002,7 @@ Query( }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -5143,7 +5147,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -5155,7 +5159,7 @@ Query( }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -5261,7 +5265,7 @@ Query( }, }, ), - group_by: [], + group_by: None, having: None, }, ), @@ -5364,7 +5368,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -5499,7 +5503,7 @@ Insert( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -5592,7 +5596,7 @@ Query( ], from: [], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -8412,7 +8416,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -8559,7 +8563,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -8640,7 +8644,7 @@ Query( }, ], selection: None, - group_by: [], + group_by: None, having: None, }, ), @@ -8695,3 +8699,510 @@ DropFileFormat { } +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d) +---------- Output --------- +SELECT * FROM t GROUP BY GROUPING SETS ((a), (b), (c), (d)) +---------- AST ------------ +Query( + Query { + span: Some( + 0..51, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..51, + ), + distinct: false, + select_list: [ + QualifiedName { + qualified: [ + Star, + ], + exclude: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + name: "t", + quote: None, + span: Some( + 14..15, + ), + }, + alias: None, + travel_point: None, + }, + ], + selection: None, + group_by: Some( + GroupingSets( + [ + [ + ColumnRef { + span: Some( + 40..41, + ), + database: None, + table: None, + column: Identifier { + name: "a", + quote: None, + span: Some( + 40..41, + ), + }, + }, + ], + [ + ColumnRef { + span: Some( + 43..44, + ), + database: None, + table: None, + column: Identifier { + name: "b", + quote: None, + span: Some( + 43..44, + ), + }, + }, + ], + [ + ColumnRef { + span: Some( + 46..47, + ), + database: None, + table: None, + column: Identifier { + name: "c", + quote: None, + span: Some( + 46..47, + ), + }, + }, + ], + [ + ColumnRef { + span: Some( + 49..50, + ), + database: None, + table: None, + column: Identifier { + name: "d", + quote: None, + span: Some( + 49..50, + ), + }, + }, + ], + ], + ), + ), + having: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS (a, b, (c, d)) +---------- Output --------- +SELECT * FROM t GROUP BY GROUPING SETS ((a), (b), (c, d)) +---------- AST ------------ +Query( + Query { + span: Some( + 0..53, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..53, + ), + distinct: false, + select_list: [ + QualifiedName { + qualified: [ + Star, + ], + exclude: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + name: "t", + quote: None, + 
span: Some( + 14..15, + ), + }, + alias: None, + travel_point: None, + }, + ], + selection: None, + group_by: Some( + GroupingSets( + [ + [ + ColumnRef { + span: Some( + 40..41, + ), + database: None, + table: None, + column: Identifier { + name: "a", + quote: None, + span: Some( + 40..41, + ), + }, + }, + ], + [ + ColumnRef { + span: Some( + 43..44, + ), + database: None, + table: None, + column: Identifier { + name: "b", + quote: None, + span: Some( + 43..44, + ), + }, + }, + ], + [ + ColumnRef { + span: Some( + 47..48, + ), + database: None, + table: None, + column: Identifier { + name: "c", + quote: None, + span: Some( + 47..48, + ), + }, + }, + ColumnRef { + span: Some( + 50..51, + ), + database: None, + table: None, + column: Identifier { + name: "d", + quote: None, + span: Some( + 50..51, + ), + }, + }, + ], + ], + ), + ), + having: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (c), (d, e)) +---------- Output --------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (c), (d, e)) +---------- AST ------------ +Query( + Query { + span: Some( + 0..60, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..60, + ), + distinct: false, + select_list: [ + QualifiedName { + qualified: [ + Star, + ], + exclude: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + name: "t", + quote: None, + span: Some( + 14..15, + ), + }, + alias: None, + travel_point: None, + }, + ], + selection: None, + group_by: Some( + GroupingSets( + [ + [ + ColumnRef { + span: Some( + 41..42, + ), + database: None, + table: None, + column: Identifier { + name: "a", + quote: None, + span: Some( + 41..42, + ), + }, + }, + ColumnRef { + span: Some( + 44..45, + ), + database: None, + table: None, + column: Identifier { + name: "b", + quote: None, + span: Some( + 44..45, + ), + }, + }, + ], + [ + ColumnRef { + span: Some( + 49..50, + ), + database: None, + table: None, + column: Identifier { + name: "c", + quote: None, + span: Some( + 49..50, + ), + }, + }, + ], + [ + ColumnRef { + span: Some( + 54..55, + ), + database: None, + table: None, + column: Identifier { + name: "d", + quote: None, + span: Some( + 54..55, + ), + }, + }, + ColumnRef { + span: Some( + 57..58, + ), + database: None, + table: None, + column: Identifier { + name: "e", + quote: None, + span: Some( + 57..58, + ), + }, + }, + ], + ], + ), + ), + having: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + +---------- Input ---------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (), (d, e)) +---------- Output --------- +SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (), (d, e)) +---------- AST ------------ +Query( + Query { + span: Some( + 0..59, + ), + with: None, + body: Select( + SelectStmt { + span: Some( + 0..59, + ), + distinct: false, + select_list: [ + QualifiedName { + qualified: [ + Star, + ], + exclude: None, + }, + ], + from: [ + Table { + span: Some( + 14..15, + ), + catalog: None, + database: None, + table: Identifier { + name: "t", + quote: None, + span: Some( + 14..15, + ), + }, + alias: None, + travel_point: None, + }, + ], + selection: None, + group_by: Some( + GroupingSets( + [ + [ + ColumnRef { + span: Some( + 41..42, + ), + database: None, + table: None, + column: Identifier { + name: "a", + quote: None, + span: Some( + 41..42, + ), + }, + }, + ColumnRef { + span: Some( + 
44..45, + ), + database: None, + table: None, + column: Identifier { + name: "b", + quote: None, + span: Some( + 44..45, + ), + }, + }, + ], + [], + [ + ColumnRef { + span: Some( + 53..54, + ), + database: None, + table: None, + column: Identifier { + name: "d", + quote: None, + span: Some( + 53..54, + ), + }, + }, + ColumnRef { + span: Some( + 56..57, + ), + database: None, + table: None, + column: Identifier { + name: "e", + quote: None, + span: Some( + 56..57, + ), + }, + }, + ], + ], + ), + ), + having: None, + }, + ), + order_by: [], + limit: [], + offset: None, + ignore_result: false, + }, +) + + From fb67b0482712bc6bf680053d3933460b313f2bc9 Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Mon, 13 Mar 2023 21:04:22 +0800 Subject: [PATCH 2/7] Bind GROUPING SETS. --- src/query/sql/src/planner/binder/aggregate.rs | 97 ++++++++++++++++++- src/query/sql/src/planner/binder/distinct.rs | 1 + src/query/sql/src/planner/binder/select.rs | 6 +- src/query/sql/src/planner/metadata.rs | 6 -- .../optimizer/heuristic/decorrelate.rs | 1 + .../heuristic/prune_unused_columns.rs | 1 + .../optimizer/heuristic/subquery_rewriter.rs | 1 + src/query/sql/src/planner/plans/aggregate.rs | 2 + .../planner/semantic/distinct_to_groupby.rs | 7 +- 9 files changed, 106 insertions(+), 16 deletions(-) diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index b0fb4073bff4c..6eccf5f7de5d3 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -17,11 +17,13 @@ use std::collections::HashMap; use std::collections::HashSet; use common_ast::ast::Expr; +use common_ast::ast::GroupBy; use common_ast::ast::Literal; use common_ast::ast::SelectTarget; use common_exception::ErrorCode; use common_exception::Result; use common_expression::types::DataType; +use common_expression::types::NumberDataType; use super::prune_by_children; use crate::binder::scalar::ScalarBinder; @@ -45,6 +47,7 @@ use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::Unnest; use crate::BindContext; +use crate::IndexType; use crate::MetadataRef; #[derive(Default, Clone, PartialEq, Eq, Debug)] @@ -74,6 +77,11 @@ pub struct AggregateInfo { /// TODO(leiysky): so far we are using `Debug` string of `Scalar` as identifier, /// maybe a more reasonable way is needed pub group_items_map: HashMap, + + /// Index for virtual column `grouping_id`. It's valid only if `grouping_sets` is not empty. + pub grouping_id_index: IndexType, + /// Each grouping set is a list of column indices in `group_items`. 
+ pub grouping_sets: Vec>, } pub(super) struct AggregateRewriter<'a> { @@ -246,7 +254,7 @@ impl Binder { &mut self, bind_context: &mut BindContext, select_list: &SelectList<'a>, - group_by: &[Expr], + group_by: &GroupBy, ) -> Result<()> { let mut available_aliases = vec![]; @@ -269,8 +277,23 @@ impl Binder { } } - self.resolve_group_items(bind_context, select_list, group_by, &available_aliases) - .await + match group_by { + GroupBy::Normal(exprs) => { + self.resolve_group_items( + bind_context, + select_list, + exprs, + &available_aliases, + false, + &mut vec![], + ) + .await + } + GroupBy::GroupingSets(sets) => { + self.resolve_grouping_sets(bind_context, select_list, sets, &available_aliases) + .await + } + } } pub(super) async fn bind_aggregate( @@ -306,19 +329,78 @@ impl Binder { aggregate_functions: bind_context.aggregate_info.aggregate_functions.clone(), from_distinct: false, limit: None, + grouping_sets: agg_info.grouping_sets.clone(), }; new_expr = SExpr::create_unary(aggregate_plan.into(), new_expr); Ok(new_expr) } + async fn resolve_grouping_sets( + &mut self, + bind_context: &mut BindContext, + select_list: &SelectList<'_>, + sets: &[Vec], + available_aliases: &[(ColumnBinding, ScalarExpr)], + ) -> Result<()> { + let mut grouping_sets = Vec::with_capacity(sets.len()); + for set in sets { + self.resolve_group_items( + bind_context, + select_list, + set, + available_aliases, + true, + &mut grouping_sets, + ) + .await?; + } + // `grouping_sets` stores formatted `ScalarExpr` for each grouping set. + let grouping_sets = grouping_sets + .into_iter() + .map(|set| { + let mut set = set + .into_iter() + .map(|s| { + let offset = *bind_context.aggregate_info.group_items_map.get(&s).unwrap(); + bind_context.aggregate_info.group_items[offset].index + }) + .collect::>(); + // Grouping sets with the same items should be treated as the same. + set.sort(); + set + }) + .collect(); + bind_context.aggregate_info.grouping_sets = grouping_sets; + // Add a virtual column `_grouping_id` to group items. + let grouping_id_column = self.create_column_binding( + None, + None, + "_grouping_id".to_string(), + DataType::Number(NumberDataType::UInt32), + ); + let index = grouping_id_column.index; + bind_context.aggregate_info.group_items.push(ScalarItem { + index, + scalar: ScalarExpr::BoundColumnRef(BoundColumnRef { + column: grouping_id_column, + }), + }); + Ok(()) + } + async fn resolve_group_items( &mut self, bind_context: &mut BindContext, select_list: &SelectList<'_>, group_by: &[Expr], available_aliases: &[(ColumnBinding, ScalarExpr)], + collect_grouping_sets: bool, + grouping_sets: &mut Vec>, ) -> Result<()> { + if collect_grouping_sets { + grouping_sets.push(Vec::with_capacity(group_by.len())); + } // Resolve group items with `FROM` context. Since the alias item can not be resolved // from the context, we can detect the failure and fallback to resolving with `available_aliases`. 
for expr in group_by.iter() { @@ -361,10 +443,15 @@ impl Binder { .await .or_else(|e| Self::resolve_alias_item(bind_context, expr, available_aliases, e))?; + let scalar_str = format!("{:?}", scalar_expr); + if collect_grouping_sets && !grouping_sets.last().unwrap().contains(&scalar_str) { + grouping_sets.last_mut().unwrap().push(scalar_str.clone()); + } + if bind_context .aggregate_info .group_items_map - .get(&format!("{:?}", &scalar_expr)) + .get(&scalar_str) .is_some() { // The group key is duplicated @@ -388,7 +475,7 @@ impl Binder { index, }); bind_context.aggregate_info.group_items_map.insert( - format!("{:?}", &scalar_expr), + scalar_str, bind_context.aggregate_info.group_items.len() - 1, ); } diff --git a/src/query/sql/src/planner/binder/distinct.rs b/src/query/sql/src/planner/binder/distinct.rs index 315be7a465549..e902e5066b5c9 100644 --- a/src/query/sql/src/planner/binder/distinct.rs +++ b/src/query/sql/src/planner/binder/distinct.rs @@ -76,6 +76,7 @@ impl Binder { aggregate_functions: vec![], from_distinct: true, limit: None, + grouping_sets: vec![], }; Ok(SExpr::create_unary(distinct_plan.into(), new_expr)) diff --git a/src/query/sql/src/planner/binder/select.rs b/src/query/sql/src/planner/binder/select.rs index b4c5780b2d5e3..f46ddaf06309f 100644 --- a/src/query/sql/src/planner/binder/select.rs +++ b/src/query/sql/src/planner/binder/select.rs @@ -104,8 +104,10 @@ impl Binder { let (mut scalar_items, projections) = self.analyze_projection(&select_list)?; // This will potentially add some alias group items to `from_context` if find some. - self.analyze_group_items(&mut from_context, &select_list, &stmt.group_by) - .await?; + if let Some(group_by) = stmt.group_by.as_ref() { + self.analyze_group_items(&mut from_context, &select_list, group_by) + .await?; + } self.analyze_aggregate_select(&mut from_context, &mut select_list)?; diff --git a/src/query/sql/src/planner/metadata.rs b/src/query/sql/src/planner/metadata.rs index 65041817d89d2..e1f37a2e7a412 100644 --- a/src/query/sql/src/planner/metadata.rs +++ b/src/query/sql/src/planner/metadata.rs @@ -35,12 +35,6 @@ pub type IndexType = usize; pub static DUMMY_TABLE_INDEX: IndexType = IndexType::MAX; pub static DUMMY_COLUMN_INDEX: IndexType = IndexType::MAX; -/// A special index value to represent the internal column `_group_by_key`, which is -/// used to store the group by key in the final aggregation stage. -/// -/// TODO(leiysky): remove this after we have a better way to represent the internal column. -pub static GROUP_BY_KEY_COLUMN_INDEX: IndexType = IndexType::MAX - 1; - /// ColumnSet represents a set of columns identified by its IndexType. 
pub type ColumnSet = HashSet; diff --git a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs index 80e48f816c256..fa643a34be22c 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs @@ -624,6 +624,7 @@ impl SubqueryRewriter { aggregate_functions: agg_items, from_distinct: aggregate.from_distinct, limit: aggregate.limit, + grouping_sets: aggregate.grouping_sets.clone(), } .into(), flatten_plan, diff --git a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs index 0a4304da0a04c..01007c84c11b4 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs @@ -150,6 +150,7 @@ impl UnusedColumnPruner { from_distinct: p.from_distinct, mode: p.mode, limit: p.limit, + grouping_sets: p.grouping_sets.clone(), }), Self::keep_required_columns(expr.child(0)?, required)?, )) diff --git a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs index 03a3f94561ff0..a2e765e355ce5 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs @@ -420,6 +420,7 @@ impl SubqueryRewriter { from_distinct: false, mode: AggregateMode::Initial, limit: None, + grouping_sets: vec![], }; let compare = ComparisonExpr { diff --git a/src/query/sql/src/planner/plans/aggregate.rs b/src/query/sql/src/planner/plans/aggregate.rs index ed155784bd454..67c4b23ee8135 100644 --- a/src/query/sql/src/planner/plans/aggregate.rs +++ b/src/query/sql/src/planner/plans/aggregate.rs @@ -48,6 +48,8 @@ pub struct Aggregate { // True if the plan is generated from distinct, else the plan is a normal aggregate; pub from_distinct: bool, pub limit: Option, + /// The grouping sets, each grouping set is a list of `group_items` indices. + pub grouping_sets: Vec>, } impl Aggregate { diff --git a/src/query/sql/src/planner/semantic/distinct_to_groupby.rs b/src/query/sql/src/planner/semantic/distinct_to_groupby.rs index f01000de44cf2..0a458a0662772 100644 --- a/src/query/sql/src/planner/semantic/distinct_to_groupby.rs +++ b/src/query/sql/src/planner/semantic/distinct_to_groupby.rs @@ -13,6 +13,7 @@ // limitations under the License. use common_ast::ast::Expr; +use common_ast::ast::GroupBy; use common_ast::ast::Identifier; use common_ast::ast::Query; use common_ast::ast::SelectStmt; @@ -34,7 +35,7 @@ impl VisitorMut for DistinctToGroupBy { .. } = stmt; - if group_by.is_empty() && select_list.len() == 1 && from.len() == 1 { + if group_by.is_none() && select_list.len() == 1 && from.len() == 1 { if let common_ast::ast::SelectTarget::AliasedExpr { expr: box Expr::FunctionCall { @@ -60,7 +61,7 @@ impl VisitorMut for DistinctToGroupBy { select_list: vec![], from: from.clone(), selection: selection.clone(), - group_by: args.clone(), + group_by: Some(GroupBy::Normal(args.clone())), having: None, })), order_by: vec![], @@ -93,7 +94,7 @@ impl VisitorMut for DistinctToGroupBy { alias: None, }], selection: None, - group_by: vec![], + group_by: None, having: having.clone(), }; From f154f15dc1f64a14894b3515c69c1220914abebe Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Tue, 14 Mar 2023 16:06:23 +0800 Subject: [PATCH 3/7] Build plan for GROUPING SETS. 
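
When `grouping_sets` is not empty, the plan builder now wraps the
aggregation input in a new `AggregateExpand` node. A rough sketch of the
expansion this node stands for, using plain scalars as illustrative
stand-ins for the engine's `DataBlock`/`Column` types (the helper below
is hypothetical, not part of this patch): every input row is replicated
once per grouping set, the group-by columns outside that set become
NULL, and a virtual `_grouping_id` column is appended.

    // Illustration only: expand the group-by values of one row.
    fn expand_row(
        group_bys: &[Option<i64>],    // one row's group-by values
        grouping_sets: &[Vec<usize>], // offsets into `group_bys`
    ) -> Vec<(Vec<Option<i64>>, u32)> {
        grouping_sets
            .iter()
            .map(|set| {
                let expanded: Vec<Option<i64>> = group_bys
                    .iter()
                    .enumerate()
                    .map(|(i, v)| if set.contains(&i) { *v } else { None })
                    .collect();
                // Bit i of `_grouping_id` is set when group_bys[i] is NULL-ed.
                let mut id = 0u32;
                for i in 0..group_bys.len() {
                    if !set.contains(&i) {
                        id |= 1 << i;
                    }
                }
                (expanded, id)
            })
            .collect()
    }

For example, with GROUP BY GROUPING SETS ((a, b), ()) a row (1, 2)
expands to (1, 2, _grouping_id = 0b00) and (NULL, NULL,
_grouping_id = 0b11). This is also why `AggregateExpand::output_schema`
wraps every input field as nullable and appends a UInt32 column.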
--- .../service/src/pipelines/pipeline_builder.rs | 6 ++ src/query/sql/src/executor/format.rs | 52 +++++++++++++++ src/query/sql/src/executor/physical_plan.rs | 37 +++++++++++ .../sql/src/executor/physical_plan_builder.rs | 64 +++++++++++++++---- .../sql/src/executor/physical_plan_display.rs | 20 ++++++ .../sql/src/executor/physical_plan_visitor.rs | 17 +++++ src/query/sql/src/planner/binder/aggregate.rs | 2 + src/query/sql/src/planner/binder/distinct.rs | 1 + .../optimizer/heuristic/decorrelate.rs | 1 + .../heuristic/prune_unused_columns.rs | 1 + .../optimizer/heuristic/subquery_rewriter.rs | 1 + src/query/sql/src/planner/plans/aggregate.rs | 5 +- 12 files changed, 192 insertions(+), 15 deletions(-) diff --git a/src/query/service/src/pipelines/pipeline_builder.rs b/src/query/service/src/pipelines/pipeline_builder.rs index 8c0a070ed2c0a..678f6ad2f3891 100644 --- a/src/query/service/src/pipelines/pipeline_builder.rs +++ b/src/query/service/src/pipelines/pipeline_builder.rs @@ -42,6 +42,7 @@ use common_pipeline_transforms::processors::transforms::try_create_transform_sor use common_profile::ProfSpanSetRef; use common_sql::evaluator::BlockOperator; use common_sql::evaluator::CompoundBlockOperator; +use common_sql::executor::AggregateExpand; use common_sql::executor::AggregateFinal; use common_sql::executor::AggregateFunctionDesc; use common_sql::executor::AggregatePartial; @@ -160,6 +161,7 @@ impl PipelineBuilder { PhysicalPlan::Filter(filter) => self.build_filter(filter), PhysicalPlan::Project(project) => self.build_project(project), PhysicalPlan::EvalScalar(eval_scalar) => self.build_eval_scalar(eval_scalar), + PhysicalPlan::AggregateExpand(aggregate) => self.build_aggregate_expand(aggregate), PhysicalPlan::AggregatePartial(aggregate) => self.build_aggregate_partial(aggregate), PhysicalPlan::AggregateFinal(aggregate) => self.build_aggregate_final(aggregate), PhysicalPlan::Sort(sort) => self.build_sort(sort), @@ -422,6 +424,10 @@ impl PipelineBuilder { }) } + fn build_aggregate_expand(&mut self, _aggregate: &AggregateExpand) -> Result<()> { + todo!() + } + fn build_aggregate_partial(&mut self, aggregate: &AggregatePartial) -> Result<()> { self.build_pipeline(&aggregate.input)?; diff --git a/src/query/sql/src/executor/format.rs b/src/query/sql/src/executor/format.rs index 7dd0091ed3367..2d520d9dad156 100644 --- a/src/query/sql/src/executor/format.rs +++ b/src/query/sql/src/executor/format.rs @@ -19,6 +19,7 @@ use common_functions::scalars::BUILTIN_FUNCTIONS; use common_profile::ProfSpanSetRef; use itertools::Itertools; +use super::AggregateExpand; use super::AggregateFinal; use super::AggregateFunctionDesc; use super::AggregatePartial; @@ -65,6 +66,9 @@ fn to_format_tree( PhysicalPlan::Filter(plan) => filter_to_format_tree(plan, metadata, prof_span_set), PhysicalPlan::Project(plan) => project_to_format_tree(plan, metadata, prof_span_set), PhysicalPlan::EvalScalar(plan) => eval_scalar_to_format_tree(plan, metadata, prof_span_set), + PhysicalPlan::AggregateExpand(plan) => { + aggregate_expand_to_format_tree(plan, metadata, prof_span_set) + } PhysicalPlan::AggregatePartial(plan) => { aggregate_partial_to_format_tree(plan, metadata, prof_span_set) } @@ -281,6 +285,54 @@ pub fn pretty_display_agg_desc(desc: &AggregateFunctionDesc, metadata: &Metadata ) } +fn aggregate_expand_to_format_tree( + plan: &AggregateExpand, + metadata: &MetadataRef, + prof_span_set: &ProfSpanSetRef, +) -> Result> { + let sets = plan + .grouping_sets + .iter() + .map(|set| { + set.iter() + .map(|column| { + let column = 
metadata.read().column(*column).clone(); + match column { + ColumnEntry::BaseTableColumn(BaseTableColumn { column_name, .. }) => { + column_name + } + ColumnEntry::DerivedColumn(DerivedColumn { alias, .. }) => alias, + } + }) + .collect::>() + .join(", ") + }) + .map(|s| format!("({})", s)) + .collect::>() + .join(", "); + + let mut children = vec![FormatTreeNode::new(format!("grouping sets: [{sets}]"))]; + + if let Some(info) = &plan.stat_info { + let items = plan_stats_info_to_format_tree(info); + children.extend(items); + } + + if let Some(prof_span) = prof_span_set.lock().unwrap().get(&plan.plan_id) { + let process_time = prof_span.process_time / 1000 / 1000; // milliseconds + children.push(FormatTreeNode::new(format!( + "total process time: {process_time}ms" + ))); + } + + children.push(to_format_tree(&plan.input, metadata, prof_span_set)?); + + Ok(FormatTreeNode::with_children( + "AggregateExpand".to_string(), + children, + )) +} + fn aggregate_partial_to_format_tree( plan: &AggregatePartial, metadata: &MetadataRef, diff --git a/src/query/sql/src/executor/physical_plan.rs b/src/query/sql/src/executor/physical_plan.rs index f17f3ae46479e..514c611d288b0 100644 --- a/src/query/sql/src/executor/physical_plan.rs +++ b/src/query/sql/src/executor/physical_plan.rs @@ -17,6 +17,7 @@ use std::collections::BTreeMap; use common_catalog::plan::DataSourcePlan; use common_exception::Result; use common_expression::types::DataType; +use common_expression::types::NumberDataType; use common_expression::DataBlock; use common_expression::DataField; use common_expression::DataSchemaRef; @@ -162,6 +163,38 @@ impl Unnest { } } +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct AggregateExpand { + /// A unique id of operator in a `PhysicalPlan` tree. + /// Only used for display. + pub plan_id: u32, + + pub input: Box, + pub grouping_id_index: IndexType, + pub grouping_sets: Vec>, + /// Only used for explain + pub stat_info: Option, +} + +impl AggregateExpand { + pub fn output_schema(&self) -> Result { + let input_schema = self.input.output_schema()?; + let input_fields = input_schema.fields(); + let mut output_fields = Vec::with_capacity(input_fields.len() + 1); + for field in input_fields { + output_fields.push(DataField::new( + field.name(), + field.data_type().wrap_nullable(), + )); + } + output_fields.push(DataField::new( + &self.grouping_id_index.to_string(), + DataType::Number(NumberDataType::UInt32), + )); + Ok(DataSchemaRefExt::create(output_fields)) + } +} + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct AggregatePartial { /// A unique id of operator in a `PhysicalPlan` tree. 
@@ -513,6 +546,7 @@ pub enum PhysicalPlan { Project(Project), EvalScalar(EvalScalar), Unnest(Unnest), + AggregateExpand(AggregateExpand), AggregatePartial(AggregatePartial), AggregateFinal(AggregateFinal), Sort(Sort), @@ -545,6 +579,7 @@ impl PhysicalPlan { PhysicalPlan::Filter(plan) => plan.output_schema(), PhysicalPlan::Project(plan) => plan.output_schema(), PhysicalPlan::EvalScalar(plan) => plan.output_schema(), + PhysicalPlan::AggregateExpand(plan) => plan.output_schema(), PhysicalPlan::AggregatePartial(plan) => plan.output_schema(), PhysicalPlan::AggregateFinal(plan) => plan.output_schema(), PhysicalPlan::Sort(plan) => plan.output_schema(), @@ -566,6 +601,7 @@ impl PhysicalPlan { PhysicalPlan::Filter(_) => "Filter".to_string(), PhysicalPlan::Project(_) => "Project".to_string(), PhysicalPlan::EvalScalar(_) => "EvalScalar".to_string(), + PhysicalPlan::AggregateExpand(_) => "AggregateExpand".to_string(), PhysicalPlan::AggregatePartial(_) => "AggregatePartial".to_string(), PhysicalPlan::AggregateFinal(_) => "AggregateFinal".to_string(), PhysicalPlan::Sort(_) => "Sort".to_string(), @@ -587,6 +623,7 @@ impl PhysicalPlan { PhysicalPlan::Filter(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::Project(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::EvalScalar(plan) => Box::new(std::iter::once(plan.input.as_ref())), + PhysicalPlan::AggregateExpand(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::AggregatePartial(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::AggregateFinal(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::Sort(plan) => Box::new(std::iter::once(plan.input.as_ref())), diff --git a/src/query/sql/src/executor/physical_plan_builder.rs b/src/query/sql/src/executor/physical_plan_builder.rs index 77929df0316b3..0909f50ed166a 100644 --- a/src/query/sql/src/executor/physical_plan_builder.rs +++ b/src/query/sql/src/executor/physical_plan_builder.rs @@ -36,6 +36,7 @@ use common_functions::scalars::BUILTIN_FUNCTIONS; use itertools::Itertools; use super::cast_expr_to_non_null_boolean; +use super::AggregateExpand; use super::AggregateFinal; use super::AggregateFunctionDesc; use super::AggregateFunctionSignature; @@ -473,12 +474,29 @@ impl PhysicalPlanBuilder { match input { PhysicalPlan::Exchange(PhysicalExchange { input, kind, .. 
}) => { - let aggregate_partial = AggregatePartial { - plan_id: self.next_plan_id(), - input, - agg_funcs, - group_by: group_items, - stat_info: Some(stat_info), + let aggregate_partial = if !agg.grouping_sets.is_empty() { + let expand = AggregateExpand { + plan_id: self.next_plan_id(), + input, + grouping_id_index: agg.grouping_id_index, + grouping_sets: agg.grouping_sets.clone(), + stat_info: Some(stat_info.clone()), + }; + AggregatePartial { + plan_id: self.next_plan_id(), + input: Box::new(PhysicalPlan::AggregateExpand(expand)), + agg_funcs, + group_by: group_items, + stat_info: Some(stat_info), + } + } else { + AggregatePartial { + plan_id: self.next_plan_id(), + input, + agg_funcs, + group_by: group_items, + stat_info: Some(stat_info), + } }; let group_by_key_index = @@ -505,14 +523,32 @@ impl PhysicalPlanBuilder { }], }) } - _ => PhysicalPlan::AggregatePartial(AggregatePartial { - plan_id: self.next_plan_id(), - agg_funcs, - group_by: group_items, - input: Box::new(input), - - stat_info: Some(stat_info), - }), + _ => { + if !agg.grouping_sets.is_empty() { + let expand = AggregateExpand { + plan_id: self.next_plan_id(), + input: Box::new(input), + grouping_id_index: agg.grouping_id_index, + grouping_sets: agg.grouping_sets.clone(), + stat_info: Some(stat_info.clone()), + }; + PhysicalPlan::AggregatePartial(AggregatePartial { + plan_id: self.next_plan_id(), + agg_funcs, + group_by: group_items, + input: Box::new(PhysicalPlan::AggregateExpand(expand)), + stat_info: Some(stat_info), + }) + } else { + PhysicalPlan::AggregatePartial(AggregatePartial { + plan_id: self.next_plan_id(), + agg_funcs, + group_by: group_items, + input: Box::new(input), + stat_info: Some(stat_info), + }) + } + } } } diff --git a/src/query/sql/src/executor/physical_plan_display.rs b/src/query/sql/src/executor/physical_plan_display.rs index 38dd97fc8561b..e935239e4f446 100644 --- a/src/query/sql/src/executor/physical_plan_display.rs +++ b/src/query/sql/src/executor/physical_plan_display.rs @@ -18,6 +18,7 @@ use std::fmt::Formatter; use common_functions::scalars::BUILTIN_FUNCTIONS; use itertools::Itertools; +use super::AggregateExpand; use super::DistributedInsertSelect; use super::Unnest; use crate::executor::AggregateFinal; @@ -57,6 +58,7 @@ impl<'a> Display for PhysicalPlanIndentFormatDisplay<'a> { PhysicalPlan::Filter(filter) => write!(f, "{}", filter)?, PhysicalPlan::Project(project) => write!(f, "{}", project)?, PhysicalPlan::EvalScalar(eval_scalar) => write!(f, "{}", eval_scalar)?, + PhysicalPlan::AggregateExpand(aggregate) => write!(f, "{}", aggregate)?, PhysicalPlan::AggregatePartial(aggregate) => write!(f, "{}", aggregate)?, PhysicalPlan::AggregateFinal(aggregate) => write!(f, "{}", aggregate)?, PhysicalPlan::Sort(sort) => write!(f, "{}", sort)?, @@ -146,6 +148,24 @@ impl Display for EvalScalar { } } +impl Display for AggregateExpand { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let sets = self + .grouping_sets + .iter() + .map(|set| { + set.iter() + .map(|index| index.to_string()) + .collect::>() + .join(", ") + }) + .map(|s| format!("[{}]", s)) + .collect::>() + .join(", "); + write!(f, "Aggregate(Expand): grouping sets: [{}]", sets) + } +} + impl Display for AggregateFinal { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let group_items = self diff --git a/src/query/sql/src/executor/physical_plan_visitor.rs b/src/query/sql/src/executor/physical_plan_visitor.rs index 46f406d31d4df..335ddd97b4b97 100644 --- a/src/query/sql/src/executor/physical_plan_visitor.rs +++ 
b/src/query/sql/src/executor/physical_plan_visitor.rs @@ -14,6 +14,7 @@ use common_exception::Result; +use super::AggregateExpand; use super::AggregateFinal; use super::AggregatePartial; use super::DistributedInsertSelect; @@ -39,6 +40,7 @@ pub trait PhysicalPlanReplacer { PhysicalPlan::Filter(plan) => self.replace_filter(plan), PhysicalPlan::Project(plan) => self.replace_project(plan), PhysicalPlan::EvalScalar(plan) => self.replace_eval_scalar(plan), + PhysicalPlan::AggregateExpand(plan) => self.replace_aggregate_expand(plan), PhysicalPlan::AggregatePartial(plan) => self.replace_aggregate_partial(plan), PhysicalPlan::AggregateFinal(plan) => self.replace_aggregate_final(plan), PhysicalPlan::Sort(plan) => self.replace_sort(plan), @@ -92,6 +94,18 @@ pub trait PhysicalPlanReplacer { })) } + fn replace_aggregate_expand(&mut self, plan: &AggregateExpand) -> Result { + let input = self.replace(&plan.input)?; + + Ok(PhysicalPlan::AggregateExpand(AggregateExpand { + plan_id: plan.plan_id, + input: Box::new(input), + grouping_id_index: plan.grouping_id_index, + grouping_sets: plan.grouping_sets.clone(), + stat_info: plan.stat_info.clone(), + })) + } + fn replace_aggregate_partial(&mut self, plan: &AggregatePartial) -> Result { let input = self.replace(&plan.input)?; @@ -264,6 +278,9 @@ impl PhysicalPlan { PhysicalPlan::EvalScalar(plan) => { Self::traverse(&plan.input, pre_visit, visit, post_visit); } + PhysicalPlan::AggregateExpand(plan) => { + Self::traverse(&plan.input, pre_visit, visit, post_visit); + } PhysicalPlan::AggregatePartial(plan) => { Self::traverse(&plan.input, pre_visit, visit, post_visit); } diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 6eccf5f7de5d3..2d56855960ce5 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -329,6 +329,7 @@ impl Binder { aggregate_functions: bind_context.aggregate_info.aggregate_functions.clone(), from_distinct: false, limit: None, + grouping_id_index: agg_info.grouping_id_index, grouping_sets: agg_info.grouping_sets.clone(), }; new_expr = SExpr::create_unary(aggregate_plan.into(), new_expr); @@ -380,6 +381,7 @@ impl Binder { DataType::Number(NumberDataType::UInt32), ); let index = grouping_id_column.index; + bind_context.aggregate_info.grouping_id_index = index; bind_context.aggregate_info.group_items.push(ScalarItem { index, scalar: ScalarExpr::BoundColumnRef(BoundColumnRef { diff --git a/src/query/sql/src/planner/binder/distinct.rs b/src/query/sql/src/planner/binder/distinct.rs index e902e5066b5c9..51b10eee7a2de 100644 --- a/src/query/sql/src/planner/binder/distinct.rs +++ b/src/query/sql/src/planner/binder/distinct.rs @@ -76,6 +76,7 @@ impl Binder { aggregate_functions: vec![], from_distinct: true, limit: None, + grouping_id_index: 0, grouping_sets: vec![], }; diff --git a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs index fa643a34be22c..3ca120fb7a2e3 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs @@ -624,6 +624,7 @@ impl SubqueryRewriter { aggregate_functions: agg_items, from_distinct: aggregate.from_distinct, limit: aggregate.limit, + grouping_id_index: aggregate.grouping_id_index, grouping_sets: aggregate.grouping_sets.clone(), } .into(), diff --git a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs 
b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs index 01007c84c11b4..ae89cb0efe20a 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs @@ -150,6 +150,7 @@ impl UnusedColumnPruner { from_distinct: p.from_distinct, mode: p.mode, limit: p.limit, + grouping_id_index: p.grouping_id_index, grouping_sets: p.grouping_sets.clone(), }), Self::keep_required_columns(expr.child(0)?, required)?, diff --git a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs index a2e765e355ce5..bed4b08e0c263 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs @@ -420,6 +420,7 @@ impl SubqueryRewriter { from_distinct: false, mode: AggregateMode::Initial, limit: None, + grouping_id_index: 0, grouping_sets: vec![], }; diff --git a/src/query/sql/src/planner/plans/aggregate.rs b/src/query/sql/src/planner/plans/aggregate.rs index 67c4b23ee8135..46278d948de23 100644 --- a/src/query/sql/src/planner/plans/aggregate.rs +++ b/src/query/sql/src/planner/plans/aggregate.rs @@ -27,6 +27,7 @@ use crate::optimizer::Statistics; use crate::plans::Operator; use crate::plans::RelOp; use crate::plans::ScalarItem; +use crate::IndexType; #[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)] pub enum AggregateMode { @@ -48,8 +49,10 @@ pub struct Aggregate { // True if the plan is generated from distinct, else the plan is a normal aggregate; pub from_distinct: bool, pub limit: Option, + /// The index of the virtual column `_grouping_id`. It's valid only if `grouping_sets` is not empty. + pub grouping_id_index: IndexType, /// The grouping sets, each grouping set is a list of `group_items` indices. - pub grouping_sets: Vec>, + pub grouping_sets: Vec>, } impl Aggregate { From fea23fc5fe65bf61ee66a4e9bf361b0a8f5cf3b4 Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Tue, 14 Mar 2023 20:45:32 +0800 Subject: [PATCH 4/7] New Transform for GROUPING SETS. 
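
The transform derives one integer id per grouping set from bit
positions in `group_bys`, matching the worked comment inside
`build_aggregate_expand` below. A minimal standalone sketch of that id
computation (the mask is an assumption of this sketch, added so the
result keeps only the low bits):

    fn grouping_ids(n_group_bys: u32, grouping_sets: &[Vec<usize>]) -> Vec<u32> {
        let mask = (1u32 << n_group_bys) - 1;
        grouping_sets
            .iter()
            .map(|set| {
                // Set bit i for each group-by column kept by this set...
                let mut id = 0u32;
                for &i in set {
                    id |= 1 << i;
                }
                // ...then invert, so bit i = 1 marks a NULL-ed column.
                !id & mask
            })
            .collect()
    }

    // GROUP BY GROUPING SETS ((a, b), (a), (b), ()) with group_bys = [a, b]:
    // grouping_ids(2, &[vec![0, 1], vec![0], vec![1], vec![]])
    //     == vec![0b00, 0b01, 0b10, 0b11]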
--- src/query/expression/src/values.rs | 13 ++- .../service/src/pipelines/pipeline_builder.rs | 60 +++++++++++- .../service/src/pipelines/processors/mod.rs | 1 + .../processors/transforms/aggregator/mod.rs | 2 + .../aggregator/transform_aggregate_expand.rs | 98 +++++++++++++++++++ .../pipelines/processors/transforms/mod.rs | 1 + src/query/sql/src/executor/physical_plan.rs | 1 + .../sql/src/executor/physical_plan_builder.rs | 2 + .../sql/src/executor/physical_plan_visitor.rs | 1 + src/query/sql/src/planner/binder/aggregate.rs | 4 +- .../group/group_by_grouping_sets.test | 21 ++++ 11 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs create mode 100644 tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs index 59dc139f9c639..3a6f970c33c79 100755 --- a/src/query/expression/src/values.rs +++ b/src/query/expression/src/values.rs @@ -238,6 +238,13 @@ impl Value<AnyType> { pub fn try_downcast<T: ValueType>(&self) -> Option<Value<T>> { Some(self.as_ref().try_downcast::<T>()?.to_owned()) } + + pub fn wrap_nullable(&self) -> Self { + match self { + Value::Column(c) => Value::Column(c.wrap_nullable()), + scalar => scalar.clone(), + } + } } impl<'a> ValueRef<'a, AnyType> { @@ -1479,14 +1486,14 @@ impl Column { } } - pub fn wrap_nullable(self) -> Self { + pub fn wrap_nullable(&self) -> Self { match self { - col @ Column::Nullable(_) => col, + col @ Column::Nullable(_) => col.clone(), col => { let mut validity = MutableBitmap::with_capacity(col.len()); validity.extend_constant(col.len(), true); Column::Nullable(Box::new(NullableColumn { - column: col, + column: col.clone(), validity: validity.into(), })) } diff --git a/src/query/service/src/pipelines/pipeline_builder.rs b/src/query/service/src/pipelines/pipeline_builder.rs index 678f6ad2f3891..99947c3a2af25 100644 --- a/src/query/service/src/pipelines/pipeline_builder.rs +++ b/src/query/service/src/pipelines/pipeline_builder.rs @@ -66,6 +66,7 @@ use common_sql::IndexType; use common_storage::DataOperator; use super::processors::ProfileWrapper; +use super::processors::TransformExpandGroupingSets; use crate::api::DefaultExchangeInjector; use crate::api::ExchangeInjector; use crate::pipelines::processors::transforms::build_partition_bucket; @@ -424,8 +425,63 @@ impl PipelineBuilder { }) } - fn build_aggregate_expand(&mut self, _aggregate: &AggregateExpand) -> Result<()> { - todo!() + fn build_aggregate_expand(&mut self, expand: &AggregateExpand) -> Result<()> { + self.build_pipeline(&expand.input)?; + let input_schema = expand.input.output_schema()?; + let group_bys = expand + .group_bys + .iter() + .filter_map(|i| { + // Do not collect the virtual column "_grouping_id".
+ if *i != expand.grouping_id_index { + match input_schema.index_of(&i.to_string()) { + Ok(index) => { + let ty = input_schema.field(index).data_type().clone(); + Some(Ok((index, ty))) + } + Err(e) => Some(Err(e)), + } + } else { + None + } + }) + .collect::<Result<Vec<_>>>()?; + let grouping_sets = expand + .grouping_sets + .iter() + .map(|sets| { + sets.iter() + .map(|i| { + let i = input_schema.index_of(&i.to_string())?; + let offset = group_bys.iter().position(|(j, _)| *j == i).unwrap(); + Ok(offset) + }) + .collect::<Result<Vec<_>>>() + }) + .collect::<Result<Vec<_>>>()?; + let mut grouping_ids = Vec::with_capacity(grouping_sets.len()); + for set in grouping_sets { + let mut id = 0; + for i in set { + id |= 1 << i; + } + // For each element in `group_bys`: if it is in the current grouping set, its bit is 0; + // otherwise its bit is 1 (1 means the column will be NULL in this grouping). + // Example: GROUP BY GROUPING SETS ((a, b), (a), (b), ()) + // group_bys: [a, b] + // grouping_sets: [[0, 1], [0], [1], []] + // grouping_ids (written low bit first): 00, 01, 10, 11 + grouping_ids.push(!id); + } + + self.main_pipeline.add_transform(|input, output| { + Ok(TransformExpandGroupingSets::create( + input, + output, + group_bys.clone(), + grouping_ids.clone(), + )) + }) } fn build_aggregate_partial(&mut self, aggregate: &AggregatePartial) -> Result<()> { diff --git a/src/query/service/src/pipelines/processors/mod.rs b/src/query/service/src/pipelines/processors/mod.rs index 255f83a89fe59..8e49c02914d79 100644 --- a/src/query/service/src/pipelines/processors/mod.rs +++ b/src/query/service/src/pipelines/processors/mod.rs @@ -33,6 +33,7 @@ pub use transforms::TransformBlockCompact; pub use transforms::TransformCastSchema; pub use transforms::TransformCompact; pub use transforms::TransformCreateSets; +pub use transforms::TransformExpandGroupingSets; pub use transforms::TransformHashJoinProbe; pub use transforms::TransformLimit; pub use transforms::TransformResortAddOn; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs index 0297ced4ef9ef..41aa7b1055b86 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs @@ -17,6 +17,7 @@ mod aggregate_exchange_injector; mod aggregate_meta; mod aggregator_params; mod serde; +mod transform_aggregate_expand; mod transform_aggregate_final; mod transform_aggregate_partial; mod transform_group_by_final; @@ -29,6 +30,7 @@ pub use aggregate_cell::HashTableCell; pub use aggregate_cell::PartitionedHashTableDropper; pub use aggregate_exchange_injector::AggregateInjector; pub use aggregator_params::AggregatorParams; +pub use transform_aggregate_expand::TransformExpandGroupingSets; pub use transform_aggregate_final::TransformFinalAggregate; pub use transform_aggregate_partial::TransformPartialAggregate; pub use transform_group_by_final::TransformFinalGroupBy; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs new file mode 100644 index 0000000000000..a21f749844cb7 --- /dev/null +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs @@ -0,0 +1,98 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use common_exception::Result; +use common_expression::types::DataType; +use common_expression::types::NumberDataType; +use common_expression::types::NumberScalar; +use common_expression::BlockEntry; +use common_expression::DataBlock; +use common_expression::Scalar; +use common_expression::Value; +use common_pipeline_core::processors::port::InputPort; +use common_pipeline_core::processors::port::OutputPort; +use common_pipeline_core::processors::processor::ProcessorPtr; +use common_pipeline_transforms::processors::transforms::Transform; +use common_pipeline_transforms::processors::transforms::Transformer; + +pub struct TransformExpandGroupingSets { + group_bys: Vec<(usize, DataType)>, + grouping_ids: Vec<usize>, +} + +impl TransformExpandGroupingSets { + pub fn create( + input: Arc<InputPort>, + output: Arc<OutputPort>, + group_bys: Vec<(usize, DataType)>, + grouping_ids: Vec<usize>, + ) -> ProcessorPtr { + ProcessorPtr::create(Transformer::create( + input, + output, + TransformExpandGroupingSets { + grouping_ids, + group_bys, + }, + )) + } +} + +impl Transform for TransformExpandGroupingSets { + const NAME: &'static str = "TransformExpandGroupingSets"; + + fn transform(&mut self, data: DataBlock) -> Result<DataBlock> { + let num_rows = data.num_rows(); + let num_group_bys = self.group_bys.len(); + let mut output_blocks = Vec::with_capacity(self.grouping_ids.len()); + + for &id in &self.grouping_ids { + // Repeat data for each grouping set. + let grouping_column = BlockEntry { + data_type: DataType::Number(NumberDataType::UInt32), + value: Value::Scalar(Scalar::Number(NumberScalar::UInt32(id as u32))), + }; + let mut columns = data + .columns() + .iter() + .cloned() + .chain(vec![grouping_column]) + .collect::<Vec<_>>(); + let bits = !id; + for i in 0..num_group_bys { + let entry = unsafe { + let offset = self.group_bys.get_unchecked(i).0; + columns.get_unchecked_mut(offset) + }; + if bits & (1 << i) == 0 { + // This column should be filled with NULLs.
+ *entry = BlockEntry { + data_type: entry.data_type.wrap_nullable(), + value: Value::Scalar(Scalar::Null), + } + } else { + *entry = BlockEntry { + data_type: entry.data_type.wrap_nullable(), + value: entry.value.wrap_nullable(), + } + } + } + output_blocks.push(DataBlock::new(columns, num_rows)); + } + + DataBlock::concat(&output_blocks) + } +} diff --git a/src/query/service/src/pipelines/processors/transforms/mod.rs b/src/query/service/src/pipelines/processors/transforms/mod.rs index 0a8571593c869..b097ac41fdb03 100644 --- a/src/query/service/src/pipelines/processors/transforms/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/mod.rs @@ -42,6 +42,7 @@ pub use aggregator::TransformAggregateDeserializer; pub use aggregator::TransformAggregateSerializer; pub use aggregator::TransformAggregateSpillReader; pub use aggregator::TransformAggregateSpillWriter; +pub use aggregator::TransformExpandGroupingSets; pub use aggregator::TransformFinalAggregate; pub use aggregator::TransformGroupByDeserializer; pub use aggregator::TransformGroupBySerializer; diff --git a/src/query/sql/src/executor/physical_plan.rs b/src/query/sql/src/executor/physical_plan.rs index 514c611d288b0..1ba12d343747a 100644 --- a/src/query/sql/src/executor/physical_plan.rs +++ b/src/query/sql/src/executor/physical_plan.rs @@ -170,6 +170,7 @@ pub struct AggregateExpand { pub plan_id: u32, pub input: Box<PhysicalPlan>, + pub group_bys: Vec<IndexType>, pub grouping_id_index: IndexType, pub grouping_sets: Vec<Vec<IndexType>>, /// Only used for explain diff --git a/src/query/sql/src/executor/physical_plan_builder.rs b/src/query/sql/src/executor/physical_plan_builder.rs index 0909f50ed166a..369574d5c4a51 100644 --- a/src/query/sql/src/executor/physical_plan_builder.rs +++ b/src/query/sql/src/executor/physical_plan_builder.rs @@ -478,6 +478,7 @@ impl PhysicalPlanBuilder { let expand = AggregateExpand { plan_id: self.next_plan_id(), input, + group_bys: group_items.clone(), grouping_id_index: agg.grouping_id_index, grouping_sets: agg.grouping_sets.clone(), stat_info: Some(stat_info.clone()), @@ -528,6 +529,7 @@ impl PhysicalPlanBuilder { let expand = AggregateExpand { plan_id: self.next_plan_id(), input: Box::new(input), + group_bys: group_items.clone(), grouping_id_index: agg.grouping_id_index, grouping_sets: agg.grouping_sets.clone(), stat_info: Some(stat_info.clone()), diff --git a/src/query/sql/src/executor/physical_plan_visitor.rs b/src/query/sql/src/executor/physical_plan_visitor.rs index 335ddd97b4b97..c8e33bed27e79 100644 --- a/src/query/sql/src/executor/physical_plan_visitor.rs +++ b/src/query/sql/src/executor/physical_plan_visitor.rs @@ -100,6 +100,7 @@ pub trait PhysicalPlanReplacer { Ok(PhysicalPlan::AggregateExpand(AggregateExpand { plan_id: plan.plan_id, input: Box::new(input), + group_bys: plan.group_bys.clone(), grouping_id_index: plan.grouping_id_index, grouping_sets: plan.grouping_sets.clone(), stat_info: plan.stat_info.clone(), diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 2d56855960ce5..642fbc0982557 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -24,6 +24,7 @@ use common_exception::ErrorCode; use common_exception::Result; use common_expression::types::DataType; use common_expression::types::NumberDataType; +use itertools::Itertools; use super::prune_by_children; use crate::binder::scalar::ScalarBinder; @@ -371,7 +372,8 @@ impl Binder { set.sort(); set }) - .collect(); + .collect::<Vec<_>>(); + let grouping_sets = 
grouping_sets.into_iter().unique().collect(); bind_context.aggregate_info.grouping_sets = grouping_sets; // Add a virtual column `_grouping_id` to group items. let grouping_id_column = self.create_column_binding( diff --git a/tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test b/tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test new file mode 100644 index 0000000000000..502d44b467b90 --- /dev/null +++ b/tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test @@ -0,0 +1,21 @@ +statement ok +drop table if exists t; + +statement ok +create table t (a string, b string, c int); + +statement ok +insert into t values ('a','A',1),('a','A',2),('a','B',1),('a','B',3),('b','A',1),('b','A',4),('b','B',1),('b','B',5); + +query TTI +select a, b, sum(c) as sc from t group by grouping sets ((a,b),(),(b),(a)) order by sc; +---- +a A 3 +a B 4 +b A 5 +b B 6 +a NULL 7 +NULL A 8 +NULL B 10 +b NULL 11 +NULL NULL 18 From f17d3736ff74ecbc09f6d59d14afbf5f539196d7 Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Tue, 14 Mar 2023 22:07:06 +0800 Subject: [PATCH 5/7] Fix explain results. --- .../mode/cluster/04_0002_explain_v2.test | 2 +- .../suites/mode/cluster/exchange.test | 8 ++++---- .../mode/standalone/explain/prune_column.test | 18 +++++++++--------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/sqllogictests/suites/mode/cluster/04_0002_explain_v2.test b/tests/sqllogictests/suites/mode/cluster/04_0002_explain_v2.test index c6834ea7da96b..e6790e06e458b 100644 --- a/tests/sqllogictests/suites/mode/cluster/04_0002_explain_v2.test +++ b/tests/sqllogictests/suites/mode/cluster/04_0002_explain_v2.test @@ -157,7 +157,7 @@ Limit ├── sort keys: [c ASC NULLS LAST, e ASC NULLS LAST, d ASC NULLS LAST] ├── estimated rows: 1.00 └── EvalScalar - ├── expressions: [count(1) (#8), max(a) (#10), count(b) (#9)] + ├── expressions: [count(1) (#5), max(a) (#7), count(b) (#6)] ├── estimated rows: 1.00 └── AggregateFinal ├── group by: [] diff --git a/tests/sqllogictests/suites/mode/cluster/exchange.test b/tests/sqllogictests/suites/mode/cluster/exchange.test index 267dbce03486f..8df0811e47759 100644 --- a/tests/sqllogictests/suites/mode/cluster/exchange.test +++ b/tests/sqllogictests/suites/mode/cluster/exchange.test @@ -90,12 +90,12 @@ Exchange ├── exchange type: Merge └── HashJoin ├── join type: INNER - ├── build keys: [t1.number (#3)] + ├── build keys: [t1.number (#2)] ├── probe keys: [t.a (#0)] ├── filters: [] ├── estimated rows: 6.00 ├── Exchange(Build) - │ ├── exchange type: Hash(t1.number (#3)) + │ ├── exchange type: Hash(t1.number (#2)) │ └── TableScan │ ├── table: default.system.numbers │ ├── read rows: 2 @@ -109,7 +109,7 @@ Exchange └── HashJoin ├── join type: INNER ├── build keys: [t.b (#1)] - ├── probe keys: [t2.number (#4)] + ├── probe keys: [t2.number (#3)] ├── filters: [] ├── estimated rows: 3.00 ├── Exchange(Build) @@ -126,7 +126,7 @@ Exchange │ ├── push downs: [filters: [], limit: NONE] │ └── estimated rows: 1.00 └── Exchange(Probe) - ├── exchange type: Hash(t2.number (#4)) + ├── exchange type: Hash(t2.number (#3)) └── TableScan ├── table: default.system.numbers ├── read rows: 3 diff --git a/tests/sqllogictests/suites/mode/standalone/explain/prune_column.test b/tests/sqllogictests/suites/mode/standalone/explain/prune_column.test index fc81893973a3d..b6f932804d78c 100644 --- a/tests/sqllogictests/suites/mode/standalone/explain/prune_column.test +++ 
b/tests/sqllogictests/suites/mode/standalone/explain/prune_column.test @@ -66,7 +66,7 @@ explain select * from (select t1.a from (select number + 1 as a, number + 1 as b HashJoin ├── join type: INNER ├── build keys: [t1.b (#2)] -├── probe keys: [t2.b (#11)] +├── probe keys: [t2.b (#7)] ├── filters: [] ├── estimated rows: 0.33 ├── Filter(Build) @@ -84,7 +84,7 @@ HashJoin │ ├── push downs: [filters: [], limit: NONE] │ └── estimated rows: 1.00 └── EvalScalar(Probe) - ├── expressions: [numbers.number (#9) + 1] + ├── expressions: [numbers.number (#5) + 1] ├── estimated rows: 1.00 └── TableScan ├── table: default.system.numbers @@ -99,7 +99,7 @@ query T explain select t1.a from (select number + 1 as a, number + 1 as b from numbers(1)) as t1 where t1.a = (select count(*) from (select t2.a, t3.a from (select number + 1 as a, number + 1 as b, number + 1 as c, number + 1 as d from numbers(1)) as t2, (select number + 1 as a, number + 1 as b, number + 1 as c from numbers(1)) as t3 where t2.b = t3.b and t2.c = 1)) ---- Filter -├── filters: [is_true(CAST(t1.a (#1) AS UInt64 NULL) = scalar_subquery_21 (#21))] +├── filters: [is_true(CAST(t1.a (#1) AS UInt64 NULL) = scalar_subquery_12 (#12))] ├── estimated rows: 0.33 └── HashJoin ├── join type: SINGLE @@ -108,7 +108,7 @@ Filter ├── filters: [] ├── estimated rows: 1.00 ├── EvalScalar(Build) - │ ├── expressions: [COUNT(*) (#22)] + │ ├── expressions: [COUNT(*) (#13)] │ ├── estimated rows: 1.00 │ └── AggregateFinal │ ├── group by: [] @@ -120,15 +120,15 @@ Filter │ ├── estimated rows: 1.00 │ └── HashJoin │ ├── join type: INNER - │ ├── build keys: [t2.b (#7)] - │ ├── probe keys: [t3.b (#16)] + │ ├── build keys: [t2.b (#5)] + │ ├── probe keys: [t3.b (#10)] │ ├── filters: [] │ ├── estimated rows: 0.33 │ ├── Filter(Build) - │ │ ├── filters: [t2.c (#8) = 1] + │ │ ├── filters: [t2.c (#6) = 1] │ │ ├── estimated rows: 0.33 │ │ └── EvalScalar - │ │ ├── expressions: [numbers.number (#5) + 1, numbers.number (#5) + 1] + │ │ ├── expressions: [numbers.number (#3) + 1, numbers.number (#3) + 1] │ │ ├── estimated rows: 1.00 │ │ └── TableScan │ │ ├── table: default.system.numbers @@ -139,7 +139,7 @@ Filter │ │ ├── push downs: [filters: [], limit: NONE] │ │ └── estimated rows: 1.00 │ └── EvalScalar(Probe) - │ ├── expressions: [numbers.number (#14) + 1] + │ ├── expressions: [numbers.number (#8) + 1] │ ├── estimated rows: 1.00 │ └── TableScan │ ├── table: default.system.numbers From 3b356b6ebbfcf4e87af60a6598d5355502b32a4e Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Wed, 15 Mar 2023 10:37:37 +0800 Subject: [PATCH 6/7] Make wrap_nullable get the ownership. 
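
Taking `self` by value lets the non-nullable case move its buffers into the nullable wrapper instead of cloning them, while callers that only hold a reference now clone explicitly at the call site, which keeps the copy visible. A rough sketch of the pattern, assuming a toy MyColumn enum as a stand-in for the engine's Column type:

    // Illustrative sketch of the by-value wrap pattern; MyColumn and its
    // variants are assumptions, not the engine's actual Column definition.
    enum MyColumn {
        Plain(Vec<i64>),
        Nullable(Vec<i64>, Vec<bool>),
    }

    impl MyColumn {
        // By-value receiver: an already-nullable column passes through
        // unchanged, and a plain column is moved (not cloned) into the
        // nullable wrapper with an all-true validity bitmap.
        fn wrap_nullable(self) -> Self {
            match self {
                col @ MyColumn::Nullable(..) => col,
                MyColumn::Plain(values) => {
                    let validity = vec![true; values.len()];
                    MyColumn::Nullable(values, validity)
                }
            }
        }
    }

Accordingly, the call site in transform_aggregate_expand.rs becomes `entry.value.clone().wrap_nullable()`, so a clone happens only where a borrowed value really must be copied.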
--- src/query/expression/src/values.rs | 16 ++++++++-------- .../aggregator/transform_aggregate_expand.rs | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs index 3a6f970c33c79..2f43a2d137a14 100755 --- a/src/query/expression/src/values.rs +++ b/src/query/expression/src/values.rs @@ -239,10 +239,10 @@ impl Value { Some(self.as_ref().try_downcast::()?.to_owned()) } - pub fn wrap_nullable(&self) -> Self { + pub fn wrap_nullable(self) -> Self { match self { Value::Column(c) => Value::Column(c.wrap_nullable()), - scalar => scalar.clone(), + scalar => scalar, } } } @@ -1486,14 +1486,14 @@ impl Column { } } - pub fn wrap_nullable(&self) -> Self { + pub fn wrap_nullable(self) -> Self { match self { - col @ Column::Nullable(_) => col.clone(), - col => { - let mut validity = MutableBitmap::with_capacity(col.len()); - validity.extend_constant(col.len(), true); + column @ Column::Nullable(_) => column, + column => { + let mut validity = MutableBitmap::with_capacity(column.len()); + validity.extend_constant(column.len(), true); Column::Nullable(Box::new(NullableColumn { - column: col.clone(), + column, validity: validity.into(), })) } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs index a21f749844cb7..0ee0b1bb28f77 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_expand.rs @@ -86,7 +86,7 @@ impl Transform for TransformExpandGroupingSets { } else { *entry = BlockEntry { data_type: entry.data_type.wrap_nullable(), - value: entry.value.wrap_nullable(), + value: entry.value.clone().wrap_nullable(), } } } From d5f0ebc34f78549ee7900e895008c18b1a4e9684 Mon Sep 17 00:00:00 2001 From: RinChanNOWWW Date: Wed, 15 Mar 2023 10:56:41 +0800 Subject: [PATCH 7/7] Fix after rebasing main. --- src/query/ast/tests/it/testdata/statement-error.txt | 2 +- src/query/sql/src/planner/binder/copy.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/query/ast/tests/it/testdata/statement-error.txt b/src/query/ast/tests/it/testdata/statement-error.txt index 76ba129a6a040..36094d63b9218 100644 --- a/src/query/ast/tests/it/testdata/statement-error.txt +++ b/src/query/ast/tests/it/testdata/statement-error.txt @@ -369,7 +369,7 @@ error: --> SQL:1:41 | 1 | SELECT * FROM t GROUP BY GROUPING SETS () - | ------ ^ expected `(`, `IS`, `IN`, `EXISTS`, `BETWEEN`, `+`, or 58 more ... + | ------ ^ expected `(`, `IS`, `IN`, `EXISTS`, `BETWEEN`, `+`, or 65 more ... | | | while parsing `SELECT ...` diff --git a/src/query/sql/src/planner/binder/copy.rs b/src/query/sql/src/planner/binder/copy.rs index 80804a6f9e402..3ddc31e20b4a3 100644 --- a/src/query/sql/src/planner/binder/copy.rs +++ b/src/query/sql/src/planner/binder/copy.rs @@ -675,7 +675,7 @@ fn check_transform_query( && query.with.is_none() { if let SetExpr::Select(select) = &query.body { - if select.group_by.is_empty() + if select.group_by.is_none() && !select.distinct && select.having.is_none() && select.from.len() == 1