Skip to content

Commit

Permalink
support 'GROUP BY' column number
Browse files Browse the repository at this point in the history
  • Loading branch information
dt committed Jan 5, 2016
1 parent 6a80a4f commit 7d40505
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
31 changes: 26 additions & 5 deletions sql/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,18 @@ func (p *planner) groupBy(n *parser.Select, s *scanNode) (*groupNode, error) {
if err != nil {
return nil, err
}
n.GroupBy[i] = norm
// If a col index is specified, replace it with that expression first.
// NB: This is not a deep copy, and thus when extractAggregateFuncs runs
// on s.render, the GroupBy expressions can contain wrapped qvalues.
// aggregateFunc's Eval() method handles being called during grouping.
if col, err := s.colIndex(norm); err != nil {
return nil, err
} else if col >= 0 {
n.GroupBy[i] = s.render[col]
} else {
n.GroupBy[i] = norm
}

}

if err := checkAggregateExprs(n.GroupBy, s.render); err != nil {
Expand Down Expand Up @@ -158,8 +169,8 @@ type groupNode struct {

addNullBucketIfEmpty bool

values valuesNode
intialized bool
values valuesNode
initialized bool

// During rendering, aggregateFuncs compute their result for group.currentBucket.
currentBucket string
Expand All @@ -182,7 +193,7 @@ func (n *groupNode) Values() parser.DTuple {
}

func (n *groupNode) Next() bool {
if !n.intialized && n.err == nil {
if !n.initialized && n.err == nil {
n.computeAggregates()
}
if n.err != nil {
Expand All @@ -192,7 +203,6 @@ func (n *groupNode) Next() bool {
}

func (n *groupNode) computeAggregates() {
n.intialized = true
var scratch []byte

// Loop over the rows passing the values into the corresponding aggregation
Expand Down Expand Up @@ -229,6 +239,9 @@ func (n *groupNode) computeAggregates() {
n.buckets[""] = struct{}{}
}

// Since this controls Eval behavior of aggregateFunc, it is not set until init is complete.
n.initialized = true

// Render the results.
n.values.rows = make([]parser.DTuple, 0, len(n.buckets))
for k := range n.buckets {
Expand Down Expand Up @@ -260,6 +273,7 @@ func (n *groupNode) computeAggregates() {

n.values.rows = append(n.values.rows, row)
}

}

func (n *groupNode) Err() error {
Expand Down Expand Up @@ -562,6 +576,13 @@ func (a *aggregateFunc) TypeCheck(args parser.MapArgs) (parser.Datum, error) {
}

func (a *aggregateFunc) Eval(ctx parser.EvalContext) (parser.Datum, error) {
// During init of the group buckets, grouped expressions (i.e. wrapped
// qvalues) are Eval()'ed to determine the bucket for a row, so pass these
// calls through to the underlying `arg` expr Eval until init is done.
if !a.group.initialized {
return a.arg.Eval(ctx)
}

found, ok := a.buckets[a.group.currentBucket]
if !ok {
found = a.create()
Expand Down
34 changes: 32 additions & 2 deletions sql/testdata/aggregate
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,34 @@ SELECT COUNT(*), k FROM kv GROUP BY k
1 7
1 8

query II rowsort
SELECT COUNT(*), k FROM kv GROUP BY 2
----
1 1
1 3
1 5
1 6
1 7
1 8

query error invalid column index: 5 not in range \[1, 2\]
SELECT COUNT(*), k FROM kv GROUP BY 5

query error invalid column index: 0 not in range \[1, 2\]
SELECT COUNT(*), k FROM kv GROUP BY 0

query error invalid column index: -4 not in range \[1, 2\]
SELECT COUNT(*), k FROM kv GROUP BY -4

query error non-integer constant column index
SELECT 1 GROUP BY 'a'

query I rowsort
SELECT 1 GROUP BY 'a';
SELECT 1 FROM kv GROUP BY v
----
1
1
1

query I rowsort
SELECT MIN(1) FROM kv;
Expand All @@ -48,14 +72,20 @@ SELECT COUNT(*), k+v FROM kv GROUP BY k+v
1 9
1 NULL


query IT rowsort
SELECT COUNT(*), UPPER(s) FROM kv GROUP BY UPPER(s)
----
1 NULL
2 B
3 A

query IT rowsort
SELECT COUNT(*), UPPER(s) FROM kv GROUP BY 2
----
1 NULL
2 B
3 A

query IT rowsort
SELECT COUNT(*), UPPER(s) FROM kv GROUP BY s
----
Expand Down

0 comments on commit 7d40505

Please sign in to comment.