From ebd894782faf128f28778c495a8e338f603732da Mon Sep 17 00:00:00 2001 From: David Taylor Date: Fri, 18 Dec 2015 16:36:07 +0000 Subject: [PATCH] GROUP BY implementation Add buckets of values to aggregateFuncs, and capture grouped qvalues. Render buckets into rows of a valuesNode on first call to Next. Note: ORDER BY (with GROUP BY) still unsupported (as is HAVING). --- sql/group.go | 345 ++++++++++++++++++++++++++--------------- sql/group_test.go | 11 +- sql/testdata/aggregate | 67 +++++++- 3 files changed, 291 insertions(+), 132 deletions(-) diff --git a/sql/group.go b/sql/group.go index 84eed5855046..cbfd59053b96 100644 --- a/sql/group.go +++ b/sql/group.go @@ -35,7 +35,7 @@ var aggregates = map[string]func() aggregateImpl{ func (p *planner) groupBy(n *parser.Select, s *scanNode) (*groupNode, error) { // Start by normalizing the GROUP BY expressions (to match what has been done to - // the SELECT expressions in addRenderer) so that we can compare them later. + // the SELECT expressions in addRender) so that we can compare to them later. // This is done before determining if aggregation is being performed, because // that determination is made during validation, which will require matching // expressions. @@ -51,56 +51,60 @@ func (p *planner) groupBy(n *parser.Select, s *scanNode) (*groupNode, error) { n.GroupBy[i] = norm } - // Determine if aggregation is being performed and, if so, if it is valid. - // Check each render expressions and verify that the only qvalues mentioned - // in it are either aggregated or appear in GROUP BY expressions. - if aggregation, err := checkAggregateExprs(n.GroupBy, s.render); !aggregation { + // Determine if aggregation is being performed and, if so, if it is valid, + // i.e. ever render expression either aggregates all qvalues or appears in + // the GROUP BY expressions. + if aggregation, invalidAggErr := checkAggregateExprs(n.GroupBy, s.render); !aggregation { return nil, nil - } else if err != nil { - return nil, err + } else if invalidAggErr != nil { + return nil, invalidAggErr + } + + group := &groupNode{ + planner: p, + columns: s.columns, + render: s.render, } - // TODO(pmattis): This only handles aggregate functions, not GROUP BY. + // Loop over the render expressions and extract any aggregate functions -- + // qvalues are also replaced (with identAggregates, which just returns the last + // value added to them for a bucket) to provide grouped-by values for each bucket. + // After extraction, group.render will be entirely rendered from aggregateFuncs, + // and group.funcs will contain all the functions which need to be fed values. - // Loop over the render expressions and extract any aggregate functions. - var funcs []*aggregateFunc - for i, r := range s.render { - r, f, err := p.extractAggregateFuncs(r) + for i := range group.render { + expr, err := p.extractAggregateFuncs(group, group.render[i]) if err != nil { return nil, err } - s.render[i] = r - funcs = append(funcs, f...) - } - if len(funcs) == 0 { - return nil, nil + group.render[i] = expr } + // Queries like `SELECT MAX(n) FROM t` expect a row of NULLs if nothing was aggregated. + group.addNullBucketIfEmpty = len(n.GroupBy) == 0 + + group.buckets = make(map[string]struct{}) + if log.V(2) { - strs := make([]string, 0, len(funcs)) - for _, f := range funcs { - strs = append(strs, f.val.String()) + strs := make([]string, 0, len(group.funcs)) + for _, f := range group.funcs { + strs = append(strs, f.String()) } log.Infof("Group: %s", strings.Join(strs, ", ")) } - group := &groupNode{ - planner: p, - columns: s.columns, - render: s.render, - funcs: funcs, - } - // Replace the render expressions in the scanNode with expressions that // compute only the arguments to the aggregate expressions. - s.columns = make([]column, 0, len(funcs)) - s.render = make([]parser.Expr, 0, len(funcs)) - for _, f := range funcs { - if len(f.val.expr.Exprs) != 1 { - panic(fmt.Sprintf("%s has %d arguments (expected 1)", f.val.expr.Name, len(f.val.expr.Exprs))) + s.render = make([]parser.Expr, len(group.funcs)) + for i, f := range group.funcs { + s.render[i] = f.arg + } + + // Add the group-by expressions so they are available for bucketing. + for _, g := range n.GroupBy { + if err := s.addRender(parser.SelectExpr{Expr: g}); err != nil { + return nil, err } - s.columns = append(s.columns, column{name: f.val.String(), typ: f.val.datum}) - s.render = append(s.render, f.val.expr.Exprs[0]) } group.desiredOrdering = desiredAggregateOrdering(group.funcs) @@ -108,14 +112,24 @@ func (p *planner) groupBy(n *parser.Select, s *scanNode) (*groupNode, error) { } type groupNode struct { - planner *planner - plan planNode - columns []column - row parser.DTuple - render []parser.Expr - funcs []*aggregateFunc + planner *planner + plan planNode + + columns []column + render []parser.Expr + + funcs []*aggregateFunc + // The set of bucket keys. + buckets map[string]struct{} + + addNullBucketIfEmpty bool + + // This will be nil until the first call to Next calls computeAggregates. + values *valuesNode + // During rendering, aggregateFuncs compute their result for group.currentBucket. + currentBucket string + desiredOrdering []int - needGroup bool err error } @@ -124,52 +138,86 @@ func (n *groupNode) Columns() []column { } func (n *groupNode) Ordering() ([]int, int) { - return n.plan.Ordering() + // TODO(dt): aggregate buckets are returned un-ordered for now. + return nil, 0 } func (n *groupNode) Values() parser.DTuple { - return n.row + return n.values.Values() } func (n *groupNode) Next() bool { - if !n.needGroup || n.err != nil { + if n.values == nil && n.err == nil { + n.computeAggregates() + } + if n.err != nil { return false } - n.needGroup = false + return n.values.Next() +} + +func (n *groupNode) computeAggregates() { + values := &valuesNode{ + columns: n.columns, + } + n.values = values + + var scratch []byte // Loop over the rows passing the values into the corresponding aggregation // functions. for n.plan.Next() { values := n.plan.Values() - for i, f := range n.funcs { - if n.err = f.add(values[i]); n.err != nil { - return false + aggregatedValues, groupedValues := values[:len(n.funcs)], values[len(n.funcs):] + + //TODO(dt): optimization: skip buckets when underlying plan is ordered by grouped values. + + var encoded []byte + encoded, n.err = encodeDTuple(scratch, groupedValues) + if n.err != nil { + return + } + + e := string(encoded) + + n.buckets[e] = struct{}{} + // Feed the aggregateFuncs for this bucket the non-grouped values. + for i, value := range aggregatedValues { + if n.err = n.funcs[i].add(e, value); n.err != nil { + return } } + scratch = encoded[0:0] } n.err = n.plan.Err() if n.err != nil { - return false + return } - // Fill in the aggregate function result value. - for _, f := range n.funcs { - if f.val.datum, n.err = f.impl.result(); n.err != nil { - return false - } + if len(n.buckets) < 1 && n.addNullBucketIfEmpty { + n.buckets[""] = struct{}{} } // Render the results. - n.row = make([]parser.Datum, len(n.render)) - for i, r := range n.render { - n.row[i], n.err = r.Eval(n.planner.evalCtx) - if n.err != nil { - return false + n.values.rows = make([]parser.DTuple, 0, len(n.buckets)) + + for k := range n.buckets { + n.currentBucket = k + + row := make(parser.DTuple, 0, len(n.render)) + + for _, r := range n.render { + res, err := r.Eval(n.planner.evalCtx) + if err != nil { + n.err = err + return + } + row = append(row, res) } - } - return n.err == nil + n.values.rows = append(n.values.rows, row) + } } func (n *groupNode) Err() error { @@ -180,7 +228,7 @@ func (n *groupNode) ExplainPlan() (name, description string, children []planNode name = "group" strs := make([]string, 0, len(n.funcs)) for _, f := range n.funcs { - strs = append(strs, f.val.String()) + strs = append(strs, f.String()) } description = strings.Join(strs, ", ") return name, description, []planNode{n.plan} @@ -192,7 +240,6 @@ func (n *groupNode) wrap(plan planNode) planNode { return plan } n.plan = plan - n.needGroup = true return n } @@ -211,7 +258,7 @@ func (n *groupNode) isNotNullFilter(expr parser.Expr) parser.Expr { f := n.funcs[i-1] isNotNull := &parser.ComparisonExpr{ Operator: parser.IsNot, - Left: f.val.expr.Exprs[0], + Left: f.arg, Right: parser.DNull, } if expr == nil { @@ -231,15 +278,16 @@ func (n *groupNode) isNotNullFilter(expr parser.Expr) parser.Expr { func desiredAggregateOrdering(funcs []*aggregateFunc) []int { var limit int for i, f := range funcs { - switch f.impl.(type) { + impl := f.create() + switch impl.(type) { case *maxAggregate, *minAggregate: - if limit != 0 || len(f.val.expr.Exprs) != 1 { + if limit != 0 || f.arg == nil { return nil } - switch f.val.expr.Exprs[0].(type) { + switch f.arg.(type) { case *qvalue: limit = i + 1 - if _, ok := f.impl.(*maxAggregate); ok { + if _, ok := impl.(*maxAggregate); ok { limit = -limit } default: @@ -257,8 +305,8 @@ func desiredAggregateOrdering(funcs []*aggregateFunc) []int { } type extractAggregatesVisitor struct { - funcs []*aggregateFunc - err error + n *groupNode + err error } var _ parser.Visitor = &extractAggregatesVisitor{} @@ -273,42 +321,60 @@ func (v *extractAggregatesVisitor) Visit(expr parser.Expr, pre bool) (parser.Vis break } if impl, ok := aggregates[strings.ToLower(string(t.Name.Base))]; ok { + if len(t.Exprs) != 1 { + // Type checking has already run on these expressions thus + // if an aggregate function of the wrong arity gets here, + // something has gone really wrong. + panic(fmt.Sprintf("%s has %d arguments (expected 1)", t.Name.Base, len(t.Exprs))) + } + f := &aggregateFunc{ - val: aggregateValue{ - expr: t, - }, - impl: impl(), + expr: t, + arg: t.Exprs[0], + create: impl, + group: v.n, + buckets: make(map[string]aggregateImpl), } if t.Distinct { f.seen = make(map[string]struct{}) } - v.funcs = append(v.funcs, f) - return nil, &f.val + v.n.funcs = append(v.n.funcs, f) + return nil, f + } + case *qvalue: + f := &aggregateFunc{ + expr: t, + arg: t, + create: newIdentAggregate, + group: v.n, + buckets: make(map[string]aggregateImpl), } + v.n.funcs = append(v.n.funcs, f) + return nil, f } return v, expr } -func (v *extractAggregatesVisitor) run(expr parser.Expr) (parser.Expr, []*aggregateFunc, error) { - *v = extractAggregatesVisitor{} +func (v *extractAggregatesVisitor) run(n *groupNode, expr parser.Expr) (parser.Expr, error) { + *v = extractAggregatesVisitor{n: n} expr = parser.WalkExpr(v, expr) - return expr, v.funcs, v.err + return expr, v.err } -func (p *planner) extractAggregateFuncs(expr parser.Expr) (parser.Expr, []*aggregateFunc, error) { - return p.extractAggregatesVisitor.run(expr) +func (p *planner) extractAggregateFuncs(n *groupNode, expr parser.Expr) (parser.Expr, error) { + return p.extractAggregatesVisitor.run(n, expr) } type checkAggregateVisitor struct { groupStrs map[string]struct{} aggregated bool - err error + invalid error } var _ parser.Visitor = &checkAggregateVisitor{} func (v *checkAggregateVisitor) Visit(expr parser.Expr, pre bool) (parser.Visitor, parser.Expr) { - if !pre || v.err != nil { + if !pre || v.invalid != nil { return nil, expr } @@ -323,7 +389,7 @@ func (v *checkAggregateVisitor) Visit(expr parser.Expr, pre bool) (parser.Visito if _, ok := v.groupStrs[t.String()]; ok { return nil, expr } - v.err = fmt.Errorf("column \"%s\" must appear in the GROUP BY clause or be used in an aggregate function", t.col.Name) + v.invalid = fmt.Errorf("column \"%s\" must appear in the GROUP BY clause or be used in an aggregate function", t.col.Name) return v, expr } @@ -336,18 +402,14 @@ func (v *checkAggregateVisitor) Visit(expr parser.Expr, pre bool) (parser.Visito // Check if expressions use aggregation and, if so, if they are valid. // "Valid" expressions must either contain no unaggregated qvalues // or must appear, verbatim, in the group-by clause. Expressions are -// string-compared to the group-by clauses (as an approximation of) a +// string-compared to the group-by clauses as (an approximation of) a // recursive expression-tree equality check. func checkAggregateExprs(group parser.GroupBy, exprs []parser.Expr) (bool, error) { aggregated := len(group) > 0 - //TODO(davidt): remove when group by is implemented - if aggregated { - return aggregated, fmt.Errorf("GROUP BY not supported yet") - } - v := checkAggregateVisitor{} + // TODO(dt): consider other ways of comparing expression trees. v.groupStrs = make(map[string]struct{}, len(group)) for i := range group { v.groupStrs[group[i].String()] = struct{}{} @@ -358,47 +420,25 @@ func checkAggregateExprs(group parser.GroupBy, exprs []parser.Expr) (bool, error if v.aggregated { aggregated = true } - if v.err != nil { - return aggregated, v.err + if v.invalid != nil && aggregated { + return aggregated, v.invalid } } return aggregated, nil } -type aggregateValue struct { - datum parser.Datum - expr *parser.FuncExpr -} - -var _ parser.VariableExpr = &aggregateValue{} - -func (*aggregateValue) Variable() {} - -func (av *aggregateValue) String() string { - return av.expr.String() -} - -func (av *aggregateValue) Walk(v parser.Visitor) { - // I expected to implement: - // av.datum = parser.WalkExpr(v, av.datum).(parser.Datum) - // But it seems `av.datum` is sometimes nil. -} - -func (av *aggregateValue) TypeCheck() (parser.Datum, error) { - return av.expr.TypeCheck() -} - -func (av *aggregateValue) Eval(ctx parser.EvalContext) (parser.Datum, error) { - return av.datum.Eval(ctx) -} +var _ parser.VariableExpr = &aggregateFunc{} type aggregateFunc struct { - val aggregateValue - impl aggregateImpl - seen map[string]struct{} + expr parser.Expr + arg parser.Expr + create func() aggregateImpl + group *groupNode + buckets map[string]aggregateImpl + seen map[string]struct{} } -func (a *aggregateFunc) add(d parser.Datum) error { +func (a *aggregateFunc) add(bucket string, d parser.Datum) error { if a.seen != nil { encoded, err := encodeDatum(nil, d) if err != nil { @@ -411,7 +451,42 @@ func (a *aggregateFunc) add(d parser.Datum) error { } a.seen[e] = struct{}{} } - return a.impl.add(d) + + if _, ok := a.buckets[bucket]; !ok { + a.buckets[bucket] = a.create() + } + + return a.buckets[bucket].add(d) +} + +func (*aggregateFunc) Variable() {} + +func (a *aggregateFunc) String() string { + return a.expr.String() +} + +func (a *aggregateFunc) Walk(v parser.Visitor) { +} + +func (a *aggregateFunc) TypeCheck() (parser.Datum, error) { + return a.expr.TypeCheck() +} + +func (a *aggregateFunc) Eval(ctx parser.EvalContext) (parser.Datum, error) { + found, ok := a.buckets[a.group.currentBucket] + + if !ok { + found = a.create() + } + + datum, err := found.result() + + if err != nil { + return nil, err + } + + // This is almost certainly the identity. Oh well. + return datum.Eval(ctx) } func encodeDatum(b []byte, d parser.Datum) ([]byte, error) { @@ -442,6 +517,30 @@ var _ aggregateImpl = &countAggregate{} var _ aggregateImpl = &maxAggregate{} var _ aggregateImpl = &minAggregate{} var _ aggregateImpl = &sumAggregate{} +var _ aggregateImpl = &identAggregate{} + +// In order to render the unaggregated (ie grouped) fields, during aggregation, +// the values for those fields have to be stored for each bucket. +// The `identAggregate` provides an "aggregate" function that actually +// just returns the last value passed to `add`, unchanged. For accumulating +// and rendering though it behaves like the other aggregate functions, +// allowing both those steps to avoid special-casing grouped vs aggregated fields. +type identAggregate struct { + val parser.Datum +} + +func newIdentAggregate() aggregateImpl { + return &identAggregate{} +} + +func (a *identAggregate) add(datum parser.Datum) error { + a.val = datum + return nil +} + +func (a *identAggregate) result() (parser.Datum, error) { + return a.val, nil +} type avgAggregate struct { sumAggregate diff --git a/sql/group_test.go b/sql/group_test.go index 0207af126501..d4d83af26bfd 100644 --- a/sql/group_test.go +++ b/sql/group_test.go @@ -21,17 +21,13 @@ import ( "reflect" "testing" - "github.com/cockroachdb/cockroach/sql/parser" "github.com/cockroachdb/cockroach/util/leaktest" ) func TestDesiredAggregateOrder(t *testing.T) { defer leaktest.AfterTest(t) - extractAggregateFuncs := func(expr parser.Expr) (parser.Expr, []*aggregateFunc, error) { - var v extractAggregatesVisitor - return v.run(expr) - } + p := planner{} testData := []struct { expr string @@ -53,11 +49,12 @@ func TestDesiredAggregateOrder(t *testing.T) { } for _, d := range testData { expr, _ := parseAndNormalizeExpr(t, d.expr) - _, funcs, err := extractAggregateFuncs(expr) + group := &groupNode{} + _, err := p.extractAggregateFuncs(group, expr) if err != nil { t.Fatal(err) } - ordering := desiredAggregateOrdering(funcs) + ordering := desiredAggregateOrdering(group.funcs) if !reflect.DeepEqual(d.ordering, ordering) { t.Fatalf("%s: expected %d, but found %d", d.expr, d.ordering, ordering) } diff --git a/sql/testdata/aggregate b/sql/testdata/aggregate index f3dd50a25d5e..88b8ac044bec 100644 --- a/sql/testdata/aggregate +++ b/sql/testdata/aggregate @@ -10,9 +10,15 @@ INSERT INTO kv VALUES (1, 2), (3, 4), (5, NULL), (6, 2), (7, 2), (8, 4) query error column "k" must appear in the GROUP BY clause or be used in an aggregate function SELECT COUNT(*), k FROM kv -# TODO(pmattis): fix -query error GROUP BY not supported yet +query II rowsort SELECT COUNT(*), k FROM kv GROUP BY k +---- +1 1 +1 3 +1 5 +1 6 +1 7 +1 8 query error syntax error at or near "," SELECT COUNT(*, 1) FROM kv @@ -236,3 +242,60 @@ EXPLAIN SELECT MAX(x) FROM xyz WHERE (z, y) = (3, 2) ---- 0 group MAX(x) 1 revscan xyz@zyx 1:/3/2/#-/3/3 + +statement ok +CREATE TABLE links ( + id INT PRIMARY KEY, + up INT, + down INT, + hide BOOLEAN, + tag STRING +) + +statement OK +INSERT INTO links VALUES +(1, 2, 3, true, 'a'), +(2, 3, 1, false, 'b'), +(3, 1, 4, true, 'a'), +(4, 10, NULL, NULL, 'c'), +(5, 1, 1, false, NULL), +(6, 1, 1, false, 'C'); + +query I +SELECT COUNT(*) FROM links +---- +6 + +query TI rowsort +SELECT tag, COUNT(*) FROM links GROUP BY tag +---- +C 1 +NULL 1 +a 2 +b 1 +c 1 + +query TI rowsort +SELECT UPPER(tag), COUNT(*) FROM links GROUP BY tag +---- +A 2 +B 1 +C 1 +C 1 +NULL 1 + +query TIIIR rowsort +SELECT UPPER(tag), COUNT(*), SUM(up), SUM(down), AVG(up+down) as avg_votes FROM links GROUP BY UPPER(tag) +---- +A 2 3 7 5 +B 1 3 1 4 +C 2 11 1 2 +NULL 1 1 1 2 + +query II rowsort +SELECT count(links.id) AS count_1, links.up + links.down AS lx FROM links GROUP BY links.up + links.down +---- +1 4 +1 NULL +2 2 +2 5