Skip to content

Commit

Permalink
fix: nested top metrics sort on keyword field (#85058)
Browse files Browse the repository at this point in the history
Using a double as the return value works only if the field we are
sorting on is numeric. If the field holds a value that cannot be
converted to a double, such as a non-numeric keyword, the conversion
returns `NaN`. Without this patch, when the order field points to a
non-numeric value, the buckets end up sorted by bucket key instead:
the bucket key comparator is implicitly added as a tie breaker to
avoid non-deterministic sorting of buckets, and it is the only
comparator left that can tell the buckets apart.

With this change we support sorting using any subclass of SortValue.
This means the bucket key is used only as a tie breaker, when two
buckets have equal values on the order field.

Issue: #78506
  • Loading branch information
salvatore-campagna committed Mar 21, 2022
1 parent dc8ec42 commit 08141cf
Show file tree
Hide file tree
Showing 13 changed files with 411 additions and 25 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/85058.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 85058
summary: "Fix: nested top metrics sort on keyword field"
area: Aggregations
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.search.sort.SortValue;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand Down Expand Up @@ -231,15 +232,15 @@ public String toString() {
/**
* Get value to use when sorting by this aggregation.
*/
public double sortValue(String key) {
public SortValue sortValue(String key) {
// subclasses will override this with a real implementation if they can be sorted
throw new IllegalArgumentException("Can't sort a [" + getType() + "] aggregation [" + getName() + "]");
}

/**
* Get value to use when sorting by a descendant of this aggregation.
*/
public double sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
public SortValue sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
// subclasses will override this with a real implementation if you can sort on a descendant
throw new IllegalArgumentException("Can't sort by a descendant of a [" + getType() + "] aggregation [" + head + "]");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.search.aggregations.pipeline.SiblingPipelineAggregator;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.aggregations.support.SamplingContext;
import org.elasticsearch.search.sort.SortValue;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -80,7 +81,7 @@ private List<InternalAggregation> getInternalAggregations() {
/**
* Get value to use when sorting by a descendant of the aggregation containing this.
*/
public double sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
public SortValue sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
InternalAggregation aggregation = get(head.name());
if (aggregation == null) {
throw new IllegalArgumentException("Cannot find aggregation named [" + head.name() + "]");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.util.Comparators;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.search.aggregations.Aggregator.BucketComparator;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.sort.SortValue;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
Expand Down Expand Up @@ -74,9 +74,10 @@ public <T extends Bucket> Comparator<T> partiallyBuiltBucketComparator(ToLongFun
@Override
public Comparator<Bucket> comparator() {
return (lhs, rhs) -> {
double l = path.resolveValue(((InternalAggregations) lhs.getAggregations()));
double r = path.resolveValue(((InternalAggregations) rhs.getAggregations()));
return Comparators.compareDiscardNaN(l, r, order == SortOrder.ASC);
final SortValue l = path.resolveValue(((InternalAggregations) lhs.getAggregations()));
final SortValue r = path.resolveValue(((InternalAggregations) rhs.getAggregations()));
int compareResult = l.compareTo(r);
return order == SortOrder.ASC ? compareResult : -compareResult;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.PipelineTree;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.sort.SortValue;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand Down Expand Up @@ -154,7 +155,7 @@ public XContentBuilder doXContentBody(XContentBuilder builder, Params params) th
}

@Override
public final double sortValue(String key) {
public final SortValue sortValue(String key) {
if (key != null && false == key.equals("doc_count")) {
throw new IllegalArgumentException(
"Unknown value key ["
Expand All @@ -164,11 +165,11 @@ public final double sortValue(String key) {
+ "]. Either use [doc_count] as key or drop the key all together."
);
}
return docCount;
return SortValue.from(docCount);
}

@Override
public final double sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
public final SortValue sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
return aggregations.sortValue(head, tail);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.sort.SortValue;

import java.io.IOException;
import java.util.Iterator;
Expand All @@ -30,7 +31,7 @@ protected InternalMultiValueAggregation(StreamInput in) throws IOException {
}

@Override
public final double sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
public final SortValue sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
throw new IllegalArgumentException("Metrics aggregations cannot have sub-aggregations (at [>" + head + "]");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.sort.SortValue;

import java.io.IOException;
import java.util.Iterator;
Expand Down Expand Up @@ -62,7 +63,7 @@ public Object getProperty(List<String> path) {
}

@Override
public final double sortValue(String key) {
public final SortValue sortValue(String key) {
if (key != null && false == key.equals("value")) {
throw new IllegalArgumentException(
"Unknown value key ["
Expand All @@ -72,7 +73,7 @@ public final double sortValue(String key) {
+ "]. Either use [value] as key or drop the key all together"
);
}
return value();
return SortValue.from(value());
}
}

Expand Down Expand Up @@ -115,11 +116,11 @@ public Object getProperty(List<String> path) {
}

@Override
public final double sortValue(String key) {
public final SortValue sortValue(String key) {
if (key == null) {
throw new IllegalArgumentException("Missing value key in [" + key + "] which refers to a multi-value metric aggregation");
}
return value(key);
return SortValue.from(value(key));
}
}

Expand All @@ -146,7 +147,7 @@ protected InternalNumericMetricsAggregation(StreamInput in, boolean readFormat)
}

@Override
public final double sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
public final SortValue sortValue(AggregationPath.PathElement head, Iterator<AggregationPath.PathElement> tail) {
throw new IllegalArgumentException("Metrics aggregations cannot have sub-aggregations (at [>" + head + "]");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator;
import org.elasticsearch.search.profile.aggregation.ProfilingAggregator;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.sort.SortValue;

import java.util.ArrayList;
import java.util.Iterator;
Expand Down Expand Up @@ -163,7 +164,7 @@ public List<String> getPathElementsAsStringList() {
/**
* Looks up the value of this path against a set of aggregation results.
*/
public double resolveValue(InternalAggregations aggregations) {
public SortValue resolveValue(InternalAggregations aggregations) {
try {
Iterator<PathElement> path = pathElements.iterator();
assert path.hasNext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
* A {@link Comparable}, {@link DocValueFormat} aware wrapper around a sort value.
*/
public abstract class SortValue implements NamedWriteable, Comparable<SortValue> {
private static final SortValue EMPTY_SORT_VALUE = new EmptySortValue();

/**
* Get a {@linkplain SortValue} for a double.
*/
Expand All @@ -47,14 +49,22 @@ public static SortValue from(BytesRef bytes) {
return new BytesSortValue(bytes);
}

/**
* Get a {@linkplain SortValue} for data which cannot be sorted.
* The shared {@code EMPTY_SORT_VALUE} constant is returned, so no new object is allocated per call.
*/
public static SortValue empty() {
return EMPTY_SORT_VALUE;
}

/**
* Get the list of {@linkplain NamedWriteable}s that this class needs.
*/
public static List<NamedWriteableRegistry.Entry> namedWriteables() {
return Arrays.asList(
new NamedWriteableRegistry.Entry(SortValue.class, DoubleSortValue.NAME, DoubleSortValue::new),
new NamedWriteableRegistry.Entry(SortValue.class, LongSortValue.NAME, LongSortValue::new),
new NamedWriteableRegistry.Entry(SortValue.class, BytesSortValue.NAME, BytesSortValue::new)
new NamedWriteableRegistry.Entry(SortValue.class, BytesSortValue.NAME, BytesSortValue::new),
new NamedWriteableRegistry.Entry(SortValue.class, EmptySortValue.NAME, EmptySortValue::new)
);
}

Expand Down Expand Up @@ -338,4 +348,62 @@ public Number numberValue() {
return Double.NaN;
}
}

/**
 * A {@link SortValue} that stands in for data which cannot be sorted.
 * It carries no state: nothing is written to or read from the stream, it formats as an
 * empty string, and every instance is equal to every other instance of this class.
 */
private static class EmptySortValue extends SortValue {

    public static final String NAME = "empty";
    private static final String EMPTY_STRING = "";

    private EmptySortValue() {}

    /**
     * Stream constructor. Nothing is read because {@link #writeTo(StreamOutput)} writes nothing.
     */
    private EmptySortValue(StreamInput ignoredIn) {}

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Intentionally writes nothing: there is no state to serialize.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {}

    @Override
    public Object getKey() {
        return EMPTY_STRING;
    }

    /**
     * Always the empty string; the supplied format is not consulted.
     */
    @Override
    public String format(DocValueFormat format) {
        return EMPTY_STRING;
    }

    /**
     * Returns the builder untouched; no value is emitted for an empty sort value.
     */
    @Override
    protected XContentBuilder rawToXContent(XContentBuilder builder) throws IOException {
        return builder;
    }

    /**
     * Two empty sort values always compare as equal.
     */
    @Override
    protected int compareToSameType(SortValue obj) {
        return 0;
    }

    /**
     * Equal to any non-null object of exactly this class, since there is no state to compare.
     */
    @Override
    public boolean equals(Object obj) {
        return obj != null && getClass().equals(obj.getClass());
    }

    /**
     * Constant hash, consistent with {@link #equals(Object)} treating all instances as equal.
     */
    @Override
    public int hashCode() {
        return 0;
    }

    @Override
    public String toString() {
        return EMPTY_STRING;
    }

    /**
     * There is no numeric value, so report {@link Double#NaN} rather than {@code null}.
     * This keeps callers that unbox the result from hitting a {@link NullPointerException}
     * and matches the other non-numeric implementation in this file, which also returns NaN.
     */
    @Override
    public Number numberValue() {
        return Double.NaN;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() {

@Override
protected SortValue createTestInstance() {
return switch (between(0, 2)) {
return switch (between(0, 3)) {
case 0 -> SortValue.from(randomDouble());
case 1 -> SortValue.from(randomLong());
case 2 -> SortValue.from(new BytesRef(randomAlphaOfLength(5)));
case 3 -> SortValue.empty();
default -> throw new AssertionError();
};
}
Expand All @@ -72,6 +73,10 @@ public void testFormatLong() {
assertThat(SortValue.from(1).format(STRICT_DATE_TIME), equalTo("1970-01-01T00:00:00.001Z"));
}

/**
 * Formatting the empty sort value with the raw format must yield an empty string.
 */
public void testFormatEmpty() {
    final String formatted = SortValue.empty().format(DocValueFormat.RAW);
    assertThat(formatted, equalTo(""));
}

public void testToXContentDouble() {
assertThat(toXContent(SortValue.from(1.0), DocValueFormat.RAW), equalTo("{\"test\":1.0}"));
// The date formatter coerces the double into a long to format it
Expand All @@ -91,6 +96,10 @@ public void testToXContentBytes() {
);
}

/**
 * Pins the XContent rendering of the empty sort value: the expected output
 * carries the field name but no value.
 */
public void testToXContentEmpty() {
    final String rendered = toXContent(SortValue.empty(), DocValueFormat.RAW);
    assertThat(rendered, equalTo("{\"test\"}"));
}

public void testCompareDifferentTypes() {
assertThat(SortValue.from(1.0), lessThan(SortValue.from(1)));
assertThat(SortValue.from(Double.MAX_VALUE), lessThan(SortValue.from(Long.MIN_VALUE)));
Expand All @@ -102,6 +111,20 @@ public void testCompareDifferentTypes() {
assertThat(SortValue.from(1.0), greaterThan(SortValue.from(new BytesRef("cat"))));
}

/**
 * Values of different types are ordered by their writeable name rather than by their content.
 * That is why a "long" sorts after "empty" while a "double" sorts before it.
 * See {@link org.elasticsearch.search.sort.SortValue#compareTo}.
 */
public void testCompareToEmpty() {
    final SortValue empty = SortValue.empty();

    // Doubles sort before the empty value, whatever number they wrap.
    assertThat(SortValue.from(1.0), lessThan(empty));
    assertThat(SortValue.from(Double.MAX_VALUE), lessThan(empty));
    assertThat(SortValue.from(Double.NaN), lessThan(empty));

    // Longs sort after the empty value, even the smallest one.
    assertThat(SortValue.from(1), greaterThan(empty));
    assertThat(SortValue.from(Long.MIN_VALUE), greaterThan(empty));

    // Bytes values sort before the empty value.
    assertThat(SortValue.from(new BytesRef("cat")), lessThan(empty));
}

public void testCompareDoubles() {
double r = randomDouble();
assertThat(SortValue.from(r), equalTo(SortValue.from(r)));
Expand All @@ -116,6 +139,10 @@ public void testCompareLongs() {
assertThat(SortValue.from(r), greaterThan(SortValue.from(r - 1)));
}

/**
 * Two empty sort values must be equal to each other.
 */
public void testCompareEmpty() {
    final SortValue first = SortValue.empty();
    final SortValue second = SortValue.empty();
    assertThat(first, equalTo(second));
}

public void testBytes() {
String r = randomAlphaOfLength(5);
assertThat(SortValue.from(new BytesRef(r)), equalTo(SortValue.from(new BytesRef(r))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,23 +171,21 @@ public boolean equals(Object obj) {
}

@Override
public final double sortValue(String key) {
public final SortValue sortValue(String key) {
int index = metricNames.indexOf(key);
if (index < 0) {
throw new IllegalArgumentException("unknown metric [" + key + "]");
}
if (topMetrics.isEmpty()) {
return Double.NaN;
return SortValue.empty();
}

MetricValue value = topMetrics.get(0).metricValues.get(index);
if (value == null) {
return Double.NaN;
return SortValue.empty();
}

// TODO it'd probably be nicer to have "compareTo" instead of assuming a double.
// non-numeric fields always return NaN
return value.numberValue().doubleValue();
return value.getValue();
}

@Override
Expand Down

0 comments on commit 08141cf

Please sign in to comment.