Skip to content

Commit

Permalink
Clean ctors for IncludeExclude (#84592)
Browse files Browse the repository at this point in the history
In order to fix an error where large regexes in `include` or
`exclude` fields of the `terms` agg crash the node (#82923) I'd like to
centralize construction of the `RegExp` so we can test it for
large-ness in one spot. The trouble is, there are half a dozen ctors for
`IncludeExclude` and some take `String` and some take `RegExp` and some
take sets of `String` and some take sets of `BytesRef`. It's all very
convenient for client code, but confusing to deal with. This removes all
but two of the ctors for `IncludeExclude` and mostly standardizes on one
that has:
```
String includeRe, String excludeRe, Set<BytesRef> includePrecise, Set<BytesRef> excludePrecise
```

Now I can fix #82923 in a fairly simple follow up.
  • Loading branch information
nik9000 committed Mar 3, 2022
1 parent 0ca2556 commit b17a7e1
Show file tree
Hide file tree
Showing 16 changed files with 178 additions and 198 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ private void testMinDocCountOnTerms(String field, Script script, BucketOrder ord
.executionHint(randomExecutionHint())
.order(order)
.size(size)
.includeExclude(include == null ? null : new IncludeExclude(include, null))
.includeExclude(include == null ? null : new IncludeExclude(include, null, null, null))
.shardSize(cardinality + randomInt(10))
.minDocCount(minDocCount)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ public void testNoBuckets() {
SearchResponse response = client().prepareSearch("idx")
.addAggregation(
terms(termsName).field("tag")
.includeExclude(new IncludeExclude(null, "tag.*"))
.includeExclude(new IncludeExclude(null, "tag.*", null, null))
.subAggregation(sum("sum").field(SINGLE_VALUED_FIELD_NAME))
)
.addAggregation(BucketMetricsPipelineAgg("pipeline_agg", termsName + ">sum"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public void testWrongPercents() throws Exception {
SearchResponse response = client().prepareSearch("idx")
.addAggregation(
terms(termsName).field("tag")
.includeExclude(new IncludeExclude(null, "tag.*"))
.includeExclude(new IncludeExclude(null, "tag.*", null, null))
.subAggregation(sum("sum").field(SINGLE_VALUED_FIELD_NAME))
)
.addAggregation(percentilesBucket("percentiles_bucket", termsName + ">sum").setPercents(PERCENTS))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentFragment;
Expand Down Expand Up @@ -69,15 +70,20 @@ public static IncludeExclude merge(IncludeExclude include, IncludeExclude exclud
throw new IllegalArgumentException("Cannot specify any excludes when using a partition-based include");
}

return new IncludeExclude(include.include, exclude.exclude, include.includeValues, exclude.excludeValues);
return new IncludeExclude(
include.include == null ? null : include.include.getOriginalString(),
exclude.exclude == null ? null : exclude.exclude.getOriginalString(),
include.includeValues,
exclude.excludeValues
);
}

public static IncludeExclude parseInclude(XContentParser parser) throws IOException {
XContentParser.Token token = parser.currentToken();
if (token == XContentParser.Token.VALUE_STRING) {
return new IncludeExclude(parser.text(), null);
return new IncludeExclude(parser.text(), null, null, null);
} else if (token == XContentParser.Token.START_ARRAY) {
return new IncludeExclude(new TreeSet<>(parseArrayToSet(parser)), null);
return new IncludeExclude(null, null, new TreeSet<>(parseArrayToSet(parser)), null);
} else if (token == XContentParser.Token.START_OBJECT) {
String currentFieldName = null;
Integer partition = null, numPartitions = null;
Expand Down Expand Up @@ -111,9 +117,9 @@ public static IncludeExclude parseInclude(XContentParser parser) throws IOExcept
public static IncludeExclude parseExclude(XContentParser parser) throws IOException {
XContentParser.Token token = parser.currentToken();
if (token == XContentParser.Token.VALUE_STRING) {
return new IncludeExclude(null, parser.text());
return new IncludeExclude(null, parser.text(), null, null);
} else if (token == XContentParser.Token.START_ARRAY) {
return new IncludeExclude(null, new TreeSet<>(parseArrayToSet(parser)));
return new IncludeExclude(null, null, null, new TreeSet<>(parseArrayToSet(parser)));
} else {
throw new IllegalArgumentException("Unrecognized token for an exclude [" + token + "]");
}
Expand Down Expand Up @@ -311,11 +317,12 @@ public LongBitSet acceptedGlobalOrdinals(SortedSetDocValues globalOrdinals) thro
* @param include The regular expression pattern for the terms to be included
* @param exclude The regular expression pattern for the terms to be excluded
*/
public IncludeExclude(RegExp include, RegExp exclude) {
this(include, exclude, null, null);
}

public IncludeExclude(RegExp include, RegExp exclude, SortedSet<BytesRef> includeValues, SortedSet<BytesRef> excludeValues) {
public IncludeExclude(
@Nullable String include,
@Nullable String exclude,
@Nullable SortedSet<BytesRef> includeValues,
@Nullable SortedSet<BytesRef> excludeValues
) {
if (include == null && exclude == null && includeValues == null && excludeValues == null) {
throw new IllegalArgumentException();
}
Expand All @@ -325,47 +332,14 @@ public IncludeExclude(RegExp include, RegExp exclude, SortedSet<BytesRef> includ
if (exclude != null && excludeValues != null) {
throw new IllegalArgumentException();
}
this.include = include;
this.exclude = exclude;
this.include = include == null ? null : new RegExp(include);
this.exclude = exclude == null ? null : new RegExp(exclude);
this.includeValues = includeValues;
this.excludeValues = excludeValues;
this.incZeroBasedPartition = 0;
this.incNumPartitions = 0;
}

public IncludeExclude(String include, String exclude, String[] includeValues, String[] excludeValues) {
this(
include == null ? null : new RegExp(include),
exclude == null ? null : new RegExp(exclude),
convertToBytesRefSet(includeValues),
convertToBytesRefSet(excludeValues)
);
}

public IncludeExclude(String include, String exclude) {
this(include == null ? null : new RegExp(include), exclude == null ? null : new RegExp(exclude));
}

/**
* @param includeValues The terms to be included
* @param excludeValues The terms to be excluded
*/
public IncludeExclude(SortedSet<BytesRef> includeValues, SortedSet<BytesRef> excludeValues) {
this(null, null, includeValues, excludeValues);
}

public IncludeExclude(String[] includeValues, String[] excludeValues) {
this(convertToBytesRefSet(includeValues), convertToBytesRefSet(excludeValues));
}

public IncludeExclude(double[] includeValues, double[] excludeValues) {
this(convertToBytesRefSet(includeValues), convertToBytesRefSet(excludeValues));
}

public IncludeExclude(long[] includeValues, long[] excludeValues) {
this(convertToBytesRefSet(includeValues), convertToBytesRefSet(excludeValues));
}

public IncludeExclude(int partition, int numPartitions) {
if (partition < 0 || partition >= numPartitions) {
throw new IllegalArgumentException("Partition must be >=0 and < numPartition which is " + numPartitions);
Expand Down Expand Up @@ -452,39 +426,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(incZeroBasedPartition);
}

private static SortedSet<BytesRef> convertToBytesRefSet(String[] values) {
SortedSet<BytesRef> returnSet = null;
if (values != null) {
returnSet = new TreeSet<>();
for (String value : values) {
returnSet.add(new BytesRef(value));
}
}
return returnSet;
}

private static SortedSet<BytesRef> convertToBytesRefSet(double[] values) {
SortedSet<BytesRef> returnSet = null;
if (values != null) {
returnSet = new TreeSet<>();
for (double value : values) {
returnSet.add(new BytesRef(String.valueOf(value)));
}
}
return returnSet;
}

private static SortedSet<BytesRef> convertToBytesRefSet(long[] values) {
SortedSet<BytesRef> returnSet = null;
if (values != null) {
returnSet = new TreeSet<>();
for (long value : values) {
returnSet.add(new BytesRef(String.valueOf(value)));
}
}
return returnSet;
}

/**
* Terms adapter around doc values.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package org.elasticsearch.search.aggregations.bucket;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.RegExp;
import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
import org.elasticsearch.search.aggregations.bucket.terms.IncludeExclude;
import org.elasticsearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder;
Expand All @@ -34,17 +33,17 @@ protected RareTermsAggregationBuilder createTestAggregatorBuilder() {
if (randomBoolean()) {
IncludeExclude incExc = null;
switch (randomInt(6)) {
case 0 -> incExc = new IncludeExclude(new RegExp("foobar"), null);
case 1 -> incExc = new IncludeExclude(null, new RegExp("foobaz"));
case 2 -> incExc = new IncludeExclude(new RegExp("foobar"), new RegExp("foobaz"));
case 0 -> incExc = new IncludeExclude("foobar", null, null, null);
case 1 -> incExc = new IncludeExclude(null, "foobaz", null, null);
case 2 -> incExc = new IncludeExclude("foobar", "foobaz", null, null);
case 3 -> {
SortedSet<BytesRef> includeValues = new TreeSet<>();
int numIncs = randomIntBetween(1, 20);
for (int i = 0; i < numIncs; i++) {
includeValues.add(new BytesRef(randomAlphaOfLengthBetween(1, 30)));
}
SortedSet<BytesRef> excludeValues = null;
incExc = new IncludeExclude(includeValues, excludeValues);
incExc = new IncludeExclude(null, null, includeValues, excludeValues);
}
case 4 -> {
SortedSet<BytesRef> includeValues2 = null;
Expand All @@ -53,7 +52,7 @@ protected RareTermsAggregationBuilder createTestAggregatorBuilder() {
for (int i = 0; i < numExcs2; i++) {
excludeValues2.add(new BytesRef(randomAlphaOfLengthBetween(1, 30)));
}
incExc = new IncludeExclude(includeValues2, excludeValues2);
incExc = new IncludeExclude(null, null, includeValues2, excludeValues2);
}
case 5 -> {
SortedSet<BytesRef> includeValues3 = new TreeSet<>();
Expand All @@ -66,7 +65,7 @@ protected RareTermsAggregationBuilder createTestAggregatorBuilder() {
for (int i = 0; i < numExcs3; i++) {
excludeValues3.add(new BytesRef(randomAlphaOfLengthBetween(1, 30)));
}
incExc = new IncludeExclude(includeValues3, excludeValues3);
incExc = new IncludeExclude(null, null, includeValues3, excludeValues3);
}
case 6 -> {
final int numPartitions = randomIntBetween(1, 100);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package org.elasticsearch.search.aggregations.bucket;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.RegExp;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
import org.elasticsearch.search.aggregations.bucket.terms.IncludeExclude;
Expand Down Expand Up @@ -122,17 +121,17 @@ static SignificanceHeuristic getSignificanceHeuristic() {
static IncludeExclude getIncludeExclude() {
IncludeExclude incExc = null;
switch (randomInt(5)) {
case 0 -> incExc = new IncludeExclude(new RegExp("foobar"), null);
case 1 -> incExc = new IncludeExclude(null, new RegExp("foobaz"));
case 2 -> incExc = new IncludeExclude(new RegExp("foobar"), new RegExp("foobaz"));
case 0 -> incExc = new IncludeExclude("foobar", null, null, null);
case 1 -> incExc = new IncludeExclude(null, "foobaz", null, null);
case 2 -> incExc = new IncludeExclude("foobar", "foobaz", null, null);
case 3 -> {
SortedSet<BytesRef> includeValues = new TreeSet<>();
int numIncs = randomIntBetween(1, 20);
for (int i = 0; i < numIncs; i++) {
includeValues.add(new BytesRef(randomAlphaOfLengthBetween(1, 30)));
}
SortedSet<BytesRef> excludeValues = null;
incExc = new IncludeExclude(includeValues, excludeValues);
incExc = new IncludeExclude(null, null, includeValues, excludeValues);
}
case 4 -> {
SortedSet<BytesRef> includeValues2 = null;
Expand All @@ -141,7 +140,7 @@ static IncludeExclude getIncludeExclude() {
for (int i = 0; i < numExcs2; i++) {
excludeValues2.add(new BytesRef(randomAlphaOfLengthBetween(1, 30)));
}
incExc = new IncludeExclude(includeValues2, excludeValues2);
incExc = new IncludeExclude(null, null, includeValues2, excludeValues2);
}
case 5 -> {
SortedSet<BytesRef> includeValues3 = new TreeSet<>();
Expand All @@ -154,7 +153,7 @@ static IncludeExclude getIncludeExclude() {
for (int i = 0; i < numExcs3; i++) {
excludeValues3.add(new BytesRef(randomAlphaOfLengthBetween(1, 30)));
}
incExc = new IncludeExclude(includeValues3, excludeValues3);
incExc = new IncludeExclude(null, null, includeValues3, excludeValues3);
}
default -> fail();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package org.elasticsearch.search.aggregations.bucket;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.RegExp;
import org.elasticsearch.search.aggregations.Aggregator.SubAggCollectionMode;
import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
import org.elasticsearch.search.aggregations.BucketOrder;
Expand Down Expand Up @@ -91,14 +90,14 @@ protected TermsAggregationBuilder createTestAggregatorBuilder() {
factory.format("###.##");
}
if (randomBoolean()) {
RegExp includeRegexp = null, excludeRegexp = null;
String includeRegexp = null, excludeRegexp = null;
SortedSet<BytesRef> includeValues = null, excludeValues = null;
boolean hasIncludeOrExclude = false;

if (randomBoolean()) {
hasIncludeOrExclude = true;
if (randomBoolean()) {
includeRegexp = new RegExp(randomAlphaOfLengthBetween(5, 10));
includeRegexp = randomAlphaOfLengthBetween(5, 10);
} else {
includeValues = new TreeSet<>();
int numIncs = randomIntBetween(1, 20);
Expand All @@ -111,7 +110,7 @@ protected TermsAggregationBuilder createTestAggregatorBuilder() {
if (randomBoolean()) {
hasIncludeOrExclude = true;
if (randomBoolean()) {
excludeRegexp = new RegExp(randomAlphaOfLengthBetween(5, 10));
excludeRegexp = randomAlphaOfLengthBetween(5, 10);
} else {
excludeValues = new TreeSet<>();
int numIncs = randomIntBetween(1, 20);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.RegExp;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.index.mapper.BinaryFieldMapper;
import org.elasticsearch.index.mapper.MappedFieldType;
Expand Down Expand Up @@ -73,7 +72,7 @@ public void testMatchAllDocs() throws IOException {
}

public void testBadIncludeExclude() throws IOException {
IncludeExclude includeExclude = new IncludeExclude(new RegExp("foo"), null);
IncludeExclude includeExclude = new IncludeExclude("foo", null, null, null);

// Make sure the include/exclude fails regardless of how the user tries to type hint the agg
AggregationExecutionException e = expectThrows(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.util.automaton.RegExp;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
Expand Down Expand Up @@ -92,7 +91,7 @@ public void testMatchAllDocs() throws IOException {
}

public void testBadIncludeExclude() throws IOException {
IncludeExclude includeExclude = new IncludeExclude(new RegExp("foo"), null);
IncludeExclude includeExclude = new IncludeExclude("foo", null, null, null);

// Numerics don't support any regex include/exclude, so should fail no matter what we do

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Consumer;

import static java.util.stream.Collectors.toList;
Expand Down Expand Up @@ -155,27 +157,14 @@ public void testIncludeExclude() throws IOException {
dataset,
aggregation -> aggregation.field(LONG_FIELD)
.maxDocCount(2) // bump to 2 since we're only including "2"
.includeExclude(new IncludeExclude(new long[] { 2 }, new long[] {})),
.includeExclude(new IncludeExclude(null, null, new TreeSet<>(Set.of(new BytesRef("2"))), new TreeSet<>())),
agg -> {
assertEquals(1, agg.getBuckets().size());
LongRareTerms.Bucket bucket = (LongRareTerms.Bucket) agg.getBuckets().get(0);
assertThat(bucket.getKey(), equalTo(2L));
assertThat(bucket.getDocCount(), equalTo(2L));
}
);
testSearchCase(
query,
dataset,
aggregation -> aggregation.field(KEYWORD_FIELD)
.maxDocCount(2) // bump to 2 since we're only including "2"
.includeExclude(new IncludeExclude(new String[] { "2" }, new String[] {})),
agg -> {
assertEquals(1, agg.getBuckets().size());
StringRareTerms.Bucket bucket = (StringRareTerms.Bucket) agg.getBuckets().get(0);
assertThat(bucket.getKeyAsString(), equalTo("2"));
assertThat(bucket.getDocCount(), equalTo(2L));
}
);
}

public void testEmbeddedMaxAgg() throws IOException {
Expand Down

0 comments on commit b17a7e1

Please sign in to comment.