Skip to content

Commit

Permalink
builtins: use intermediate decimal context for distributed aggregation
Browse files Browse the repository at this point in the history
Aggregate builtin functions round the final result using the default decimal
precision. In some cases, one aggregate uses the result of another in its
own calculations. This could previously cause slightly different results
between local and distributed execution for aggregates that return decimal
values.

This patch adds `intermediateResult` methods to the aggregates that are
used in the intermediate calculations for other aggregates. This avoids
the rounding of intermediate results, which ensures accurate results for
distributed execution.

Fixes #94827

Release note (bug fix): Fixed a rounding error that could cause distributed
execution for some decimal aggregate functions to return slightly
inaccurate results in rare cases.
  • Loading branch information
DrewKimball committed Jul 6, 2023
1 parent 845b23a commit 75084b0
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 77 deletions.
18 changes: 9 additions & 9 deletions pkg/sql/logictest/testdata/logic_test/window
Expand Up @@ -495,38 +495,38 @@ SELECT k, variance(d) OVER (PARTITION BY v ORDER BY w) FROM kv ORDER BY 1
query IR
SELECT k, stddev(d) OVER (PARTITION BY v ORDER BY w) FROM kv ORDER BY 1
----
1 3.4501207708330056852
1 3.4501207708330056853
3 3.5355339059327376220
5 NULL
6 3.4501207708330056852
6 3.4501207708330056853
7 NULL
8 NULL

query IR
SELECT k, stddev(d) OVER w FROM kv WINDOW w as (PARTITION BY v) ORDER BY variance(d) OVER w, k
----
5 NULL
1 3.4501207708330056852
6 3.4501207708330056852
7 3.4501207708330056852
1 3.4501207708330056853
6 3.4501207708330056853
7 3.4501207708330056853
3 3.5355339059327376220
8 3.5355339059327376220

query IRIR
SELECT * FROM (SELECT k, d, v, stddev(d) OVER (PARTITION BY v) FROM kv) sub ORDER BY variance(d) OVER (PARTITION BY v), k
----
5 -321 NULL NULL
1 1 2 3.4501207708330056852
6 4.4 2 3.4501207708330056852
7 7.9 2 3.4501207708330056852
1 1 2 3.4501207708330056853
6 4.4 2 3.4501207708330056853
7 7.9 2 3.4501207708330056853
3 8 4 3.5355339059327376220
8 3 4 3.5355339059327376220

query IR
SELECT k, max(stddev) OVER (ORDER BY d) FROM (SELECT k, d, stddev(d) OVER (PARTITION BY v) as stddev FROM kv) sub ORDER BY 2, k
----
5 NULL
1 3.4501207708330056852
1 3.4501207708330056853
3 3.5355339059327376220
6 3.5355339059327376220
7 3.5355339059327376220
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/opt/exec/execbuilder/testdata/window
Expand Up @@ -49,9 +49,9 @@ fetched: /kv/kv_pkey/8/v -> /4
fetched: /kv/kv_pkey/8/d -> 3
fetched: /kv/kv_pkey/8 -> <undecoded>
output row: [5 NULL]
output row: [1 3.4501207708330056852]
output row: [6 3.4501207708330056852]
output row: [7 3.4501207708330056852]
output row: [1 3.4501207708330056853]
output row: [6 3.4501207708330056853]
output row: [7 3.4501207708330056853]
output row: [3 3.5355339059327376220]
output row: [8 3.5355339059327376220]

Expand Down
135 changes: 70 additions & 65 deletions pkg/sql/sem/builtins/aggregate_builtins.go
Expand Up @@ -1285,6 +1285,41 @@ const sizeOfSTUnionAggregate = int64(unsafe.Sizeof(stUnionAgg{}))
const sizeOfSTCollectAggregate = int64(unsafe.Sizeof(stCollectAgg{}))
const sizeOfSTExtentAggregate = int64(unsafe.Sizeof(stExtentAgg{}))

// aggregateWithIntermediateResult is a common interface for aggregate functions
// which can return a result without loss of precision. This is useful when an
// aggregate function uses the result of another in its own calculations.
type aggregateWithIntermediateResult interface {
eval.AggregateFunc
intermediateResult() (tree.Datum, error)
}

