Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

builtins: use intermediate decimal context for distributed aggregation #105694

Merged
merged 1 commit into from Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 9 additions & 9 deletions pkg/sql/logictest/testdata/logic_test/window
Expand Up @@ -495,38 +495,38 @@ SELECT k, variance(d) OVER (PARTITION BY v ORDER BY w) FROM kv ORDER BY 1
query IR
SELECT k, stddev(d) OVER (PARTITION BY v ORDER BY w) FROM kv ORDER BY 1
----
1 3.4501207708330056852
1 3.4501207708330056853
3 3.5355339059327376220
5 NULL
6 3.4501207708330056852
6 3.4501207708330056853
7 NULL
8 NULL

query IR
SELECT k, stddev(d) OVER w FROM kv WINDOW w as (PARTITION BY v) ORDER BY variance(d) OVER w, k
----
5 NULL
1 3.4501207708330056852
6 3.4501207708330056852
7 3.4501207708330056852
1 3.4501207708330056853
6 3.4501207708330056853
7 3.4501207708330056853
3 3.5355339059327376220
8 3.5355339059327376220

query IRIR
SELECT * FROM (SELECT k, d, v, stddev(d) OVER (PARTITION BY v) FROM kv) sub ORDER BY variance(d) OVER (PARTITION BY v), k
----
5 -321 NULL NULL
1 1 2 3.4501207708330056852
6 4.4 2 3.4501207708330056852
7 7.9 2 3.4501207708330056852
1 1 2 3.4501207708330056853
6 4.4 2 3.4501207708330056853
7 7.9 2 3.4501207708330056853
3 8 4 3.5355339059327376220
8 3 4 3.5355339059327376220

query IR
SELECT k, max(stddev) OVER (ORDER BY d) FROM (SELECT k, d, stddev(d) OVER (PARTITION BY v) as stddev FROM kv) sub ORDER BY 2, k
----
5 NULL
1 3.4501207708330056852
1 3.4501207708330056853
3 3.5355339059327376220
6 3.5355339059327376220
7 3.5355339059327376220
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/opt/exec/execbuilder/testdata/window
Expand Up @@ -49,9 +49,9 @@ fetched: /kv/kv_pkey/8/v -> /4
fetched: /kv/kv_pkey/8/d -> 3
fetched: /kv/kv_pkey/8 -> <undecoded>
output row: [5 NULL]
output row: [1 3.4501207708330056852]
output row: [6 3.4501207708330056852]
output row: [7 3.4501207708330056852]
output row: [1 3.4501207708330056853]
output row: [6 3.4501207708330056853]
output row: [7 3.4501207708330056853]
output row: [3 3.5355339059327376220]
output row: [8 3.5355339059327376220]

Expand Down
135 changes: 70 additions & 65 deletions pkg/sql/sem/builtins/aggregate_builtins.go
Expand Up @@ -1285,6 +1285,41 @@ const sizeOfSTUnionAggregate = int64(unsafe.Sizeof(stUnionAgg{}))
const sizeOfSTCollectAggregate = int64(unsafe.Sizeof(stCollectAgg{}))
const sizeOfSTExtentAggregate = int64(unsafe.Sizeof(stExtentAgg{}))

// aggregateWithIntermediateResult is a common interface for aggregate functions
// which can return a result without loss of precision. This is useful when an
// aggregate function uses the result of another in its own calculations, e.g.
// a standard deviation aggregate that takes the square root of a variance
// aggregate's result.
type aggregateWithIntermediateResult interface {
eval.AggregateFunc
// intermediateResult returns the current value of the accumulation without
// the final rounding to the default decimal context (tree.DecimalCtx) that
// Result applies. The returned datum may be tree.DNull.
intermediateResult() (tree.Datum, error)
}

