Fix merging of terms aggregation with compound order (#64469)
This change fixes a bug, introduced in #61779, where a compound order was used to
compare buckets when merging. The bug is triggered when the compound order
uses a primary sort ordered by key (asc or desc).
This commit ensures that we always extract the primary sort when comparing keys
during merging.
The PR is marked as no-issue since the bug has not been released in any official version.
jimczi committed Nov 5, 2020
1 parent 1819637 commit 977c779
Showing 5 changed files with 133 additions and 82 deletions.
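
To see why a compound comparator cannot drive the merge, consider a minimal sketch (the Bucket record and comparators below are illustrative stand-ins, not Elasticsearch classes): shard responses are sorted by key alone, and two buckets for the same term must compare as equal so the merge can combine them, which a key-then-count comparator does not guarantee.

import java.util.Comparator;

public class CompoundOrderBugSketch {
    // Illustrative stand-in for a terms bucket.
    record Bucket(String key, long docCount) {}

    public static void main(String[] args) {
        // Compound order: primary sort by key asc, tie-break by doc count desc.
        Comparator<Bucket> compound = Comparator.comparing(Bucket::key)
            .thenComparing(Comparator.comparingLong(Bucket::docCount).reversed());
        // The primary sort alone, which is what the fix extracts.
        Comparator<Bucket> primaryOnly = Comparator.comparing(Bucket::key);

        // Two shard responses for the same term with different doc counts.
        Bucket fromShard1 = new Bucket("a", 1);
        Bucket fromShard2 = new Bucket("a", 9);

        // The compound comparator sees the buckets as unequal (doc counts differ),
        // so a merge keyed on it cannot tell that they belong to the same term;
        // the extracted primary sort compares them as equal, as the merge requires.
        System.out.println(compound.compare(fromShard1, fromShard2));    // > 0 (unequal)
        System.out.println(primaryOnly.compare(fromShard1, fromShard2)); // 0 (equal)
    }
}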
@@ -45,6 +45,7 @@
 import java.util.Map;
 import java.util.Objects;
 
+import static org.elasticsearch.search.aggregations.InternalOrder.isKeyAsc;
 import static org.elasticsearch.search.aggregations.InternalOrder.isKeyOrder;
 
 public abstract class InternalTerms<A extends InternalTerms<A, B>, B extends InternalTerms.Bucket<B>>
@@ -257,9 +258,9 @@ private long getDocCountError(InternalTerms<?, ?> terms) {
     }
 
     private List<B> reduceMergeSort(List<InternalAggregation> aggregations,
-                                    BucketOrder reduceOrder, ReduceContext reduceContext) {
-        assert isKeyOrder(reduceOrder);
-        final Comparator<MultiBucketsAggregation.Bucket> cmp = reduceOrder.comparator();
+                                    BucketOrder thisReduceOrder, ReduceContext reduceContext) {
+        assert isKeyOrder(thisReduceOrder);
+        final Comparator<MultiBucketsAggregation.Bucket> cmp = thisReduceOrder.comparator();
         final PriorityQueue<IteratorAndCurrent<B>> pq = new PriorityQueue<IteratorAndCurrent<B>>(aggregations.size()) {
             @Override
             protected boolean lessThan(IteratorAndCurrent<B> a, IteratorAndCurrent<B> b) {
@@ -369,15 +370,22 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
                 bucket.docCountError -= thisAggDocCountError;
             }
         }
+
+        final List<B> reducedBuckets;
         /**
         * Buckets returned by a partial reduce or a shard response are sorted by key since {@link Version#V_7_10_0}.
         * That allows to perform a merge sort when reducing multiple aggregations together.
         * For backward compatibility, we disable the merge sort and use ({@link InternalTerms#reduceLegacy} if any of
         * the provided aggregations use a different {@link InternalTerms#reduceOrder}.
         */
        BucketOrder thisReduceOrder = getReduceOrder(aggregations);
-        List<B> reducedBuckets = isKeyOrder(thisReduceOrder) ?
-            reduceMergeSort(aggregations, thisReduceOrder, reduceContext) : reduceLegacy(aggregations, reduceContext);
+        if (isKeyOrder(thisReduceOrder)) {
+            // extract the primary sort in case this is a compound order.
+            thisReduceOrder = InternalOrder.key(isKeyAsc(thisReduceOrder) ? true : false);
+            reducedBuckets = reduceMergeSort(aggregations, thisReduceOrder, reduceContext);
+        } else {
+            reducedBuckets = reduceLegacy(aggregations, reduceContext);
+        }
         final B[] list;
         if (reduceContext.isFinalReduce()) {
             final int size = Math.min(requiredSize, reducedBuckets.size());
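
The reduceMergeSort method above performs a k-way merge over per-shard bucket lists that are each pre-sorted by the same reduce order. A rough sketch of that pattern with plain java.util types (the real code uses Lucene's PriorityQueue and IteratorAndCurrent, and combines equal-key buckets rather than keeping them adjacent) might look like this:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

public class KWayMergeSketch {

    // Pairs an iterator with its current element, like the real IteratorAndCurrent.
    static final class IteratorAndCurrent<T> {
        final Iterator<T> it;
        T current;

        IteratorAndCurrent(Iterator<T> it) {
            this.it = it;
            this.current = it.next();
        }
    }

    // Merges input lists that are all pre-sorted by cmp into one sorted list.
    static <T> List<T> merge(List<List<T>> sortedInputs, Comparator<T> cmp) {
        PriorityQueue<IteratorAndCurrent<T>> pq =
            new PriorityQueue<>((a, b) -> cmp.compare(a.current, b.current));
        for (List<T> input : sortedInputs) {
            if (input.isEmpty() == false) {
                pq.add(new IteratorAndCurrent<>(input.iterator()));
            }
        }
        List<T> merged = new ArrayList<>();
        while (pq.isEmpty() == false) {
            IteratorAndCurrent<T> top = pq.poll();
            merged.add(top.current);   // the real code combines equal-key buckets here
            if (top.it.hasNext()) {
                top.current = top.it.next();
                pq.add(top);           // reinsert with its next element
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        List<List<String>> shards = List.of(List.of("a", "c"), List.of("a", "b"));
        System.out.println(merge(shards, Comparator.naturalOrder()));  // [a, a, b, c]
    }
}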
@@ -35,8 +35,8 @@
 
 public abstract class InternalTermsTestCase extends InternalMultiBucketAggregationTestCase<InternalTerms<?, ?>> {
 
-    private boolean showDocCount;
-    private long docCountError;
+    protected boolean showDocCount;
+    protected long docCountError;
 
     @Before
     public void init() {
@@ -34,31 +34,24 @@
 import java.util.Set;
 
 public class StringTermsTests extends InternalTermsTestCase {
-
     @Override
     protected InternalTerms<?, ?> createTestInstance(String name,
                                                      Map<String, Object> metadata,
                                                      InternalAggregations aggregations,
                                                      boolean showTermDocCountError,
                                                      long docCountError) {
-        BucketOrder order = BucketOrder.count(false);
-        long minDocCount = 1;
-        int requiredSize = 3;
-        int shardSize = requiredSize + 2;
-        DocValueFormat format = DocValueFormat.RAW;
-        long otherDocCount = 0;
-        List<StringTerms.Bucket> buckets = new ArrayList<>();
-        final int numBuckets = randomNumberOfBuckets();
-        Set<BytesRef> terms = new HashSet<>();
-        for (int i = 0; i < numBuckets; ++i) {
-            BytesRef term = randomValueOtherThanMany(b -> terms.add(b) == false, () -> new BytesRef(randomAlphaOfLength(10)));
-            int docCount = randomIntBetween(1, 100);
-            buckets.add(new StringTerms.Bucket(term, docCount, aggregations, showTermDocCountError, docCountError, format));
+        return createTestInstance(generateRandomDict(), name, metadata, aggregations, showTermDocCountError, docCountError);
     }
-        BucketOrder reduceOrder = rarely() ? order : BucketOrder.key(true);
-        Collections.sort(buckets, reduceOrder.comparator());
-        return new StringTerms(name, reduceOrder, order, requiredSize, minDocCount,
-            metadata, format, shardSize, showTermDocCountError, otherDocCount, buckets, docCountError);
+
+    @Override
+    protected List<InternalTerms<?, ?>> randomResultsToReduce(String name, int size) {
+        List<InternalTerms<?, ?>> inputs = new ArrayList<>();
+        BytesRef[] dict = generateRandomDict();
+        for (int i = 0; i < size; i++) {
+            InternalTerms<?, ?> t = randomBoolean() ? createUnmappedInstance(name) : createTestInstance(dict, name);
+            inputs.add(t);
+        }
+        return inputs;
     }
 
     @Override
@@ -82,74 +75,116 @@ protected Class<? extends ParsedMultiBucketAggregation> implementationClass() {
             long docCountError = stringTerms.getDocCountError();
             Map<String, Object> metadata = stringTerms.getMetadata();
             switch (between(0, 8)) {
-            case 0:
-                name += randomAlphaOfLength(5);
-                break;
-            case 1:
-                requiredSize += between(1, 100);
-                break;
-            case 2:
-                minDocCount += between(1, 100);
-                break;
-            case 3:
-                shardSize += between(1, 100);
-                break;
-            case 4:
-                showTermDocCountError = showTermDocCountError == false;
-                break;
-            case 5:
-                otherDocCount += between(1, 100);
-                break;
-            case 6:
-                docCountError += between(1, 100);
-                break;
-            case 7:
-                buckets = new ArrayList<>(buckets);
-                buckets.add(new StringTerms.Bucket(new BytesRef(randomAlphaOfLengthBetween(1, 10)), randomNonNegativeLong(),
+                case 0:
+                    name += randomAlphaOfLength(5);
+                    break;
+                case 1:
+                    requiredSize += between(1, 100);
+                    break;
+                case 2:
+                    minDocCount += between(1, 100);
+                    break;
+                case 3:
+                    shardSize += between(1, 100);
+                    break;
+                case 4:
+                    showTermDocCountError = showTermDocCountError == false;
+                    break;
+                case 5:
+                    otherDocCount += between(1, 100);
+                    break;
+                case 6:
+                    docCountError += between(1, 100);
+                    break;
+                case 7:
+                    buckets = new ArrayList<>(buckets);
+                    buckets.add(new StringTerms.Bucket(new BytesRef(randomAlphaOfLengthBetween(1, 10)), randomNonNegativeLong(),
                     InternalAggregations.EMPTY, showTermDocCountError, docCountError, format));
-                break;
-            case 8:
-                if (metadata == null) {
-                    metadata = new HashMap<>(1);
-                } else {
-                    metadata = new HashMap<>(instance.getMetadata());
-                }
-                metadata.put(randomAlphaOfLength(15), randomInt());
-                break;
-            default:
-                throw new AssertionError("Illegal randomisation branch");
+                    break;
+                case 8:
+                    if (metadata == null) {
+                        metadata = new HashMap<>(1);
+                    } else {
+                        metadata = new HashMap<>(instance.getMetadata());
+                    }
+                    metadata.put(randomAlphaOfLength(15), randomInt());
+                    break;
+                default:
+                    throw new AssertionError("Illegal randomisation branch");
             }
+            Collections.sort(buckets, stringTerms.reduceOrder.comparator());
             return new StringTerms(name, stringTerms.reduceOrder, order, requiredSize, minDocCount, metadata, format, shardSize,
-                showTermDocCountError, otherDocCount, buckets, docCountError);
+                    showTermDocCountError, otherDocCount, buckets, docCountError);
         } else {
             String name = instance.getName();
             BucketOrder order = instance.order;
             int requiredSize = instance.requiredSize;
             long minDocCount = instance.minDocCount;
             Map<String, Object> metadata = instance.getMetadata();
             switch (between(0, 3)) {
-            case 0:
-                name += randomAlphaOfLength(5);
-                break;
-            case 1:
-                requiredSize += between(1, 100);
-                break;
-            case 2:
-                minDocCount += between(1, 100);
-                break;
-            case 3:
-                if (metadata == null) {
-                    metadata = new HashMap<>(1);
-                } else {
-                    metadata = new HashMap<>(instance.getMetadata());
-                }
-                metadata.put(randomAlphaOfLength(15), randomInt());
-                break;
-            default:
-                throw new AssertionError("Illegal randomisation branch");
+                case 0:
+                    name += randomAlphaOfLength(5);
+                    break;
+                case 1:
+                    requiredSize += between(1, 100);
+                    break;
+                case 2:
+                    minDocCount += between(1, 100);
+                    break;
+                case 3:
+                    if (metadata == null) {
+                        metadata = new HashMap<>(1);
+                    } else {
+                        metadata = new HashMap<>(instance.getMetadata());
+                    }
+                    metadata.put(randomAlphaOfLength(15), randomInt());
+                    break;
+                default:
+                    throw new AssertionError("Illegal randomisation branch");
             }
             return new UnmappedTerms(name, order, requiredSize, minDocCount, metadata);
         }
     }
 
+    private BytesRef[] generateRandomDict() {
+        Set<BytesRef> terms = new HashSet<>();
+        int numTerms = randomIntBetween(2, 100);
+        for (int i = 0; i < numTerms; i++) {
+            terms.add(new BytesRef(randomAlphaOfLength(10)));
+        }
+        return terms.stream().toArray(BytesRef[]::new);
+    }
+
+    private InternalTerms<?, ?> createTestInstance(BytesRef[] dict, String name) {
+        return createTestInstance(dict, name, createTestMetadata(), createSubAggregations(), showDocCount, docCountError);
+    }
+
+    private InternalTerms<?, ?> createTestInstance(BytesRef[] dict,
+                                                   String name,
+                                                   Map<String, Object> metadata,
+                                                   InternalAggregations aggregations,
+                                                   boolean showTermDocCountError,
+                                                   long docCountError) {
+        BucketOrder order = BucketOrder.count(false);
+        long minDocCount = 1;
+        int requiredSize = 3;
+        int shardSize = requiredSize + 2;
+        DocValueFormat format = DocValueFormat.RAW;
+        long otherDocCount = 0;
+        List<StringTerms.Bucket> buckets = new ArrayList<>();
+        final int numBuckets = randomNumberOfBuckets();
+        Set<BytesRef> terms = new HashSet<>();
+        for (int i = 0; i < numBuckets; ++i) {
+            BytesRef term = dict[randomIntBetween(0, dict.length-1)];
+            if (terms.add(term)) {
+                int docCount = randomIntBetween(1, 100);
+                buckets.add(new StringTerms.Bucket(term, docCount, aggregations, showTermDocCountError, docCountError, format));
+            }
+        }
+        BucketOrder reduceOrder = randomBoolean() ?
+            BucketOrder.compound(BucketOrder.key(true), BucketOrder.count(false)) : BucketOrder.key(true);
+        Collections.sort(buckets, reduceOrder.comparator());
+        return new StringTerms(name, reduceOrder, order, requiredSize, minDocCount,
+            metadata, format, shardSize, showTermDocCountError, otherDocCount, buckets, docCountError);
+    }
 }
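
The test rewrite above draws terms from a shared dictionary so that the same key shows up in several of the aggregations being reduced; independently generated 10-character terms would almost never collide, leaving the merge path untested. A hypothetical toy demonstration of that collision effect (not part of the test suite):

import java.util.HashSet;
import java.util.Random;
import java.util.Set;

public class SharedDictSketch {
    public static void main(String[] args) {
        Random random = new Random();
        // Shared dictionary: every simulated shard samples from the same small
        // term set, so overlapping keys across shards are very likely.
        String[] dict = {"alpha", "beta", "gamma", "delta"};
        Set<String> shard1 = new HashSet<>();
        Set<String> shard2 = new HashSet<>();
        for (int i = 0; i < 3; i++) {
            shard1.add(dict[random.nextInt(dict.length)]);
            shard2.add(dict[random.nextInt(dict.length)]);
        }
        Set<String> overlap = new HashSet<>(shard1);
        overlap.retainAll(shard2);
        System.out.println("keys the reduce phase must merge: " + overlap);
    }
}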
@@ -419,7 +419,7 @@ public final T createTestInstance() {
         return createTestInstance(randomAlphaOfLength(5));
     }
 
-    private T createTestInstance(String name) {
+    public final Map<String, Object> createTestMetadata() {
         Map<String, Object> metadata = null;
         if (randomBoolean()) {
             metadata = new HashMap<>();
@@ -428,7 +428,11 @@ private T createTestInstance(String name) {
                 metadata.put(randomAlphaOfLength(5), randomAlphaOfLength(5));
             }
         }
-        return createTestInstance(name, metadata);
+        return metadata;
     }
 
+    private T createTestInstance(String name) {
+        return createTestInstance(name, createTestMetadata());
+    }
+
     /** Return an instance on an unmapped field. */
@@ -66,6 +66,10 @@ public void setSubAggregationsSupplier(Supplier<InternalAggregations> subAggrega
         this.subAggregationsSupplier = subAggregationsSupplier;
     }
 
+    public final InternalAggregations createSubAggregations() {
+        return subAggregationsSupplier.get();
+    }
+
     @Override
     public void setUp() throws Exception {
         super.setUp();
