Skip to content

Commit

Permalink
Make BloomFilter concurrent.
Browse files Browse the repository at this point in the history
Original pull request by Val Markovic at #2761

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=154881019
  • Loading branch information
lowasser authored and cpovirk committed May 2, 2017
1 parent 8f7bbd4 commit 6092a4a
Show file tree
Hide file tree
Showing 16 changed files with 1,754 additions and 116 deletions.
100 changes: 98 additions & 2 deletions android/guava-tests/test/com/google/common/hash/BloomFilterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@
package com.google.common.hash;

import static com.google.common.base.Charsets.UTF_8;
import static com.google.common.hash.BloomFilterStrategies.BitArray;
import static com.google.common.truth.Truth.assertThat;

import com.google.common.base.Stopwatch;
import com.google.common.collect.ImmutableSet;
import com.google.common.hash.BloomFilterStrategies.LockFreeBitArray;
import com.google.common.math.LongMath;
import com.google.common.primitives.Ints;
import com.google.common.testing.EqualsTester;
import com.google.common.testing.NullPointerTester;
import com.google.common.testing.SerializableTester;
import com.google.common.util.concurrent.Uninterruptibles;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.lang.Thread.UncaughtExceptionHandler;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import junit.framework.TestCase;

Expand All @@ -39,12 +45,22 @@
* @author Dimitris Andreou
*/
public class BloomFilterTest extends TestCase {
private static final int NUM_PUTS = 100_000;
private static final ThreadLocal<Random> random = new ThreadLocal<Random>(){
@Override
protected Random initialValue() {
return new Random();
}
};

private static final int GOLDEN_PRESENT_KEY = random.get().nextInt();

@AndroidIncompatible // OutOfMemoryError
public void testLargeBloomFilterDoesntOverflow() {
long numBits = Integer.MAX_VALUE;
numBits++;

BitArray bitArray = new BitArray(numBits);
LockFreeBitArray bitArray = new LockFreeBitArray(numBits);
assertTrue(
"BitArray.bitSize() must return a positive number, but was " + bitArray.bitSize(),
bitArray.bitSize() > 0);
Expand Down Expand Up @@ -498,4 +514,84 @@ public void testBloomFilterStrategies() {
assertEquals(BloomFilterStrategies.MURMUR128_MITZ_32, BloomFilterStrategies.values()[0]);
assertEquals(BloomFilterStrategies.MURMUR128_MITZ_64, BloomFilterStrategies.values()[1]);
}

public void testNoRaceConditions() throws Exception {
final BloomFilter<Integer> bloomFilter =
BloomFilter.create(Funnels.integerFunnel(), 15_000_000, 0.01);

// This check has to be BEFORE the loop because the random insertions can
// flip GOLDEN_PRESENT_KEY to true even if it wasn't explicitly inserted
// (false positive).
assertThat(bloomFilter.mightContain(GOLDEN_PRESENT_KEY)).isFalse();
for (int i = 0; i < NUM_PUTS; i++) {
bloomFilter.put(getNonGoldenRandomKey());
}
bloomFilter.put(GOLDEN_PRESENT_KEY);

int numThreads = 12;
final double safetyFalsePositiveRate = 0.1;
final Stopwatch stopwatch = Stopwatch.createStarted();

Runnable task =
new Runnable() {
@Override
public void run() {
do {
// We can't have a GOLDEN_NOT_PRESENT_KEY because false positives are
// possible! It's false negatives that can't happen.
assertThat(bloomFilter.mightContain(GOLDEN_PRESENT_KEY)).isTrue();

int key = getNonGoldenRandomKey();
// We can't check that the key is mightContain() == false before the
// put() because the key could have already been generated *or* the
// bloom filter might say true even when it's not there (false
// positive).
bloomFilter.put(key);
// False negative should *never* happen.
assertThat(bloomFilter.mightContain(key)).isTrue();

// If this check ever fails, that means we need to either bump the
// number of expected insertions or don't run the test for so long.
// Don't forget, the bloom filter slowly saturates over time and the
// expected false positive probability goes up!
assertThat(bloomFilter.expectedFpp()).isLessThan(safetyFalsePositiveRate);
} while (stopwatch.elapsed(TimeUnit.SECONDS) < 1);
}
};

List<Throwable> exceptions = runThreadsAndReturnExceptions(numThreads, task);

assertThat(exceptions).isEmpty();
}

private static List<Throwable> runThreadsAndReturnExceptions(int numThreads, Runnable task) {
List<Thread> threads = new ArrayList<>(numThreads);
final List<Throwable> exceptions = new ArrayList<>(numThreads);
for (int i = 0; i < numThreads; i++) {
Thread thread = new Thread(task);
thread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler() {
@Override
public void uncaughtException(Thread t, Throwable e) {
exceptions.add(e);
}
});
threads.add(thread);
}
for (Thread t : threads) {
t.start();
}
for (Thread t : threads) {
Uninterruptibles.joinUninterruptibly(t);
}
return exceptions;
}

private static int getNonGoldenRandomKey() {
int key;
do {
key = random.get().nextInt();
} while (key == GOLDEN_PRESENT_KEY);
return key;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.google.common.hash;

import com.google.common.hash.BloomFilterStrategies.BitArray;
import com.google.common.hash.BloomFilterStrategies.LockFreeBitArray;
import com.google.common.testing.AbstractPackageSanityTests;

/**
Expand All @@ -27,7 +27,7 @@

public class PackageSanityTests extends AbstractPackageSanityTests {
public PackageSanityTests() {
setDefault(BitArray.class, new BitArray(1));
setDefault(LockFreeBitArray.class, new LockFreeBitArray(1));
setDefault(HashCode.class, HashCode.fromInt(1));
setDefault(String.class, "MD5");
setDefault(int.class, 32);
Expand Down
34 changes: 18 additions & 16 deletions android/guava/src/com/google/common/hash/BloomFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.base.Predicate;
import com.google.common.hash.BloomFilterStrategies.BitArray;
import com.google.common.hash.BloomFilterStrategies.LockFreeBitArray;
import com.google.common.math.DoubleMath;
import com.google.common.primitives.SignedBytes;
import com.google.common.primitives.UnsignedBytes;
Expand Down Expand Up @@ -54,10 +54,13 @@
* of the code may not be readable by older versions of the code (e.g., a serialized Bloom filter
* generated today may <i>not</i> be readable by a binary that was compiled 6 months ago).
*
* <p>As of Guava 22.0, this class is thread-safe and lock-free. It internally uses atomics and
* compare-and-swap to ensure correctness when multiple threads are used to access it.
*
* @param <T> the type of instances that the {@code BloomFilter} accepts
* @author Dimitris Andreou
* @author Kevin Bourrillion
* @since 11.0
* @since 11.0 (thread-safe since 22.0)
*/
@Beta
public final class BloomFilter<T> implements Predicate<T>, Serializable {
Expand All @@ -73,14 +76,15 @@ interface Strategy extends java.io.Serializable {
*
* <p>Returns whether any bits changed as a result of this operation.
*/
<T> boolean put(T object, Funnel<? super T> funnel, int numHashFunctions, BitArray bits);
<T> boolean put(
T object, Funnel<? super T> funnel, int numHashFunctions, LockFreeBitArray bits);

/**
* Queries {@code numHashFunctions} bits of the given bit array, by hashing a user element;
* returns {@code true} if and only if all selected bits are set.
*/
<T> boolean mightContain(
T object, Funnel<? super T> funnel, int numHashFunctions, BitArray bits);
T object, Funnel<? super T> funnel, int numHashFunctions, LockFreeBitArray bits);

/**
* Identifier used to encode this strategy, when marshalled as part of a BloomFilter. Only
Expand All @@ -93,7 +97,7 @@ <T> boolean mightContain(
}

/** The bit set of the BloomFilter (not necessarily power of 2!) */
private final BitArray bits;
private final LockFreeBitArray bits;

/** Number of hashes per element */
private final int numHashFunctions;
Expand All @@ -106,11 +110,9 @@ <T> boolean mightContain(
*/
private final Strategy strategy;

/**
* Creates a BloomFilter.
*/
/** Creates a BloomFilter. */
private BloomFilter(
BitArray bits, int numHashFunctions, Funnel<? super T> funnel, Strategy strategy) {
LockFreeBitArray bits, int numHashFunctions, Funnel<? super T> funnel, Strategy strategy) {
checkArgument(numHashFunctions > 0, "numHashFunctions (%s) must be > 0", numHashFunctions);
checkArgument(
numHashFunctions <= 255, "numHashFunctions (%s) must be <= 255", numHashFunctions);
Expand Down Expand Up @@ -361,7 +363,7 @@ static <T> BloomFilter<T> create(
long numBits = optimalNumOfBits(expectedInsertions, fpp);
int numHashFunctions = optimalNumOfHashFunctions(expectedInsertions, numBits);
try {
return new BloomFilter<T>(new BitArray(numBits), numHashFunctions, funnel, strategy);
return new BloomFilter<T>(new LockFreeBitArray(numBits), numHashFunctions, funnel, strategy);
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Could not create BloomFilter of " + numBits + " bits", e);
}
Expand Down Expand Up @@ -469,14 +471,14 @@ private static class SerialForm<T> implements Serializable {
final Strategy strategy;

SerialForm(BloomFilter<T> bf) {
this.data = bf.bits.data;
this.data = LockFreeBitArray.toPlainArray(bf.bits.data);
this.numHashFunctions = bf.numHashFunctions;
this.funnel = bf.funnel;
this.strategy = bf.strategy;
}

Object readResolve() {
return new BloomFilter<T>(new BitArray(data), numHashFunctions, funnel, strategy);
return new BloomFilter<T>(new LockFreeBitArray(data), numHashFunctions, funnel, strategy);
}

private static final long serialVersionUID = 1;
Expand All @@ -498,9 +500,9 @@ public void writeTo(OutputStream out) throws IOException {
DataOutputStream dout = new DataOutputStream(out);
dout.writeByte(SignedBytes.checkedCast(strategy.ordinal()));
dout.writeByte(UnsignedBytes.checkedCast(numHashFunctions)); // note: checked at the c'tor
dout.writeInt(bits.data.length);
for (long value : bits.data) {
dout.writeLong(value);
dout.writeInt(bits.data.length());
for (int i = 0; i < bits.data.length(); i++) {
dout.writeLong(bits.data.get(i));
}
}

Expand Down Expand Up @@ -536,7 +538,7 @@ public static <T> BloomFilter<T> readFrom(InputStream in, Funnel<? super T> funn
for (int i = 0; i < data.length; i++) {
data[i] = din.readLong();
}
return new BloomFilter<T>(new BitArray(data), numHashFunctions, funnel, strategy);
return new BloomFilter<T>(new LockFreeBitArray(data), numHashFunctions, funnel, strategy);
} catch (RuntimeException e) {
String message =
"Unable to deserialize BloomFilter from InputStream."
Expand Down
Loading

0 comments on commit 6092a4a

Please sign in to comment.