From 3a87953e6252e88c14dd249461d56f3ac4723813 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Wed, 19 Nov 2025 17:33:53 +0100 Subject: [PATCH] [DiskBBQ] save same vector operations when performing centroid filtering --- .../next/ESNextDiskBBQVectorsReader.java | 106 +++++++++++------- .../search/vectors/IVFKnnSearchStrategy.java | 2 +- .../next/ESNextDiskBBQVectorsFormatTests.java | 21 +++- 3 files changed, 85 insertions(+), 44 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java index 8e4fdab12930b..b30d560002679 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java @@ -272,11 +272,14 @@ private static CentroidIterator getCentroidIteratorNoParent( FixedBitSet acceptCentroids ) throws IOException { final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true); + final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES; score( neighborQueue, numCentroids, 0, scorer, + centroids, + centroidQuantizeSize, quantizeQuery, queryParams, globalCentroidDp, @@ -315,26 +318,41 @@ private static CentroidIterator getCentroidIteratorWithParents( FixedBitSet acceptCentroids ) throws IOException { // build the three queues we are going to use + final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES; final NeighborQueue parentsQueue = new NeighborQueue(numParents, true); final int maxChildrenSize = centroids.readVInt(); final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true); final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids); - final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true); - // score the parents + final int numCentroidsFiltered = acceptCentroids == null ? numCentroids : acceptCentroids.cardinality(); final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE]; - score( - parentsQueue, - numParents, - 0, - scorer, - quantizeQuery, - queryParams, - globalCentroidDp, - fieldInfo.getVectorSimilarityFunction(), - scores, - null - ); - final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES; + final NeighborQueue neighborQueue; + if (acceptCentroids != null && numCentroidsFiltered <= bufferSize) { + // we are collecting every non-filter centroid, therefore we do not need to score the + // parents. We give each of them the same score. + neighborQueue = new NeighborQueue(numCentroidsFiltered, true); + for (int i = 0; i < numParents; i++) { + parentsQueue.add(i, 0.5f); + } + centroids.skipBytes(centroidQuantizeSize * numParents); + } else { + neighborQueue = new NeighborQueue(bufferSize, true); + // score the parents + score( + parentsQueue, + numParents, + 0, + scorer, + centroids, + centroidQuantizeSize, + quantizeQuery, + queryParams, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction(), + scores, + null + ); + } + final long offset = centroids.getFilePointer(); final long childrenOffset = offset + (long) Long.BYTES * numParents; // populate the children's queue by reading parents one by one @@ -429,6 +447,8 @@ private static void populateOneChildrenGroup( numChildren, childrenOrdinal, scorer, + centroids, + centroidQuantizeSize, quantizeQuery, queryParams, globalCentroidDp, @@ -443,6 +463,8 @@ private static void score( int size, int scoresOffset, ES92Int7VectorsScorer scorer, + IndexInput centroids, + long centroidQuantizeSize, byte[] quantizeQuery, OptimizedScalarQuantizer.QuantizationResult queryCorrections, float centroidDp, @@ -450,41 +472,47 @@ private static void score( float[] scores, FixedBitSet acceptCentroids ) throws IOException { - // TODO: if accept centroids is not null, we can save some vector ops here int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1; int i = 0; for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) { - scorer.scoreBulk( - quantizeQuery, - queryCorrections.lowerInterval(), - queryCorrections.upperInterval(), - queryCorrections.quantizedComponentSum(), - queryCorrections.additionalCorrection(), - similarityFunction, - centroidDp, - scores - ); - for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) { - int centroidOrd = scoresOffset + i + j; - if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) { - neighborQueue.add(centroidOrd, scores[j]); + if (acceptCentroids == null + || acceptCentroids.cardinality(scoresOffset + i, scoresOffset + i + ES92Int7VectorsScorer.BULK_SIZE) > 0) { + scorer.scoreBulk( + quantizeQuery, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp, + scores + ); + for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) { + int centroidOrd = scoresOffset + i + j; + if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) { + neighborQueue.add(centroidOrd, scores[j]); + } } + } else { + centroids.skipBytes(ES92Int7VectorsScorer.BULK_SIZE * centroidQuantizeSize); } } for (; i < size; i++) { - float score = scorer.score( - quantizeQuery, - queryCorrections.lowerInterval(), - queryCorrections.upperInterval(), - queryCorrections.quantizedComponentSum(), - queryCorrections.additionalCorrection(), - similarityFunction, - centroidDp - ); int centroidOrd = scoresOffset + i; if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) { + float score = scorer.score( + quantizeQuery, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp + ); neighborQueue.add(centroidOrd, score); + } else { + centroids.skipBytes(centroidQuantizeSize); } } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java index 52311298c7bf5..868c4f8227652 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java @@ -19,7 +19,7 @@ public class IVFKnnSearchStrategy extends KnnSearchStrategy { private final SetOnce collector = new SetOnce<>(); private final LongAccumulator accumulator; - IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) { + public IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) { this.visitRatio = visitRatio; this.accumulator = accumulator; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java index 04e1416c9c3c4..59c4dc05bdd77 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java @@ -29,12 +29,15 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.search.vectors.IVFKnnSearchStrategy; import org.junit.Before; import java.io.IOException; @@ -353,17 +356,27 @@ private void doRestrictiveFilter(boolean dense) throws IOException { LeafReader leafReader = getOnlyLeafReader(reader); float[] vector = randomVector(dimensions); // we might collect the same document twice because of soar assignments - TopDocs topDocs = leafReader.searchNearestVectors( + KnnCollector collector; + if (random().nextBoolean()) { + collector = new TopKnnCollector(random().nextInt(2 * matchingDocs, 3 * matchingDocs), Integer.MAX_VALUE); + } else { + collector = new TopKnnCollector( + random().nextInt(2 * matchingDocs, 3 * matchingDocs), + Integer.MAX_VALUE, + new IVFKnnSearchStrategy(0.25f, null) + ); + } + leafReader.searchNearestVectors( "f", vector, - random().nextInt(2 * matchingDocs, 3 * matchingDocs), + collector, AcceptDocs.fromIteratorSupplier( () -> leafReader.postings(new Term("k", new BytesRef("A"))), leafReader.getLiveDocs(), leafReader.maxDoc() - ), - Integer.MAX_VALUE + ) ); + TopDocs topDocs = collector.topDocs(); Set uniqueDocIds = new HashSet<>(); for (int i = 0; i < topDocs.scoreDocs.length; i++) { uniqueDocIds.add(topDocs.scoreDocs[i].doc);