Skip to content

Commit

Permalink
ARROW-10263: [C++][Compute] Improve variance kernel numerical stability
Browse files Browse the repository at this point in the history
Improve variance merging method to address stability issue when merging
short chunks with approximate mean value.

Improve reference variance accuracy by leveraging Kahan summation.

Closes apache#8437 from cyb70289/variance-stability

Authored-by: Yibo Cai <yibo.cai@arm.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
cyb70289 authored and kszucs committed Oct 19, 2020
1 parent f72575c commit a3a35b2
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
44 changes: 28 additions & 16 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Expand Up @@ -953,14 +953,13 @@ class TestPrimitiveVarStdKernel : public ::testing::Test {
using ScalarType = typename TypeTraits<DoubleType>::ScalarType;

void AssertVarStdIs(const Array& array, const VarianceOptions& options,
double expected_var, double diff = 0) {
AssertVarStdIsInternal(array, options, expected_var, diff);
double expected_var) {
AssertVarStdIsInternal(array, options, expected_var);
}

void AssertVarStdIs(const std::shared_ptr<ChunkedArray>& array,
const VarianceOptions& options, double expected_var,
double diff = 0) {
AssertVarStdIsInternal(array, options, expected_var, diff);
const VarianceOptions& options, double expected_var) {
AssertVarStdIsInternal(array, options, expected_var);
}

void AssertVarStdIs(const std::string& json, const VarianceOptions& options,
Expand Down Expand Up @@ -999,18 +998,14 @@ class TestPrimitiveVarStdKernel : public ::testing::Test {

private:
void AssertVarStdIsInternal(const Datum& array, const VarianceOptions& options,
double expected_var, double diff = 0) {
double expected_var) {
ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options));
ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options));
auto var = checked_cast<const ScalarType*>(out_var.scalar().get());
auto std = checked_cast<const ScalarType*>(out_std.scalar().get());
ASSERT_TRUE(var->is_valid && std->is_valid);
ASSERT_DOUBLE_EQ(std->value * std->value, var->value);
if (diff == 0) {
ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP
} else {
ASSERT_NEAR(var->value, expected_var, diff);
}
ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP
}

void AssertVarStdIsInvalidInternal(const Datum& array, const VarianceOptions& options) {
Expand Down Expand Up @@ -1070,22 +1065,39 @@ TEST_F(TestVarStdKernelStability, Basics) {
VarianceOptions options{1}; // ddof = 1
this->AssertVarStdIs("[100000004, 100000007, 100000013, 100000016]", options, 30.0);
this->AssertVarStdIs("[1000000004, 1000000007, 1000000013, 1000000016]", options, 30.0);

#ifndef __MINGW32__ // MinGW has precision issues
// This test is to make sure our variance combining method is stable.
// XXX: The reference value from numpy is actually wrong due to floating
// point limits. The correct result should equals variance(90, 0) = 4050.
std::vector<std::string> chunks = {"[40000008000000490]", "[40000008000000400]"};
this->AssertVarStdIs(chunks, options, 3904.0);
#endif
}

// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
void KahanSum(double& sum, double& adjust, double addend) {
double y = addend - adjust;
double t = sum + y;
adjust = (t - sum) - y;
sum = t;
}

// Calculate reference variance with Welford's online algorithm
// Calculate reference variance with Welford's online algorithm + Kahan summation
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
std::pair<double, double> WelfordVar(const Array& array) {
const auto& array_numeric = reinterpret_cast<const DoubleArray&>(array);
const auto values = array_numeric.raw_values();
internal::BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
double count = 0, mean = 0, m2 = 0;
double mean_adjust = 0, m2_adjust = 0;
for (int64_t i = 0; i < array.length(); ++i) {
if (reader.IsSet()) {
++count;
double delta = values[i] - mean;
mean += delta / count;
KahanSum(mean, mean_adjust, delta / count);
double delta2 = values[i] - mean;
m2 += delta * delta2;
KahanSum(m2, m2_adjust, delta * delta2);
}
reader.Next();
}
Expand Down Expand Up @@ -1116,8 +1128,8 @@ TEST_F(TestVarStdKernelRandom, Basics) {
double var_population, var_sample;
std::tie(var_population, var_sample) = WelfordVar(*(array->Slice(0, total_size)));

this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population, 0.0001);
this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample, 0.0001);
this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population);
this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample);
}

} // namespace compute
Expand Down
21 changes: 11 additions & 10 deletions cpp/src/arrow/compute/kernels/aggregate_var_std.cc
Expand Up @@ -53,32 +53,33 @@ struct VarStdState {
[]() {});

this->count = count;
this->sum = sum;
this->mean = mean;
this->m2 = m2;
}

// Combine `m2` from two chunks
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
// Combine `m2` from two chunks (m2 = n*s2)
// https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html
void MergeFrom(const ThisType& state) {
if (state.count == 0) {
return;
}
if (this->count == 0) {
this->count = state.count;
this->sum = state.sum;
this->mean = state.mean;
this->m2 = state.m2;
return;
}
double delta = this->sum / this->count - state.sum / state.count;
this->m2 += state.m2 +
delta * delta * this->count * state.count / (this->count + state.count);
double mean = (this->mean * this->count + state.mean * state.count) /
(this->count + state.count);
this->m2 += state.m2 + this->count * (this->mean - mean) * (this->mean - mean) +
state.count * (state.mean - mean) * (state.mean - mean);
this->count += state.count;
this->sum += state.sum;
this->mean = mean;
}

int64_t count = 0;
double sum = 0;
double m2 = 0; // sum((X-mean)^2)
double mean = 0;
double m2 = 0; // m2 = count*s2 = sum((X-mean)^2)
};

enum class VarOrStd : bool { Var, Std };
Expand Down

0 comments on commit a3a35b2

Please sign in to comment.