[7.17] Fix cardinality agg in async search (#82108) (#82131)
Async search reads aggregation results on many threads. It's quite
possible for it to concurrently serialize an aggregation and render
its xcontent. But the cardinality agg's results were not thread safe!
They reuse the non-thread-safe shard-side constructs used to collect
the agg. It's fine for the collection side not to be thread safe. But
not the results.
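
To picture the race: each data structure kept one iterator and reset it
for every read. A rough sketch of that shape, with made-up names rather
than the actual classes:
```
// Sketch of the broken shape (hypothetical names). The real code kept a
// single HashesIterator per LinearCounting and reset it for every caller.
class SharedIterator {
    private final int[] hashes;
    private int pos; // one cursor shared by every reader: the race
    private int value;

    SharedIterator(int[] hashes) {
        this.hashes = hashes;
    }

    // Thread A resets pos while thread B is mid-iteration, so B re-reads
    // some values and skips others. Serialized output comes out garbled.
    SharedIterator reset() {
        pos = 0;
        return this;
    }

    boolean next() {
        if (pos < hashes.length) {
            value = hashes[pos++];
            return true;
        }
        return false;
    }

    int value() {
        return value;
    }
}
```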

Anyway! This would sometimes cause the async search index to contain
invalid results which would fail to deserialize. This seems to happen
frequently enough for some folks that it makes cardinality totally
unusable with async search. So far as I can tell, you have to create a
race on the iterator to make that happen. This is enough:
```
curl -XDELETE -uelastic:password localhost:9200/test
echo
for i in {1..100}; do
  rm -f /tmp/bulk
  printf "%03d:  " $i
  for j in {1..10000}; do
    echo '{"index": {}}' >> /tmp/bulk
    echo '{"i": '$i', "j": '$j'}' >> /tmp/bulk
  done
  curl -s -XPOST -HContent-Type:application/json -uelastic:password localhost:9200/test/_bulk?pretty --data-binary @/tmp/bulk | grep error
done

while true; do
  id=$(curl -s -XPOST -HContent-Type:application/json -uelastic:password 'localhost:9200/test/_async_search?pretty&request_cache=false&wait_for_completion_timeout=10ms' -d'{
    "size": 0,
    "aggs": {
      "i": {
        "terms": {
          "field": "i"
        },
        "aggs": {
          "j": {
            "cardinality": {
              "field": "j"
            }
          }
        }
      }
    }
  }' | jq -r .id)

  while curl -s -HContent-Type:application/json -uelastic:password localhost:9200/_async_search/$id?pretty | tee out | grep '"is_running" : true'; do
    cat out
  done
  cat out

  sleep 1
  curl --fail-with-body -s -HContent-Type:application/json -uelastic:password localhost:9200/_async_search/$id?pretty || break
done
```

Run that without this PR and it'll break with a message about being
unable to deserialize stuff from the index. It'll give a 400 error too,
which is totally bogus. On my laptop it takes less than ten iterations
of the loop.

So this PR fixes it! It removes the non-thread-safe stuff from the
cardinality results. It also adds a half dozen extra unit tests that'll
run against hundreds of objects, which should catch similar sorts of
possible errors.
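
The shape of the fix, condensed (made-up names again; the real change is
in the diff below): values() builds a fresh iterator with its own cursor
on every call, so any number of threads can walk the same bucket at once:
```
// Sketch of the fixed shape: per-call cursor state, shared read-only data.
class Counting {
    private final int[] hashes;

    Counting(int[] hashes) {
        this.hashes = hashes;
    }

    Cursor values() {
        return new Cursor(hashes); // fresh object per caller, nothing shared
    }

    static class Cursor {
        private final int[] hashes;
        private int pos; // private to this caller
        private int value;

        Cursor(int[] hashes) {
            this.hashes = hashes;
        }

        boolean next() {
            if (pos < hashes.length) {
                value = hashes[pos++];
                return true;
            }
            return false;
        }

        int value() {
            return value;
        }
    }
}
```
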
nik9000 committed Dec 30, 2021
1 parent 5735f10 commit e09d863
Showing 7 changed files with 232 additions and 57 deletions.
@@ -141,10 +141,11 @@ protected void addRunLen(long bucketOrd, int register, int runLen) {
}

void upgradeToHll(long bucketOrd) {
hll.ensureCapacity(bucketOrd + 1);
final AbstractLinearCounting.HashesIterator hashes = lc.values(bucketOrd);
// We need to copy the values into an array as we will overwrite
// the values in the buffer
hll.ensureCapacity(bucketOrd + 1);
// It's safe to reuse lc's readSpare because we're single threaded.
final AbstractLinearCounting.HashesIterator hashes = new LinearCountingIterator(lc, lc.readSpare, bucketOrd);
final IntArray values = lc.bigArrays.newIntArray(hashes.size());
try {
int i = 0;
@@ -201,15 +202,15 @@ private void merge(long thisBucket, AbstractHyperLogLog.RunLenIterator runLens)

private static class HyperLogLog extends AbstractHyperLogLog implements Releasable {
private final BigArrays bigArrays;
private final HyperLogLogIterator iterator;
private final int precision;
// array for holding the runlens.
private ByteArray runLens;

HyperLogLog(BigArrays bigArrays, long initialBucketCount, int precision) {
super(precision);
this.runLens = bigArrays.newByteArray(initialBucketCount << precision);
this.bigArrays = bigArrays;
this.iterator = new HyperLogLogIterator(this, precision, m);
this.precision = precision;
}

public long maxOrd() {
@@ -224,8 +225,7 @@ protected void addRunLen(long bucketOrd, int register, int encoded) {

@Override
protected RunLenIterator getRunLens(long bucketOrd) {
iterator.reset(bucketOrd);
return iterator;
return new HyperLogLogIterator(this, bucketOrd);
}

protected void reset(long bucketOrd) {
@@ -245,25 +245,18 @@ public void close() {
private static class HyperLogLogIterator implements AbstractHyperLogLog.RunLenIterator {

private final HyperLogLog hll;
private final int m, p;
int pos;
long start;
private byte value;

HyperLogLogIterator(HyperLogLog hll, int p, int m) {
HyperLogLogIterator(HyperLogLog hll, long bucket) {
this.hll = hll;
this.m = m;
this.p = p;
}

void reset(long bucket) {
pos = 0;
start = bucket << p;
start = bucket << hll.p;
}

@Override
public boolean next() {
if (pos < m) {
if (pos < hll.m) {
value = hll.runLens.get(start + pos);
pos++;
return true;
@@ -284,31 +277,31 @@ private static class LinearCounting extends AbstractLinearCounting implements Releasable {
private final BytesRef readSpare;
private final ByteBuffer writeSpare;
private final BigArrays bigArrays;
private final LinearCountingIterator iterator;
// We are actually using HyperLogLog's runLens array but interpreting it as a hash set for linear counting.
private final HyperLogLog hll;
private final int capacity;
// Number of elements stored.
private IntArray sizes;

LinearCounting(BigArrays bigArrays, long initialBucketCount, int p, HyperLogLog hll) {
super(p);
this.bigArrays = bigArrays;
this.hll = hll;
final int capacity = (1 << p) / 4; // because ints take 4 bytes
this.capacity = (1 << p) / 4; // because ints take 4 bytes
threshold = (int) (capacity * MAX_LOAD_FACTOR);
mask = capacity - 1;
sizes = bigArrays.newIntArray(initialBucketCount);
readSpare = new BytesRef();
writeSpare = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
iterator = new LinearCountingIterator(this, capacity);
}

@Override
protected int addEncoded(long bucketOrd, int encoded) {
sizes = bigArrays.grow(sizes, bucketOrd + 1);
assert encoded != 0;
for (int i = (encoded & mask);; i = (i + 1) & mask) {
final int v = get(bucketOrd, i);
hll.runLens.get(index(bucketOrd, i), 4, readSpare);
final int v = ByteUtils.readIntLE(readSpare.bytes, readSpare.offset);
if (v == 0) {
// means unused, take it!
set(bucketOrd, i, encoded);
@@ -332,19 +325,14 @@ protected int size(long bucketOrd) {

@Override
protected HashesIterator values(long bucketOrd) {
iterator.reset(bucketOrd, size(bucketOrd));
return iterator;
// Make a fresh BytesRef for reading scratch work because this method can be called on many threads
return new LinearCountingIterator(this, new BytesRef(), bucketOrd);
}

private long index(long bucketOrd, int index) {
return (bucketOrd << p) + (index << 2);
}

private int get(long bucketOrd, int index) {
hll.runLens.get(index(bucketOrd, index), 4, readSpare);
return ByteUtils.readIntLE(readSpare.bytes, readSpare.offset);
}

private void set(long bucketOrd, int index, int value) {
writeSpare.putInt(0, value);
hll.runLens.set(index(bucketOrd, index), writeSpare.array(), 0, 4);
@@ -355,8 +343,10 @@ private int recomputedSize(long bucketOrd) {
return 0;
}
int size = 0;
BytesRef spare = new BytesRef();
for (int i = 0; i <= mask; ++i) {
final int v = get(bucketOrd, i);
hll.runLens.get(index(bucketOrd, i), 4, spare);
final int v = ByteUtils.readIntLE(spare.bytes, spare.offset);
if (v != 0) {
++size;
}
@@ -373,20 +363,18 @@ public void close() {
private static class LinearCountingIterator implements AbstractLinearCounting.HashesIterator {

private final LinearCounting lc;
private final int capacity;
private int pos, size;
private long bucketOrd;
private final BytesRef spare;
private final long bucketOrd;
private final int size;
private int pos;
private int value;

LinearCountingIterator(LinearCounting lc, int capacity) {
LinearCountingIterator(LinearCounting lc, BytesRef spare, long bucketOrd) {
this.lc = lc;
this.capacity = capacity;
}

void reset(long bucketOrd, int size) {
this.spare = spare;
this.bucketOrd = bucketOrd;
this.size = size;
this.pos = size == 0 ? capacity : 0;
this.size = lc.size(bucketOrd);
this.pos = size == 0 ? lc.capacity : 0;
}

@Override
@@ -396,9 +384,10 @@ public int size() {

@Override
public boolean next() {
if (pos < capacity) {
for (; pos < capacity; ++pos) {
final int k = lc.get(bucketOrd, pos);
if (pos < lc.capacity) {
for (; pos < lc.capacity; ++pos) {
lc.hll.runLens.get(lc.index(bucketOrd, pos), 4, spare);
int k = ByteUtils.readIntLE(spare.bytes, spare.offset);
if (k != 0) {
++pos;
value = k;
@@ -82,7 +82,6 @@ protected void addEncoded(long bucket, int encoded) {
private static class LinearCounting extends AbstractLinearCounting implements Releasable {

private final BigArrays bigArrays;
private final LinearCountingIterator iterator;
// We are actually using HyperLogLog's runLens array but interpreting it as a hash set for linear counting.
// Number of elements stored.
private ObjectArray<IntArray> values;
@@ -105,7 +104,6 @@ private static class LinearCounting extends AbstractLinearCounting implements Releasable {
}
this.values = values;
this.sizes = sizes;
iterator = new LinearCountingIterator();
}

@Override
@@ -138,8 +136,7 @@ protected int size(long bucketOrd) {

@Override
protected HashesIterator values(long bucketOrd) {
iterator.reset(values.get(bucketOrd), size(bucketOrd));
return iterator;
return new LinearCountingIterator(values.get(bucketOrd), size(bucketOrd));
}

private int set(long bucketOrd, int value) {
@@ -176,16 +173,14 @@ public void close() {

private static class LinearCountingIterator implements AbstractLinearCounting.HashesIterator {

IntArray values;
int size, value;
private final IntArray values;
private final int size;
private int value;
private long pos;

LinearCountingIterator() {}

void reset(IntArray values, int size) {
LinearCountingIterator(IntArray values, int size) {
this.values = values;
this.size = size;
this.pos = 0;
}

@Override
@@ -26,13 +26,11 @@

public class InternalCardinalityTests extends InternalAggregationTestCase<InternalCardinality> {
private static List<HyperLogLogPlusPlus> algos;
private static int p;

@Override
public void setUp() throws Exception {
super.setUp();
algos = new ArrayList<>();
p = randomIntBetween(AbstractHyperLogLog.MIN_PRECISION, AbstractHyperLogLog.MAX_PRECISION);
}

@After // we force @After to have it run before ESTestCase#after otherwise it fails
@@ -46,18 +44,33 @@ public void tearDown() throws Exception {

@Override
protected InternalCardinality createTestInstance(String name, Map<String, Object> metadata) {
return createTestInstance(name, metadata, randomIntBetween(AbstractHyperLogLog.MIN_PRECISION, AbstractHyperLogLog.MAX_PRECISION));
}

private InternalCardinality createTestInstance(String name, Map<String, Object> metadata, int precision) {
HyperLogLogPlusPlus hllpp = new HyperLogLogPlusPlus(
p,
precision,
new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()),
1
);
algos.add(hllpp);
for (int i = 0; i < 100; i++) {
hllpp.collect(0, BitMixer.mix64(randomIntBetween(1, 100)));
int values = between(0, 1000);
for (int i = 0; i < values; i++) {
hllpp.collect(0, BitMixer.mix64(randomInt()));
}
return new InternalCardinality(name, hllpp, metadata);
}

@Override
protected List<InternalCardinality> randomResultsToReduce(String name, int size) {
int precision = randomIntBetween(AbstractHyperLogLog.MIN_PRECISION, AbstractHyperLogLog.MAX_PRECISION);
List<InternalCardinality> result = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
result.add(createTestInstance(name, createTestMetadata(), precision));
}
return result;
}

@Override
protected void assertReduced(InternalCardinality reduced, List<InternalCardinality> inputs) {
HyperLogLogPlusPlus[] algos = inputs.stream().map(InternalCardinality::getState).toArray(size -> new HyperLogLogPlusPlus[size]);
@@ -233,7 +233,7 @@ static TaskInfo randomTaskInfo() {
Task.Status status = randomBoolean() ? randomRawTaskStatus() : null;
String description = randomBoolean() ? randomAlphaOfLength(5) : null;
long startTime = randomLong();
long runningTimeNanos = randomLong();
long runningTimeNanos = randomNonNegativeLong();
boolean cancellable = randomBoolean();
boolean cancelled = cancellable && randomBoolean();
TaskId parentTaskId = randomBoolean() ? TaskId.EMPTY_TASK_ID : randomTaskId();
@@ -8,8 +8,10 @@

package org.elasticsearch.test;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.rest.action.search.RestSearchAction;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
@@ -18,8 +20,11 @@
import java.io.IOException;
import java.time.Instant;
import java.util.Date;
import java.util.concurrent.ExecutionException;
import java.util.function.Predicate;

import static java.util.Collections.singletonMap;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;

public abstract class AbstractSerializingTestCase<T extends ToXContent & Writeable> extends AbstractWireSerializingTestCase<T> {
@@ -40,6 +45,40 @@ public final void testFromXContent() throws IOException {
.test();
}

/**
* Calls {@link ToXContent#toXContent} on many threads and verifies that
* they produce the same result. Async search sometimes does this to
* aggregation responses and, in general, we think it's reasonable for
* everything that can convert itself to json to be able to do so
* concurrently.
*/
public final void testConcurrentToXContent() throws IOException, InterruptedException, ExecutionException {
XContentType xContentType = randomFrom(XContentType.values());
T testInstance = createXContextTestInstance(xContentType);
ToXContent.Params params = new ToXContent.DelegatingMapParams(
singletonMap(RestSearchAction.TYPED_KEYS_PARAM, "true"),
getToXContentParams()
);
boolean humanReadable = randomBoolean();
BytesRef firstTimeBytes = toXContent(testInstance, xContentType, params, humanReadable).toBytesRef();

/*
* 500 rounds seems to consistently reproduce the issue on Nik's
* laptop. Larger numbers are going to be slower but more likely
* to reproduce the issue.
*/
int rounds = scaledRandomIntBetween(300, 5000);
concurrentTest(() -> {
try {
for (int r = 0; r < rounds; r++) {
assertEquals(firstTimeBytes, toXContent(testInstance, xContentType, params, humanReadable).toBytesRef());
}
} catch (IOException e) {
throw new AssertionError(e);
}
});
}

/**
* Parses to a new instance using the provided {@link XContentParser}
*/