Merge pull request #330 from kitta65/feature/udaf
support user-defined aggregate function
kitta65 committed Mar 22, 2024
2 parents 287aa88 + 1b640dd commit e8a3a0d
Showing 3 changed files with 99 additions and 39 deletions.
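The statement exercised by the new test case in src/parser/tests/tests_ddl.rs (a BigQuery user-defined aggregate function whose parameter is declared NOT AGGREGATE) looks like this:

    CREATE AGGREGATE FUNCTION plus_one(n int64 not aggregate)
    AS (n + 1)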
95 changes: 56 additions & 39 deletions src/parser.rs
@@ -358,7 +358,7 @@ impl Parser {
}
}
"STRUCT" => {
let type_ = self.parse_type(false)?;
let type_ = self.parse_type(false, false)?;
self.next_token()?; // STRUCT -> (, > -> (
let mut struct_literal = self.construct_node(NodeType::StructLiteral)?;
let mut exprs = vec![];
@@ -378,7 +378,7 @@ impl Parser {
left = struct_literal;
}
"RANGE" => {
let type_ = self.parse_type(false)?;
let type_ = self.parse_type(false, false)?;
self.next_token()?; // > -> '[lower, upper)'
let mut range_literal = self.construct_node(NodeType::RangeLiteral)?;
range_literal.push_node("type", type_);
@@ -399,7 +399,7 @@ impl Parser {
"ARRAY" => {
// when used as literal
if !self.get_token(1)?.is("(") {
let type_ = self.parse_type(false)?;
let type_ = self.parse_type(false, false)?;
self.next_token()?; // > -> [
let mut arr = self.construct_node(NodeType::ArrayLiteral)?;
self.next_token()?; // [ -> exprs | ]
@@ -523,7 +523,7 @@ impl Parser {
let mut as_ = self.construct_node(NodeType::CastArgument)?;
as_.push_node("cast_from", cast_from);
self.next_token()?; // -> type
as_.push_node("cast_to", self.parse_type(false)?);
as_.push_node("cast_to", self.parse_type(false, false)?);
if self.get_token(1)?.is("FORMAT") {
self.next_token()?; // -> FORMAT
let mut format =
@@ -762,6 +762,7 @@ impl Parser {
fn parse_grouped_type_declaration_or_constraints(
&mut self,
schema: bool,
aggregate: bool,
) -> BQ2CSTResult<Node> {
let mut group = self.construct_node(NodeType::GroupedTypeDeclarationOrConstraints)?;
self.next_token()?; // ( -> INOUT | ident | type | PRIMARY | CONSTRAINT | FOREIGN
@@ -784,7 +785,7 @@ impl Parser {
type_declaration = self.construct_node(NodeType::TypeDeclaration)?;
type_declaration.push_node("in_out", in_out);
self.next_token()?; // -> type
type_declaration.push_node("type", self.parse_type(schema)?);
type_declaration.push_node("type", self.parse_type(schema, aggregate)?);
} else if (self.get_token(0)?.is("PRIMARY") && self.get_token(1)?.is("KEY"))
|| (self.get_token(0)?.is("FOREIGN") && self.get_token(1)?.is("KEY"))
|| (self.get_token(0)?.is("CONSTRAINT") && !self.get_token(2)?.in_(&marker_tokens))
@@ -793,10 +794,10 @@ impl Parser {
} else if !self.get_token(1)?.in_(&marker_tokens) {
type_declaration = self.construct_node(NodeType::TypeDeclaration)?;
self.next_token()?; // -> type
type_declaration.push_node("type", self.parse_type(schema)?);
type_declaration.push_node("type", self.parse_type(schema, aggregate)?);
} else {
type_declaration = Node::empty(NodeType::TypeDeclaration);
type_declaration.push_node("type", self.parse_type(schema)?);
type_declaration.push_node("type", self.parse_type(schema, aggregate)?);
}
self.next_token()?; // -> , | > | )
if self.get_token(0)?.is(",") {
@@ -1299,15 +1300,15 @@ impl Parser {
}
Ok(left)
}
fn parse_type(&mut self, schema: bool) -> BQ2CSTResult<Node> {
fn parse_type(&mut self, schema: bool, aggregate: bool) -> BQ2CSTResult<Node> {
let mut res = match self.get_token(0)?.literal.to_uppercase().as_str() {
"ARRAY" | "RANGE" => {
let mut res = self.construct_node(NodeType::Type)?;
if self.get_token(1)?.literal.as_str() == "<" {
self.next_token()?; // -> <
let mut type_ = self.construct_node(NodeType::GroupedType)?;
self.next_token()?; // < -> type
type_.push_node("type", self.parse_type(schema)?);
type_.push_node("type", self.parse_type(schema, false)?);
self.next_token()?; // type -> >
type_.push_node("rparen", self.construct_node(NodeType::Symbol)?);
res.push_node("type_declaration", type_);
@@ -1332,7 +1333,7 @@ impl Parser {
} else {
type_declaration = Node::empty(NodeType::TypeDeclaration);
}
type_declaration.push_node("type", self.parse_type(schema)?);
type_declaration.push_node("type", self.parse_type(schema, false)?);
self.next_token()?; // type -> , or next_declaration
if self.get_token(0)?.is(",") {
type_declaration
@@ -1405,24 +1406,36 @@ impl Parser {
res.push_node("enforced", self.parse_enforced()?);
}
}
if self.get_token(1)?.is("DEFAULT") && schema {
self.next_token()?; // -> DEFAULT
let mut default = self.construct_node(NodeType::KeywordWithExpr)?;
self.next_token()?; // -> expr
default.push_node("expr", self.parse_expr(usize::MAX, false, false)?);
res.push_node("default", default);
}
if self.get_token(1)?.is("NOT") && schema {
self.next_token()?; // -> NOT
let not_ = self.construct_node(NodeType::Keyword)?;
self.next_token()?; // -> null
let null = self.construct_node(NodeType::Keyword)?;
res.push_node_vec("not_null", vec![not_, null]);
if schema {
if self.get_token(1)?.is("DEFAULT") {
self.next_token()?; // -> DEFAULT
let mut default = self.construct_node(NodeType::KeywordWithExpr)?;
self.next_token()?; // -> expr
default.push_node("expr", self.parse_expr(usize::MAX, false, false)?);
res.push_node("default", default);
}
if self.get_token(1)?.is("NOT") {
self.next_token()?; // -> NOT
let not_ = self.construct_node(NodeType::Keyword)?;
self.next_token()?; // -> null
let null = self.construct_node(NodeType::Keyword)?;
res.push_node_vec("not_null", vec![not_, null]);
}
if self.get_token(1)?.is("OPTIONS") {
self.next_token()?; // -> OPTIONS
let options = self.parse_keyword_with_grouped_exprs(false)?;
res.push_node("options", options);
}
}
if self.get_token(1)?.is("OPTIONS") && schema {
self.next_token()?; // -> OPTIONS
let options = self.parse_keyword_with_grouped_exprs(false)?;
res.push_node("options", options);
if aggregate {
if self.get_token(1)?.is("NOT") {
self.next_token()?; // -> NOT
let mut not_ = self.construct_node(NodeType::KeywordSequence)?;
self.next_token()?; // -> AGGREGATE
let null = self.construct_node(NodeType::Keyword)?;
not_.push_node("next_keyword", null);
res.push_node("aggregate", not_);
}
}
Ok(res)
}
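The new aggregate flag lets parse_type accept the NOT AGGREGATE modifier that BigQuery allows on parameters of user-defined aggregate functions, while the existing schema flag keeps gating column attributes (DEFAULT, NOT NULL, OPTIONS). A rough sketch of the two kinds of declarations the flags distinguish; the table and function names are hypothetical, and only the NOT AGGREGATE form comes from this diff's test case:

    -- column schema, parsed with schema = true (see the CREATE TABLE path below)
    CREATE TABLE dataset.t (x INT64 NOT NULL OPTIONS(description = 'example'));

    -- parameter list, parsed with aggregate = true (see the CREATE AGGREGATE FUNCTION path below)
    CREATE AGGREGATE FUNCTION dataset.plus_one(n INT64 NOT AGGREGATE) AS (n + 1);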
@@ -2106,7 +2119,7 @@ impl Parser {
self.next_token()?; // -> (
create.push_node(
"column_schema_group",
self.parse_grouped_type_declaration_or_constraints(true)?,
self.parse_grouped_type_declaration_or_constraints(true, false)?,
);
}
if self.get_token(1)?.is("default") {
@@ -2146,7 +2159,7 @@ impl Parser {
self.next_token()?; // -> (
with.push_node(
"column_schema_group",
self.parse_grouped_type_declaration_or_constraints(false)?,
self.parse_grouped_type_declaration_or_constraints(false, false)?,
);
}
create.push_node("with_partition_columns", with);
@@ -2278,6 +2291,10 @@ impl Parser {
node.push_node("table", self.construct_node(NodeType::Keyword)?);
is_tvf = true;
}
if self.get_token(1)?.is("AGGREGATE") {
self.next_token()?; // -> AGGREGATE
node.push_node("aggregate", self.construct_node(NodeType::Keyword)?);
}
self.next_token()?; // -> FUNCTION
node.push_node("what", self.construct_node(NodeType::Keyword)?);
if self.get_token(1)?.in_(&vec!["IF"]) {
@@ -2289,13 +2306,13 @@ impl Parser {
self.next_token()?; // -> (
node.push_node(
"group",
self.parse_grouped_type_declaration_or_constraints(false)?,
self.parse_grouped_type_declaration_or_constraints(false, true)?,
);
if self.get_token(1)?.is("RETURNS") {
self.next_token()?; // -> RETURNS
let mut returns = self.construct_node(NodeType::KeywordWithType)?;
self.next_token()?; // -> type
returns.push_node("type", self.parse_type(false)?);
returns.push_node("type", self.parse_type(false, false)?);
node.push_node("returns", returns);
}
if self.get_token(1)?.is("REMOTE") {
@@ -2381,7 +2398,7 @@ impl Parser {
self.next_token()?; // -> (
create.push_node(
"group",
self.parse_grouped_type_declaration_or_constraints(true)?,
self.parse_grouped_type_declaration_or_constraints(true, false)?,
);
if self.get_token(1)?.is("EXTERNAL") {
self.next_token()?; // -> EXTERNAL
@@ -2491,7 +2508,7 @@ impl Parser {
self.next_token()?; // -> (
input.push_node(
"group",
self.parse_grouped_type_declaration_or_constraints(false)?,
self.parse_grouped_type_declaration_or_constraints(false, false)?,
);
create.push_node("input", input);
}
@@ -2501,7 +2518,7 @@ impl Parser {
self.next_token()?; // -> (
output.push_node(
"group",
self.parse_grouped_type_declaration_or_constraints(false)?,
self.parse_grouped_type_declaration_or_constraints(false, false)?,
);
create.push_node("output", output);
}
@@ -2632,7 +2649,7 @@ impl Parser {
self.next_token()?; // -> ident
let mut ident = self.construct_node(NodeType::TypeDeclaration)?;
self.next_token()?; // -> type
ident.push_node("type", self.parse_type(true)?);
ident.push_node("type", self.parse_type(true, false)?);
add_column.push_node("type_declaration", ident);
if self.get_token(1)?.is(",") {
self.next_token()?; // -> ,
@@ -2776,7 +2793,7 @@ impl Parser {
self.next_token()?; // -> DATA
alter.push_node_vec("data_type", self.parse_n_keywords(2)?);
self.next_token()?; // -> type
alter.push_node("type", self.parse_type(false)?);
alter.push_node("type", self.parse_type(false, false)?);
} else {
self.next_token()?; // -> DEFAULT
let mut default = self.construct_node(NodeType::KeywordWithExpr)?;
@@ -3098,7 +3115,7 @@ impl Parser {
declare.push_node_vec("idents", idents);
if !self.get_token(1)?.is("DEFAULT") {
self.next_token()?; // ident -> variable_type
declare.push_node("variable_type", self.parse_type(false)?);
declare.push_node("variable_type", self.parse_type(false, false)?);
}
if self.get_token(1)?.is("DEFAULT") {
self.next_token()?; // -> DEFAULT
@@ -3517,7 +3534,7 @@ impl Parser {
self.next_token()?; // -> (
load.push_node(
"column_group",
self.parse_grouped_type_declaration_or_constraints(false)?,
self.parse_grouped_type_declaration_or_constraints(false, false)?,
);
}
if self.get_token(1)?.is("PARTITION") {
Expand Down Expand Up @@ -3547,7 +3564,7 @@ impl Parser {
self.next_token()?; // -> (
with.push_node(
"column_schema_group",
self.parse_grouped_type_declaration_or_constraints(false)?,
self.parse_grouped_type_declaration_or_constraints(false, false)?,
);
}
load.push_node("with_partition_columns", with);
41 changes: 41 additions & 0 deletions src/parser/tests/tests_ddl.rs
@@ -1633,6 +1633,47 @@ table:
self: TABLE (Keyword)
what:
self: FUNCTION (Keyword)
",
0,
)),
// AGGREGATE function
Box::new(SuccessTestCase::new(
"\
CREATE AGGREGATE FUNCTION plus_one(n int64 not aggregate)
AS (n + 1)
",
"\
self: CREATE (CreateFunctionStatement)
aggregate:
self: AGGREGATE (Keyword)
as:
self: AS (KeywordWithGroupedXXX)
group:
self: ( (GroupedExpr)
expr:
self: + (BinaryOperator)
left:
self: n (Identifier)
right:
self: 1 (NumericLiteral)
rparen:
self: ) (Symbol)
group:
self: ( (GroupedTypeDeclarationOrConstraints)
declarations:
- self: n (TypeDeclaration)
type:
self: int64 (Type)
aggregate:
self: not (KeywordSequence)
next_keyword:
self: aggregate (Keyword)
rparen:
self: ) (Symbol)
ident:
self: plus_one (Identifier)
what:
self: FUNCTION (Keyword)
",
0,
)),
2 changes: 2 additions & 0 deletions src/types.rs
@@ -572,6 +572,7 @@ export type CreateFunctionStatement = XXXStatement & {
or_replace?: NodeVecChild;
temp?: NodeChild;
table?: NodeChild;
aggregate?: NodeChild;
what: NodeChild;
if_not_exists?: NodeVecChild;
ident: NodeChild;
@@ -1331,6 +1332,7 @@ export type Type = BaseNode & {
default?: NodeChild;
options?: NodeChild;
collate?: NodeChild
aggregate?: NodeChild;
};
};
