From 75084b09bfb6e7f539601513026411d49f8edb63 Mon Sep 17 00:00:00 2001 From: Drew Kimball Date: Wed, 28 Jun 2023 00:54:47 -0600 Subject: [PATCH] builtins: use intermediate decimal context for distributed aggregation Aggregate builtin functions round the final result using the default decimal precision. In some cases, one aggregate uses the result of another in its own calculations. This could previously cause slightly different results between local and distributed execution for aggregates that return decimal values. This patch adds `intermediateResult` methods to the aggregates that are used in the intermediate calculations for other aggregates. This avoids the rounding of intermediate results, which ensures accurate results for distributed execution. Fixes #94827 Release note (bug fix): Fixed a rounding error that could cause distributed execution for some decimal aggregate functions to return slightly inaccurate results in rare cases. --- pkg/sql/logictest/testdata/logic_test/window | 18 +-- pkg/sql/opt/exec/execbuilder/testdata/window | 6 +- pkg/sql/sem/builtins/aggregate_builtins.go | 135 ++++++++++--------- 3 files changed, 82 insertions(+), 77 deletions(-) diff --git a/pkg/sql/logictest/testdata/logic_test/window b/pkg/sql/logictest/testdata/logic_test/window index 213777a0784a..035a490879a3 100644 --- a/pkg/sql/logictest/testdata/logic_test/window +++ b/pkg/sql/logictest/testdata/logic_test/window @@ -495,10 +495,10 @@ SELECT k, variance(d) OVER (PARTITION BY v ORDER BY w) FROM kv ORDER BY 1 query IR SELECT k, stddev(d) OVER (PARTITION BY v ORDER BY w) FROM kv ORDER BY 1 ---- -1 3.4501207708330056852 +1 3.4501207708330056853 3 3.5355339059327376220 5 NULL -6 3.4501207708330056852 +6 3.4501207708330056853 7 NULL 8 NULL @@ -506,9 +506,9 @@ query IR SELECT k, stddev(d) OVER w FROM kv WINDOW w as (PARTITION BY v) ORDER BY variance(d) OVER w, k ---- 5 NULL -1 3.4501207708330056852 -6 3.4501207708330056852 -7 3.4501207708330056852 +1 3.4501207708330056853 +6 3.4501207708330056853 +7 3.4501207708330056853 3 3.5355339059327376220 8 3.5355339059327376220 @@ -516,9 +516,9 @@ query IRIR SELECT * FROM (SELECT k, d, v, stddev(d) OVER (PARTITION BY v) FROM kv) sub ORDER BY variance(d) OVER (PARTITION BY v), k ---- 5 -321 NULL NULL -1 1 2 3.4501207708330056852 -6 4.4 2 3.4501207708330056852 -7 7.9 2 3.4501207708330056852 +1 1 2 3.4501207708330056853 +6 4.4 2 3.4501207708330056853 +7 7.9 2 3.4501207708330056853 3 8 4 3.5355339059327376220 8 3 4 3.5355339059327376220 @@ -526,7 +526,7 @@ query IR SELECT k, max(stddev) OVER (ORDER BY d) FROM (SELECT k, d, stddev(d) OVER (PARTITION BY v) as stddev FROM kv) sub ORDER BY 2, k ---- 5 NULL -1 3.4501207708330056852 +1 3.4501207708330056853 3 3.5355339059327376220 6 3.5355339059327376220 7 3.5355339059327376220 diff --git a/pkg/sql/opt/exec/execbuilder/testdata/window b/pkg/sql/opt/exec/execbuilder/testdata/window index dfa1a015609b..445aba534799 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/window +++ b/pkg/sql/opt/exec/execbuilder/testdata/window @@ -49,9 +49,9 @@ fetched: /kv/kv_pkey/8/v -> /4 fetched: /kv/kv_pkey/8/d -> 3 fetched: /kv/kv_pkey/8 -> output row: [5 NULL] -output row: [1 3.4501207708330056852] -output row: [6 3.4501207708330056852] -output row: [7 3.4501207708330056852] +output row: [1 3.4501207708330056853] +output row: [6 3.4501207708330056853] +output row: [7 3.4501207708330056853] output row: [3 3.5355339059327376220] output row: [8 3.5355339059327376220] diff --git a/pkg/sql/sem/builtins/aggregate_builtins.go b/pkg/sql/sem/builtins/aggregate_builtins.go index 53b0c86621aa..7d481addf329 100644 --- a/pkg/sql/sem/builtins/aggregate_builtins.go +++ b/pkg/sql/sem/builtins/aggregate_builtins.go @@ -1285,6 +1285,41 @@ const sizeOfSTUnionAggregate = int64(unsafe.Sizeof(stUnionAgg{})) const sizeOfSTCollectAggregate = int64(unsafe.Sizeof(stCollectAgg{})) const sizeOfSTExtentAggregate = int64(unsafe.Sizeof(stExtentAgg{})) +// aggregateWithIntermediateResult is a common interface for aggregate functions +// which can return a result without loss of precision. This is useful when an +// aggregate function uses the result of another in its own calculations. +type aggregateWithIntermediateResult interface { + eval.AggregateFunc + intermediateResult() (tree.Datum, error) +} + +// roundIntermediateDecimalResult retrieves the intermediate result of the +// given aggregate, and rounds it using the default decimal context. This can +// be used to calculate the final result for an aggregate that implements +// aggregateWithIntermediateResult. +func roundIntermediateDecimalResult(a aggregateWithIntermediateResult) (tree.Datum, error) { + res, err := a.intermediateResult() + if err != nil || res == tree.DNull { + return res, err + } + dd := res.(*tree.DDecimal) + _, err = tree.DecimalCtx.Round(&dd.Decimal, &dd.Decimal) + if err != nil { + return nil, err + } + // Remove trailing zeros. Depending on the order in which the input + // is processed, some number of trailing zeros could be added to the + // output. Remove them so that the results are the same regardless of order. + dd.Reduce(&dd.Decimal) + return dd, nil +} + +var _ aggregateWithIntermediateResult = &intSqrDiffAggregate{} +var _ aggregateWithIntermediateResult = &decimalSqrDiffAggregate{} +var _ aggregateWithIntermediateResult = &decimalSumSqrDiffsAggregate{} +var _ aggregateWithIntermediateResult = &decimalVarPopAggregate{} +var _ aggregateWithIntermediateResult = &decimalVarianceAggregate{} + // singleDatumAggregateBase is a utility struct that helps aggregate builtins // that store a single datum internally track their memory usage related to // that single datum. @@ -3736,27 +3771,7 @@ func (a *decimalSqrDiffAggregate) intermediateResult() (tree.Datum, error) { } func (a *decimalSqrDiffAggregate) Result() (tree.Datum, error) { - res, err := a.intermediateResult() - if err != nil || res == tree.DNull { - return res, err - } - - dd := res.(*tree.DDecimal) - // Sqrdiff calculation is used in variance and var_pop as one of intermediate - // results. We want the intermediate results to be as precise as possible. - // That's why sqrdiff uses IntermediateCtx, but due to operations reordering - // in distributed mode the result might be different (see issue #13689, - // PR #18701). By rounding the end result to the DecimalCtx precision we avoid - // such inconsistencies. - _, err = tree.DecimalCtx.Round(&dd.Decimal, &a.sqrDiff) - if err != nil { - return nil, err - } - // Remove trailing zeros. Depending on the order in which the input - // is processed, some number of trailing zeros could be added to the - // output. Remove them so that the results are the same regardless of order. - dd.Decimal.Reduce(&dd.Decimal) - return dd, nil + return roundIntermediateDecimalResult(a) } // Reset implements eval.AggregateFunc interface. @@ -3969,21 +3984,7 @@ func (a *decimalSumSqrDiffsAggregate) intermediateResult() (tree.Datum, error) { } func (a *decimalSumSqrDiffsAggregate) Result() (tree.Datum, error) { - res, err := a.intermediateResult() - if err != nil || res == tree.DNull { - return res, err - } - - dd := res.(*tree.DDecimal) - _, err = tree.DecimalCtx.Round(&dd.Decimal, &dd.Decimal) - if err != nil { - return nil, err - } - // Remove trailing zeros. Depending on the order in which the input - // is processed, some number of trailing zeros could be added to the - // output. Remove them so that the results are the same regardless of order. - dd.Reduce(&dd.Decimal) - return dd, nil + return roundIntermediateDecimalResult(a) } // Reset implements eval.AggregateFunc interface. @@ -4010,12 +4011,9 @@ type floatSqrDiff interface { } type decimalSqrDiff interface { - eval.AggregateFunc + aggregateWithIntermediateResult Count() *apd.Decimal Tmp() *apd.Decimal - // intermediateResult returns the current value of the accumulation without - // rounding. - intermediateResult() (tree.Datum, error) } type floatVarianceAggregate struct { @@ -4092,8 +4090,7 @@ func (a *floatVarianceAggregate) Result() (tree.Datum, error) { return tree.NewDFloat(tree.DFloat(float64(*sqrDiff.(*tree.DFloat)) / (float64(a.agg.Count()) - 1))), nil } -// Result calculates the variance from the member square difference aggregator. -func (a *decimalVarianceAggregate) Result() (tree.Datum, error) { +func (a *decimalVarianceAggregate) intermediateResult() (tree.Datum, error) { if a.agg.Count().Cmp(decimalTwo) < 0 { return tree.DNull, nil } @@ -4105,7 +4102,7 @@ func (a *decimalVarianceAggregate) Result() (tree.Datum, error) { return nil, err } dd := &tree.DDecimal{} - if _, err = tree.DecimalCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Tmp()); err != nil { + if _, err = tree.IntermediateCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Tmp()); err != nil { return nil, err } // Remove trailing zeros. Depending on the order in which the input is @@ -4116,6 +4113,11 @@ func (a *decimalVarianceAggregate) Result() (tree.Datum, error) { return dd, nil } +// Result calculates the variance from the member square difference aggregator. +func (a *decimalVarianceAggregate) Result() (tree.Datum, error) { + return roundIntermediateDecimalResult(a) +} + // Reset implements eval.AggregateFunc interface. func (a *floatVarianceAggregate) Reset(ctx context.Context) { a.agg.Reset(ctx) @@ -4208,8 +4210,7 @@ func (a *floatVarPopAggregate) Result() (tree.Datum, error) { return tree.NewDFloat(tree.DFloat(float64(*sqrDiff.(*tree.DFloat)) / (float64(a.agg.Count())))), nil } -// Result calculates the population variance from the member square difference aggregator. -func (a *decimalVarPopAggregate) Result() (tree.Datum, error) { +func (a *decimalVarPopAggregate) intermediateResult() (tree.Datum, error) { if a.agg.Count().Cmp(decimalOne) < 0 { return tree.DNull, nil } @@ -4218,17 +4219,21 @@ func (a *decimalVarPopAggregate) Result() (tree.Datum, error) { return nil, err } dd := &tree.DDecimal{} - if _, err = tree.DecimalCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Count()); err != nil { + if _, err = tree.IntermediateCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Count()); err != nil { return nil, err } - // Remove trailing zeros. Depending on the order in which the input is - // processed, some number of trailing zeros could be added to the - // output. Remove them so that the results are the same regardless of - // order. - dd.Decimal.Reduce(&dd.Decimal) + // Remove trailing zeros. Depending on the order in which the input + // is processed, some number of trailing zeros could be added to the + // output. Remove them so that the results are the same regardless of order. + dd.Reduce(&dd.Decimal) return dd, nil } +// Result calculates the population variance from the member square difference aggregator. +func (a *decimalVarPopAggregate) Result() (tree.Datum, error) { + return roundIntermediateDecimalResult(a) +} + // Reset implements eval.AggregateFunc interface. func (a *floatVarPopAggregate) Reset(ctx context.Context) { a.agg.Reset(ctx) @@ -4264,7 +4269,7 @@ type floatStdDevAggregate struct { } type decimalStdDevAggregate struct { - agg eval.AggregateFunc + agg aggregateWithIntermediateResult } // Both StdDev and FinalStdDev aggregators have the same codepath for @@ -4276,7 +4281,8 @@ type decimalStdDevAggregate struct { func newIntStdDevAggregate( params []*types.T, evalCtx *eval.Context, arguments tree.Datums, ) eval.AggregateFunc { - return &decimalStdDevAggregate{agg: newIntVarianceAggregate(params, evalCtx, arguments)} + agg := newIntVarianceAggregate(params, evalCtx, arguments) + return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)} } func newFloatStdDevAggregate( @@ -4288,7 +4294,8 @@ func newFloatStdDevAggregate( func newDecimalStdDevAggregate( params []*types.T, evalCtx *eval.Context, arguments tree.Datums, ) eval.AggregateFunc { - return &decimalStdDevAggregate{agg: newDecimalVarianceAggregate(params, evalCtx, arguments)} + agg := newDecimalVarianceAggregate(params, evalCtx, arguments) + return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)} } func newFloatFinalStdDevAggregate( @@ -4300,13 +4307,15 @@ func newFloatFinalStdDevAggregate( func newDecimalFinalStdDevAggregate( params []*types.T, evalCtx *eval.Context, arguments tree.Datums, ) eval.AggregateFunc { - return &decimalStdDevAggregate{agg: newDecimalFinalVarianceAggregate(params, evalCtx, arguments)} + agg := newDecimalFinalVarianceAggregate(params, evalCtx, arguments) + return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)} } func newIntStdDevPopAggregate( params []*types.T, evalCtx *eval.Context, arguments tree.Datums, ) eval.AggregateFunc { - return &decimalStdDevAggregate{agg: newIntVarPopAggregate(params, evalCtx, arguments)} + agg := newIntVarPopAggregate(params, evalCtx, arguments) + return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)} } func newFloatStdDevPopAggregate( @@ -4318,13 +4327,15 @@ func newFloatStdDevPopAggregate( func newDecimalStdDevPopAggregate( params []*types.T, evalCtx *eval.Context, arguments tree.Datums, ) eval.AggregateFunc { - return &decimalStdDevAggregate{agg: newDecimalVarPopAggregate(params, evalCtx, arguments)} + agg := newDecimalVarPopAggregate(params, evalCtx, arguments) + return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)} } func newDecimalFinalStdDevPopAggregate( params []*types.T, evalCtx *eval.Context, arguments tree.Datums, ) eval.AggregateFunc { - return &decimalStdDevAggregate{agg: newDecimalFinalVarPopAggregate(params, evalCtx, arguments)} + agg := newDecimalFinalVarPopAggregate(params, evalCtx, arguments) + return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)} } func newFloatFinalStdDevPopAggregate( @@ -4369,13 +4380,7 @@ func (a *floatStdDevAggregate) Result() (tree.Datum, error) { // Result computes the square root of the variance aggregator. func (a *decimalStdDevAggregate) Result() (tree.Datum, error) { - // TODO(richardwu): both decimalVarianceAggregate and - // finalDecimalVarianceAggregate return a decimal result with - // default tree.DecimalCtx precision. We want to be able to specify that the - // varianceAggregate use tree.IntermediateCtx (with the extra precision) - // since it is returning an intermediate value for stdDevAggregate (of - // which we take the Sqrt). - variance, err := a.agg.Result() + variance, err := a.agg.intermediateResult() if err != nil { return nil, err }