// roundIntermediateDecimalResult produces the final result for an aggregate
// that implements aggregateWithIntermediateResult. It fetches the aggregate's
// intermediate result and rounds it with the default decimal context
// (tree.DecimalCtx). An error or a NULL intermediate result is passed through
// to the caller unchanged.
func roundIntermediateDecimalResult(a aggregateWithIntermediateResult) (tree.Datum, error) {
	intermediate, err := a.intermediateResult()
	if err != nil {
		return intermediate, err
	}
	if intermediate == tree.DNull {
		return intermediate, nil
	}
	// The datum is rounded in place: it is both the source and the
	// destination of the rounding operation.
	rounded := intermediate.(*tree.DDecimal)
	if _, err := tree.DecimalCtx.Round(&rounded.Decimal, &rounded.Decimal); err != nil {
		return nil, err
	}
	// The number of trailing zeros in the output can vary with the order in
	// which the input was processed. Strip them so that the result does not
	// depend on that order.
	rounded.Reduce(&rounded.Decimal)
	return rounded, nil
}

// Compile-time checks that every decimal-producing aggregate below implements
// aggregateWithIntermediateResult.
var (
	_ aggregateWithIntermediateResult = (*intSqrDiffAggregate)(nil)
	_ aggregateWithIntermediateResult = (*decimalSqrDiffAggregate)(nil)
	_ aggregateWithIntermediateResult = (*decimalSumSqrDiffsAggregate)(nil)
	_ aggregateWithIntermediateResult = (*decimalVarPopAggregate)(nil)
	_ aggregateWithIntermediateResult = (*decimalVarianceAggregate)(nil)
)

// singleDatumAggregateBase is a utility struct that helps aggregate builtins
// that store a single datum internally track their memory usage related to
// that single datum.
Expand Down Expand Up @@ -3736,27 +3771,7 @@ func (a *decimalSqrDiffAggregate) intermediateResult() (tree.Datum, error) {
}

