Skip to content

Commit

Permalink
Limit blast redius of SearchContext in aggs (#64068)
Browse files Browse the repository at this point in the history
This takes away access to the `SearchContext` from all subclasses of
`Aggregator`. Now they have access to three things:
* BigArrays
* The top level Query
* The IndexSearcher

These are used by a whole bunch of aggs.

This is a useful change because `SearchContext` is very large and
difficult to mock in tests and difficult to reason about in general.
Limiting what aggs can use when they are being collected helps with
this.

We still pass `SearchContext` to `AggregatorBase`'s ctor so the thing is
still around. But we can remove that access in a follow up.
  • Loading branch information
nik9000 committed Oct 27, 2020
1 parent 719d408 commit 6ef0e5f
Show file tree
Hide file tree
Showing 69 changed files with 402 additions and 483 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.index.fielddata.NumericDoubleValues;
import org.elasticsearch.search.MultiValueMode;
Expand Down Expand Up @@ -69,7 +68,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
if (valuesSources == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
final BigArrays bigArrays = context.bigArrays();
final NumericDoubleValues[] values = new NumericDoubleValues[valuesSources.fieldNames().length];
for (int i = 0; i < values.length; ++i) {
values[i] = valuesSources.getField(i, ctx);
Expand All @@ -83,7 +81,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
public void collect(int doc, long bucket) throws IOException {
// get fields
if (includeDocument(doc)) {
stats = bigArrays.grow(stats, bucket + 1);
stats = bigArrays().grow(stats, bucket + 1);
RunningStats stat = stats.get(bucket);
// add document fields to correlation stats
if (stat == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public void postCollection() throws IOException {

@Override
protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {
IndexReader indexReader = context().searcher().getIndexReader();
IndexReader indexReader = searcher().getIndexReader();
for (LeafReaderContext ctx : indexReader.leaves()) {
Scorer childDocsScorer = outFilter.scorer(ctx);
if (childDocsScorer == null) {
Expand Down Expand Up @@ -196,7 +196,7 @@ protected class DenseCollectionStrategy implements CollectionStrategy {
private final BitArray ordsBits;

public DenseCollectionStrategy(long maxOrd, BigArrays bigArrays) {
ordsBits = new BitArray(maxOrd, context.bigArrays());
ordsBits = new BitArray(maxOrd, bigArrays());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ protected Aggregator createInternal(SearchContext searchContext,
Aggregator parent,
CardinalityUpperBound cardinality,
Map<String, Object> metadata) throws IOException {
return new TestAggregator(name, parent, searchContext);
return new TestAggregator(name, parent);
}
};
}
Expand Down Expand Up @@ -541,12 +541,10 @@ public String getType() {
private static class TestAggregator extends Aggregator {
private final String name;
private final Aggregator parent;
private final SearchContext context;

private TestAggregator(String name, Aggregator parent, SearchContext context) {
private TestAggregator(String name, Aggregator parent) {
this.name = name;
this.parent = parent;
this.context = context;
}


Expand All @@ -555,11 +553,6 @@ public String name() {
return name;
}

@Override
public SearchContext context() {
return context;
}

@Override
public Aggregator parent() {
return parent;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.sort.SortOrder;

import java.io.IOException;
Expand Down Expand Up @@ -68,11 +67,6 @@ public interface Parser {
*/
public abstract String name();

/**
* Return the {@link SearchContext} attached with this {@link Aggregator}.
*/
public abstract SearchContext context();

/**
* Return the parent aggregator.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
package org.elasticsearch.search.aggregations;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
Expand All @@ -47,7 +50,7 @@ public abstract class AggregatorBase extends Aggregator {

protected final String name;
protected final Aggregator parent;
protected final SearchContext context;
private final SearchContext context;
private final Map<String, Object> metadata;

protected final Aggregator[] subAggregators;
Expand Down Expand Up @@ -119,7 +122,7 @@ public ScoreMode scoreMode() {
* @param config The config for the values source metric.
*/
public final Function<byte[], Number> pointReaderIfAvailable(ValuesSourceConfig config) {
if (context.query() != null && context.query().getClass() != MatchAllDocsQuery.class) {
if (topLevelQuery() != null && topLevelQuery().getClass() != MatchAllDocsQuery.class) {
return null;
}
if (parent != null) {
Expand Down Expand Up @@ -240,14 +243,6 @@ public Aggregator subAggregator(String aggName) {
return subAggregatorbyName.get(aggName);
}

/**
* @return The current aggregation context.
*/
@Override
public SearchContext context() {
return context;
}

/**
* Called after collection of all document is done.
* <p>
Expand Down Expand Up @@ -292,4 +287,34 @@ protected final InternalAggregations buildEmptySubAggregations() {
public String toString() {
return name;
}

/**
* Utilities for sharing large primitive arrays and tracking their usage.
* Used by all subclasses.
*/
protected final BigArrays bigArrays() {
return context.bigArrays();
}

/**
* The "top level" query that will filter the results sent to this
* {@linkplain Aggregator}. Used by all {@linkplain Aggregator}s that
* perform extra collection phases in addition to the one done in
* {@link #getLeafCollector(LeafReaderContext, LeafBucketCollector)}.
*/
protected final Query topLevelQuery() {
return context.query();
}

/**
* The searcher for the shard this {@linkplain Aggregator} is running
* against. Used by all {@linkplain Aggregator}s that perform extra
* collection phases in addition to the one done in
* {@link #getLeafCollector(LeafReaderContext, LeafBucketCollector)}
* and by to look up extra "background" information about contents of
* the shard itself.
*/
protected final IndexSearcher searcher() {
return context.searcher();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public Aggregator[] createSubAggregators(SearchContext searchContext, Aggregator
Aggregator[] aggregators = new Aggregator[countAggregators()];
for (int i = 0; i < factories.length; ++i) {
Aggregator factory = factories[i].create(searchContext, parent, cardinality);
Profilers profilers = factory.context().getProfilers();
Profilers profilers = searchContext.getProfilers();
if (profilers != null) {
factory = new ProfilingAggregator(factory, profilers.getAggregationProfiler());
}
Expand All @@ -211,7 +211,7 @@ public Aggregator[] createTopLevelAggregators(SearchContext searchContext) throw
* *exactly* what CardinalityUpperBound.ONE *means*.
*/
Aggregator factory = factories[i].create(searchContext, null, CardinalityUpperBound.ONE);
Profilers profilers = factory.context().getProfilers();
Profilers profilers = searchContext.getProfilers();
if (profilers != null) {
factory = new ProfilingAggregator(factory, profilers.getAggregationProfiler());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
Expand All @@ -36,7 +37,6 @@
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -64,7 +64,8 @@ static class Entry {

protected List<Entry> entries = new ArrayList<>();
protected BucketCollector collector;
protected final SearchContext searchContext;
private final Query topLevelQuery;
private final IndexSearcher searcher;
protected final boolean isGlobal;
protected LeafReaderContext context;
protected PackedLongValues.Builder docDeltasBuilder;
Expand All @@ -75,11 +76,11 @@ static class Entry {

/**
* Sole constructor.
* @param context The search context
* @param isGlobal Whether this collector visits all documents (global context)
*/
public BestBucketsDeferringCollector(SearchContext context, boolean isGlobal) {
this.searchContext = context;
public BestBucketsDeferringCollector(Query topLevelQuery, IndexSearcher searcher, boolean isGlobal) {
this.topLevelQuery = topLevelQuery;
this.searcher = searcher;
this.isGlobal = isGlobal;
}

Expand Down Expand Up @@ -162,8 +163,8 @@ public void prepareSelectedBuckets(long... selectedBuckets) throws IOException {
boolean needsScores = scoreMode().needsScores();
Weight weight = null;
if (needsScores) {
Query query = isGlobal ? new MatchAllDocsQuery() : searchContext.query();
weight = searchContext.searcher().createWeight(searchContext.searcher().rewrite(query), ScoreMode.COMPLETE, 1f);
Query query = isGlobal ? new MatchAllDocsQuery() : topLevelQuery;
weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE, 1f);
}

for (Entry entry : entries) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
import org.elasticsearch.search.aggregations.MultiBucketCollector;
import org.elasticsearch.search.internal.SearchContext;

Expand Down Expand Up @@ -74,7 +74,7 @@ protected void doPreCollection() throws IOException {
public DeferringBucketCollector getDeferringCollector() {
// Default impl is a collector that selects the best buckets
// but an alternative defer policy may be based on best docs.
return new BestBucketsDeferringCollector(context(), descendsFromGlobalAggregator(parent()));
return new BestBucketsDeferringCollector(topLevelQuery(), searcher(), descendsFromGlobalAggregator(parent()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.support.AggregationPath.PathElement;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.sort.SortOrder;

import java.io.IOException;
Expand Down Expand Up @@ -84,11 +83,6 @@ public Aggregator parent() {
return in.parent();
}

@Override
public SearchContext context() {
return in.context();
}

@Override
public Aggregator subAggregator(String name) {
return in.subAggregator(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@

package org.elasticsearch.search.aggregations.bucket;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.packed.PackedInts;
import org.apache.lucene.util.packed.PackedLongValues;
import org.elasticsearch.search.internal.SearchContext;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -34,8 +35,8 @@
* rounding interval.
*/
public class MergingBucketsDeferringCollector extends BestBucketsDeferringCollector {
public MergingBucketsDeferringCollector(SearchContext context, boolean isGlobal) {
super(context, isGlobal);
public MergingBucketsDeferringCollector(Query topLevelQuery, IndexSearcher searcher, boolean isGlobal) {
super(topLevelQuery, searcher, isGlobal);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Sort;
Expand Down Expand Up @@ -358,10 +357,10 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
fieldDoc.doc = -1;
}
BooleanQuery newQuery = new BooleanQuery.Builder()
.add(context.query(), BooleanClause.Occur.MUST)
.add(topLevelQuery(), BooleanClause.Occur.MUST)
.add(new SearchAfterSortedDocQuery(applySortFieldRounding(indexSortPrefix), fieldDoc), BooleanClause.Occur.FILTER)
.build();
Weight weight = context.searcher().createWeight(context.searcher().rewrite(newQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
Weight weight = searcher().createWeight(searcher().rewrite(newQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
Scorer scorer = weight.scorer(ctx);
if (scorer != null) {
DocIdSetIterator docIt = scorer.iterator();
Expand All @@ -387,12 +386,12 @@ protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucket
int sortPrefixLen = computeSortPrefixLen(indexSortPrefix);

SortedDocsProducer sortedDocsProducer = sortPrefixLen == 0 ?
sources[0].createSortedDocsProducerOrNull(ctx.reader(), context.query()) : null;
sources[0].createSortedDocsProducerOrNull(ctx.reader(), topLevelQuery()) : null;
if (sortedDocsProducer != null) {
// Visit documents sorted by the leading source of the composite definition and terminates
// when the leading source value is guaranteed to be greater than the lowest composite bucket
// in the queue.
DocIdSet docIdSet = sortedDocsProducer.processLeaf(context.query(), queue, ctx, fillDocIdSet);
DocIdSet docIdSet = sortedDocsProducer.processLeaf(topLevelQuery(), queue, ctx, fillDocIdSet);
if (fillDocIdSet) {
entries.add(new Entry(ctx, docIdSet));
}
Expand Down Expand Up @@ -457,8 +456,7 @@ private void runDeferredCollections() throws IOException {
final boolean needsScores = scoreMode().needsScores();
Weight weight = null;
if (needsScores) {
Query query = context.query();
weight = context.searcher().createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1f);
weight = searcher().createWeight(searcher().rewrite(topLevelQuery()), ScoreMode.COMPLETE, 1f);
}
deferredCollectors.preCollection();
for (Entry entry : entries) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public abstract class GeoGridAggregator<T extends InternalGeoGrid> extends Bucke
this.valuesSource = valuesSource;
this.requiredSize = requiredSize;
this.shardSize = shardSize;
bucketOrds = LongKeyedBucketOrds.build(context.bigArrays(), cardinality);
bucketOrds = LongKeyedBucketOrds.build(bigArrays(), cardinality);
}

@Override
Expand Down

0 comments on commit 6ef0e5f

Please sign in to comment.