Skip to content

Commit

Permalink
Enabling sort optimizatin back for half_float with custom comparators
Browse files Browse the repository at this point in the history
Signed-off-by: Chaitanya Gohel <gashutos@amazon.com>
  • Loading branch information
gashutos committed Oct 31, 2023
1 parent 63aff16 commit f200786
Show file tree
Hide file tree
Showing 7 changed files with 322 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
properties:
counter:
type: double

- do:
indices.create:
index: test_3
body:
mappings:
properties:
counter:
type: half_float

- do:
bulk:
refresh: true
Expand All @@ -33,15 +43,21 @@
- index:
_index: test_2
- counter: 184.4
- index:
_index: test_3
- counter: 187.4
- index:
_index: test_3
- counter: 194.4

- do:
search:
index: test_*
rest_total_hits_as_int: true
body:
sort: [{ counter: desc }]
- match: { hits.total: 3 }
- length: { hits.hits: 3 }
- match: { hits.total: 5 }
- length: { hits.hits: 5 }
- match: { hits.hits.0._index: test_2 }
- match: { hits.hits.0._source.counter: 1223372036854775800.23 }
- match: { hits.hits.0.sort.0: 1223372036854775800.23 }
Expand All @@ -55,14 +71,13 @@
rest_total_hits_as_int: true
body:
sort: [{ counter: asc }]
- match: { hits.total: 3 }
- length: { hits.hits: 3 }
- match: { hits.total: 5 }
- length: { hits.hits: 5 }
- match: { hits.hits.0._index: test_2 }
- match: { hits.hits.0._source.counter: 184.4 }
- match: { hits.hits.0.sort.0: 184.4 }
- match: { hits.hits.1._index: test_1 }
- match: { hits.hits.1._source.counter: 223372036854775800 }
- match: { hits.hits.1.sort.0: 223372036854775800 }
- match: { hits.hits.1._index: test_3 }
- match: { hits.hits.1._source.counter: 187.4 }

---
"search across indices with mixed long, double and unsigned_long numeric types":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,71 @@
- length: { hits.hits: 1 }
- match: { hits.hits.0._index: test }
- match: { hits.hits.0._source.population: null }

---
"half float":
- do:
indices.create:
index: test
body:
mappings:
properties:
population:
type: half_float
- do:
bulk:
refresh: true
index: test
body: |
{"index":{}}
{"population": 184.4}
{"index":{}}
{"population": 194.4}
- do:
search:
index: test
rest_total_hits_as_int: true
body:
size: 1
sort: [ { population: asc } ]
- match: { hits.total: 2 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._index: test }
- match: { hits.hits.0._source.population: 184.4 }

# search_after with the sort
- do:
search:
index: test
rest_total_hits_as_int: true
body:
size: 1
sort: [ { population: asc } ]
search_after: [ 184.4 ]
- match: { hits.total: 2 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._index: test }
- match: { hits.hits.0._source.population: 194.4 }

# search_after with the sort with missing
- do:
bulk:
refresh: true
index: test
body: |
{"index":{}}
{"population": null}
- do:
search:
index: test
rest_total_hits_as_int: true
body:
"size": 5
"sort": [ { "population": { "order": "asc", "missing": "_last" } } ]
search_after: [ 200 ] # making it out of min/max so only missing value hit is qualified

