docs/changelog/138138.yaml (5 changes: 5 additions & 0 deletions)
@@ -0,0 +1,5 @@
+ pr: 138138
+ summary: Fixing sorted indices for GPU built indices
+ area: Vector Search
+ type: bug
+ issues: []
@@ -15,16 +15,20 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.SearchHit;
+ import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.gpu.GPUPlugin;
import org.elasticsearch.xpack.gpu.GPUSupport;
import org.junit.Assert;
import org.junit.BeforeClass;

import java.util.Collection;
+ import java.util.HashSet;
import java.util.List;
+ import java.util.Locale;
+ import java.util.Set;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
@@ -58,20 +62,19 @@ public void testBasic() {
assertSearch(indexName, randomFloatVector(dims), totalDocs);
}

- @AwaitsFix(bugUrl = "Fix sorted index")
public void testSortedIndexReturnsSameResultsAsUnsorted() {
String indexName1 = "index_unsorted";
String indexName2 = "index_sorted";
final int dims = randomIntBetween(4, 128);
createIndex(indexName1, dims, false);
createIndex(indexName2, dims, true);

- final int[] numDocs = new int[] { randomIntBetween(50, 100), randomIntBetween(50, 100) };
+ final int[] numDocs = new int[] { randomIntBetween(300, 999), randomIntBetween(300, 999) };
for (int i = 0; i < numDocs.length; i++) {
BulkRequestBuilder bulkRequest1 = client().prepareBulk();
BulkRequestBuilder bulkRequest2 = client().prepareBulk();
for (int j = 0; j < numDocs[i]; j++) {
- String id = String.valueOf(i * 100 + j);
+ String id = String.valueOf(i * 1000 + j);
String keywordValue = String.valueOf(numDocs[i] - j);
float[] vector = randomFloatVector(dims);
bulkRequest1.add(prepareIndex(indexName1).setId(id).setSource("my_vector", vector, "my_keyword", keywordValue));
@@ -86,8 +89,9 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {

float[] queryVector = randomFloatVector(dims);
int k = 10;
- int numCandidates = k * 10;
+ int numCandidates = k * 5;

+ // Test 1: Approximate KNN search - expect at least k-3 out of k matches
var searchResponse1 = prepareSearch(indexName1).setSize(k)
.setFetchSource(false)
.addFetchField("my_keyword")
@@ -103,22 +107,40 @@
try {
SearchHit[] hits1 = searchResponse1.getHits().getHits();
SearchHit[] hits2 = searchResponse2.getHits().getHits();
- Assert.assertEquals(hits1.length, hits2.length);
- for (int i = 0; i < hits1.length; i++) {
- Assert.assertEquals(hits1[i].getId(), hits2[i].getId());
- Assert.assertEquals(hits1[i].field("my_keyword").getValue(), (String) hits2[i].field("my_keyword").getValue());
- Assert.assertEquals(hits1[i].getScore(), hits2[i].getScore(), 0.001f);
- }
+ assertAtLeastNOutOfKMatches(hits1, hits2, k - 3, k);
} finally {
searchResponse1.decRef();
searchResponse2.decRef();
}

+ // Test 2: Exact KNN search (brute-force) - expect perfect k out of k matches
+ var exactSearchResponse1 = prepareSearch(indexName1).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ var exactSearchResponse2 = prepareSearch(indexName2).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ try {
+ SearchHit[] exactHits1 = exactSearchResponse1.getHits().getHits();
+ SearchHit[] exactHits2 = exactSearchResponse2.getHits().getHits();
+ assertExactMatches(exactHits1, exactHits2, k);
+ } finally {
+ exactSearchResponse1.decRef();
+ exactSearchResponse2.decRef();
+ }

// Force merge and search again
assertNoFailures(indicesAdmin().prepareForceMerge(indexName1).get());
assertNoFailures(indicesAdmin().prepareForceMerge(indexName2).get());
ensureGreen();

+ // Test 3: Approximate KNN search - expect at least k-3 out of k matches
var searchResponse3 = prepareSearch(indexName1).setSize(k)
.setFetchSource(false)
.addFetchField("my_keyword")
@@ -134,16 +156,33 @@ public void testSortedIndexReturnsSameResultsAsUnsorted() {
try {
SearchHit[] hits3 = searchResponse3.getHits().getHits();
SearchHit[] hits4 = searchResponse4.getHits().getHits();
- Assert.assertEquals(hits3.length, hits4.length);
- for (int i = 0; i < hits3.length; i++) {
- Assert.assertEquals(hits3[i].getId(), hits4[i].getId());
- Assert.assertEquals(hits3[i].field("my_keyword").getValue(), (String) hits4[i].field("my_keyword").getValue());
- Assert.assertEquals(hits3[i].getScore(), hits4[i].getScore(), 0.01f);
- }
+ assertAtLeastNOutOfKMatches(hits3, hits4, k - 3, k);
} finally {
searchResponse3.decRef();
searchResponse4.decRef();
}

+ // Test 4: Exact KNN search after merge - expect perfect k out of k matches
+ var exactSearchResponse3 = prepareSearch(indexName1).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ var exactSearchResponse4 = prepareSearch(indexName2).setSize(k)
+ .setFetchSource(false)
+ .addFetchField("my_keyword")
+ .setQuery(new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), "my_vector", null))
+ .get();
+
+ try {
+ SearchHit[] exactHits3 = exactSearchResponse3.getHits().getHits();
+ SearchHit[] exactHits4 = exactSearchResponse4.getHits().getHits();
+ assertExactMatches(exactHits3, exactHits4, k);
+ } finally {
+ exactSearchResponse3.decRef();
+ exactSearchResponse4.decRef();
+ }
}

public void testSearchWithoutGPU() {
@@ -263,4 +302,56 @@ private static float[] randomFloatVector(int dims) {
}
return vector;
}

+ /**
+ * Asserts that at least N out of K hits have matching IDs between two result sets.
+ */
+ private static void assertAtLeastNOutOfKMatches(SearchHit[] hits1, SearchHit[] hits2, int minMatches, int k) {
+ Assert.assertEquals("Both result sets should have k hits", k, hits1.length);
+ Assert.assertEquals("Both result sets should have k hits", k, hits2.length);
+ Set<String> ids1 = new HashSet<>();
+ Set<String> ids2 = new HashSet<>();
+
+ for (SearchHit hit : hits1) {
+ ids1.add(hit.getId());
+ }
+ for (SearchHit hit : hits2) {
+ ids2.add(hit.getId());
+ }
+
+ Set<String> intersection = new HashSet<>(ids1);
+ intersection.retainAll(ids2);
+ Assert.assertTrue(
+ String.format(
+ Locale.ROOT,
+ "Expected at least %d matching IDs out of %d, but found %d. IDs1: %s, IDs2: %s",
+ minMatches,
+ k,
+ intersection.size(),
+ ids1,
+ ids2
+ ),
+ intersection.size() >= minMatches
+ );
+ }
+
+ /**
+ * Asserts that two result sets have exactly the same document IDs in the same order with the same scores.
+ * Used for exact (brute-force) KNN search which should be deterministic.
+ * Expects k out of k matches.
+ */
+ private static void assertExactMatches(SearchHit[] hits1, SearchHit[] hits2, int k) {
+ Assert.assertEquals("Both result sets should have k hits", k, hits1.length);
+ Assert.assertEquals("Both result sets should have k hits", k, hits2.length);
+
+ for (int i = 0; i < k; i++) {
+ Assert.assertEquals(String.format(Locale.ROOT, "Document ID mismatch at position %d", i), hits1[i].getId(), hits2[i].getId());
+ Assert.assertEquals(
+ String.format(Locale.ROOT, "Score mismatch for document ID %s at position %d", hits1[i].getId(), i),
+ hits1[i].getScore(),
+ hits2[i].getScore(),
+ 0.0001f
+ );
+ }
+ }
}
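Side note on the assertion strategy in this test: CAGRA, like HNSW, builds an approximate graph, so indexing the same vectors in a different (index-sorted) physical order can legitimately shift a few of the top-k neighbors; only the brute-force queries are required to agree exactly. Below is a small, self-contained sketch of the overlap check that assertAtLeastNOutOfKMatches performs, outside the Elasticsearch test harness; the class name and sample IDs are illustrative only.

    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    // Illustrates the recall tolerance used above: approximate KNN over two
    // physically different layouts of the same data may disagree on a few of
    // the top-k hits, so the test only requires an overlap of at least k - 3.
    public class RecallOverlapSketch {

        static int overlap(List<String> idsA, List<String> idsB) {
            Set<String> common = new HashSet<>(idsA);
            common.retainAll(idsB);
            return common.size();
        }

        public static void main(String[] args) {
            int k = 10;
            List<String> unsortedTopK = List.of("1", "2", "3", "4", "5", "6", "7", "8", "9", "10");
            List<String> sortedTopK = List.of("1", "2", "3", "4", "5", "6", "7", "9", "11", "12");
            System.out.println(overlap(unsortedTopK, sortedTopK) >= k - 3); // true: 8 of 10 IDs match
        }
    }

With k = 10 the test tolerates up to three disagreements between the sorted and unsorted indices, which is what the k - 3 threshold in the calls above encodes.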
@@ -20,6 +20,7 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
+ import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
@@ -165,7 +166,6 @@ public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException
* </p>
*/
@Override
- // TODO: fix sorted index case
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
var started = System.nanoTime();
flatVectorWriter.flush(maxDoc, sortMap);
@@ -184,7 +184,11 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IOException {
var started = System.nanoTime();
var fieldInfo = field.fieldInfo;

- var numVectors = field.flatFieldVectorsWriter.getVectors().size();
+ var originalVectors = field.flatFieldVectorsWriter.getVectors();
+ final List<float[]> vectorsInSortedOrder = sortMap == null
+ ? originalVectors
+ : getVectorsInSortedOrder(field, sortMap, originalVectors);
+ int numVectors = vectorsInSortedOrder.size();
CagraIndexParams cagraIndexParams = createCagraIndexParams(
fieldInfo.getVectorSimilarityFunction(),
numVectors,
@@ -194,7 +198,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IOException {
if (numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
logger.debug("Skip building carga index; vectors length {} < {} (min for GPU)", numVectors, MIN_NUM_VECTORS_FOR_GPU_BUILD);
// Will not be indexed on the GPU
- flushFieldWithMockGraph(fieldInfo, numVectors, sortMap);
+ generateMockGraphAndWriteMeta(fieldInfo, numVectors);
} else {
try (
var resourcesHolder = new ResourcesHolder(
@@ -208,11 +212,11 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IOException {
fieldInfo.getVectorDimension(),
CuVSMatrix.DataType.FLOAT
);
- for (var vector : field.flatFieldVectorsWriter.getVectors()) {
+ for (var vector : vectorsInSortedOrder) {
builder.addVector(vector);
}
try (var dataset = builder.build()) {
- flushFieldWithGpuGraph(resourcesHolder, fieldInfo, dataset, sortMap, cagraIndexParams);
+ generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
}
}
}
@@ -221,28 +225,17 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IOException {
}
}

- private void flushFieldWithMockGraph(FieldInfo fieldInfo, int numVectors, Sorter.DocMap sortMap) throws IOException {
- if (sortMap == null) {
- generateMockGraphAndWriteMeta(fieldInfo, numVectors);
- } else {
- // TODO: use sortMap
- generateMockGraphAndWriteMeta(fieldInfo, numVectors);
- }
- }
-
- private void flushFieldWithGpuGraph(
- ResourcesHolder resourcesHolder,
- FieldInfo fieldInfo,
- CuVSMatrix dataset,
- Sorter.DocMap sortMap,
- CagraIndexParams cagraIndexParams
- ) throws IOException {
- if (sortMap == null) {
- generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
- } else {
- // TODO: use sortMap
- generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
+ private List<float[]> getVectorsInSortedOrder(FieldWriter field, Sorter.DocMap sortMap, List<float[]> originalVectors)
+ throws IOException {
+ DocsWithFieldSet docsWithField = field.getDocsWithFieldSet();
+ int[] ordMap = new int[docsWithField.cardinality()];
+ DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
+ KnnVectorsWriter.mapOldOrdToNewOrd(docsWithField, sortMap, null, ordMap, newDocsWithField);
+ List<float[]> vectorsInSortedOrder = new ArrayList<>(ordMap.length);
+ for (int oldOrd : ordMap) {
+ vectorsInSortedOrder.add(originalVectors.get(oldOrd));
}
+ return vectorsInSortedOrder;
}
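A note for readers on the remapping above: the loop consumes ordMap position by position, and each entry holds the old ordinal of the vector that belongs at that new (sorted) position, so the returned list is simply the buffered vectors permuted into sorted-document order. Below is a toy sketch of the same permutation with a hand-rolled ordinal map standing in for the one mapOldOrdToNewOrd fills; the class and variable names are illustrative, not part of this change.

    import java.util.ArrayList;
    import java.util.List;

    // Toy version of the reorder step: vectors were buffered in arrival order,
    // and ordMap[newOrd] holds the old ordinal of the vector that belongs at
    // position newOrd once the segment is sorted.
    class SortedOrderSketch {

        static List<float[]> reorder(List<float[]> originalVectors, int[] ordMap) {
            List<float[]> sorted = new ArrayList<>(ordMap.length);
            for (int oldOrd : ordMap) { // walk new ordinals 0..n-1, fetch each old vector
                sorted.add(originalVectors.get(oldOrd));
            }
            return sorted;
        }

        public static void main(String[] args) {
            List<float[]> buffered = List.of(new float[] { 0f }, new float[] { 1f }, new float[] { 2f });
            int[] ordMap = { 2, 0, 1 }; // sorted position 0 is the vector that arrived last
            for (float[] v : reorder(buffered, ordMap)) {
                System.out.println(v[0]); // prints 2.0, 0.0, 1.0
            }
        }
    }

Feeding the GPU builder this permuted list keeps the CAGRA ordinals aligned with the sorted doc IDs that flatVectorWriter.flush(maxDoc, sortMap) already uses for the raw vectors, which is the heart of this fix.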

@Override
@@ -514,8 +507,10 @@ public NodesIterator getNodesOnLevel(int level) {

// TODO check with deleted documents
@Override
- // fix sorted index case
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ // Note: Merged raw vectors are already in sorted order. The flatVectorWriter and MergedVectorValues utilities
+ // apply mergeState.docMaps internally, so vectors are returned in the final sorted document order.
+ // Unlike flush(), we don't need to explicitly handle sorting here.
try (var scorerSupplier = flatVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState)) {
var started = System.nanoTime();
int numVectors = scorerSupplier.totalVectorCount();
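To make the note above concrete: during a merge each incoming segment gets a doc map into the merged (sorted) address space, so a consumer that places every vector at its mapped slot ends up with the final order for free. A toy sketch with hand-rolled maps standing in for mergeState.docMaps follows; it is illustrative only and uses no Lucene API.

    import java.util.List;

    // Two source segments, each with vectors in its own order, plus per-segment
    // maps oldDoc -> mergedDoc. Writing each vector at its mapped slot yields
    // the merged (sorted) order without a separate sort step.
    class MergeOrderSketch {

        public static void main(String[] args) {
            List<float[]> seg0 = List.of(new float[] { 10f }, new float[] { 30f });
            List<float[]> seg1 = List.of(new float[] { 0f }, new float[] { 20f });
            int[][] docMaps = { { 1, 3 }, { 0, 2 } }; // docMaps[segment][oldDoc] = mergedDoc

            float[][] merged = new float[4][];
            List<List<float[]>> segments = List.of(seg0, seg1);
            for (int s = 0; s < segments.size(); s++) {
                for (int d = 0; d < segments.get(s).size(); d++) {
                    merged[docMaps[s][d]] = segments.get(s).get(d);
                }
            }
            for (float[] v : merged) {
                System.out.println(v[0]); // 0.0, 10.0, 20.0, 30.0 -- already sorted
            }
        }
    }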
@@ -846,5 +841,9 @@ public float[] copyValue(float[] vectorValue) {
public long ramBytesUsed() {
return SHALLOW_SIZE + flatFieldVectorsWriter.ramBytesUsed();
}

+ public DocsWithFieldSet getDocsWithFieldSet() {
+ return flatFieldVectorsWriter.getDocsWithFieldSet();
+ }
}
}
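Background on the new accessor: DocsWithFieldSet is Lucene's accumulator of the doc IDs that actually carry a value for the field, and its cardinality() is what sizes ordMap in getVectorsInSortedOrder above. Below is a minimal sketch of the contract the flush path relies on, assuming only lucene-core on the classpath.

    import java.io.IOException;

    import org.apache.lucene.index.DocsWithFieldSet;
    import org.apache.lucene.search.DocIdSetIterator;

    // Demonstrates the DocsWithFieldSet contract used by the flush path above:
    // doc IDs are added in increasing order, cardinality() counts them, and
    // iterator() replays them in the same order.
    class DocsWithFieldSetSketch {

        public static void main(String[] args) throws IOException {
            DocsWithFieldSet docsWithField = new DocsWithFieldSet();
            docsWithField.add(0); // only docs that actually carry a vector
            docsWithField.add(3);
            docsWithField.add(7);

            System.out.println(docsWithField.cardinality()); // 3 -> length of ordMap

            DocIdSetIterator it = docsWithField.iterator();
            for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
                System.out.println(doc); // 0, 3, 7
            }
        }
    }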