@@ -272,11 +272,14 @@ private static CentroidIterator getCentroidIteratorNoParent(
         FixedBitSet acceptCentroids
     ) throws IOException {
         final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
+        final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
         score(
             neighborQueue,
             numCentroids,
             0,
             scorer,
+            centroids,
+            centroidQuantizeSize,
             quantizeQuery,
             queryParams,
             globalCentroidDp,
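For orientation, centroidQuantizeSize in the hunk above is the on-disk footprint of one quantized centroid: one byte per vector dimension for the int7-quantized components, plus three floats and an int that presumably carry the optimized-scalar-quantizer corrections. A minimal worked example of the arithmetic, assuming a hypothetical 768-dimension field (the dimension is illustrative, not taken from this PR):

    // Hypothetical illustration; 768 dims is an assumed value.
    int dims = 768;                   // fieldInfo.getVectorDimension()
    long centroidQuantizeSize = dims  // one byte per int7-quantized component
        + 3 * Float.BYTES             // three float corrections per centroid
        + Integer.BYTES;              // quantized component sum
    // 768 + 12 + 4 = 784 bytes per centroid, so seeking past n centroids
    // is a single skipBytes(n * 784L) call rather than n reads.

This fixed per-centroid size is what makes the skipBytes calls in the hunks below cheap: a filtered-out centroid costs one file seek instead of a read plus vector ops.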
@@ -315,26 +318,41 @@ private static CentroidIterator getCentroidIteratorWithParents(
         FixedBitSet acceptCentroids
     ) throws IOException {
         // build the three queues we are going to use
-        final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
         final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
         final int maxChildrenSize = centroids.readVInt();
         final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
         final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids);
-        final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
-        // score the parents
+        final int numCentroidsFiltered = acceptCentroids == null ? numCentroids : acceptCentroids.cardinality();
         final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
-        score(
-            parentsQueue,
-            numParents,
-            0,
-            scorer,
-            quantizeQuery,
-            queryParams,
-            globalCentroidDp,
-            fieldInfo.getVectorSimilarityFunction(),
-            scores,
-            null
-        );
+        final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
+        final NeighborQueue neighborQueue;
+        if (acceptCentroids != null && numCentroidsFiltered <= bufferSize) {
+            // we are collecting every non-filter centroid, therefore we do not need to score the
+            // parents. We give each of them the same score.
+            neighborQueue = new NeighborQueue(numCentroidsFiltered, true);
+            for (int i = 0; i < numParents; i++) {
+                parentsQueue.add(i, 0.5f);
+            }
+            centroids.skipBytes(centroidQuantizeSize * numParents);
+        } else {
+            neighborQueue = new NeighborQueue(bufferSize, true);
+            // score the parents
+            score(
+                parentsQueue,
+                numParents,
+                0,
+                scorer,
+                centroids,
+                centroidQuantizeSize,
+                quantizeQuery,
+                queryParams,
+                globalCentroidDp,
+                fieldInfo.getVectorSimilarityFunction(),
+                scores,
+                null
+            );
+        }
+
         final long offset = centroids.getFilePointer();
         final long childrenOffset = offset + (long) Long.BYTES * numParents;
         // populate the children's queue by reading parents one by one
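The new branch above short-circuits parent scoring: when a filter is present and the accepted-centroid count already fits in the collection buffer, every accepted centroid will be collected regardless of rank, so parent scores carry no information; each parent receives the same placeholder score (0.5f) and the reader seeks past all parent data in one call. A self-contained sketch of that decision, with hypothetical names and counts (accept, bufferSize, and the sizes are illustrative, not the PR's code):

    import org.apache.lucene.util.FixedBitSet;

    final class ParentSkipSketch {
        public static void main(String[] args) {
            int numCentroids = 10_000;
            int bufferSize = 512;                              // centroids we plan to collect
            FixedBitSet accept = new FixedBitSet(numCentroids);
            for (int i = 0; i < 300; i++) {
                accept.set(i * 33);                            // 300 centroids pass the filter
            }
            int numCentroidsFiltered = accept.cardinality();
            if (numCentroidsFiltered <= bufferSize) {
                // every accepted centroid is collected whatever it scores, so a
                // uniform placeholder ranking is as good as a real one
                System.out.println("skip parent scoring; seek past all parent data");
            } else {
                System.out.println("score parents as usual");
            }
        }
    }

Any constant would do in place of 0.5f: once all accepted children are guaranteed to be collected, the ordering among parents stops mattering.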
@@ -429,6 +447,8 @@ private static void populateOneChildrenGroup(
             numChildren,
             childrenOrdinal,
             scorer,
+            centroids,
+            centroidQuantizeSize,
             quantizeQuery,
             queryParams,
             globalCentroidDp,
@@ -443,48 +463,56 @@ private static void score(
         int size,
         int scoresOffset,
         ES92Int7VectorsScorer scorer,
+        IndexInput centroids,
+        long centroidQuantizeSize,
         byte[] quantizeQuery,
         OptimizedScalarQuantizer.QuantizationResult queryCorrections,
         float centroidDp,
         VectorSimilarityFunction similarityFunction,
         float[] scores,
         FixedBitSet acceptCentroids
     ) throws IOException {
-        // TODO: if accept centroids is not null, we can save some vector ops here
         int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1;
         int i = 0;
         for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) {
-            scorer.scoreBulk(
-                quantizeQuery,
-                queryCorrections.lowerInterval(),
-                queryCorrections.upperInterval(),
-                queryCorrections.quantizedComponentSum(),
-                queryCorrections.additionalCorrection(),
-                similarityFunction,
-                centroidDp,
-                scores
-            );
-            for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
-                int centroidOrd = scoresOffset + i + j;
-                if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
-                    neighborQueue.add(centroidOrd, scores[j]);
-                }
-            }
+            if (acceptCentroids == null
+                || acceptCentroids.cardinality(scoresOffset + i, scoresOffset + i + ES92Int7VectorsScorer.BULK_SIZE) > 0) {
+                scorer.scoreBulk(
+                    quantizeQuery,
+                    queryCorrections.lowerInterval(),
+                    queryCorrections.upperInterval(),
+                    queryCorrections.quantizedComponentSum(),
+                    queryCorrections.additionalCorrection(),
+                    similarityFunction,
+                    centroidDp,
+                    scores
+                );
+                for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
+                    int centroidOrd = scoresOffset + i + j;
+                    if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
+                        neighborQueue.add(centroidOrd, scores[j]);
+                    }
+                }
+            } else {
+                centroids.skipBytes(ES92Int7VectorsScorer.BULK_SIZE * centroidQuantizeSize);
+            }
         }

         for (; i < size; i++) {
-            float score = scorer.score(
-                quantizeQuery,
-                queryCorrections.lowerInterval(),
-                queryCorrections.upperInterval(),
-                queryCorrections.quantizedComponentSum(),
-                queryCorrections.additionalCorrection(),
-                similarityFunction,
-                centroidDp
-            );
             int centroidOrd = scoresOffset + i;
             if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
+                float score = scorer.score(
+                    quantizeQuery,
+                    queryCorrections.lowerInterval(),
+                    queryCorrections.upperInterval(),
+                    queryCorrections.quantizedComponentSum(),
+                    queryCorrections.additionalCorrection(),
+                    similarityFunction,
+                    centroidDp
+                );
                 neighborQueue.add(centroidOrd, score);
+            } else {
+                centroids.skipBytes(centroidQuantizeSize);
+            }
             }
         }
     }
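The reworked loop replaces the old TODO with the actual saving: before scoring a block of BULK_SIZE centroids, FixedBitSet.cardinality(from, to) checks whether any ordinal in the window survives the filter, and if none does, the reader skips the whole block's bytes instead of running the bulk vector ops. A self-contained sketch of the windowed skip, with assumed stand-ins (BULK_SIZE = 16 and the byte size are illustrative, not the real ES92Int7VectorsScorer values):

    import org.apache.lucene.util.FixedBitSet;

    final class BulkSkipSketch {
        static final int BULK_SIZE = 16;                       // assumed stand-in

        public static void main(String[] args) {
            int numCentroids = 100;
            long bytesPerCentroid = 784;                       // e.g. 768 dims + 3 floats + 1 int
            FixedBitSet accept = new FixedBitSet(numCentroids);
            accept.set(3);
            accept.set(70);                                    // only two centroids pass the filter

            long scoredBytes = 0, skippedBytes = 0;
            int limit = numCentroids - BULK_SIZE + 1;
            int i = 0;
            for (; i < limit; i += BULK_SIZE) {
                if (accept.cardinality(i, i + BULK_SIZE) > 0) {
                    scoredBytes += BULK_SIZE * bytesPerCentroid;   // scoreBulk path
                } else {
                    skippedBytes += BULK_SIZE * bytesPerCentroid;  // skipBytes path
                }
            }
            for (; i < numCentroids; i++) {                        // scalar tail
                if (accept.get(i)) {
                    scoredBytes += bytesPerCentroid;
                } else {
                    skippedBytes += bytesPerCentroid;
                }
            }
            System.out.println("scored=" + scoredBytes + " skipped=" + skippedBytes);
        }
    }

With this filter, four of the six full windows contain no accepted centroid and are skipped outright; only the windows holding ordinals 3 and 70 pay for the bulk scoring.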
@@ -19,7 +19,7 @@ public class IVFKnnSearchStrategy extends KnnSearchStrategy {
     private final SetOnce<AbstractMaxScoreKnnCollector> collector = new SetOnce<>();
     private final LongAccumulator accumulator;

-    IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) {
+    public IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) {
         this.visitRatio = visitRatio;
         this.accumulator = accumulator;
     }

Author comment on the constructor: I need this in tests to increase the visit ratio.
@@ -29,12 +29,15 @@
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.search.AcceptDocs;
+import org.apache.lucene.search.KnnCollector;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopKnnCollector;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.tests.util.TestUtil;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
 import org.junit.Before;

 import java.io.IOException;
@@ -353,17 +356,27 @@ private void doRestrictiveFilter(boolean dense) throws IOException {
             LeafReader leafReader = getOnlyLeafReader(reader);
             float[] vector = randomVector(dimensions);
             // we might collect the same document twice because of soar assignments
-            TopDocs topDocs = leafReader.searchNearestVectors(
+            KnnCollector collector;
+            if (random().nextBoolean()) {
+                collector = new TopKnnCollector(random().nextInt(2 * matchingDocs, 3 * matchingDocs), Integer.MAX_VALUE);
+            } else {
+                collector = new TopKnnCollector(
+                    random().nextInt(2 * matchingDocs, 3 * matchingDocs),
+                    Integer.MAX_VALUE,
+                    new IVFKnnSearchStrategy(0.25f, null)
+                );
+            }
+            leafReader.searchNearestVectors(
                 "f",
                 vector,
-                random().nextInt(2 * matchingDocs, 3 * matchingDocs),
+                collector,
                 AcceptDocs.fromIteratorSupplier(
                     () -> leafReader.postings(new Term("k", new BytesRef("A"))),
                     leafReader.getLiveDocs(),
                     leafReader.maxDoc()
-                ),
-                Integer.MAX_VALUE
+                )
             );
+            TopDocs topDocs = collector.topDocs();
             Set<Integer> uniqueDocIds = new HashSet<>();
             for (int i = 0; i < topDocs.scoreDocs.length; i++) {
                 uniqueDocIds.add(topDocs.scoreDocs[i].doc);