- match: { hits.total: 3 }
- length: { hits.hits: 1 }
- match: { hits.hits.0._index: test }
- match: { hits.hits.0._source.population: null }
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ public void testSimpleSorts() throws Exception {
.startObject("float_value")
.field("type", "float")
.endObject()
.startObject("half_float_value")
.field("type", "half_float")
.endObject()
.startObject("double_value")
.field("type", "double")
.endObject()
Expand All @@ -628,6 +631,7 @@ public void testSimpleSorts() throws Exception {
.field("long_value", i)
.field("unsigned_long_value", UNSIGNED_LONG_BASE.add(BigInteger.valueOf(10000 * i)))
.field("float_value", 0.1 * i)
.field("half_float_value", 0.1 * i)
.field("double_value", 0.1 * i)
.endObject()
);
Expand Down Expand Up @@ -794,6 +798,28 @@ public void testSimpleSorts() throws Exception {

assertThat(searchResponse.toString(), not(containsString("error")));

// HALF_FLOAT
size = 1 + random.nextInt(10);
searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.ASC).get();

assertHitCount(searchResponse, 10L);
assertThat(searchResponse.getHits().getHits().length, equalTo(size));
for (int i = 0; i < size; i++) {
assertThat(searchResponse.getHits().getAt(i).getId(), equalTo(Integer.toString(i)));
}

assertThat(searchResponse.toString(), not(containsString("error")));
size = 1 + random.nextInt(10);
searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("half_float_value", SortOrder.DESC).get();

assertHitCount(searchResponse, 10);
assertThat(searchResponse.getHits().getHits().length, equalTo(size));
for (int i = 0; i < size; i++) {
assertThat(searchResponse.getHits().getAt(i).getId(), equalTo(Integer.toString(9 - i)));
}

assertThat(searchResponse.toString(), not(containsString("error")));

// DOUBLE
size = 1 + random.nextInt(10);
searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(size).addSort("double_value", SortOrder.ASC).get();
Expand Down Expand Up @@ -1330,6 +1356,9 @@ public void testSortMVField() throws Exception {
.startObject("float_values")
.field("type", "float")
.endObject()
.startObject("half_float_values")
.field("type", "float")
.endObject()
.startObject("double_values")
.field("type", "double")
.endObject()
Expand All @@ -1351,6 +1380,7 @@ public void testSortMVField() throws Exception {
.array("short_values", 1, 5, 10, 8)
.array("byte_values", 1, 5, 10, 8)
.array("float_values", 1f, 5f, 10f, 8f)
.array("half_float_values", 1f, 5f, 10f, 8f)
.array("double_values", 1d, 5d, 10d, 8d)
.array("string_values", "01", "05", "10", "08")
.endObject()
Expand All @@ -1365,6 +1395,7 @@ public void testSortMVField() throws Exception {
.array("short_values", 11, 15, 20, 7)
.array("byte_values", 11, 15, 20, 7)
.array("float_values", 11f, 15f, 20f, 7f)
.array("half_float_values", 11f, 15f, 20f, 7f)
.array("double_values", 11d, 15d, 20d, 7d)
.array("string_values", "11", "15", "20", "07")
.endObject()
Expand All @@ -1379,6 +1410,7 @@ public void testSortMVField() throws Exception {
.array("short_values", 2, 1, 3, -4)
.array("byte_values", 2, 1, 3, -4)
.array("float_values", 2f, 1f, 3f, -4f)
.array("half_float_values", 2f, 1f, 3f, -4f)
.array("double_values", 2d, 1d, 3d, -4d)
.array("string_values", "02", "01", "03", "!4")
.endObject()
Expand Down Expand Up @@ -1585,6 +1617,34 @@ public void testSortMVField() throws Exception {
assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(3)));
assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(3f));

searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("half_float_values", SortOrder.ASC).get();

assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L));
assertThat(searchResponse.getHits().getHits().length, equalTo(3));

assertThat(searchResponse.getHits().getAt(0).getId(), equalTo(Integer.toString(3)));
assertThat(((Number) searchResponse.getHits().getAt(0).getSortValues()[0]).floatValue(), equalTo(-4f));

assertThat(searchResponse.getHits().getAt(1).getId(), equalTo(Integer.toString(1)));
assertThat(((Number) searchResponse.getHits().getAt(1).getSortValues()[0]).floatValue(), equalTo(1f));

assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(2)));
assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(7f));

searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("half_float_values", SortOrder.DESC).get();

assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L));
assertThat(searchResponse.getHits().getHits().length, equalTo(3));

assertThat(searchResponse.getHits().getAt(0).getId(), equalTo(Integer.toString(2)));
assertThat(((Number) searchResponse.getHits().getAt(0).getSortValues()[0]).floatValue(), equalTo(20f));

