diff --git a/pkg/sql/distsqlrun/aggregator.go b/pkg/sql/distsqlrun/aggregator.go index c92c663b3ded..cf0d74c474a3 100644 --- a/pkg/sql/distsqlrun/aggregator.go +++ b/pkg/sql/distsqlrun/aggregator.go @@ -57,7 +57,7 @@ func GetAggregateInfo( datumTypes[i] = inputTypes[i].ToDatumType() } - _, builtins := builtins.GetBuiltinProperties(strings.ToLower(fn.String())) + props, builtins := builtins.GetBuiltinProperties(strings.ToLower(fn.String())) for _, b := range builtins { types := b.Types.Types() if len(types) != len(inputTypes) { @@ -66,6 +66,9 @@ func GetAggregateInfo( match := true for i, t := range types { if !datumTypes[i].Equivalent(t) { + if props.NullableArgs && datumTypes[i].IsAmbiguous() { + continue + } match = false break } diff --git a/pkg/sql/distsqlrun/windower.go b/pkg/sql/distsqlrun/windower.go index 1aab9b17e7a4..f9979733fcb8 100644 --- a/pkg/sql/distsqlrun/windower.go +++ b/pkg/sql/distsqlrun/windower.go @@ -67,7 +67,7 @@ func GetWindowFunctionInfo( "function is neither an aggregate nor a window function", ) } - _, builtins := builtins.GetBuiltinProperties(strings.ToLower(funcStr)) + props, builtins := builtins.GetBuiltinProperties(strings.ToLower(funcStr)) for _, b := range builtins { types := b.Types.Types() if len(types) != len(inputTypes) { @@ -76,6 +76,9 @@ func GetWindowFunctionInfo( match := true for i, t := range types { if !datumTypes[i].Equivalent(t) { + if props.NullableArgs && datumTypes[i].IsAmbiguous() { + continue + } match = false break } diff --git a/pkg/sql/logictest/testdata/logic_test/aggregate b/pkg/sql/logictest/testdata/logic_test/aggregate index 774605f0494f..15cd635e120d 100644 --- a/pkg/sql/logictest/testdata/logic_test/aggregate +++ b/pkg/sql/logictest/testdata/logic_test/aggregate @@ -1195,6 +1195,22 @@ ORDER BY company_id; ---- company_id string_agg +query IT colnames +SELECT company_id, string_agg(employee, NULL) +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg + +query IT colnames +SELECT company_id, string_agg(employee::BYTES, NULL) +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg + statement OK INSERT INTO string_agg_test VALUES (1, 1, 'A'), @@ -1261,6 +1277,86 @@ company_id string_agg 3 CCC 4 DDDD +query IT colnames +SELECT company_id, string_agg(employee, NULL) +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg +1 A +2 BB +3 CCC +4 DDDD + +query IT colnames +SELECT company_id, string_agg(employee::BYTES, NULL) +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg +1 A +2 BB +3 CCC +4 DDDD + +query IT colnames +SELECT company_id, string_agg(NULL::STRING, ',') +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg +1 NULL +2 NULL +3 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL::BYTES, b',') +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg +1 NULL +2 NULL +3 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL::STRING, NULL) +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg +1 NULL +2 NULL +3 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL::BYTES, NULL) +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; +---- +company_id string_agg +1 NULL +2 NULL +3 NULL +4 NULL + +query error pq: ambiguous call: string_agg\(unknown, unknown\) +SELECT company_id, string_agg(NULL, NULL) +FROM string_agg_test +GROUP BY company_id +ORDER BY company_id; + +# Now test the window function version of string_agg. + query IT colnames SELECT company_id, string_agg(employee, ',') OVER (PARTITION BY company_id ORDER BY id) @@ -1333,6 +1429,156 @@ company_id string_agg 4 DDD 4 DDDD +query IT colnames +SELECT company_id, string_agg(employee, NULL) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 A +2 B +2 BB +3 C +3 CC +3 CCC +4 D +4 DD +4 DDD +4 DDDD + +query IT colnames +SELECT company_id, string_agg(employee::BYTES, NULL) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 A +2 B +2 BB +3 C +3 CC +3 CCC +4 D +4 DD +4 DDD +4 DDDD + +query IT colnames +SELECT company_id, string_agg(NULL::STRING, employee) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 NULL +2 NULL +2 NULL +3 NULL +3 NULL +3 NULL +4 NULL +4 NULL +4 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL::BYTES, employee::BYTES) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 NULL +2 NULL +2 NULL +3 NULL +3 NULL +3 NULL +4 NULL +4 NULL +4 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL::STRING, NULL) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 NULL +2 NULL +2 NULL +3 NULL +3 NULL +3 NULL +4 NULL +4 NULL +4 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL::BYTES, NULL) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 NULL +2 NULL +2 NULL +3 NULL +3 NULL +3 NULL +4 NULL +4 NULL +4 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL, NULL::STRING) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 NULL +2 NULL +2 NULL +3 NULL +3 NULL +3 NULL +4 NULL +4 NULL +4 NULL +4 NULL + +query IT colnames +SELECT company_id, string_agg(NULL, NULL::BYTES) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; +---- +company_id string_agg +1 NULL +2 NULL +2 NULL +3 NULL +3 NULL +3 NULL +4 NULL +4 NULL +4 NULL +4 NULL + +query error pq: ambiguous call: string_agg\(unknown, unknown\) +SELECT company_id, string_agg(NULL, NULL) +OVER (PARTITION BY company_id ORDER BY id) +FROM string_agg_test +ORDER BY company_id, id; + query IT colnames SELECT company_id, string_agg(employee, lower(employee)) OVER (PARTITION BY company_id) @@ -1443,6 +1689,32 @@ ORDER BY e.company_id; company_id string_agg 1 D, C, B, A +query IT colnames +SELECT e.company_id, string_agg(e.employee, NULL) +FROM ( + SELECT employee, company_id + FROM string_agg_test + ORDER BY employee DESC + ) AS e +GROUP BY e.company_id +ORDER BY e.company_id; +---- +company_id string_agg +1 DCBA + +query IT colnames +SELECT e.company_id, string_agg(e.employee, NULL) +FROM ( + SELECT employee::BYTES, company_id + FROM string_agg_test + ORDER BY employee DESC + ) AS e +GROUP BY e.company_id +ORDER BY e.company_id; +---- +company_id string_agg +1 DCBA + statement OK DROP TABLE string_agg_test diff --git a/pkg/sql/sem/builtins/aggregate_builtins.go b/pkg/sql/sem/builtins/aggregate_builtins.go index 64284d709e7a..b0e1c1bc221b 100644 --- a/pkg/sql/sem/builtins/aggregate_builtins.go +++ b/pkg/sql/sem/builtins/aggregate_builtins.go @@ -178,7 +178,7 @@ var aggregates = map[string]builtinDefinition{ "Identifies the minimum selected value.") }), - "string_agg": makeBuiltin(aggProps(), + "string_agg": makeBuiltin(aggPropsNullableArgs(), makeAggOverload([]types.T{types.String, types.String}, types.String, newStringConcatAggregate, "Concatenates all selected values using the provided delimiter."), makeAggOverload([]types.T{types.Bytes, types.Bytes}, types.Bytes, newBytesConcatAggregate, @@ -588,13 +588,11 @@ func (a *avgAggregate) Size() int64 { } type concatAggregate struct { - forBytes bool - sawNonNull bool - delimiter string // used for non window functions - delimiterSize uintptr // used for non window functions - first bool - result bytes.Buffer - acc mon.BoundAccount + forBytes bool + sawNonNull bool + delimiter string // used for non window functions + result bytes.Buffer + acc mon.BoundAccount } func newBytesConcatAggregate( @@ -602,12 +600,10 @@ func newBytesConcatAggregate( ) tree.AggregateFunc { concatAgg := &concatAggregate{ forBytes: true, - first: true, acc: evalCtx.Mon.MakeBoundAccount(), } - if len(arguments) == 1 { + if len(arguments) == 1 && arguments[0] != tree.DNull { concatAgg.delimiter = string(tree.MustBeDBytes(arguments[0])) - concatAgg.delimiterSize = arguments[0].Size() } else if len(arguments) > 1 { panic(fmt.Sprintf("too many arguments passed in, expected < 2, got %d", len(arguments))) } @@ -618,12 +614,10 @@ func newStringConcatAggregate( _ []types.T, evalCtx *tree.EvalContext, arguments tree.Datums, ) tree.AggregateFunc { concatAgg := &concatAggregate{ - first: true, - acc: evalCtx.Mon.MakeBoundAccount(), + acc: evalCtx.Mon.MakeBoundAccount(), } - if len(arguments) == 1 { + if len(arguments) == 1 && arguments[0] != tree.DNull { concatAgg.delimiter = string(tree.MustBeDString(arguments[0])) - concatAgg.delimiterSize = arguments[0].Size() } else if len(arguments) > 1 { panic(fmt.Sprintf("too many arguments passed in, expected < 2, got %d", len(arguments))) } @@ -634,26 +628,23 @@ func (a *concatAggregate) Add(ctx context.Context, datum tree.Datum, others ...t if datum == tree.DNull { return nil } - delimiter := a.delimiter - delimiterSize := a.delimiterSize - // If this is called as part of a window function, the delimiter is passed in - // via the first element in others. - if len(others) == 1 && others[0] != tree.DNull { - if a.forBytes { - delimiter = string(tree.MustBeDBytes(others[0])) - } else { - delimiter = string(tree.MustBeDString(others[0])) + if !a.sawNonNull { + a.sawNonNull = true + } else { + delimiter := a.delimiter + // If this is called as part of a window function, the delimiter is passed in + // via the first element in others. + if len(others) == 1 && others[0] != tree.DNull { + if a.forBytes { + delimiter = string(tree.MustBeDBytes(others[0])) + } else { + delimiter = string(tree.MustBeDString(others[0])) + } + } else if len(others) > 1 { + panic(fmt.Sprintf("too many other datums passed in, expected < 2, got %d", len(others))) } - delimiterSize = others[0].Size() - } else if len(others) > 1 { - panic(fmt.Sprintf("too many other datums passed in, expected < 2, got %d", len(others))) - } - a.sawNonNull = true - if delimiterSize > 0 { - if a.first { - a.first = false - } else { - if err := a.acc.Grow(ctx, int64(delimiterSize)); err != nil { + if len(delimiter) > 0 { + if err := a.acc.Grow(ctx, int64(len(delimiter))); err != nil { return err } a.result.WriteString(delimiter) @@ -665,7 +656,7 @@ func (a *concatAggregate) Add(ctx context.Context, datum tree.Datum, others ...t } else { arg = string(tree.MustBeDString(datum)) } - if err := a.acc.Grow(ctx, int64(datum.Size())); err != nil { + if err := a.acc.Grow(ctx, int64(len(arg))); err != nil { return err } a.result.WriteString(arg)