func (a *decimalSqrDiffAggregate) Result() (tree.Datum, error) {
res, err := a.intermediateResult()
if err != nil || res == tree.DNull {
return res, err
}

dd := res.(*tree.DDecimal)
// Sqrdiff calculation is used in variance and var_pop as one of intermediate
// results. We want the intermediate results to be as precise as possible.
// That's why sqrdiff uses IntermediateCtx, but due to operations reordering
// in distributed mode the result might be different (see issue #13689,
// PR #18701). By rounding the end result to the DecimalCtx precision we avoid
// such inconsistencies.
_, err = tree.DecimalCtx.Round(&dd.Decimal, &a.sqrDiff)
if err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
dd.Decimal.Reduce(&dd.Decimal)
return dd, nil
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
Expand Down Expand Up @@ -3969,21 +3984,7 @@ func (a *decimalSumSqrDiffsAggregate) intermediateResult() (tree.Datum, error) {
}

func (a *decimalSumSqrDiffsAggregate) Result() (tree.Datum, error) {
res, err := a.intermediateResult()
if err != nil || res == tree.DNull {
return res, err
}

dd := res.(*tree.DDecimal)
_, err = tree.DecimalCtx.Round(&dd.Decimal, &dd.Decimal)
if err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
dd.Reduce(&dd.Decimal)
return dd, nil
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
Expand All @@ -4010,12 +4011,9 @@ type floatSqrDiff interface {
}

type decimalSqrDiff interface {
eval.AggregateFunc
aggregateWithIntermediateResult
Count() *apd.Decimal
Tmp() *apd.Decimal
// intermediateResult returns the current value of the accumulation without
// rounding.
intermediateResult() (tree.Datum, error)
}

type floatVarianceAggregate struct {
Expand Down Expand Up @@ -4092,8 +4090,7 @@ func (a *floatVarianceAggregate) Result() (tree.Datum, error) {
return tree.NewDFloat(tree.DFloat(float64(*sqrDiff.(*tree.DFloat)) / (float64(a.agg.Count()) - 1))), nil
}

// Result calculates the variance from the member square difference aggregator.
func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
func (a *decimalVarianceAggregate) intermediateResult() (tree.Datum, error) {
if a.agg.Count().Cmp(decimalTwo) < 0 {
return tree.DNull, nil
}
Expand All @@ -4105,7 +4102,7 @@ func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
return nil, err
}
dd := &tree.DDecimal{}
if _, err = tree.DecimalCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Tmp()); err != nil {
if _, err = tree.IntermediateCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Tmp()); err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input is
Expand All @@ -4116,6 +4113,11 @@ func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
return dd, nil
}

// Result calculates the variance from the member square difference aggregator.
func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
func (a *floatVarianceAggregate) Reset(ctx context.Context) {
a.agg.Reset(ctx)
Expand Down Expand Up @@ -4208,8 +4210,7 @@ func (a *floatVarPopAggregate) Result() (tree.Datum, error) {
return tree.NewDFloat(tree.DFloat(float64(*sqrDiff.(*tree.DFloat)) / (float64(a.agg.Count())))), nil
}

// Result calculates the population variance from the member square difference aggregator.
func (a *decimalVarPopAggregate) Result() (tree.Datum, error) {
func (a *decimalVarPopAggregate) intermediateResult() (tree.Datum, error) {
if a.agg.Count().Cmp(decimalOne) < 0 {
return tree.DNull, nil
}
Expand All @@ -4218,17 +4219,21 @@ func (a *decimalVarPopAggregate) Result() (tree.Datum, error) {
return nil, err
}
dd := &tree.DDecimal{}
if _, err = tree.DecimalCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Count()); err != nil {
if _, err = tree.IntermediateCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Count()); err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input is
// processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of
// order.
dd.Decimal.Reduce(&dd.Decimal)
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
dd.Reduce(&dd.Decimal)
return dd, nil
}

// Result calculates the population variance from the member square difference aggregator.
func (a *decimalVarPopAggregate) Result() (tree.Datum, error) {
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
func (a *floatVarPopAggregate) Reset(ctx context.Context) {
a.agg.Reset(ctx)
Expand Down Expand Up @@ -4264,7 +4269,7 @@ type floatStdDevAggregate struct {
}

type decimalStdDevAggregate struct {
agg eval.AggregateFunc
agg aggregateWithIntermediateResult
}

// Both StdDev and FinalStdDev aggregators have the same codepath for
Expand All @@ -4276,7 +4281,8 @@ type decimalStdDevAggregate struct {
func newIntStdDevAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newIntVarianceAggregate(params, evalCtx, arguments)}
agg := newIntVarianceAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatStdDevAggregate(
Expand All @@ -4288,7 +4294,8 @@ func newFloatStdDevAggregate(
func newDecimalStdDevAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalVarianceAggregate(params, evalCtx, arguments)}
agg := newDecimalVarianceAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatFinalStdDevAggregate(
Expand All @@ -4300,13 +4307,15 @@ func newFloatFinalStdDevAggregate(
func newDecimalFinalStdDevAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalFinalVarianceAggregate(params, evalCtx, arguments)}
agg := newDecimalFinalVarianceAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newIntStdDevPopAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newIntVarPopAggregate(params, evalCtx, arguments)}
agg := newIntVarPopAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatStdDevPopAggregate(
Expand All @@ -4318,13 +4327,15 @@ func newFloatStdDevPopAggregate(
func newDecimalStdDevPopAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalVarPopAggregate(params, evalCtx, arguments)}
agg := newDecimalVarPopAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newDecimalFinalStdDevPopAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalFinalVarPopAggregate(params, evalCtx, arguments)}
agg := newDecimalFinalVarPopAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatFinalStdDevPopAggregate(
Expand Down Expand Up @@ -4369,13 +4380,7 @@ func (a *floatStdDevAggregate) Result() (tree.Datum, error) {

// Result computes the square root of the variance aggregator.
func (a *decimalStdDevAggregate) Result() (tree.Datum, error) {
// TODO(richardwu): both decimalVarianceAggregate and
// finalDecimalVarianceAggregate return a decimal result with
// default tree.DecimalCtx precision. We want to be able to specify that the
// varianceAggregate use tree.IntermediateCtx (with the extra precision)
// since it is returning an intermediate value for stdDevAggregate (of
// which we take the Sqrt).
variance, err := a.agg.Result()
variance, err := a.agg.intermediateResult()
if err != nil {
return nil, err
}
Expand Down