assertThat(searchResponse.getHits().getAt(1).getId(), equalTo(Integer.toString(1)));
assertThat(((Number) searchResponse.getHits().getAt(1).getSortValues()[0]).floatValue(), equalTo(10f));

assertThat(searchResponse.getHits().getAt(2).getId(), equalTo(Integer.toString(3)));
assertThat(((Number) searchResponse.getHits().getAt(2).getSortValues()[0]).floatValue(), equalTo(3f));

searchResponse = client().prepareSearch().setQuery(matchAllQuery()).setSize(10).addSort("double_values", SortOrder.ASC).get();

assertThat(searchResponse.getHits().getTotalHits().value, equalTo(3L));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.index.fielddata.IndexFieldData.XFieldComparatorSource.Nested;
import org.opensearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource;
import org.opensearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
import org.opensearch.index.fielddata.fieldcomparator.HalfFloatValuesComparatorSource;
import org.opensearch.index.fielddata.fieldcomparator.IntValuesComparatorSource;
import org.opensearch.index.fielddata.fieldcomparator.LongValuesComparatorSource;
import org.opensearch.index.fielddata.fieldcomparator.UnsignedLongValuesComparatorSource;
Expand Down Expand Up @@ -220,6 +221,8 @@ private XFieldComparatorSource comparatorSource(
final XFieldComparatorSource source;
switch (targetNumericType) {
case HALF_FLOAT:
source = new HalfFloatValuesComparatorSource(this, missingValue, sortMode, nested);
break;
case FLOAT:
source = new FloatValuesComparatorSource(this, missingValue, sortMode, nested);
break;
Expand All @@ -242,7 +245,7 @@ private XFieldComparatorSource comparatorSource(
assert !targetNumericType.isFloatingPoint();
source = new IntValuesComparatorSource(this, missingValue, sortMode, nested);
}
if (targetNumericType != getNumericType() || getNumericType() == NumericType.HALF_FLOAT) {
if (targetNumericType != getNumericType()) {
source.disableSkipping(); // disable skipping logic for cast of sort field
}
return source;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
*/
public class FloatValuesComparatorSource extends IndexFieldData.XFieldComparatorSource {

private final IndexNumericFieldData indexFieldData;
protected final IndexNumericFieldData indexFieldData;

public FloatValuesComparatorSource(
IndexNumericFieldData indexFieldData,
Expand All @@ -78,7 +78,7 @@ public SortField.Type reducedType() {
return SortField.Type.FLOAT;
}

private NumericDoubleValues getNumericDocValues(LeafReaderContext context, float missingValue) throws IOException {
protected NumericDoubleValues getNumericDocValues(LeafReaderContext context, float missingValue) throws IOException {
final SortedNumericDoubleValues values = indexFieldData.load(context).getDoubleValues();
if (nested == null) {
return FieldData.replaceMissing(sortMode.select(values), missingValue);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.fielddata.fieldcomparator;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.opensearch.index.fielddata.IndexNumericFieldData;
import org.opensearch.index.search.comparators.HalfFloatComparator;
import org.opensearch.search.MultiValueMode;

import java.io.IOException;

/**
* Comparator source for half_float values.
*
* @opensearch.internal
*/
public class HalfFloatValuesComparatorSource extends FloatValuesComparatorSource {
public HalfFloatValuesComparatorSource(
IndexNumericFieldData indexFieldData,
Object missingValue,
MultiValueMode sortMode,
Nested nested
) {
super(indexFieldData, missingValue, sortMode, nested);
}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) {
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

final float fMissingValue = (Float) missingObject(missingValue, reversed);
// NOTE: it's important to pass null as a missing value in the constructor so that
// the comparator doesn't check docsWithField since we replace missing values in select()
return new HalfFloatComparator(numHits, fieldname, null, reversed, enableSkipping && this.enableSkipping) {
@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new HalfFloatLeafComparator(context) {
@Override
protected NumericDocValues getNumericDocValues(LeafReaderContext context, String field) throws IOException {
return HalfFloatValuesComparatorSource.this.getNumericDocValues(context, fMissingValue).getRawFloatValues();
}
};
}
};
}
}

0 comments on commit f200786

Please sign in to comment.