From 22df86d06bb147f5df585cdb67b32c68e877e46f Mon Sep 17 00:00:00 2001 From: David Taylor Date: Fri, 18 Dec 2015 16:36:07 +0000 Subject: [PATCH] GROUP BY implementation Add buckets of values to aggregateFuncs, and capture grouped qvalues. Render buckets into rows of a valuesNode on first call to Next. Note: ORDER BY (with GROUP BY) still unsupported (as is HAVING). --- sql/group.go | 354 ++++++++++++++++++++++++++--------------- sql/group_test.go | 11 +- sql/testdata/aggregate | 104 +++++++++++- 3 files changed, 333 insertions(+), 136 deletions(-) diff --git a/sql/group.go b/sql/group.go index 84eed5855046..5507657d36f2 100644 --- a/sql/group.go +++ b/sql/group.go @@ -35,7 +35,7 @@ var aggregates = map[string]func() aggregateImpl{ func (p *planner) groupBy(n *parser.Select, s *scanNode) (*groupNode, error) { // Start by normalizing the GROUP BY expressions (to match what has been done to - // the SELECT expressions in addRenderer) so that we can compare them later. + // the SELECT expressions in addRender) so that we can compare to them later. // This is done before determining if aggregation is being performed, because // that determination is made during validation, which will require matching // expressions. @@ -51,56 +51,60 @@ func (p *planner) groupBy(n *parser.Select, s *scanNode) (*groupNode, error) { n.GroupBy[i] = norm } - // Determine if aggregation is being performed and, if so, if it is valid. - // Check each render expressions and verify that the only qvalues mentioned - // in it are either aggregated or appear in GROUP BY expressions. - if aggregation, err := checkAggregateExprs(n.GroupBy, s.render); !aggregation { + // Determine if aggregation is being performed and, if so, if it is valid, + // i.e. every render expression either aggregates all qvalues or appears in + // the GROUP BY expressions. 
+ if aggregation, invalidAggErr := checkAggregateExprs(n.GroupBy, s.render); !aggregation { return nil, nil - } else if err != nil { - return nil, err + } else if invalidAggErr != nil { + return nil, invalidAggErr + } + + group := &groupNode{ + planner: p, + values: valuesNode{columns: s.columns}, + render: s.render, } - // TODO(pmattis): This only handles aggregate functions, not GROUP BY. + // Loop over the render expressions and extract any aggregate functions -- + // qvalues are also replaced (with identAggregates, which just returns the last + // value added to them for a bucket) to provide grouped-by values for each bucket. + // After extraction, group.render will be entirely rendered from aggregateFuncs, + // and group.funcs will contain all the functions which need to be fed values. - // Loop over the render expressions and extract any aggregate functions. - var funcs []*aggregateFunc - for i, r := range s.render { - r, f, err := p.extractAggregateFuncs(r) + for i := range group.render { + expr, err := p.extractAggregateFuncs(group, group.render[i]) if err != nil { return nil, err } - s.render[i] = r - funcs = append(funcs, f...) - } - if len(funcs) == 0 { - return nil, nil + group.render[i] = expr } + // Queries like `SELECT MAX(n) FROM t` expect a row of NULLs if nothing was aggregated. + group.addNullBucketIfEmpty = len(n.GroupBy) == 0 + + group.buckets = make(map[string]struct{}) + if log.V(2) { - strs := make([]string, 0, len(funcs)) - for _, f := range funcs { - strs = append(strs, f.val.String()) + strs := make([]string, 0, len(group.funcs)) + for _, f := range group.funcs { + strs = append(strs, f.String()) } log.Infof("Group: %s", strings.Join(strs, ", ")) } - group := &groupNode{ - planner: p, - columns: s.columns, - render: s.render, - funcs: funcs, - } - // Replace the render expressions in the scanNode with expressions that // compute only the arguments to the aggregate expressions. 
- s.columns = make([]column, 0, len(funcs)) - s.render = make([]parser.Expr, 0, len(funcs)) - for _, f := range funcs { - if len(f.val.expr.Exprs) != 1 { - panic(fmt.Sprintf("%s has %d arguments (expected 1)", f.val.expr.Name, len(f.val.expr.Exprs))) + s.render = make([]parser.Expr, len(group.funcs)) + for i, f := range group.funcs { + s.render[i] = f.arg + } + + // Add the group-by expressions so they are available for bucketing. + for _, g := range n.GroupBy { + if err := s.addRender(parser.SelectExpr{Expr: g}); err != nil { + return nil, err } - s.columns = append(s.columns, column{name: f.val.String(), typ: f.val.datum}) - s.render = append(s.render, f.val.expr.Exprs[0]) } group.desiredOrdering = desiredAggregateOrdering(group.funcs) @@ -108,68 +112,108 @@ func (p *planner) groupBy(n *parser.Select, s *scanNode) (*groupNode, error) { } type groupNode struct { - planner *planner - plan planNode - columns []column - row parser.DTuple - render []parser.Expr - funcs []*aggregateFunc + planner *planner + plan planNode + + render []parser.Expr + + funcs []*aggregateFunc + // The set of bucket keys. + buckets map[string]struct{} + + addNullBucketIfEmpty bool + + values valuesNode + intialized bool + + // During rendering, aggregateFuncs compute their result for group.currentBucket. + currentBucket string + desiredOrdering []int - needGroup bool err error } func (n *groupNode) Columns() []column { - return n.columns + return n.values.Columns() } func (n *groupNode) Ordering() ([]int, int) { - return n.plan.Ordering() + // TODO(dt): aggregate buckets are returned un-ordered for now. 
+ return nil, 0 } func (n *groupNode) Values() parser.DTuple { - return n.row + return n.values.Values() } func (n *groupNode) Next() bool { - if !n.needGroup || n.err != nil { + if !n.intialized && n.err == nil { + n.computeAggregates() + } + if n.err != nil { return false } - n.needGroup = false + return n.values.Next() +} + +func (n *groupNode) computeAggregates() { + n.intialized = true + var scratch []byte // Loop over the rows passing the values into the corresponding aggregation // functions. for n.plan.Next() { values := n.plan.Values() - for i, f := range n.funcs { - if n.err = f.add(values[i]); n.err != nil { - return false + aggregatedValues, groupedValues := values[:len(n.funcs)], values[len(n.funcs):] + + //TODO(dt): optimization: skip buckets when underlying plan is ordered by grouped values. + + var encoded []byte + encoded, n.err = encodeDTuple(scratch, groupedValues) + if n.err != nil { + return + } + + e := string(encoded) + + n.buckets[e] = struct{}{} + // Feed the aggregateFuncs for this bucket the non-grouped values. + for i, value := range aggregatedValues { + if n.err = n.funcs[i].add(e, value); n.err != nil { + return } } + scratch = encoded[0:0] } n.err = n.plan.Err() if n.err != nil { - return false + return } - // Fill in the aggregate function result value. - for _, f := range n.funcs { - if f.val.datum, n.err = f.impl.result(); n.err != nil { - return false - } + if len(n.buckets) < 1 && n.addNullBucketIfEmpty { + n.buckets[""] = struct{}{} } // Render the results. 
- n.row = make([]parser.Datum, len(n.render)) - for i, r := range n.render { - n.row[i], n.err = r.Eval(n.planner.evalCtx) - if n.err != nil { - return false + n.values.rows = make([]parser.DTuple, 0, len(n.buckets)) + + for k := range n.buckets { + n.currentBucket = k + + row := make(parser.DTuple, 0, len(n.render)) + + for _, r := range n.render { + res, err := r.Eval(n.planner.evalCtx) + if err != nil { + n.err = err + return + } + row = append(row, res) } - } - return n.err == nil + n.values.rows = append(n.values.rows, row) + } } func (n *groupNode) Err() error { @@ -180,7 +224,7 @@ func (n *groupNode) ExplainPlan() (name, description string, children []planNode name = "group" strs := make([]string, 0, len(n.funcs)) for _, f := range n.funcs { - strs = append(strs, f.val.String()) + strs = append(strs, f.String()) } description = strings.Join(strs, ", ") return name, description, []planNode{n.plan} @@ -192,7 +236,6 @@ func (n *groupNode) wrap(plan planNode) planNode { return plan } n.plan = plan - n.needGroup = true return n } @@ -211,7 +254,7 @@ func (n *groupNode) isNotNullFilter(expr parser.Expr) parser.Expr { f := n.funcs[i-1] isNotNull := &parser.ComparisonExpr{ Operator: parser.IsNot, - Left: f.val.expr.Exprs[0], + Left: f.arg, Right: parser.DNull, } if expr == nil { @@ -231,15 +274,16 @@ func (n *groupNode) isNotNullFilter(expr parser.Expr) parser.Expr { func desiredAggregateOrdering(funcs []*aggregateFunc) []int { var limit int for i, f := range funcs { - switch f.impl.(type) { + impl := f.create() + switch impl.(type) { case *maxAggregate, *minAggregate: - if limit != 0 || len(f.val.expr.Exprs) != 1 { + if limit != 0 || f.arg == nil { return nil } - switch f.val.expr.Exprs[0].(type) { + switch f.arg.(type) { case *qvalue: limit = i + 1 - if _, ok := f.impl.(*maxAggregate); ok { + if _, ok := impl.(*maxAggregate); ok { limit = -limit } default: @@ -257,8 +301,8 @@ func desiredAggregateOrdering(funcs []*aggregateFunc) []int { } type 
extractAggregatesVisitor struct { - funcs []*aggregateFunc - err error + n *groupNode + err error } var _ parser.Visitor = &extractAggregatesVisitor{} @@ -273,42 +317,60 @@ func (v *extractAggregatesVisitor) Visit(expr parser.Expr, pre bool) (parser.Vis break } if impl, ok := aggregates[strings.ToLower(string(t.Name.Base))]; ok { + if len(t.Exprs) != 1 { + // Type checking has already run on these expressions thus + // if an aggregate function of the wrong arity gets here, + // something has gone really wrong. + panic(fmt.Sprintf("%s has %d arguments (expected 1)", t.Name.Base, len(t.Exprs))) + } + f := &aggregateFunc{ - val: aggregateValue{ - expr: t, - }, - impl: impl(), + expr: t, + arg: t.Exprs[0], + create: impl, + group: v.n, + buckets: make(map[string]aggregateImpl), } if t.Distinct { f.seen = make(map[string]struct{}) } - v.funcs = append(v.funcs, f) - return nil, &f.val + v.n.funcs = append(v.n.funcs, f) + return nil, f + } + case *qvalue: + f := &aggregateFunc{ + expr: t, + arg: t, + create: newIdentAggregate, + group: v.n, + buckets: make(map[string]aggregateImpl), } + v.n.funcs = append(v.n.funcs, f) + return nil, f } return v, expr } -func (v *extractAggregatesVisitor) run(expr parser.Expr) (parser.Expr, []*aggregateFunc, error) { - *v = extractAggregatesVisitor{} +func (v *extractAggregatesVisitor) run(n *groupNode, expr parser.Expr) (parser.Expr, error) { + *v = extractAggregatesVisitor{n: n} expr = parser.WalkExpr(v, expr) - return expr, v.funcs, v.err + return expr, v.err } -func (p *planner) extractAggregateFuncs(expr parser.Expr) (parser.Expr, []*aggregateFunc, error) { - return p.extractAggregatesVisitor.run(expr) +func (p *planner) extractAggregateFuncs(n *groupNode, expr parser.Expr) (parser.Expr, error) { + return p.extractAggregatesVisitor.run(n, expr) } type checkAggregateVisitor struct { groupStrs map[string]struct{} aggregated bool - err error + aggrErr error } var _ parser.Visitor = &checkAggregateVisitor{} func (v 
*checkAggregateVisitor) Visit(expr parser.Expr, pre bool) (parser.Visitor, parser.Expr) { - if !pre || v.err != nil { + if !pre || v.aggrErr != nil { return nil, expr } @@ -323,7 +385,7 @@ func (v *checkAggregateVisitor) Visit(expr parser.Expr, pre bool) (parser.Visito if _, ok := v.groupStrs[t.String()]; ok { return nil, expr } - v.err = fmt.Errorf("column \"%s\" must appear in the GROUP BY clause or be used in an aggregate function", t.col.Name) + v.aggrErr = fmt.Errorf("column \"%s\" must appear in the GROUP BY clause or be used in an aggregate function", t.col.Name) return v, expr } @@ -336,18 +398,23 @@ func (v *checkAggregateVisitor) Visit(expr parser.Expr, pre bool) (parser.Visito // Check if expressions use aggregation and, if so, if they are valid. // "Valid" expressions must either contain no unaggregated qvalues // or must appear, verbatim, in the group-by clause. Expressions are -// string-compared to the group-by clauses (as an approximation of) a +// string-compared to the group-by clauses as (an approximation of) a // recursive expression-tree equality check. +// Invalid: `SELECT k, SUM(v) FROM kv` +// - `k` is unaggregated and does not appear in the (missing) GROUP BY. +// Valid: `SELECT k, SUM(v) FROM kv GROUP BY k` +// Also valid: `SELECT UPPER(k), SUM(v) FROM kv GROUP BY UPPER(k)` +// - `UPPER(k)` appears in GROUP BY. +// Also valid: `SELECT UPPER(k), SUM(v) FROM kv GROUP BY k` +// - `k` appears in GROUP BY, so `UPPER(k)` is OK, but... +// Invalid: `SELECT k, SUM(v) FROM kv GROUP BY UPPER(k)` +// - qvalue subtrees from the select must appear *verbatim* in the GROUP BY. func checkAggregateExprs(group parser.GroupBy, exprs []parser.Expr) (bool, error) { aggregated := len(group) > 0 - //TODO(davidt): remove when group by is implemented - if aggregated { - return aggregated, fmt.Errorf("GROUP BY not supported yet") - } - v := checkAggregateVisitor{} + // TODO(dt): consider other ways of comparing expression trees. 
v.groupStrs = make(map[string]struct{}, len(group)) for i := range group { v.groupStrs[group[i].String()] = struct{}{} @@ -358,49 +425,27 @@ func checkAggregateExprs(group parser.GroupBy, exprs []parser.Expr) (bool, error if v.aggregated { aggregated = true } - if v.err != nil { - return aggregated, v.err + if v.aggrErr != nil && aggregated { + return aggregated, v.aggrErr } } return aggregated, nil } -type aggregateValue struct { - datum parser.Datum - expr *parser.FuncExpr -} - -var _ parser.VariableExpr = &aggregateValue{} - -func (*aggregateValue) Variable() {} - -func (av *aggregateValue) String() string { - return av.expr.String() -} - -func (av *aggregateValue) Walk(v parser.Visitor) { - // I expected to implement: - // av.datum = parser.WalkExpr(v, av.datum).(parser.Datum) - // But it seems `av.datum` is sometimes nil. -} - -func (av *aggregateValue) TypeCheck() (parser.Datum, error) { - return av.expr.TypeCheck() -} - -func (av *aggregateValue) Eval(ctx parser.EvalContext) (parser.Datum, error) { - return av.datum.Eval(ctx) -} +var _ parser.VariableExpr = &aggregateFunc{} type aggregateFunc struct { - val aggregateValue - impl aggregateImpl - seen map[string]struct{} + expr parser.Expr + arg parser.Expr + create func() aggregateImpl + group *groupNode + buckets map[string]aggregateImpl + seen map[string]struct{} } -func (a *aggregateFunc) add(d parser.Datum) error { +func (a *aggregateFunc) add(bucket string, d parser.Datum) error { if a.seen != nil { - encoded, err := encodeDatum(nil, d) + encoded, err := encodeDatum([]byte(bucket), d) if err != nil { return err } @@ -411,7 +456,42 @@ func (a *aggregateFunc) add(d parser.Datum) error { } a.seen[e] = struct{}{} } - return a.impl.add(d) + + if _, ok := a.buckets[bucket]; !ok { + a.buckets[bucket] = a.create() + } + + return a.buckets[bucket].add(d) +} + +func (*aggregateFunc) Variable() {} + +func (a *aggregateFunc) String() string { + return a.expr.String() +} + +func (a *aggregateFunc) Walk(v 
parser.Visitor) { +} + +func (a *aggregateFunc) TypeCheck() (parser.Datum, error) { + return a.expr.TypeCheck() +} + +func (a *aggregateFunc) Eval(ctx parser.EvalContext) (parser.Datum, error) { + found, ok := a.buckets[a.group.currentBucket] + + if !ok { + found = a.create() + } + + datum, err := found.result() + + if err != nil { + return nil, err + } + + // This is almost certainly the identity. Oh well. + return datum.Eval(ctx) } func encodeDatum(b []byte, d parser.Datum) ([]byte, error) { @@ -442,6 +522,30 @@ var _ aggregateImpl = &countAggregate{} var _ aggregateImpl = &maxAggregate{} var _ aggregateImpl = &minAggregate{} var _ aggregateImpl = &sumAggregate{} +var _ aggregateImpl = &identAggregate{} + +// In order to render the unaggregated (ie grouped) fields, during aggregation, +// the values for those fields have to be stored for each bucket. +// The `identAggregate` provides an "aggregate" function that actually +// just returns the last value passed to `add`, unchanged. For accumulating +// and rendering though it behaves like the other aggregate functions, +// allowing both those steps to avoid special-casing grouped vs aggregated fields. 
+type identAggregate struct { + val parser.Datum +} + +func newIdentAggregate() aggregateImpl { + return &identAggregate{} +} + +func (a *identAggregate) add(datum parser.Datum) error { + a.val = datum + return nil +} + +func (a *identAggregate) result() (parser.Datum, error) { + return a.val, nil +} type avgAggregate struct { sumAggregate diff --git a/sql/group_test.go b/sql/group_test.go index 0207af126501..d4d83af26bfd 100644 --- a/sql/group_test.go +++ b/sql/group_test.go @@ -21,17 +21,13 @@ import ( "reflect" "testing" - "github.com/cockroachdb/cockroach/sql/parser" "github.com/cockroachdb/cockroach/util/leaktest" ) func TestDesiredAggregateOrder(t *testing.T) { defer leaktest.AfterTest(t) - extractAggregateFuncs := func(expr parser.Expr) (parser.Expr, []*aggregateFunc, error) { - var v extractAggregatesVisitor - return v.run(expr) - } + p := planner{} testData := []struct { expr string @@ -53,11 +49,12 @@ func TestDesiredAggregateOrder(t *testing.T) { } for _, d := range testData { expr, _ := parseAndNormalizeExpr(t, d.expr) - _, funcs, err := extractAggregateFuncs(expr) + group := &groupNode{} + _, err := p.extractAggregateFuncs(group, expr) if err != nil { t.Fatal(err) } - ordering := desiredAggregateOrdering(funcs) + ordering := desiredAggregateOrdering(group.funcs) if !reflect.DeepEqual(d.ordering, ordering) { t.Fatalf("%s: expected %d, but found %d", d.expr, d.ordering, ordering) } diff --git a/sql/testdata/aggregate b/sql/testdata/aggregate index f3dd50a25d5e..eaba8951141d 100644 --- a/sql/testdata/aggregate +++ b/sql/testdata/aggregate @@ -1,18 +1,106 @@ statement ok CREATE TABLE kv ( k INT PRIMARY KEY, - v INT + v INT, + w INT, + s STRING ) statement OK -INSERT INTO kv VALUES (1, 2), (3, 4), (5, NULL), (6, 2), (7, 2), (8, 4) +INSERT INTO kv VALUES +(1, 2, 3, 'a'), +(3, 4, 5, 'a'), +(5, NULL, NULL, NULL), +(6, 2, 3, 'b'), +(7, 2, 2, 'b'), +(8, 4, 2, 'A') query error column "k" must appear in the GROUP BY clause or be used in an aggregate function 
SELECT COUNT(*), k FROM kv -# TODO(pmattis): fix -query error GROUP BY not supported yet +query II rowsort SELECT COUNT(*), k FROM kv GROUP BY k +---- +1 1 +1 3 +1 5 +1 6 +1 7 +1 8 + +query II rowsort +SELECT COUNT(*), k+v FROM kv GROUP BY k+v +---- +1 12 +1 3 +1 7 +1 8 +1 9 +1 NULL + + +query IT rowsort +SELECT COUNT(*), UPPER(s) FROM kv GROUP BY UPPER(s) +---- +1 NULL +2 B +3 A + +query IT rowsort +SELECT COUNT(*), UPPER(s) FROM kv GROUP BY s +---- +1 A +1 NULL +2 A +2 B + +query IT rowsort +SELECT COUNT(*), UPPER(kv.s) FROM kv GROUP BY s +---- +1 A +1 NULL +2 A +2 B + +query IT rowsort +SELECT COUNT(*), UPPER(s) FROM kv GROUP BY kv.s +---- +1 A +1 NULL +2 A +2 B + + +query error column "s" must appear in the GROUP BY clause or be used in an aggregate function +SELECT COUNT(*), s FROM kv GROUP BY UPPER(s) + +query error column "v" must appear in the GROUP BY clause or be used in an aggregate function +SELECT COUNT(*), k+v FROM kv GROUP BY k + +query error column "v" must appear in the GROUP BY clause or be used in an aggregate function +SELECT COUNT(*), v/(k+v) FROM kv GROUP BY k+v + +query IIR rowsort +SELECT COUNT(*), k+v, floor(v/(k+v)) FROM kv WHERE k+v > 8 GROUP BY v, k+v +---- +1 12 0 +1 9 0 + +query TIIIR rowsort +SELECT UPPER(s), COUNT(*), SUM(v), SUM(w), AVG(v+w) as avg FROM kv GROUP BY UPPER(s) +---- +A 3 10 10 6.666666666666667 +B 2 4 5 4.5 +NULL 1 NULL NULL NULL + +query II rowsort +SELECT count(kv.k) AS count_1, kv.v + kv.w AS lx FROM kv GROUP BY kv.v + kv.w +---- +1 4 +1 6 +1 9 +1 NULL +2 5 query error syntax error at or near "," SELECT COUNT(*, 1) FROM kv @@ -34,6 +122,14 @@ SELECT COUNT(DISTINCT k), COUNT(DISTINCT v), COUNT(DISTINCT (v)) FROM kv ---- 6 2 2 +query TIII rowsort +SELECT UPPER(s), COUNT(DISTINCT k), COUNT(DISTINCT v), COUNT(DISTINCT (v)) FROM kv GROUP BY UPPER(s) +---- +A 3 2 2 +B 2 1 1 +NULL 1 0 0 + + query I SELECT COUNT((k, v)) FROM kv ----