
Commit

[ML] scaling metric aggregation counts and sums when aggregations are sampled (#83263)

This is a follow-up to #81228.

This commit allows aggregations to be aware that they are being reduced within a sampling context. Currently, the only sampling context is provided by the `random_sampler` aggregation.

This commit also enables the following metric aggregations to be sampled, scaling their values appropriately before serializing them back to the user:

 - percentiles 
 - avg
 - extended_stats
 - geo_bounds
 - geo_centroid
 - max
 - median_absolute_deviation
 - min
 - scripted_metric
 - stats
 - sum
 - top_hits
 - value_count
 - weighted_avg
 - rate
 - string_stats

No multi-bucket aggregation support is added in this commit; that will come in a later commit.
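
For intuition, here is a minimal, self-contained sketch of the scaling being added (plain Java, illustrating the statistics rather than the Elasticsearch classes in this diff): documents are kept with probability p, additive statistics such as doc counts and sums are multiplied by 1/p before being reported, and ratio statistics such as avg need no scaling because the factor cancels out.

// Illustrative only: estimate count, sum and avg of a field from a p-probability sample.
import java.util.Random;

public class InverseScaleSketch {
    public static void main(String[] args) {
        double probability = 0.1;          // sampling probability p
        long actualCount = 100_000;
        Random random = new Random(42);    // the sampler is seeded, like random_sampler

        long sampledCount = 0;
        double sampledSum = 0;
        for (long doc = 1; doc <= actualCount; doc++) {
            if (random.nextDouble() < probability) {  // keep each doc with probability p
                sampledCount++;
                sampledSum += doc;                    // pretend the field value equals the doc id
            }
        }

        long scaledCount = Math.round(sampledCount / probability);  // counts are inverse-scaled
        double scaledSum = sampledSum / probability;                 // sums are inverse-scaled
        double avg = sampledSum / sampledCount;                      // ratios are left untouched

        System.out.printf("count ~ %d (actual %d)%n", scaledCount, actualCount);
        System.out.printf("sum   ~ %.0f (actual %.0f)%n", scaledSum, actualCount * (actualCount + 1) / 2.0);
        System.out.printf("avg   ~ %.1f%n", avg);
    }
}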
benwtrent committed Feb 9, 2022
1 parent 07fa7a5 commit bf9879f
Showing 65 changed files with 743 additions and 14 deletions.
InternalMatrixStats.java
@@ -11,6 +11,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
@@ -67,6 +68,9 @@ public String getWriteableName() {
/** get the number of documents */
@Override
public long getDocCount() {
if (results != null) {
return results.getDocCount();
}
if (stats == null) {
return 0;
}
@@ -241,6 +245,17 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Aggreg
return new InternalMatrixStats(name, runningStats.docCount, runningStats, null, getMetadata());
}

@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return new InternalMatrixStats(
name,
samplingContext.inverseScale(getDocCount()),
stats,
new MatrixStatsResults(stats, samplingContext),
getMetadata()
);
}

@Override
protected boolean mustReduceOnSingleInternalAgg() {
return true;
MatrixStatsAggregationBuilder.java
@@ -46,6 +46,11 @@ protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBu
return new MatrixStatsAggregationBuilder(this, factoriesBuilder, metadata);
}

@Override
public boolean supportsSampling() {
return true;
}

/**
* Read from a stream.
*/
MatrixStatsResults.java
@@ -11,6 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.support.SamplingContext;

import java.io.IOException;
import java.util.Collections;
@@ -41,6 +42,18 @@ class MatrixStatsResults implements Writeable {
this.compute();
}

/** creates and computes the result from the provided stats, scaling as necessary given the sampling context */
MatrixStatsResults(RunningStats stats, SamplingContext samplingContext) {
this.results = stats.clone();
this.correlation = new HashMap<>();
this.compute();
// Note: it is important to scale counts AFTER compute as scaling before could introduce bias
this.results.docCount = samplingContext.inverseScale(this.results.docCount);
for (String field : this.results.counts.keySet()) {
this.results.counts.computeIfPresent(field, (k, v) -> samplingContext.inverseScale(v));
}
}

/** creates a results object from the given stream */
@SuppressWarnings("unchecked")
protected MatrixStatsResults(StreamInput in) {
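The ordering note in the constructor above is worth a concrete illustration (hypothetical numbers, not the MatrixStats code): derived statistics are ratios over the same sample, so they are already unbiased, and only the counts reported to the user need inverse-scaling, after those statistics have been computed. Scaling the counts first would deflate every ratio by roughly the sampling probability.

// Illustrative only: why counts are scaled after derived statistics are computed.
public class ScaleAfterComputeSketch {
    public static void main(String[] args) {
        double probability = 0.25;   // sampling probability p
        long sampledCount = 1_000;   // documents that landed in the sample
        double sampledSum = 50_000;  // sum of a field over those documents

        // Correct order: compute from the raw sampled values ...
        double mean = sampledSum / sampledCount;                     // 50.0, unbiased
        // ... then inverse-scale only the count that gets reported.
        long reportedCount = Math.round(sampledCount / probability); // 4000

        // Wrong order: an inflated denominator biases every ratio built from it.
        double biasedMean = sampledSum / (sampledCount / probability); // 12.5 == probability * mean

        System.out.printf("mean=%.1f reportedCount=%d biasedMean=%.1f%n", mean, reportedCount, biasedMean);
    }
}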
AggregationBuilder.java
@@ -14,6 +14,7 @@
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentParser;
@@ -185,6 +186,19 @@ public static final class CommonFields extends ParseField.CommonFields {
public static final ParseField VALUE_TYPE = new ParseField("value_type");
}

/**
* Does this aggregation support running within a sampling context.
*
* By default, it's false for all aggregations.
*
* If the sub-classed builder supports sampling, make sure that the resulting internal aggregation objects
* override {@link InternalAggregation#finalizeSampling(SamplingContext)} and scale any values that require scaling.
* @return does this aggregation builder support sampling
*/
public boolean supportsSampling() {
return false;
}

@Override
public String toString() {
return Strings.toString(this);
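To make the opt-in contract described in the Javadoc above concrete, here is a rough sketch with simplified, hypothetical types (SubAggBuilder is a stand-in for AggregationBuilder, not an Elasticsearch class): a sampling aggregation walks its sub-builders and rejects any that still report supportsSampling() == false, which is how the random_sampler validation further down in this commit consumes the flag.

// Hypothetical sketch of how the supportsSampling() flag is consumed during validation.
import java.util.List;

interface SubAggBuilder {
    String name();
    List<SubAggBuilder> subBuilders();
    /** Defaults to false; builders whose results can be finalized under sampling override this. */
    default boolean supportsSampling() {
        return false;
    }
}

final class SamplingValidation {
    /** Recursively verify that every sub-aggregation has opted in to sampling. */
    static void checkSupportsSampling(String samplerName, List<SubAggBuilder> builders) {
        for (SubAggBuilder builder : builders) {
            if (builder.supportsSampling() == false) {
                throw new IllegalArgumentException(
                    "[random_sampler] aggregation [" + samplerName + "] does not support sampling sub-aggregation [" + builder.name() + "]"
                );
            }
            checkSupportsSampling(samplerName, builder.subBuilders());
        }
    }
}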
InternalAggregation.java
@@ -15,6 +15,7 @@
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
@@ -121,7 +122,16 @@ public InternalAggregation reducePipelines(
public abstract InternalAggregation reduce(List<InternalAggregation> aggregations, AggregationReduceContext reduceContext);

/**
* Signal the framework if the {@linkplain InternalAggregation#reduce(List, ReduceContext)} phase needs to be called
* Called by the parent sampling context. Should only ever be called once as some aggregations scale their internal values
* @param samplingContext the current sampling context
* @return new aggregation with the sampling context applied, could be the same aggregation instance if nothing needs to be done
*/
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
throw new UnsupportedOperationException(getWriteableName() + " aggregation [" + getName() + "] does not support sampling");
}

/**
* Signal the framework if the {@linkplain InternalAggregation#reduce(List, AggregationReduceContext)} phase needs to be called
* when there is only one {@linkplain InternalAggregation}.
*/
protected abstract boolean mustReduceOnSingleInternalAgg();
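The other half of that contract is the finalizeSampling hook added above. A hypothetical sketch with simplified types (SampledResult stands in for InternalAggregation; the real method takes a SamplingContext rather than a bare probability): results holding absolute quantities inverse-scale their values, order statistics return themselves unchanged, and anything that never opted in fails loudly through the default implementation instead of silently returning un-scaled numbers.

// Hypothetical sketch of the finalizeSampling contract with simplified types.
abstract class SampledResult {
    /** Mirrors the default added in this commit: aggregations that never opted in fail loudly. */
    SampledResult finalizeSampling(double probability) {
        throw new UnsupportedOperationException(getClass().getSimpleName() + " does not support sampling");
    }
}

final class SumResult extends SampledResult {
    final double sum;

    SumResult(double sum) {
        this.sum = sum;
    }

    @Override
    SumResult finalizeSampling(double probability) {
        return new SumResult(sum / probability); // absolute quantities are inverse-scaled by 1/p
    }
}

final class MaxResult extends SampledResult {
    final double max;

    MaxResult(double max) {
        this.max = max;
    }

    @Override
    MaxResult finalizeSampling(double probability) {
        return this; // order statistics (max, percentiles) need no scaling
    }
}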
InternalRandomSampler.java
@@ -10,23 +10,38 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.bucket.InternalSingleBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.sampler.Sampler;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class InternalRandomSampler extends InternalSingleBucketAggregation implements Sampler {
public static final String NAME = "mapped_random_sampler";
public static final String PARSER_NAME = "random_sampler";

private final int seed;
private final double probability;

InternalRandomSampler(String name, long docCount, int seed, InternalAggregations subAggregations, Map<String, Object> metadata) {
InternalRandomSampler(
String name,
long docCount,
int seed,
double probability,
InternalAggregations subAggregations,
Map<String, Object> metadata
) {
super(name, docCount, subAggregations, metadata);
this.seed = seed;
this.probability = probability;
}

/**
@@ -35,12 +50,14 @@ public class InternalRandomSampler imple
public InternalRandomSampler(StreamInput in) throws IOException {
super(in);
this.seed = in.readInt();
this.probability = in.readDouble();
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
super.doWriteTo(out);
out.writeInt(seed);
out.writeDouble(probability);
}

@Override
@@ -55,12 +72,36 @@ public String getType() {

@Override
protected InternalSingleBucketAggregation newAggregation(String name, long docCount, InternalAggregations subAggregations) {
return new InternalRandomSampler(name, docCount, seed, subAggregations, metadata);
return new InternalRandomSampler(name, docCount, seed, probability, subAggregations, metadata);
}

@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, AggregationReduceContext reduceContext) {
long docCount = 0L;
List<InternalAggregations> subAggregationsList = new ArrayList<>(aggregations.size());
for (InternalAggregation aggregation : aggregations) {
docCount += ((InternalSingleBucketAggregation) aggregation).getDocCount();
subAggregationsList.add(((InternalSingleBucketAggregation) aggregation).getAggregations());
}
InternalAggregations aggs = InternalAggregations.reduce(subAggregationsList, reduceContext);
if (reduceContext.isFinalReduce() && aggs != null) {
SamplingContext context = buildContext();
aggs = InternalAggregations.from(
aggs.asList().stream().map(agg -> ((InternalAggregation) agg).finalizeSampling(context)).collect(Collectors.toList())
);
}

return newAggregation(getName(), docCount, aggs);
}

public SamplingContext buildContext() {
return new SamplingContext(probability, seed);
}

@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.field(RandomSamplerAggregationBuilder.SEED.getPreferredName(), seed);
builder.field(RandomSamplerAggregationBuilder.PROBABILITY.getPreferredName(), probability);
builder.field(CommonFields.DOC_COUNT.getPreferredName(), getDocCount());
getAggregations().toXContentInternal(builder, params);
return builder;
RandomSamplerAggregationBuilder.java
@@ -16,10 +16,6 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.AggregatorFactory;
import org.elasticsearch.search.aggregations.bucket.nested.NestedAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.sampler.DiversifiedAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.sampler.SamplerAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.CardinalityAggregationBuilder;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
@@ -78,6 +74,10 @@ public RandomSamplerAggregationBuilder(StreamInput in) throws IOException {
this.seed = in.readInt();
}

public double getProbability() {
return p;
}

protected RandomSamplerAggregationBuilder(
RandomSamplerAggregationBuilder clone,
AggregatorFactories.Builder factoriesBuilder,
@@ -118,10 +118,7 @@ protected AggregatorFactory doBuild(
}
recursivelyCheckSubAggs(subfactoriesBuilder.getAggregatorFactories(), builder -> {
// TODO add a method or interface to aggregation builder that defaults to false
if (builder instanceof CardinalityAggregationBuilder
|| builder instanceof NestedAggregationBuilder
|| builder instanceof SamplerAggregationBuilder
|| builder instanceof DiversifiedAggregationBuilder) {
if (builder.supportsSampling() == false) {
throw new IllegalArgumentException(
"[random_sampler] aggregation ["
+ getName()
@@ -136,6 +133,10 @@ protected AggregatorFactory doBuild(
return new RandomSamplerAggregatorFactory(name, seed, p, context, parent, subfactoriesBuilder, metadata);
}

public int getSeed() {
return seed;
}

@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
RandomSamplerAggregator.java
@@ -29,11 +29,13 @@
public class RandomSamplerAggregator extends BucketsAggregator implements SingleBucketAggregator {

private final int seed;
private final double probability;
private final CheckedSupplier<Weight, IOException> weightSupplier;

RandomSamplerAggregator(
String name,
int seed,
double probability,
CheckedSupplier<Weight, IOException> weightSupplier,
AggregatorFactories factories,
AggregationContext context,
@@ -43,6 +45,7 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
) throws IOException {
super(name, factories, context, parent, cardinalityUpperBound, metadata);
this.seed = seed;
this.probability = probability;
if (this.subAggregators().length == 0) {
throw new IllegalArgumentException(
RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
@@ -59,6 +62,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I
name,
bucketDocCount(owningBucketOrd),
seed,
probability,
subAggregationResults,
metadata()
)
@@ -67,7 +71,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I

@Override
public InternalAggregation buildEmptyAggregation() {
return new InternalRandomSampler(name, 0, seed, buildEmptySubAggregations(), metadata());
return new InternalRandomSampler(name, 0, seed, probability, buildEmptySubAggregations(), metadata());
}

/**
RandomSamplerAggregatorFactory.java
@@ -44,7 +44,7 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
@Override
public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
throws IOException {
return new RandomSamplerAggregator(name, seed, this::getWeight, factories, context, parent, cardinality, metadata);
return new RandomSamplerAggregator(name, seed, probability, this::getWeight, factories, context, parent, cardinality, metadata);
}

/**
AbstractInternalHDRPercentiles.java
@@ -14,6 +14,7 @@
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
@@ -133,6 +134,11 @@ public AbstractInternalHDRPercentiles reduce(List<InternalAggregation> aggregati
return createReduced(getName(), keys, merged, keyed, getMetadata());
}

@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return this;
}

protected abstract AbstractInternalHDRPercentiles createReduced(
String name,
double[] keys,
AbstractInternalTDigestPercentiles.java
@@ -13,6 +13,7 @@
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
@@ -116,6 +117,11 @@ public AbstractInternalTDigestPercentiles reduce(List<InternalAggregation> aggre
return createReduced(getName(), keys, merged, keyed, getMetadata());
}

@Override
public InternalAggregation finalizeSampling(SamplingContext samplingContext) {
return this;
}

protected abstract AbstractInternalTDigestPercentiles createReduced(
String name,
double[] keys,
AbstractPercentilesAggregationBuilder.java
@@ -157,6 +157,11 @@ public static <T extends AbstractPercentilesAggregationBuilder<T>> ConstructingO
}
}

@Override
public boolean supportsSampling() {
return true;
}

@Override
protected void innerWriteTo(StreamOutput out) throws IOException {
out.writeDoubleArray(values);
AvgAggregationBuilder.java
@@ -65,6 +65,11 @@ public AvgAggregationBuilder(StreamInput in) throws IOException {
super(in);
}

@Override
public boolean supportsSampling() {
return true;
}

@Override
protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metadata) {
return new AvgAggregationBuilder(this, factoriesBuilder, metadata);
ExtendedStatsAggregationBuilder.java
@@ -77,6 +77,11 @@ public ExtendedStatsAggregationBuilder(StreamInput in) throws IOException {
sigma = in.readDouble();
}

@Override
public boolean supportsSampling() {
return true;
}

@Override
public Set<String> metricNames() {
return InternalExtendedStats.METRIC_NAMES;
