diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 39b3f8827fb18..6d97e79a23f90 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -953,14 +953,13 @@ class TestPrimitiveVarStdKernel : public ::testing::Test { using ScalarType = typename TypeTraits::ScalarType; void AssertVarStdIs(const Array& array, const VarianceOptions& options, - double expected_var, double diff = 0) { - AssertVarStdIsInternal(array, options, expected_var, diff); + double expected_var) { + AssertVarStdIsInternal(array, options, expected_var); } void AssertVarStdIs(const std::shared_ptr& array, - const VarianceOptions& options, double expected_var, - double diff = 0) { - AssertVarStdIsInternal(array, options, expected_var, diff); + const VarianceOptions& options, double expected_var) { + AssertVarStdIsInternal(array, options, expected_var); } void AssertVarStdIs(const std::string& json, const VarianceOptions& options, @@ -999,18 +998,14 @@ class TestPrimitiveVarStdKernel : public ::testing::Test { private: void AssertVarStdIsInternal(const Datum& array, const VarianceOptions& options, - double expected_var, double diff = 0) { + double expected_var) { ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options)); ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options)); auto var = checked_cast(out_var.scalar().get()); auto std = checked_cast(out_std.scalar().get()); ASSERT_TRUE(var->is_valid && std->is_valid); ASSERT_DOUBLE_EQ(std->value * std->value, var->value); - if (diff == 0) { - ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP - } else { - ASSERT_NEAR(var->value, expected_var, diff); - } + ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP } void AssertVarStdIsInvalidInternal(const Datum& array, const VarianceOptions& options) { @@ -1070,22 +1065,39 @@ TEST_F(TestVarStdKernelStability, Basics) { VarianceOptions options{1}; // ddof = 1 this->AssertVarStdIs("[100000004, 100000007, 100000013, 100000016]", options, 30.0); this->AssertVarStdIs("[1000000004, 1000000007, 1000000013, 1000000016]", options, 30.0); + +#ifndef __MINGW32__ // MinGW has precision issues + // This test is to make sure our variance combining method is stable. + // XXX: The reference value from numpy is actually wrong due to floating + // point limits. The correct result should equals variance(90, 0) = 4050. + std::vector chunks = {"[40000008000000490]", "[40000008000000400]"}; + this->AssertVarStdIs(chunks, options, 3904.0); +#endif +} + +// https://en.wikipedia.org/wiki/Kahan_summation_algorithm +void KahanSum(double& sum, double& adjust, double addend) { + double y = addend - adjust; + double t = sum + y; + adjust = (t - sum) - y; + sum = t; } -// Calculate reference variance with Welford's online algorithm +// Calculate reference variance with Welford's online algorithm + Kahan summation // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm std::pair WelfordVar(const Array& array) { const auto& array_numeric = reinterpret_cast(array); const auto values = array_numeric.raw_values(); internal::BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length()); double count = 0, mean = 0, m2 = 0; + double mean_adjust = 0, m2_adjust = 0; for (int64_t i = 0; i < array.length(); ++i) { if (reader.IsSet()) { ++count; double delta = values[i] - mean; - mean += delta / count; + KahanSum(mean, mean_adjust, delta / count); double delta2 = values[i] - mean; - m2 += delta * delta2; + KahanSum(m2, m2_adjust, delta * delta2); } reader.Next(); } @@ -1116,8 +1128,8 @@ TEST_F(TestVarStdKernelRandom, Basics) { double var_population, var_sample; std::tie(var_population, var_sample) = WelfordVar(*(array->Slice(0, total_size))); - this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population, 0.0001); - this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample, 0.0001); + this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population); + this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index e2b98bb38fc75..327372ad4868b 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -53,32 +53,33 @@ struct VarStdState { []() {}); this->count = count; - this->sum = sum; + this->mean = mean; this->m2 = m2; } - // Combine `m2` from two chunks - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + // Combine `m2` from two chunks (m2 = n*s2) + // https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html void MergeFrom(const ThisType& state) { if (state.count == 0) { return; } if (this->count == 0) { this->count = state.count; - this->sum = state.sum; + this->mean = state.mean; this->m2 = state.m2; return; } - double delta = this->sum / this->count - state.sum / state.count; - this->m2 += state.m2 + - delta * delta * this->count * state.count / (this->count + state.count); + double mean = (this->mean * this->count + state.mean * state.count) / + (this->count + state.count); + this->m2 += state.m2 + this->count * (this->mean - mean) * (this->mean - mean) + + state.count * (state.mean - mean) * (state.mean - mean); this->count += state.count; - this->sum += state.sum; + this->mean = mean; } int64_t count = 0; - double sum = 0; - double m2 = 0; // sum((X-mean)^2) + double mean = 0; + double m2 = 0; // m2 = count*s2 = sum((X-mean)^2) }; enum class VarOrStd : bool { Var, Std };