Correctly calculate doc count error at the slice level for concurrent segment search

Signed-off-by: Jay Deng <jayd0104@gmail.com>
jed326 authored and Jay Deng committed Jan 3, 2024
1 parent 6a01d2f commit 1f31849
Showing 22 changed files with 471 additions and 68 deletions.
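
For context on what the diff below changes: a terms aggregation returns only the top shard_size buckets from each shard, so a bucket that a shard pruned may have a true doc count as large as the smallest count that shard did return; the coordinator surfaces this bound as doc_count_error_upper_bound. With concurrent segment search, a shard's segments are searched in slices whose partial results are reduced on the data node first, and each slice prunes buckets the same way. As a hypothetical example with made-up numbers: if slice 1 keeps {a: 10, b: 7} and slice 2 keeps {a: 9, c: 6} with shard_size 2, the slice-level reduce yields a = 19 exactly, while b may be undercounted by up to 6 and c by up to 7. Before this commit the slice-level reduce always reported a doc count error of 0, dropping that error; the changes below compute it per slice and propagate it to the coordinator.
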
@@ -178,8 +178,8 @@ private StringTerms newTerms(Random rand, BytesRef[] dict, boolean withNested) {
0,
buckets,
0,
- new TermsAggregator.BucketCountThresholds(1, 0, topNSize, numShards)
- );
+ new TermsAggregator.BucketCountThresholds(1, 0, topNSize, numShards),
+ false);
}

@Override
@@ -94,8 +94,8 @@ private StringTerms newTerms(boolean withNested) {
100000,
resultBuckets,
0,
- new TermsAggregator.BucketCountThresholds(1, 0, buckets, buckets)
- );
+ new TermsAggregator.BucketCountThresholds(1, 0, buckets, buckets),
+ false);
}

@Benchmark

Large diffs are not rendered by default.

@@ -131,7 +131,7 @@ public void testShardMinDocCountSignificantTermsTest() throws Exception {
(filter("inclass", QueryBuilders.termQuery("class", true))).subAggregation(
significantTerms("mySignificantTerms").field("text")
.minDocCount(2)
- .shardSize(2)
+ .shardSize(10)
.shardMinDocCount(2)
.size(2)
.executionHint(randomExecutionHint())
@@ -198,7 +198,7 @@ public void testShardMinDocCountTermsTest() throws Exception {
.minDocCount(2)
.shardMinDocCount(2)
.size(2)
- .shardSize(2)
+ .shardSize(10)
.executionHint(randomExecutionHint())
.order(BucketOrder.key(true))
)
@@ -83,7 +83,8 @@ protected StringTerms buildEmptyTermsAggregation() {
0,
emptyList(),
0,
- bucketCountThresholds
+ bucketCountThresholds,
+ false
);
}

@@ -137,7 +137,8 @@ public DoubleTerms(
long otherDocCount,
List<Bucket> buckets,
long docCountError,
- TermsAggregator.BucketCountThresholds bucketCountThresholds
+ TermsAggregator.BucketCountThresholds bucketCountThresholds,
+ boolean hasSliceLevelDocCountError
) {
super(
name,
@@ -150,7 +151,8 @@ public DoubleTerms(
otherDocCount,
buckets,
docCountError,
- bucketCountThresholds
+ bucketCountThresholds,
+ hasSliceLevelDocCountError
);
}

@@ -179,7 +181,8 @@ public DoubleTerms create(List<Bucket> buckets) {
otherDocCount,
buckets,
docCountError,
- bucketCountThresholds
+ bucketCountThresholds,
+ hasSliceLevelDocCountError
);
}

@@ -196,7 +199,14 @@ public Bucket createBucket(InternalAggregations aggregations, Bucket prototype)
}

@Override
- protected DoubleTerms create(String name, List<Bucket> buckets, BucketOrder reduceOrder, long docCountError, long otherDocCount) {
+ protected DoubleTerms create(
+ String name,
+ List<Bucket> buckets,
+ BucketOrder reduceOrder,
+ long docCountError,
+ long otherDocCount,
+ boolean hasSliceLevelDocCountError
+ ) {
return new DoubleTerms(
name,
reduceOrder,
@@ -208,7 +218,8 @@ protected DoubleTerms create(String name, List<Bucket> buckets, BucketOrder redu
otherDocCount,
buckets,
docCountError,
- bucketCountThresholds
+ bucketCountThresholds,
+ hasSliceLevelDocCountError
);
}

@@ -808,7 +808,8 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
otherDocCount,
Arrays.asList(topBuckets),
0,
- bucketCountThresholds
+ bucketCountThresholds,
+ false
);
}

@@ -71,9 +71,10 @@ protected InternalMappedTerms(
long otherDocCount,
List<B> buckets,
long docCountError,
- TermsAggregator.BucketCountThresholds bucketCountThresholds
+ TermsAggregator.BucketCountThresholds bucketCountThresholds,
+ boolean hasSliceLevelDocCountError
) {
- super(name, reduceOrder, order, bucketCountThresholds, metadata);
+ super(name, reduceOrder, order, bucketCountThresholds, metadata, hasSliceLevelDocCountError);
this.format = format;
this.shardSize = shardSize;
this.showTermDocCountError = showTermDocCountError;
@@ -242,7 +242,7 @@ public InternalMultiTerms(
List<Bucket> buckets,
TermsAggregator.BucketCountThresholds bucketCountThresholds
) {
- super(name, reduceOrder, order, bucketCountThresholds, metadata);
+ super(name, reduceOrder, order, bucketCountThresholds, metadata, false);
this.shardSize = shardSize;
this.showTermDocCountError = showTermDocCountError;
this.otherDocCount = otherDocCount;
@@ -349,7 +349,8 @@ protected InternalMultiTerms create(
List<Bucket> buckets,
BucketOrder reduceOrder,
long docCountError,
- long otherDocCount
+ long otherDocCount,
+ boolean hasSliceLevelDocCountError
) {
return new InternalMultiTerms(
name,
@@ -225,28 +225,33 @@ public int hashCode() {
protected final int requiredSize;
protected final long minDocCount;
protected final TermsAggregator.BucketCountThresholds bucketCountThresholds;
+ protected boolean hasSliceLevelDocCountError = false;

/**
- * Creates a new {@link InternalTerms}
- * @param name The name of the aggregation
- * @param reduceOrder The {@link BucketOrder} that should be used to merge shard results.
- * @param order The {@link BucketOrder} that should be used to sort the final reduce.
- * @param bucketCountThresholds Object containing values for minDocCount, shardMinDocCount, size, shardSize.
- * @param metadata The metadata associated with the aggregation.
+ *
+ * @param name The name of the aggregation
+ * @param reduceOrder The {@link org.opensearch.search.aggregations.BucketOrder} that should be used to merge shard results.
+ * @param order The {@link org.opensearch.search.aggregations.BucketOrder} that should be used to sort the final reduce.
+ * @param bucketCountThresholds Object containing values for minDocCount, shardMinDocCount, size, shardSize.
+ * @param metadata The metadata associated with the aggregation.
+ * @param hasSliceLevelDocCountError Whether the doc count error was calculated during a slice-level reduce
*/
protected InternalTerms(
String name,
BucketOrder reduceOrder,
BucketOrder order,
TermsAggregator.BucketCountThresholds bucketCountThresholds,
- Map<String, Object> metadata
+ Map<String, Object> metadata,
+ boolean hasSliceLevelDocCountError
) {
super(name, metadata);
this.reduceOrder = reduceOrder;
this.order = order;
this.bucketCountThresholds = bucketCountThresholds;
this.requiredSize = bucketCountThresholds.getRequiredSize();
this.minDocCount = bucketCountThresholds.getMinDocCount();
+ this.hasSliceLevelDocCountError = hasSliceLevelDocCountError;
}

/**
@@ -299,16 +304,19 @@ private BucketOrder getReduceOrder(List<InternalAggregation> aggregations) {

private long getDocCountError(InternalTerms<?, ?> terms, ReduceContext reduceContext) {
int size = terms.getBuckets().size();
- // doc_count_error is always computed at the coordinator based on the buckets returned by the shards. This should be 0 during the
- // shard level reduce as no buckets are being pruned at this stage.
- if (reduceContext.isSliceLevel() || size == 0 || size < terms.getShardSize() || isKeyOrder(terms.order)) {
+ // TODO: I think this can be size <= terms.getShardSize() but need to validate
+ if (size == 0 || size < terms.getShardSize() || isKeyOrder(terms.order)) {
return 0;
} else if (InternalOrder.isCountDesc(terms.order)) {
if (terms.getDocCountError() > 0) {
// If there is an existing docCountError for this agg then
// use this as the error for this aggregation
return terms.getDocCountError();
} else {
+ // We need a way to indicate to the coordinator that doc count error was gathered at the slice level, so do that here
+ if (reduceContext.isSliceLevel()) {
+ hasSliceLevelDocCountError = true;
+ }
// otherwise use the doc count of the last term in the
// aggregation
return terms.getBuckets().stream().mapToLong(MultiBucketsAggregation.Bucket::getDocCount).min().getAsLong();
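
To make the fallback above concrete: below is a minimal, self-contained sketch (hypothetical code, not part of this commit; the class and method names are invented) of how the worst-case doc count error for a slice or shard is derived from the buckets it returns when ordered by descending count.

import java.util.List;

// Sketch of the fallback in getDocCountError(): if this level returned fewer
// buckets than shard_size, nothing was pruned, so no error is introduced here.
// Otherwise any pruned term can have at most the doc count of the last
// (smallest) bucket that was kept.
public class SliceDocCountErrorSketch {

    static long docCountError(List<Long> docCountsDesc, int shardSize) {
        if (docCountsDesc.isEmpty() || docCountsDesc.size() < shardSize) {
            return 0; // no pruning happened at this level
        }
        return docCountsDesc.get(docCountsDesc.size() - 1); // smallest kept count
    }

    public static void main(String[] args) {
        // Two hypothetical slices that each kept their top 2 buckets:
        System.out.println(docCountError(List.of(10L, 7L), 2)); // 7
        System.out.println(docCountError(List.of(9L, 6L), 2));  // 6
        // A slice that returned fewer buckets than shard_size pruned nothing:
        System.out.println(docCountError(List.of(5L), 2));      // 0
    }
}

With this commit, a slice-level reduce that takes this fallback also sets hasSliceLevelDocCountError, so the coordinator knows the error did not originate at the shard boundary.
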
@@ -500,14 +508,35 @@ For backward compatibility, we disable the merge sort and use ({@link InternalTe
if (sumDocCountError == -1) {
docCountError = -1;
} else {
- docCountError = aggregations.size() == 1 ? 0 : sumDocCountError;
+ // Doc count error originating from slice_size needs to be handled differently:
+ // If there is slice level doc count error then that needs to be propagated to the top level doc count error even if no
+ // additional error is introduced by shard_size -- in other words the 1 shard case.
+ // However, if there is only 1 slice, then we can set the doc count error to 0 and disregard any slice level doc count error,
+ // which is what the shards did before.
+ if (reduceContext.isFinalReduce() && hasSliceLevelDocCountError) {
+ docCountError = sumDocCountError;
+ } else {
+ if (aggregations.size() == 1) {
+ docCountError = 0;
+ hasSliceLevelDocCountError = false;
+ } else {
+ docCountError = sumDocCountError;
+ }
+ }
}

+ // Shards must return buckets sorted by key, so we apply the sort here in shard level reduce
+ if (reduceContext.isSliceLevel()) {
+ Arrays.sort(list, thisReduceOrder.comparator());
+ }
- return create(name, Arrays.asList(list), reduceContext.isFinalReduce() ? order : thisReduceOrder, docCountError, otherDocCount);
+ return create(
+ name,
+ Arrays.asList(list),
+ reduceContext.isFinalReduce() ? order : thisReduceOrder,
+ docCountError,
+ otherDocCount,
+ hasSliceLevelDocCountError
+ );
}
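
Continuing the hypothetical two-slice numbers from earlier: the slice-level reduce sums the per-slice errors (7 + 6 = 13) and the result reaches the coordinator with hasSliceLevelDocCountError set. Previously a single-shard request always took the aggregations.size() == 1 branch and reported doc_count_error_upper_bound = 0 even though the slices had pruned buckets; with this change the final reduce keeps sumDocCountError whenever the flag is set. A single slice still reduces to 0 and clears the flag, matching what a lone shard reported before concurrent segment search.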

@Override
@@ -523,7 +552,7 @@ protected B reduceBucket(List<B> buckets, ReduceContext context) {
for (B bucket : buckets) {
docCount += bucket.getDocCount();
if (docCountError != -1) {
- if (bucket.showDocCountError() == false || bucket.getDocCountError() == -1) {
+ if (bucket.showDocCountError() == false) {
docCountError = -1;
} else {
docCountError += bucket.getDocCountError();
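
One more detail from reduceBucket above: when copies of the same bucket are merged across slices, their recorded errors are summed (for example, copies with errors 0, 4, and 2 merge to an error of 6), and the merged error is reported as unknown (-1) only when a copy does not expose a doc count error at all. Before this change, any copy whose error was the -1 sentinel also forced the merged error to -1.
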
@@ -539,7 +568,14 @@ protected B reduceBucket(List<B> buckets, ReduceContext context) {

protected abstract int getShardSize();

- protected abstract A create(String name, List<B> buckets, BucketOrder reduceOrder, long docCountError, long otherDocCount);
+ protected abstract A create(
+ String name,
+ List<B> buckets,
+ BucketOrder reduceOrder,
+ long docCountError,
+ long otherDocCount,
+ boolean hasSliceLevelDocCountError
+ );

/**
* Create an array to hold some buckets. Used in collecting the results.
@@ -149,7 +149,8 @@ public LongTerms(
long otherDocCount,
List<Bucket> buckets,
long docCountError,
- TermsAggregator.BucketCountThresholds bucketCountThresholds
+ TermsAggregator.BucketCountThresholds bucketCountThresholds,
+ boolean hasSliceLevelDocCountError
) {
super(
name,
Expand All @@ -162,7 +163,8 @@ public LongTerms(
otherDocCount,
buckets,
docCountError,
- bucketCountThresholds
+ bucketCountThresholds,
+ hasSliceLevelDocCountError
);
}

@@ -191,7 +193,8 @@ public LongTerms create(List<Bucket> buckets) {
otherDocCount,
buckets,
docCountError,
- bucketCountThresholds
+ bucketCountThresholds,
+ hasSliceLevelDocCountError
);
}

@@ -208,7 +211,14 @@ public Bucket createBucket(InternalAggregations aggregations, Bucket prototype)
}

@Override
- protected LongTerms create(String name, List<Bucket> buckets, BucketOrder reduceOrder, long docCountError, long otherDocCount) {
+ protected LongTerms create(
+ String name,
+ List<Bucket> buckets,
+ BucketOrder reduceOrder,
+ long docCountError,
+ long otherDocCount,
+ boolean hasSliceLevelDocCountError
+ ) {
return new LongTerms(
name,
reduceOrder,
@@ -220,7 +230,8 @@ protected LongTerms create(String name, List<Bucket> buckets, BucketOrder reduce
otherDocCount,
buckets,
docCountError,
- bucketCountThresholds
+ bucketCountThresholds,
+ hasSliceLevelDocCountError
);
}

@@ -296,7 +307,8 @@ static DoubleTerms convertLongTermsToDouble(LongTerms longTerms, DocValueFormat
longTerms.otherDocCount,
newBuckets,
longTerms.docCountError,
- longTerms.bucketCountThresholds
+ longTerms.bucketCountThresholds,
+ longTerms.hasSliceLevelDocCountError
);
}
}
@@ -463,7 +463,8 @@ StringTerms buildResult(long owningBucketOrd, long otherDocCount, StringTerms.Bu
otherDocCount,
Arrays.asList(topBuckets),
0,
- bucketCountThresholds
+ bucketCountThresholds,
+ false
);
}

@@ -404,7 +404,8 @@ LongTerms buildResult(long owningBucketOrd, long otherDocCount, LongTerms.Bucket
otherDocCount,
List.of(topBuckets),
0,
- bucketCountThresholds
+ bucketCountThresholds,
+ false
);
}

@@ -421,7 +422,8 @@ LongTerms buildEmptyResult() {
0,
emptyList(),
0,
- bucketCountThresholds
+ bucketCountThresholds,
+ false
);
}
}
@@ -484,7 +486,8 @@ DoubleTerms buildResult(long owningBucketOrd, long otherDocCount, DoubleTerms.Bu
otherDocCount,
List.of(topBuckets),
0,
- bucketCountThresholds
+ bucketCountThresholds,
+ false
);
}

@@ -501,7 +504,8 @@ DoubleTerms buildEmptyResult() {
0,
emptyList(),
0,
- bucketCountThresholds
+ bucketCountThresholds,
+ false
);
}
}