// roundIntermediateDecimalResult retrieves the intermediate result of the
// given aggregate, and rounds it using the default decimal context. This can
// be used to calculate the final result for an aggregate that implements
// aggregateWithIntermediateResult.
func roundIntermediateDecimalResult(a aggregateWithIntermediateResult) (tree.Datum, error) {
res, err := a.intermediateResult()
if err != nil || res == tree.DNull {
return res, err
}
dd := res.(*tree.DDecimal)
_, err = tree.DecimalCtx.Round(&dd.Decimal, &dd.Decimal)
if err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
dd.Reduce(&dd.Decimal)
return dd, nil
}

var _ aggregateWithIntermediateResult = &intSqrDiffAggregate{}
var _ aggregateWithIntermediateResult = &decimalSqrDiffAggregate{}
var _ aggregateWithIntermediateResult = &decimalSumSqrDiffsAggregate{}
var _ aggregateWithIntermediateResult = &decimalVarPopAggregate{}
var _ aggregateWithIntermediateResult = &decimalVarianceAggregate{}

// singleDatumAggregateBase is a utility struct that helps aggregate builtins
// that store a single datum internally track their memory usage related to
// that single datum.
Expand Down Expand Up @@ -3736,27 +3771,7 @@ func (a *decimalSqrDiffAggregate) intermediateResult() (tree.Datum, error) {
}

