Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8214761: Bug in parallel Kahan summation implementation #4674

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -156,7 +156,9 @@ public void combine(DoubleSummaryStatistics other) {
count += other.count;
simpleSum += other.simpleSum;
sumWithCompensation(other.sum);
sumWithCompensation(other.sumCompensation);

//Negating this value because low-order bits are in negated form
sumWithCompensation(-other.sumCompensation);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't that be double tmp = sum - sumCompensation; in getSum() in line 246 too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch; I will review and make the change.

min = Math.min(min, other.min);
max = Math.max(max, other.max);
}
Expand Down Expand Up @@ -241,7 +243,7 @@ public final long getCount() {
*/
public final double getSum() {
// Better error bounds to add both terms as the final sum
double tmp = sum + sumCompensation;
double tmp = sum - sumCompensation;
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
// If the compensated sum is spuriously NaN from
// accumulating one or more same-signed infinite values,
Expand Down
17 changes: 12 additions & 5 deletions src/java.base/share/classes/java/util/stream/Collectors.java
Expand Up @@ -734,7 +734,8 @@ public static<T,A,R,RR> Collector<T,A,RR> collectingAndThen(Collector<T,A,R> dow
a[2] += val;},
(a, b) -> { sumWithCompensation(a, b[0]);
a[2] += b[2];
return sumWithCompensation(a, b[1]); },
//Negating this value because low-order bits are in negated form
return sumWithCompensation(a, -b[1]); },
a -> computeFinalSum(a),
CH_NOID);
}
Expand Down Expand Up @@ -765,8 +766,8 @@ static double[] sumWithCompensation(double[] intermediateSum, double value) {
* correctly-signed infinity stored in the simple sum.
*/
static double computeFinalSum(double[] summands) {
// Better error bounds to add both terms as the final sum
double tmp = summands[0] + summands[1];
// Final sum with better error bounds subtract second summand as it is negated
double tmp = summands[0] - summands[1];
double simpleSum = summands[summands.length - 1];
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
return simpleSum;
Expand Down Expand Up @@ -840,13 +841,19 @@ static double computeFinalSum(double[] summands) {
/*
* In the arrays allocated for the collect operation, index 0
* holds the high-order bits of the running sum, index 1 holds
* the low-order bits of the sum computed via compensated
* the negated low-order bits of the sum computed via compensated
* summation, and index 2 holds the number of values seen.
*/
return new CollectorImpl<>(
() -> new double[4],
(a, t) -> { double val = mapper.applyAsDouble(t); sumWithCompensation(a, val); a[2]++; a[3]+= val;},
(a, b) -> { sumWithCompensation(a, b[0]); sumWithCompensation(a, b[1]); a[2] += b[2]; a[3] += b[3]; return a; },
(a, b) -> {
sumWithCompensation(a, b[0]);
//Negating this value because low-order bits are in negated form
sumWithCompensation(a, -b[1]);
a[2] += b[2]; a[3] += b[3];
return a;
},
a -> (a[2] == 0) ? 0.0d : (computeFinalSum(a) / a[2]),
CH_NOID);
}
Expand Down
Expand Up @@ -442,7 +442,7 @@ public final double sum() {
/*
* In the arrays allocated for the collect operation, index 0
* holds the high-order bits of the running sum, index 1 holds
* the low-order bits of the sum computed via compensated
* the negated low-order bits of the sum computed via compensated
* summation, and index 2 holds the simple sum used to compute
* the proper result if the stream contains infinite values of
* the same sign.
Expand All @@ -454,7 +454,8 @@ public final double sum() {
},
(ll, rr) -> {
Collectors.sumWithCompensation(ll, rr[0]);
Collectors.sumWithCompensation(ll, rr[1]);
//Negating this value because low-order bits are in negated form
Collectors.sumWithCompensation(ll, -rr[1]);
ll[2] += rr[2];
});

Expand Down Expand Up @@ -497,7 +498,8 @@ public final OptionalDouble average() {
},
(ll, rr) -> {
Collectors.sumWithCompensation(ll, rr[0]);
Collectors.sumWithCompensation(ll, rr[1]);
//Negating this value because low-order bits are in negated form
Collectors.sumWithCompensation(ll, -rr[1]);
ll[2] += rr[2];
ll[3] += rr[3];
});
Expand Down
141 changes: 141 additions & 0 deletions test/jdk/java/util/DoubleStreamSums/CompensatedSums.java
@@ -0,0 +1,141 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

/*
* @test
* @bug 8214761
* @run testng CompensatedSums
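* @summary Compare the accuracy of sequential and parallel DoubleStream
*          sums against a reference Kahan (compensated) summation.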
*/

import java.util.Random;
import java.util.function.BiConsumer;
import java.util.function.ObjDoubleConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;

import static org.testng.Assert.assertTrue;

import org.testng.Assert;
import org.testng.annotations.Test;

public class CompensatedSums {

    /**
     * Accumulates, over 100 rounds of one million random values each, the
     * squared difference between a reference sequential Kahan summation and
     * several other ways of summing the same data:
     * <ul>
     *   <li>a naive {@code reduce((x, y) -> x + y)},</li>
     *   <li>the JDK's {@code DoubleStream.sum()} (sequential and parallel),</li>
     *   <li>a hand-written collector whose combiner negates the partial
     *       compensation ("good"),</li>
     *   <li>a hand-written collector whose combiner adds the partial
     *       compensation as-is ("bad").</li>
     * </ul>
     * The assertions at the end relate those accumulated errors to each other.
     */
    @Test
    public void testCompensatedSums() {
        double naive = 0;
        double jdkSequentialStreamError = 0;
        double goodSequentialStreamError = 0;
        double jdkParallelStreamError = 0;
        double goodParallelStreamError = 0;
        double badParallelStreamError = 0;

        // One generator for the whole test. The seed is logged so that a
        // failing run can be replayed with the same input data.
        long seed = new Random().nextLong();
        System.out.println("CompensatedSums seed: " + seed);
        Random random = new Random(seed);

        for (int loop = 0; loop < 100; loop++) {
            // sequence of random numbers of varying magnitudes, both positive and negative
            double[] rand = random.doubles(1_000_000)
                    .map(Math::log)
                    .map(x -> (Double.doubleToLongBits(x) % 2 == 0) ? x : -x)
                    .toArray();

            // base case: standard Kahan summation
            double[] sum = new double[2];
            for (double value : rand) {
                sumWithCompensation(sum, value);
            }
            double reference = sum[0];

            // All errors are the squared difference between the reference
            // Kahan sum and the sum under test. The older, less accurate
            // strategies are included as a baseline.

            // squared error of naive sum by reduction - should be large
            naive += square(DoubleStream.of(rand).reduce((x, y) -> x + y).getAsDouble() - reference);

            // squared error of sequential sum - should be 0
            jdkSequentialStreamError += square(DoubleStream.of(rand).sum() - reference);

            goodSequentialStreamError += square(
                    computeFinalSum(DoubleStream.of(rand)
                            .collect(doubleSupplier, objDoubleConsumer, goodCollectorConsumer))
                    - reference);

            // squared error of parallel sum from the JDK
            jdkParallelStreamError += square(DoubleStream.of(rand).parallel().sum() - reference);

            // squared error of parallel sum with the correctly signed combiner
            goodParallelStreamError += square(
                    computeFinalSum(DoubleStream.of(rand).parallel()
                            .collect(doubleSupplier, objDoubleConsumer, goodCollectorConsumer))
                    - reference);

            // the bad parallel stream: combiner adds the compensation with the wrong sign
            badParallelStreamError += square(
                    computeFinalSum(DoubleStream.of(rand).parallel()
                            .collect(doubleSupplier, objDoubleConsumer, badCollectorConsumer))
                    - reference);
        }

        Assert.assertEquals(goodSequentialStreamError, 0.0);
        Assert.assertEquals(goodSequentialStreamError, jdkSequentialStreamError);

        Assert.assertTrue(jdkParallelStreamError <= goodParallelStreamError);
        // NOTE(review): strict inequality on randomly generated data; this
        // presumably holds in practice but is not guaranteed - confirm.
        Assert.assertTrue(badParallelStreamError > goodParallelStreamError);

        Assert.assertTrue(naive > jdkSequentialStreamError);
        Assert.assertTrue(naive > jdkParallelStreamError);
    }

    /** Returns {@code value * value}. */
    private static double square(double value) {
        return value * value;
    }

    /**
     * Adds {@code value} to a running compensated (Kahan) sum.
     *
     * @param intermediateSum index 0 holds the high-order bits of the running
     *        sum; index 1 holds the <em>negated</em> low-order bits, i.e. the
     *        rounding error of the last addition with inverted sign
     * @param value the value to add
     * @return {@code intermediateSum}, updated in place
     */
    static double[] sumWithCompensation(double[] intermediateSum, double value) {
        double tmp = value - intermediateSum[1];
        double sum = intermediateSum[0];
        double velvel = sum + tmp; // Little wolf of rounding error
        intermediateSum[1] = (velvel - sum) - tmp;
        intermediateSum[0] = velvel;
        return intermediateSum;
    }

    /**
     * Folds the two halves of a compensated sum into the final result.
     *
     * @param summands index 0 is the high-order sum, index 1 the negated
     *        low-order bits, and the last element the simple (uncompensated)
     *        sum
     * @return the compensated sum, or the simple sum when the compensated
     *         value is NaN while the simple sum is infinite
     */
    static double computeFinalSum(double[] summands) {
        // Index 1 is stored in negated form by sumWithCompensation, so it has
        // to be subtracted here, not added.
        double tmp = summands[0] - summands[1];
        double simpleSum = summands[summands.length - 1];
        if (Double.isNaN(tmp) && Double.isInfinite(simpleSum)) {
            return simpleSum;
        }
        return tmp;
    }

    // Suppliers and consumers for DoubleStream summation collection.

    /** Fresh accumulator: {high-order sum, negated low-order bits, simple sum}. */
    static final Supplier<double[]> doubleSupplier = () -> new double[3];

    /** Accumulates one value into both the compensated and the simple sum. */
    static final ObjDoubleConsumer<double[]> objDoubleConsumer = (double[] ll, double d) -> {
        sumWithCompensation(ll, d);
        ll[2] += d;
    };

    /** Combiner that adds the partial compensation without negating it. */
    static final BiConsumer<double[], double[]> badCollectorConsumer =
            (ll, rr) -> {
                sumWithCompensation(ll, rr[0]);
                sumWithCompensation(ll, rr[1]);
                ll[2] += rr[2];
            };

    /** Combiner that negates the partial compensation before adding it. */
    static final BiConsumer<double[], double[]> goodCollectorConsumer =
            (ll, rr) -> {
                sumWithCompensation(ll, rr[0]);
                // rr[1] holds negated low-order bits, so negate it back
                sumWithCompensation(ll, -rr[1]);
                ll[2] += rr[2];
            };

}
@@ -0,0 +1,72 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

/*
* @test
* @bug 8214761
* @summary When combining two DoubleSummaryStatistics, the compensation
* has to be subtracted.
*/

import java.util.DoubleSummaryStatistics;

/**
 * Checks that merging {@link DoubleSummaryStatistics} instances with
 * {@code combine} yields (almost) the same sum as feeding the same number of
 * values into a single instance with {@code accept}.
 */
public class NegativeCompensation {
    /** The value that is summed over and over. */
    static final double VAL = 1.000000001;
    /** Default number of combine rounds; the element count grows with every round. */
    static final int LOG_ITER = 21;
    /**
     * Largest tolerated absolute difference between the serially accumulated
     * sum and the combined sum.
     */
    static final double MAX_ABS_ERR = 1e-4;

    /**
     * Runs the check.
     *
     * @param args optionally, {@code args[0]} overrides the number of combine
     *        rounds (defaults to {@link #LOG_ITER}); useful for a quick run
     * @throws NumberFormatException if {@code args[0]} is not an integer
     * @throws RuntimeException if the combined sum differs from the serial
     *         sum by more than {@link #MAX_ABS_ERR}
     */
    public static void main(String[] args) {
        int rounds = (args.length > 0) ? Integer.parseInt(args[0]) : LOG_ITER;

        DoubleSummaryStatistics stat0 = new DoubleSummaryStatistics();
        DoubleSummaryStatistics stat1 = new DoubleSummaryStatistics();
        DoubleSummaryStatistics stat2 = new DoubleSummaryStatistics();

        stat1.accept(VAL);
        stat1.accept(VAL);
        stat2.accept(VAL);
        stat2.accept(VAL);
        stat2.accept(VAL);

        // Repeatedly merge the two statistics into each other.
        for (int i = 0; i < rounds; ++i) {
            stat1.combine(stat2);
            stat2.combine(stat1);
        }

        long count = stat2.getCount();
        System.out.println("count: " + count);

        // Reference 1: the same number of values accumulated one at a time.
        for (long i = 0; i < count; ++i) {
            stat0.accept(VAL);
        }

        // Reference 2: plain floating-point accumulation, for diagnostics only.
        double res = 0;
        for (long i = 0; i < count; ++i) {
            res += VAL;
        }

        double combinedSum = stat2.getSum();
        double serialSum = stat0.getSum();
        double absErrN = Math.abs(res - combinedSum);
        double absErr = Math.abs(serialSum - combinedSum);
        System.out.println("serial sum: " + serialSum);
        System.out.println("combined sum: " + combinedSum);
        System.out.println("abs error: " + absErr);
        System.out.println("naive abs error: " + absErrN);

        // Fail only when the error is actually large; an exact match
        // (absErr == 0.0) is the best possible outcome, not a failure.
        if (absErr > MAX_ABS_ERR) {
            throw new RuntimeException("Absolute error is too big: " + absErr);
        }
    }
}