Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions src/query/ast/src/ast/format/ast_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2074,19 +2074,47 @@ impl<'ast> Visitor<'ast> for AstFormatVisitor {
FormatTreeNode::with_children(selection_format_ctx, vec![selection_child]);
children.push(selection_node);
}
if !stmt.group_by.is_empty() {
let mut group_by_list_children = Vec::with_capacity(stmt.group_by.len());
for group_by in stmt.group_by.iter() {
self.visit_expr(group_by);
group_by_list_children.push(self.children.pop().unwrap());
match &stmt.group_by {
Some(GroupBy::Normal(exprs)) => {
let mut group_by_list_children = Vec::with_capacity(exprs.len());
for group_by in exprs.iter() {
self.visit_expr(group_by);
group_by_list_children.push(self.children.pop().unwrap());
}
let group_by_list_name = "GroupByList".to_string();
let group_by_list_format_ctx = AstFormatContext::with_children(
group_by_list_name,
group_by_list_children.len(),
);
let group_by_list_node =
FormatTreeNode::with_children(group_by_list_format_ctx, group_by_list_children);
children.push(group_by_list_node);
}
let group_by_list_name = "GroupByList".to_string();
let group_by_list_format_ctx =
AstFormatContext::with_children(group_by_list_name, group_by_list_children.len());
let group_by_list_node =
FormatTreeNode::with_children(group_by_list_format_ctx, group_by_list_children);
children.push(group_by_list_node);
Some(GroupBy::GroupingSets(sets)) => {
let mut grouping_sets = Vec::with_capacity(sets.len());
for set in sets.iter() {
let mut grouping_set = Vec::with_capacity(set.len());
for expr in set.iter() {
self.visit_expr(expr);
grouping_set.push(self.children.pop().unwrap());
}
let name = "GroupingSet".to_string();
let grouping_set_format_ctx =
AstFormatContext::with_children(name, grouping_set.len());
let grouping_set_node =
FormatTreeNode::with_children(grouping_set_format_ctx, grouping_set);
grouping_sets.push(grouping_set_node);
}
let group_by_list_name = "GroupByList".to_string();
let group_by_list_format_ctx =
AstFormatContext::with_children(group_by_list_name, grouping_sets.len());
let group_by_list_node =
FormatTreeNode::with_children(group_by_list_format_ctx, grouping_sets);
children.push(group_by_list_node);
}
_ => {}
}

if let Some(having) = &stmt.having {
self.visit_expr(having);
let having_child = self.children.pop().unwrap();
Expand Down
30 changes: 23 additions & 7 deletions src/query/ast/src/ast/format/syntax/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::ast::format::syntax::interweave_comma;
use crate::ast::format::syntax::parenthenized;
use crate::ast::format::syntax::NEST_FACTOR;
use crate::ast::Expr;
use crate::ast::GroupBy;
use crate::ast::JoinCondition;
use crate::ast::JoinOperator;
use crate::ast::OrderByExpr;
Expand Down Expand Up @@ -194,12 +195,19 @@ fn pretty_selection(selection: Option<Expr>) -> RcDoc<'static> {
}
}

fn pretty_group_by(group_by: Vec<Expr>) -> RcDoc<'static> {
if !group_by.is_empty() {
RcDoc::line()
/// Renders a single grouping set as `(` + its comma-separated expressions + `)`.
///
/// An empty `set` yields just the pair of parentheses.
fn pretty_group_set(set: Vec<Expr>) -> RcDoc<'static> {
    let open_paren = RcDoc::text("(");
    let members = inline_comma(set.into_iter().map(pretty_expr));
    let close_paren = RcDoc::text(")");

    RcDoc::nil()
        .append(open_paren)
        .append(members)
        .append(close_paren)
}

fn pretty_group_by(group_by: Option<GroupBy>) -> RcDoc<'static> {
match group_by {
Some(GroupBy::Normal(exprs)) => RcDoc::line()
.append(
RcDoc::text("GROUP BY").append(
if group_by.len() > 1 {
if exprs.len() > 1 {
RcDoc::line()
} else {
RcDoc::space()
Expand All @@ -208,12 +216,20 @@ fn pretty_group_by(group_by: Vec<Expr>) -> RcDoc<'static> {
),
)
.append(
interweave_comma(group_by.into_iter().map(pretty_expr))
interweave_comma(exprs.into_iter().map(pretty_expr))
.nest(NEST_FACTOR)
.group(),
),
Some(GroupBy::GroupingSets(sets)) => RcDoc::line()
.append(RcDoc::text("GROUP BY GROUPING SETS (").append(RcDoc::line().nest(NEST_FACTOR)))
.append(
interweave_comma(sets.into_iter().map(pretty_group_set))
.nest(NEST_FACTOR)
.group(),
)
} else {
RcDoc::nil()
.append(RcDoc::line())
.append(RcDoc::text(")")),
_ => RcDoc::nil(),
}
}

Expand Down
33 changes: 30 additions & 3 deletions src/query/ast/src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,22 @@ pub struct SelectStmt {
// `WHERE` clause
pub selection: Option<Expr>,
// `GROUP BY` clause
pub group_by: Vec<Expr>,
pub group_by: Option<GroupBy>,
// `HAVING` clause
pub having: Option<Expr>,
}

/// The `GROUP BY` clause of a `SELECT` statement.
///
/// `SelectStmt::group_by` stores this as an `Option`, so a query without a
/// `GROUP BY` clause is `None` rather than an empty list.
#[derive(Debug, Clone, PartialEq)]
pub enum GroupBy {
/// Plain form: `GROUP BY expr [, expr]*`
Normal(Vec<Expr>),
/// Grouping-sets form: `GROUP BY GROUPING SETS ( GroupSet [, GroupSet]* )`
///
/// GroupSet := (expr [, expr]*) | expr | ()
///
/// Each inner `Vec` holds the expressions of one grouping set. A bare
/// `expr` is stored as a one-element set, and the empty set `()` is
/// stored as an empty `Vec`.
GroupingSets(Vec<Vec<Expr>>),
}

/// A relational set expression, like `SELECT ... FROM ... {UNION|EXCEPT|INTERSECT} SELECT ... FROM ...`
#[derive(Debug, Clone, PartialEq)]
pub enum SetExpr {
Expand Down Expand Up @@ -442,9 +453,25 @@ impl Display for SelectStmt {
}

// GROUP BY clause
if !self.group_by.is_empty() {
if self.group_by.is_some() {
write!(f, " GROUP BY ")?;
write_comma_separated_list(f, &self.group_by)?;
match self.group_by.as_ref().unwrap() {
GroupBy::Normal(exprs) => {
write_comma_separated_list(f, exprs)?;
}
GroupBy::GroupingSets(sets) => {
write!(f, "GROUPING SETS (")?;
for (i, set) in sets.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "(")?;
write_comma_separated_list(f, set)?;
write!(f, ")")?;
}
write!(f, ")")?;
}
}
}

// HAVING clause
Expand Down
31 changes: 23 additions & 8 deletions src/query/ast/src/parser/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ pub enum SetOperationElement {
select_list: Box<Vec<SelectTarget>>,
from: Box<Vec<TableReference>>,
selection: Box<Option<Expr>>,
group_by: Box<Vec<Expr>>,
group_by: Option<GroupBy>,
having: Box<Option<Expr>>,
},
SetOperation {
Expand All @@ -565,6 +565,25 @@ pub enum SetOperationElement {
Group(SetExpr),
}

/// Parses whatever follows the `GROUP BY` keywords.
///
/// Two forms are recognized, tried in this order:
/// 1. `GROUPING SETS ( set [, set]* )`, producing [`GroupBy::GroupingSets`];
/// 2. a plain `expr [, expr]*` list, producing [`GroupBy::Normal`].
pub fn group_by_items(i: Input) -> IResult<GroupBy> {
    // One member of a GROUPING SETS list. The alternatives are tried in order:
    // the empty set `()`, a parenthesized expression list, then a bare
    // expression (wrapped into a single-element set).
    let single_set = alt((
        map(rule! { "(" ~ ")" }, |(_, _)| vec![]),
        map(
            rule! { "(" ~ #comma_separated_list1(expr) ~ ")" },
            |(_, exprs, _)| exprs,
        ),
        map(rule! { #expr }, |only| vec![only]),
    ));
    let grouping_sets = map(
        rule! { GROUPING ~ SETS ~ "(" ~ ^#comma_separated_list1(single_set) ~ ")" },
        |(_, _, _, all_sets, _)| GroupBy::GroupingSets(all_sets),
    );
    let plain_list = map(rule! { ^#comma_separated_list1(expr) }, GroupBy::Normal);

    rule!(#grouping_sets | #plain_list)(i)
}

pub fn set_operation_element(i: Input) -> IResult<WithSpan<SetOperationElement>> {
let set_operator = map(
rule! {
Expand All @@ -588,7 +607,7 @@ pub fn set_operation_element(i: Input) -> IResult<WithSpan<SetOperationElement>>
SELECT ~ DISTINCT? ~ ^#comma_separated_list1(select_target)
~ ( FROM ~ ^#comma_separated_list1(table_reference) )?
~ ( WHERE ~ ^#expr )?
~ ( GROUP ~ ^BY ~ ^#comma_separated_list1(expr) )?
~ ( GROUP ~ ^BY ~ ^#group_by_items )?
~ ( HAVING ~ ^#expr )?
},
|(
Expand All @@ -609,11 +628,7 @@ pub fn set_operation_element(i: Input) -> IResult<WithSpan<SetOperationElement>>
.unwrap_or_default(),
),
selection: Box::new(opt_where_block.map(|(_, selection)| selection)),
group_by: Box::new(
opt_group_by_block
.map(|(_, _, group_by)| group_by)
.unwrap_or_default(),
),
group_by: opt_group_by_block.map(|(_, _, group_by)| group_by),
having: Box::new(opt_having_block.map(|(_, having)| having)),
}
},
Expand Down Expand Up @@ -667,7 +682,7 @@ impl<'a, I: Iterator<Item = WithSpan<'a, SetOperationElement>>> PrattParser<I>
select_list: *select_list,
from: *from,
selection: *selection,
group_by: *group_by,
group_by,
having: *having,
})),
_ => unreachable!(),
Expand Down
4 changes: 4 additions & 0 deletions src/query/ast/src/parser/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,10 @@ pub enum TokenKind {
LAST,
#[token("IGNORE_RESULT", ignore(ascii_case))]
IGNORE_RESULT,
#[token("GROUPING", ignore(ascii_case))]
GROUPING,
#[token("SETS", ignore(ascii_case))]
SETS,
}

// Reference: https://www.postgresql.org/docs/current/sql-keywords-appendix.html
Expand Down
16 changes: 14 additions & 2 deletions src/query/ast/src/visitors/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,8 +568,20 @@ pub trait Visitor<'ast>: Sized {
walk_expr(self, selection);
}

for expr in group_by.iter() {
walk_expr(self, expr);
match group_by {
Some(GroupBy::Normal(exprs)) => {
for expr in exprs {
walk_expr(self, expr);
}
}
Some(GroupBy::GroupingSets(sets)) => {
for set in sets {
for expr in set {
walk_expr(self, expr);
}
}
}
_ => {}
}

if let Some(having) = having {
Expand Down
16 changes: 14 additions & 2 deletions src/query/ast/src/visitors/visitor_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,20 @@ pub trait VisitorMut: Sized {
walk_expr_mut(self, selection);
}

for expr in group_by.iter_mut() {
walk_expr_mut(self, expr);
match group_by {
Some(GroupBy::Normal(exprs)) => {
for expr in exprs {
walk_expr_mut(self, expr);
}
}
Some(GroupBy::GroupingSets(sets)) => {
for set in sets {
for expr in set {
walk_expr_mut(self, expr);
}
}
}
_ => {}
}

if let Some(having) = having {
Expand Down
6 changes: 6 additions & 0 deletions src/query/ast/tests/it/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ fn test_statement() {
type = CSV field_delimiter = ',' record_delimiter = '\n' skip_header = 1;"#,
r#"SHOW FILE FORMATS"#,
r#"DROP FILE FORMAT my_csv"#,
r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d)"#,
r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, (c, d))"#,
r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (c), (d, e))"#,
r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (), (d, e))"#,
];

for case in cases {
Expand Down Expand Up @@ -424,6 +428,8 @@ fn test_statement_error() {
r#"CALL system$test(a"#,
r#"show settings ilike 'enable%'"#,
r#"PRESIGN INVALID @my_stage/path/to/file"#,
r#"SELECT * FROM t GROUP BY GROUPING SETS a, b"#,
r#"SELECT * FROM t GROUP BY GROUPING SETS ()"#,
];

for case in cases {
Expand Down
Loading