func (a *decimalSqrDiffAggregate) Result() (tree.Datum, error) {
res, err := a.intermediateResult()
if err != nil || res == tree.DNull {
return res, err
}

dd := res.(*tree.DDecimal)
// Sqrdiff calculation is used in variance and var_pop as one of intermediate
// results. We want the intermediate results to be as precise as possible.
// That's why sqrdiff uses IntermediateCtx, but due to operations reordering
// in distributed mode the result might be different (see issue #13689,
// PR #18701). By rounding the end result to the DecimalCtx precision we avoid
// such inconsistencies.
_, err = tree.DecimalCtx.Round(&dd.Decimal, &a.sqrDiff)
if err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
dd.Decimal.Reduce(&dd.Decimal)
return dd, nil
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
Expand Down Expand Up @@ -3969,21 +3984,7 @@ func (a *decimalSumSqrDiffsAggregate) intermediateResult() (tree.Datum, error) {
}

func (a *decimalSumSqrDiffsAggregate) Result() (tree.Datum, error) {
res, err := a.intermediateResult()
if err != nil || res == tree.DNull {
return res, err
}

dd := res.(*tree.DDecimal)
_, err = tree.DecimalCtx.Round(&dd.Decimal, &dd.Decimal)
if err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
dd.Reduce(&dd.Decimal)
return dd, nil
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
Expand All @@ -4010,12 +4011,9 @@ type floatSqrDiff interface {
}

type decimalSqrDiff interface {
eval.AggregateFunc
aggregateWithIntermediateResult
Count() *apd.Decimal
Tmp() *apd.Decimal
// intermediateResult returns the current value of the accumulation without
// rounding.
intermediateResult() (tree.Datum, error)
}

type floatVarianceAggregate struct {
Expand Down Expand Up @@ -4092,8 +4090,7 @@ func (a *floatVarianceAggregate) Result() (tree.Datum, error) {
return tree.NewDFloat(tree.DFloat(float64(*sqrDiff.(*tree.DFloat)) / (float64(a.agg.Count()) - 1))), nil
}

// Result calculates the variance from the member square difference aggregator.
func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
func (a *decimalVarianceAggregate) intermediateResult() (tree.Datum, error) {
if a.agg.Count().Cmp(decimalTwo) < 0 {
return tree.DNull, nil
}
Expand All @@ -4105,7 +4102,7 @@ func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
return nil, err
}
dd := &tree.DDecimal{}
if _, err = tree.DecimalCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Tmp()); err != nil {
if _, err = tree.IntermediateCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Tmp()); err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input is
Expand All @@ -4116,6 +4113,11 @@ func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
return dd, nil
}

// Result calculates the variance from the member square difference aggregator.
func (a *decimalVarianceAggregate) Result() (tree.Datum, error) {
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
func (a *floatVarianceAggregate) Reset(ctx context.Context) {
a.agg.Reset(ctx)
Expand Down Expand Up @@ -4208,8 +4210,7 @@ func (a *floatVarPopAggregate) Result() (tree.Datum, error) {
return tree.NewDFloat(tree.DFloat(float64(*sqrDiff.(*tree.DFloat)) / (float64(a.agg.Count())))), nil
}

// Result calculates the population variance from the member square difference aggregator.
func (a *decimalVarPopAggregate) Result() (tree.Datum, error) {
func (a *decimalVarPopAggregate) intermediateResult() (tree.Datum, error) {
if a.agg.Count().Cmp(decimalOne) < 0 {
return tree.DNull, nil
}
Expand All @@ -4218,17 +4219,21 @@ func (a *decimalVarPopAggregate) Result() (tree.Datum, error) {
return nil, err
}
dd := &tree.DDecimal{}
if _, err = tree.DecimalCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Count()); err != nil {
if _, err = tree.IntermediateCtx.Quo(&dd.Decimal, &sqrDiff.(*tree.DDecimal).Decimal, a.agg.Count()); err != nil {
return nil, err
}
// Remove trailing zeros. Depending on the order in which the input is
// processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of
// order.
dd.Decimal.Reduce(&dd.Decimal)
// Remove trailing zeros. Depending on the order in which the input
// is processed, some number of trailing zeros could be added to the
// output. Remove them so that the results are the same regardless of order.
dd.Reduce(&dd.Decimal)
return dd, nil
}

// Result calculates the population variance from the member square difference aggregator.
func (a *decimalVarPopAggregate) Result() (tree.Datum, error) {
return roundIntermediateDecimalResult(a)
}

// Reset implements eval.AggregateFunc interface.
func (a *floatVarPopAggregate) Reset(ctx context.Context) {
a.agg.Reset(ctx)
Expand Down Expand Up @@ -4264,7 +4269,7 @@ type floatStdDevAggregate struct {
}

type decimalStdDevAggregate struct {
agg eval.AggregateFunc
agg aggregateWithIntermediateResult
}

// Both StdDev and FinalStdDev aggregators have the same codepath for
Expand All @@ -4276,7 +4281,8 @@ type decimalStdDevAggregate struct {
func newIntStdDevAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newIntVarianceAggregate(params, evalCtx, arguments)}
agg := newIntVarianceAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatStdDevAggregate(
Expand All @@ -4288,7 +4294,8 @@ func newFloatStdDevAggregate(
func newDecimalStdDevAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalVarianceAggregate(params, evalCtx, arguments)}
agg := newDecimalVarianceAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatFinalStdDevAggregate(
Expand All @@ -4300,13 +4307,15 @@ func newFloatFinalStdDevAggregate(
func newDecimalFinalStdDevAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalFinalVarianceAggregate(params, evalCtx, arguments)}
agg := newDecimalFinalVarianceAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newIntStdDevPopAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newIntVarPopAggregate(params, evalCtx, arguments)}
agg := newIntVarPopAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatStdDevPopAggregate(
Expand All @@ -4318,13 +4327,15 @@ func newFloatStdDevPopAggregate(
func newDecimalStdDevPopAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalVarPopAggregate(params, evalCtx, arguments)}
agg := newDecimalVarPopAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newDecimalFinalStdDevPopAggregate(
params []*types.T, evalCtx *eval.Context, arguments tree.Datums,
) eval.AggregateFunc {
return &decimalStdDevAggregate{agg: newDecimalFinalVarPopAggregate(params, evalCtx, arguments)}
agg := newDecimalFinalVarPopAggregate(params, evalCtx, arguments)
return &decimalStdDevAggregate{agg: agg.(aggregateWithIntermediateResult)}
}

func newFloatFinalStdDevPopAggregate(
Expand Down Expand Up @@ -4369,13 +4380,7 @@ func (a *floatStdDevAggregate) Result() (tree.Datum, error) {

// Result computes the square root of the variance aggregator.
func (a *decimalStdDevAggregate) Result() (tree.Datum, error) {
// TODO(richardwu): both decimalVarianceAggregate and
// finalDecimalVarianceAggregate return a decimal result with
// default tree.DecimalCtx precision. We want to be able to specify that the
// varianceAggregate use tree.IntermediateCtx (with the extra precision)
// since it is returning an intermediate value for stdDevAggregate (of
// which we take the Sqrt).
variance, err := a.agg.Result()
variance, err := a.agg.intermediateResult()
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 75084b0

Please sign in to comment.