Skip to content

Commit

Permalink
Avoid deep copy and other allocation improvements
Browse files Browse the repository at this point in the history
Signed-off-by: expani <anijainc@amazon.com>
  • Loading branch information
expani committed Jul 29, 2024
1 parent 59302a3 commit fb84412
Showing 1 changed file with 89 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalOrder;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
import org.opensearch.search.aggregations.bucket.LocalBucketCountThresholds;
import org.opensearch.search.aggregations.support.AggregationPath;
Expand Down Expand Up @@ -215,19 +216,11 @@ public InternalAggregation buildEmptyAggregation() {

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, this, sub);
return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
for (BytesRef compositeKey : collector.apply(doc)) {
long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey);
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
collectExistingBucket(sub, doc, bucketOrd);
} else {
collectBucket(sub, doc, bucketOrd);
}
}
collector.apply(doc, owningBucketOrd);
}
};
}
Expand Down Expand Up @@ -268,12 +261,10 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept
}
// we need to fill-in the blanks
for (LeafReaderContext ctx : context.searcher().getTopReaderContext().leaves()) {
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, null, null);
// brute force
for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) {
for (BytesRef compositeKey : collector.apply(docId)) {
bucketOrds.add(owningBucketOrd, compositeKey);
}
collector.apply(docId, owningBucketOrd);
}
}
}
Expand All @@ -287,7 +278,8 @@ interface MultiTermsValuesSourceCollector {
* Collect a list values of multi_terms on each doc.
* Each terms could have multi_values, so the result is the cartesian product of each term's values.
*/
List<BytesRef> apply(int doc) throws IOException;
void apply(int doc, long owningBucketOrd) throws IOException;

}

@FunctionalInterface
Expand Down Expand Up @@ -361,51 +353,17 @@ public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) {
this.valuesSources = valuesSources;
}

public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws IOException {
public MultiTermsValuesSourceCollector getValues(
LeafReaderContext ctx,
BytesKeyedBucketOrds bucketOrds,
BucketsAggregator aggregator,
LeafBucketCollector sub
) throws IOException {
List<InternalValuesSourceCollector> collectors = new ArrayList<>();
for (InternalValuesSource valuesSource : valuesSources) {
collectors.add(valuesSource.apply(ctx));
}
return new MultiTermsValuesSourceCollector() {
@Override
public List<BytesRef> apply(int doc) throws IOException {
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
for (InternalValuesSourceCollector collector : collectors) {
collectedValues.add(collector.apply(doc));
}
List<BytesRef> result = new ArrayList<>();
scratch.seek(0);
scratch.writeVInt(collectors.size()); // number of fields per composite key
cartesianProduct(result, scratch, collectedValues, 0);
return result;
}

/**
* Cartesian product using depth first search.
*
* <p>
* Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue},
* but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work.
*/
private void cartesianProduct(
List<BytesRef> compositeKeys,
BytesStreamOutput scratch,
List<List<TermValue<?>>> collectedValues,
int index
) throws IOException {
if (collectedValues.size() == index) {
compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef()));
return;
}

long position = scratch.position();
for (TermValue<?> value : collectedValues.get(index)) {
value.writeTo(scratch); // encode the value
cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs
scratch.seek(position); // backtrack
}
}
};
return new MultiValuesSourceCollectorImpl(collectors, scratch, bucketOrds, aggregator, sub);
}

@Override
Expand All @@ -414,6 +372,74 @@ public void close() {
}
}

static class MultiValuesSourceCollectorImpl implements MultiTermsValuesSourceCollector {

private final List<InternalValuesSourceCollector> collectors;
private final BytesStreamOutput scratch;
private final BytesKeyedBucketOrds bucketOrds;
private final BucketsAggregator aggregator;
private final LeafBucketCollector sub;

private final boolean collectViaAggregator;

public MultiValuesSourceCollectorImpl(
List<InternalValuesSourceCollector> collectors,
BytesStreamOutput scratch,
BytesKeyedBucketOrds bucketOrds,
BucketsAggregator aggregator,
LeafBucketCollector sub
) {
this.collectors = collectors;
this.scratch = scratch;
this.bucketOrds = bucketOrds;
this.aggregator = aggregator;
this.sub = sub;
this.collectViaAggregator = aggregator != null && sub != null;
}

@Override
public void apply(int doc, long owningBucketOrd) throws IOException {
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
for (InternalValuesSourceCollector collector : collectors) {
collectedValues.add(collector.apply(doc));
}
scratch.seek(0);
scratch.writeVInt(collectors.size()); // number of fields per composite key
cartesianProductRecursive(collectedValues, 0, owningBucketOrd, doc);
}

/**
* Cartesian product using depth first search.
*/
private void cartesianProductRecursive(List<List<TermValue<?>>> collectedValues, int index, long owningBucketOrd, int doc)
throws IOException {
if (collectedValues.size() == index) {
// Avoid performing a deep copy of the composite key
long bucketOrd = bucketOrds.add(owningBucketOrd, scratch.bytes().toBytesRef());
if (collectViaAggregator) {
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
aggregator.collectExistingBucket(sub, doc, bucketOrd);
} else {
aggregator.collectBucket(sub, doc, bucketOrd);
}
}
return;
}

long position = scratch.position();
List<TermValue<?>> values = collectedValues.get(index);
int numIterations = values.size();
for (int i = 0; i < numIterations; i++) {
TermValue<?> value = values.get(i);
value.writeTo(scratch); // encode the value
cartesianProductRecursive(collectedValues, index + 1, owningBucketOrd, doc); // dfs
scratch.seek(position); // backtrack
}
}

}

/**
* Factory for construct {@link InternalValuesSource}.
*
Expand Down Expand Up @@ -441,9 +467,13 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include
if (i > 0 && bytes.equals(previous)) {
continue;
}
BytesRef copy = BytesRef.deepCopyOf(bytes);
termValues.add(TermValue.of(copy));
previous = copy;
if (valuesCount > 1) {
BytesRef copy = BytesRef.deepCopyOf(bytes);
termValues.add(TermValue.of(copy));
previous = copy;
} else {
termValues.add(TermValue.of(bytes));
}
}
return termValues;
};
Expand Down

0 comments on commit fb84412

Please sign in to comment.