Skip to content

Commit

Permalink
feat(spanner/spansql): add support for aggregate functions (#8498)
Browse files Browse the repository at this point in the history
  • Loading branch information
toga4 committed Aug 28, 2023
1 parent 9874485 commit d440d75
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 13 deletions.
44 changes: 32 additions & 12 deletions spanner/spansql/keywords.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,16 @@ var keywords = map[string]bool{
// https://cloud.google.com/spanner/docs/functions-and-operators
var funcs = make(map[string]bool)
var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError))
var aggregateFuncs = make(map[string]bool)

func init() {
for _, f := range allFuncs {
for _, f := range funcNames {
funcs[f] = true
}
for _, f := range aggregateFuncNames {
funcs[f] = true
aggregateFuncs[f] = true
}
// Special case for CAST, SAFE_CAST and EXTRACT
funcArgParsers["CAST"] = typedArgParser
funcArgParsers["SAFE_CAST"] = typedArgParser
Expand All @@ -150,19 +155,9 @@ func init() {
funcArgParsers["GET_INTERNAL_SEQUENCE_STATE"] = sequenceArgParser
}

var allFuncs = []string{
var funcNames = []string{
// TODO: many more

// Aggregate functions.
"ANY_VALUE",
"ARRAY_AGG",
"AVG",
"BIT_XOR",
"COUNT",
"MAX",
"MIN",
"SUM",

// Cast functions.
"CAST",
"SAFE_CAST",
Expand Down Expand Up @@ -295,3 +290,28 @@ var allFuncs = []string{
// Utility functions.
"GENERATE_UUID",
}

var aggregateFuncNames = []string{
// Aggregate functions.
"ANY_VALUE",
"ARRAY_AGG",
"ARRAY_CONCAT_AGG",
"AVG",
"BIT_AND",
"BIT_OR",
"BIT_XOR",
"COUNT",
"COUNTIF",
"LOGICAL_AND",
"LOGICAL_OR",
"MAX",
"MIN",
"STRING_AGG",
"SUM",

// Statistical aggregate functions.
"STDDEV",
"STDDEV_SAMP",
"VAR_SAMP",
"VARIANCE",
}
63 changes: 63 additions & 0 deletions spanner/spansql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3566,6 +3566,65 @@ var sequenceArgParser = func(p *parser) (Expr, *parseError) {
return p.parseExpr()
}

func (p *parser) parseAggregateFunc() (Func, *parseError) {
tok := p.next()
if tok.err != nil {
return Func{}, tok.err
}
name := strings.ToUpper(tok.value)
if err := p.expect("("); err != nil {
return Func{}, err
}
var distinct bool
if p.eat("DISTINCT") {
distinct = true
}
args, err := p.parseExprList()
if err != nil {
return Func{}, err
}
var nullsHandling NullsHandling
if p.eat("IGNORE", "NULLS") {
nullsHandling = IgnoreNulls
} else if p.eat("RESPECT", "NULLS") {
nullsHandling = RespectNulls
}
var having *AggregateHaving
if p.eat("HAVING") {
tok := p.next()
if tok.err != nil {
return Func{}, tok.err
}
var cond AggregateHavingCondition
switch tok.value {
case "MAX":
cond = HavingMax
case "MIN":
cond = HavingMin
default:
return Func{}, p.errorf("got %q, want MAX or MIN", tok.value)
}
expr, err := p.parseExpr()
if err != nil {
return Func{}, err
}
having = &AggregateHaving{
Condition: cond,
Expr: expr,
}
}
if err := p.expect(")"); err != nil {
return Func{}, err
}
return Func{
Name: name,
Args: args,
Distinct: distinct,
NullsHandling: nullsHandling,
Having: having,
}, nil
}

/*
Expressions
Expand Down Expand Up @@ -3918,6 +3977,10 @@ func (p *parser) parseLit() (Expr, *parseError) {
// this is a function invocation.
// The `funcs` map is keyed by upper case strings.
if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") {
if aggregateFuncs[name] {
p.back()
return p.parseAggregateFunc()
}
var list []Expr
var err *parseError
if f, ok := funcArgParsers[name]; ok {
Expand Down
7 changes: 7 additions & 0 deletions spanner/spansql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,13 @@ func TestParseExpr(t *testing.T) {
{`GET_NEXT_SEQUENCE_VALUE(SEQUENCE MySequence)`, Func{Name: "GET_NEXT_SEQUENCE_VALUE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}},
{`GET_INTERNAL_SEQUENCE_STATE(SEQUENCE MySequence)`, Func{Name: "GET_INTERNAL_SEQUENCE_STATE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}},

// Aggregate Functions
{`COUNT(*)`, Func{Name: "COUNT", Args: []Expr{Star}}},
{`COUNTIF(DISTINCT cname)`, Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true}},
{`ARRAY_AGG(Foo IGNORE NULLS)`, Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls}},
{`ANY_VALUE(Foo HAVING MAX Bar)`, Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}},
{`STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`, Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}},

// Conditional expressions
{
`CASE X WHEN 1 THEN "X" WHEN 2 THEN "Y" ELSE NULL END`,
Expand Down
20 changes: 20 additions & 0 deletions spanner/spansql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,27 @@ func (f Func) SQL() string { return buildSQL(f) }
func (f Func) addSQL(sb *strings.Builder) {
sb.WriteString(f.Name)
sb.WriteString("(")
if f.Distinct {
sb.WriteString("DISTINCT ")
}
addExprList(sb, f.Args, ", ")
switch f.NullsHandling {
case RespectNulls:
sb.WriteString(" RESPECT NULLS")
case IgnoreNulls:
sb.WriteString(" IGNORE NULLS")
}
if ah := f.Having; ah != nil {
sb.WriteString(" HAVING")
switch ah.Condition {
case HavingMax:
sb.WriteString(" MAX")
case HavingMin:
sb.WriteString(" MIN")
}
sb.WriteString(" ")
sb.WriteString(ah.Expr.SQL())
}
sb.WriteString(")")
}

Expand Down
25 changes: 25 additions & 0 deletions spanner/spansql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,31 @@ func TestSQL(t *testing.T) {
`SELECT SAFE_CAST(7 AS DATE)`,
reparseQuery,
},
{
Func{Name: "COUNT", Args: []Expr{Star}},
`COUNT(*)`,
reparseExpr,
},
{
Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true},
`COUNTIF(DISTINCT cname)`,
reparseExpr,
},
{
Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls},
`ARRAY_AGG(Foo IGNORE NULLS)`,
reparseExpr,
},
{
Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}},
`ANY_VALUE(Foo HAVING MAX Bar)`,
reparseExpr,
},
{
Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}},
`STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`,
reparseExpr,
},
{
ComparisonOp{LHS: ID("X"), Op: NotBetween, RHS: ID("Y"), RHS2: ID("Z")},
`X NOT BETWEEN Y AND Z`,
Expand Down
27 changes: 26 additions & 1 deletion spanner/spansql/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,9 @@ type Func struct {
Name string // not ID
Args []Expr

// TODO: various functions permit as-expressions, which might warrant different types in here.
Distinct bool
NullsHandling NullsHandling
Having *AggregateHaving
}

func (Func) isBoolExpr() {} // possibly bool
Expand Down Expand Up @@ -804,6 +806,29 @@ type SequenceExpr struct {

func (SequenceExpr) isExpr() {}

// NullsHandling represents the method of dealing with NULL values in aggregate functions.
type NullsHandling int

const (
NullsHandlingUnspecified NullsHandling = iota
RespectNulls
IgnoreNulls
)

// AggregateHaving represents the HAVING clause specific to aggregate functions, restricting rows based on a maximal or minimal value.
type AggregateHaving struct {
Condition AggregateHavingCondition
Expr Expr
}

// AggregateHavingCondition represents the condition (MAX or MIN) for the AggregateHaving clause.
type AggregateHavingCondition int

const (
HavingMax AggregateHavingCondition = iota
HavingMin
)

// Paren represents a parenthesised expression.
type Paren struct {
Expr Expr
Expand Down

0 comments on commit d440d75

Please sign in to comment.