diff --git a/docs/changelog/138138.yaml b/docs/changelog/138138.yaml
new file mode 100644
index 0000000000000..59b228755f612
--- /dev/null
+++ b/docs/changelog/138138.yaml
@@ -0,0 +1,5 @@
+pr: 138138
+summary: Fixing sorted indices for GPU built indices
+area: Vector Search
+type: bug
+issues: []
diff --git a/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java b/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java
index f5d8191dcee5c..7456482985b55 100644
--- a/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java
+++ b/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java
@@ -15,7 +15,9 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
+import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.gpu.GPUPlugin;
import org.elasticsearch.xpack.gpu.GPUSupport;
@@ -23,8 +25,10 @@
import org.junit.BeforeClass;
import java.util.Collection;
+import java.util.HashSet;
import java.util.List;
import java.util.Locale;
+import java.util.Set;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
@@ -58,7 +62,6 @@ public void testBasic() {
assertSearch(indexName, randomFloatVector(dims), totalDocs);
}
- @AwaitsFix(bugUrl = "Fix sorted index")
public void testSortedIndexReturnsSameResultsAsUnsorted() {
String indexName1 = "index_unsorted";
String indexName2 = "index_sorted";
@@ -66,12 +69,12 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
createIndex(indexName1, dims, false);
createIndex(indexName2, dims, true);
- final int[] numDocs = new int[] { randomIntBetween(50, 100), randomIntBetween(50, 100) };
+ final int[] numDocs = new int[] { randomIntBetween(300, 999), randomIntBetween(300, 999) };
for (int i = 0; i < numDocs.length; i++) {
BulkRequestBuilder bulkRequest1 = client().prepareBulk();
BulkRequestBuilder bulkRequest2 = client().prepareBulk();
for (int j = 0; j < numDocs[i]; j++) {
- String id = String.valueOf(i * 100 + j);
+ String id = String.valueOf(i * 1000 + j);
String keywordValue = String.valueOf(numDocs[i] - j);
float[] vector = randomFloatVector(dims);
bulkRequest1.add(prepareIndex(indexName1).setId(id).setSource("my_vector", vector, "my_keyword", keywordValue));
@@ -86,8 +89,9 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
float[] queryVector = randomFloatVector(dims);
int k = 10;
- int numCandidates = k * 10;
+ int numCandidates = k * 5;
+ // Test 1: Approximate KNN search - expect at least k-3 out of k matches
var searchResponse1 = prepareSearch(indexName1).setSize(k)
.setFetchSource(false)
.addFetchField("my_keyword")
@@ -103,22 +107,40 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
try {
SearchHit[] hits1 = searchResponse1.getHits().getHits();
SearchHit[] hits2 = searchResponse2.getHits().getHits();
- Assert.assertEquals(hits1.length, hits2.length);
- for (int i = 0; i < hits1.length; i++) {
- Assert.assertEquals(hits1[i].getId(), hits2[i].getId());
- Assert.assertEquals(hits1[i].field("my_keyword").getValue(), (String) hits2[i].field("my_keyword").getValue());
- Assert.assertEquals(hits1[i].getScore(), hits2[i].getScore(), 0.001f);
- }
+ assertAtLeastNOutOfKMatches(hits1, hits2, k - 3, k);
} finally {
searchResponse1.decRef();
searchResponse2.decRef();
}
+ // Test 2: Exact KNN search (brute-force) - expect perfect k out of k matches
+ var exactSearchResponse1 = prepareSearch(indexName1).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ var exactSearchResponse2 = prepareSearch(indexName2).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ try {
+ SearchHit[] exactHits1 = exactSearchResponse1.getHits().getHits();
+ SearchHit[] exactHits2 = exactSearchResponse2.getHits().getHits();
+ assertExactMatches(exactHits1, exactHits2, k);
+ } finally {
+ exactSearchResponse1.decRef();
+ exactSearchResponse2.decRef();
+ }
+
// Force merge and search again
assertNoFailures(indicesAdmin().prepareForceMerge(indexName1).get());
assertNoFailures(indicesAdmin().prepareForceMerge(indexName2).get());
ensureGreen();
+ // Test 3: Approximate KNN search - expect at least k-3 out of k matches
var searchResponse3 = prepareSearch(indexName1).setSize(k)
.setFetchSource(false)
.addFetchField("my_keyword")
@@ -134,16 +156,33 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
try {
SearchHit[] hits3 = searchResponse3.getHits().getHits();
SearchHit[] hits4 = searchResponse4.getHits().getHits();
- Assert.assertEquals(hits3.length, hits4.length);
- for (int i = 0; i < hits3.length; i++) {
- Assert.assertEquals(hits3[i].getId(), hits4[i].getId());
- Assert.assertEquals(hits3[i].field("my_keyword").getValue(), (String) hits4[i].field("my_keyword").getValue());
- Assert.assertEquals(hits3[i].getScore(), hits4[i].getScore(), 0.01f);
- }
+ assertAtLeastNOutOfKMatches(hits3, hits4, k - 3, k);
} finally {
searchResponse3.decRef();
searchResponse4.decRef();
}
+
+ // Test 4: Exact KNN search after merge - expect perfect k out of k matches
+ var exactSearchResponse3 = prepareSearch(indexName1).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ var exactSearchResponse4 = prepareSearch(indexName2).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ try {
+ SearchHit[] exactHits3 = exactSearchResponse3.getHits().getHits();
+ SearchHit[] exactHits4 = exactSearchResponse4.getHits().getHits();
+ assertExactMatches(exactHits3, exactHits4, k);
+ } finally {
+ exactSearchResponse3.decRef();
+ exactSearchResponse4.decRef();
+ }
}
public void testSearchWithoutGPU() {
@@ -263,4 +302,56 @@ private static float[] randomFloatVector(int dims) {
}
return vector;
}
+
+ /**
+ * Asserts that at least N out of K hits have matching IDs between two result sets.
+ */
+ private static void assertAtLeastNOutOfKMatches(SearchHit[] hits1, SearchHit[] hits2, int minMatches, int k) {
+ Assert.assertEquals("Both result sets should have k hits", k, hits1.length);
+ Assert.assertEquals("Both result sets should have k hits", k, hits2.length);
+ Set