From df067164db3e3c52ebc101e732d495ec9fac8dea Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 12 Nov 2024 14:28:21 +0100 Subject: [PATCH 01/56] Use a FunctionScoreQuery to replace scores using a VectorSimilarity based DoubleValueSource --- .../org/elasticsearch/TransportVersions.java | 1 + .../vectors/DenseVectorFieldMapper.java | 88 +++++++++++++++-- .../VectorSimilarityByteValueSource.java | 95 +++++++++++++++++++ .../VectorSimilarityFloatValueSource.java | 95 +++++++++++++++++++ .../search/vectors/KnnVectorQueryBuilder.java | 72 ++++++++++++-- .../vectors/DenseVectorFieldMapperTests.java | 40 ++++++-- .../vectors/DenseVectorFieldTypeTests.java | 30 +++--- 7 files changed, 388 insertions(+), 33 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java create mode 100644 server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 3134eb4966115..dad1d43dc5f4c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -190,6 +190,7 @@ static TransportVersion def(int id) { public static final TransportVersion LOGSDB_TELEMETRY_STATS = def(8_785_00_0); public static final TransportVersion KQL_QUERY_ADDED = def(8_786_00_0); public static final TransportVersion ROLE_MONITOR_STATS = def(8_787_00_0); + public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_788_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index dea9368a9377e..fd2b3c19316ac 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -30,6 +30,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; @@ -121,6 +122,8 @@ public static boolean isNotUnitVector(float magnitude) { public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions + public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates + public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector public static final int MAGNITUDE_BYTES = 4; @@ -2000,6 +2003,7 @@ public Query createKnnQuery( VectorData queryVector, Integer k, int numCands, + Float rescoreOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2010,9 +2014,33 @@ public Query createKnnQuery( ); } return switch (getElementType()) { - case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); - case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter); - case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); + case BYTE -> createKnnByteQuery( + queryVector.asByteVector(), + k, + numCands, + filter, + rescoreOversample, + similarityThreshold, + parentFilter + ); + case FLOAT -> createKnnFloatQuery( + queryVector.asFloatVector(), + k, + numCands, + rescoreOversample, + filter, + similarityThreshold, + parentFilter + ); + case BIT -> createKnnBitQuery( + queryVector.asByteVector(), + k, + numCands, + rescoreOversample, + filter, + similarityThreshold, + parentFilter + ); }; } @@ -2020,11 +2048,16 @@ private Query createKnnBitQuery( byte[] queryVector, Integer k, int numCands, + Float rescoreOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter ) { elementType.checkDimensions(dims, queryVector.length); + if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { + float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); + elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); + } Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); @@ -2035,6 +2068,17 @@ private Query createKnnBitQuery( similarity.score(similarityThreshold, elementType, dims) ); } + if (rescoreOversample != null) { + knnQuery = new FunctionScoreQuery( + knnQuery, + new VectorSimilarityByteValueSource( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE) + ) + ); + + } return knnQuery; } @@ -2043,6 +2087,7 @@ private Query createKnnByteQuery( Integer k, int numCands, Query filter, + Float rescoreOversample, Float similarityThreshold, BitSetProducer parentFilter ) { @@ -2052,9 +2097,12 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } + int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); + int adjustedNumCands = Math.max(adjustedK, numCands); + Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) - : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) + : new ESKnnByteVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2062,6 +2110,17 @@ private Query createKnnByteQuery( similarity.score(similarityThreshold, elementType, dims) ); } + if (rescoreOversample != null) { + knnQuery = new FunctionScoreQuery( + knnQuery, + new VectorSimilarityByteValueSource( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE) + ) + ); + + } return knnQuery; } @@ -2069,6 +2128,7 @@ private Query createKnnFloatQuery( float[] queryVector, Integer k, int numCands, + Float rescoreOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2088,9 +2148,12 @@ && isNotUnitVector(squaredMagnitude)) { } } } + + int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); + int adjustedNumCands = Math.max(adjustedK, numCands); Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter); + ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2098,6 +2161,17 @@ && isNotUnitVector(squaredMagnitude)) { similarity.score(similarityThreshold, elementType, dims) ); } + if (rescoreOversample != null) { + knnQuery = new FunctionScoreQuery( + knnQuery, + new VectorSimilarityFloatValueSource( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT) + ) + ); + + } return knnQuery; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java new file mode 100644 index 0000000000000..661ae2ac2fcd4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.DoubleValues; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +public class VectorSimilarityByteValueSource extends DoubleValuesSource { + + private final String field; + private final byte[] target; + private final VectorSimilarityFunction vectorSimilarityFunction; + + public VectorSimilarityByteValueSource(String field, byte[] target, VectorSimilarityFunction vectorSimilarityFunction) { + this.field = field; + this.target = target; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + final LeafReader reader = ctx.reader(); + + ByteVectorValues vectorValues = reader.getByteVectorValues(field); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + + return new DoubleValues() { + private int docId = -1; + + @Override + public double doubleValue() throws IOException { + return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId)); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + docId = doc; + return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS; + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { + return this; + } + + @Override + public int hashCode() { + return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VectorSimilarityByteValueSource that = (VectorSimilarityByteValueSource) o; + return Objects.equals(field, that.field) + && Objects.deepEquals(target, that.target) + && vectorSimilarityFunction == that.vectorSimilarityFunction; + } + + @Override + public String toString() { + return "VectorSimilarityByteValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")"; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java new file mode 100644 index 0000000000000..13d42272f4744 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.DoubleValues; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +public class VectorSimilarityFloatValueSource extends DoubleValuesSource { + + private final String field; + private final float[] target; + private final VectorSimilarityFunction vectorSimilarityFunction; + + public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) { + this.field = field; + this.target = target; + this.vectorSimilarityFunction = vectorSimilarityFunction; + } + + @Override + public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { + final LeafReader reader = ctx.reader(); + + FloatVectorValues vectorValues = reader.getFloatVectorValues(field); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + + return new DoubleValues() { + private int docId = -1; + + @Override + public double doubleValue() throws IOException { + return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId)); + } + + @Override + public boolean advanceExact(int doc) throws IOException { + docId = doc; + return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS; + } + }; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { + return this; + } + + @Override + public int hashCode() { + return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + VectorSimilarityFloatValueSource that = (VectorSimilarityFloatValueSource) o; + return Objects.equals(field, that.field) + && Objects.deepEquals(target, that.target) + && vectorSimilarityFunction == that.vectorSimilarityFunction; + } + + @Override + public String toString() { + return "VectorSimilarityFloatValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")"; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index deb7e6bd035b8..6c41374818611 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -45,6 +45,7 @@ import java.util.Objects; import java.util.function.Supplier; +import static org.elasticsearch.TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE; import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -66,10 +67,10 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>( "knn", args -> new KnnVectorQueryBuilder( @@ -79,7 +80,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD ); + PARSER.declareFloat(optionalConstructorArg(), RESCORE_VECTOR_OVERSAMPLE); PARSER.declareFieldArray( KnnVectorQueryBuilder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), @@ -115,14 +118,15 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { private final String fieldName; private final VectorData queryVector; private final Integer k; - private Integer numCands; + private final Integer numCands; private final List filterQueries = new ArrayList<>(); private final Float vectorSimilarity; private final QueryVectorBuilder queryVectorBuilder; private final Supplier queryVectorSupplier; + private final Float rescoreOversample; public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity); + this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity, null); } public KnnVectorQueryBuilder( @@ -132,7 +136,7 @@ public KnnVectorQueryBuilder( Integer numCands, Float vectorSimilarity ) { - this(fieldName, null, queryVectorBuilder, null, k, numCands, vectorSimilarity); + this(fieldName, null, queryVectorBuilder, null, k, numCands, vectorSimilarity, null); } public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { @@ -151,6 +155,19 @@ private KnnVectorQueryBuilder( Integer k, Integer numCands, Float vectorSimilarity + ) { + this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity, null); + } + + private KnnVectorQueryBuilder( + String fieldName, + VectorData queryVector, + QueryVectorBuilder queryVectorBuilder, + Supplier queryVectorSupplier, + Integer k, + Integer numCands, + Float vectorSimilarity, + Float rescoreOversample ) { if (k != null && k < 1) { throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0"); @@ -187,6 +204,7 @@ private KnnVectorQueryBuilder( this.vectorSimilarity = vectorSimilarity; this.queryVectorBuilder = queryVectorBuilder; this.queryVectorSupplier = queryVectorSupplier; + this.rescoreOversample = rescoreOversample; } public KnnVectorQueryBuilder(StreamInput in) throws IOException { @@ -227,6 +245,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { } else { this.queryVectorBuilder = null; } + if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + this.rescoreOversample = in.readOptionalFloat(); + } else { + this.rescoreOversample = null; + } + this.queryVectorSupplier = null; } @@ -252,6 +276,10 @@ public Integer numCands() { return numCands; } + public Float rescoreOversample() { + return rescoreOversample; + } + public List filterQueries() { return filterQueries; } @@ -327,6 +355,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) { out.writeOptionalNamedWriteable(queryVectorBuilder); } + if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { + out.writeOptionalFloat(rescoreOversample); + } } @Override @@ -425,6 +456,11 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { return this; } + @Override + protected QueryBuilder doIndexMetadataRewrite(QueryRewriteContext context) throws IOException { + return super.doIndexMetadataRewrite(context); + } + @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { MappedFieldType fieldType = context.getFieldType(fieldName); @@ -491,14 +527,31 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { // Now join the filterQuery & parentFilter to provide the matching blocks of children filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet); + return vectorFieldType.createKnnQuery( + queryVector, + k, + adjustedNumCands, + rescoreOversample, + filterQuery, + vectorSimilarity, + parentBitSet + ); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null); + return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, rescoreOversample, filterQuery, vectorSimilarity, null); } @Override protected int doHashCode() { - return Objects.hash(fieldName, Objects.hashCode(queryVector), k, numCands, filterQueries, vectorSimilarity, queryVectorBuilder); + return Objects.hash( + fieldName, + Objects.hashCode(queryVector), + k, + numCands, + filterQueries, + vectorSimilarity, + queryVectorBuilder, + rescoreOversample + ); } @Override @@ -509,7 +562,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { && Objects.equals(numCands, other.numCands) && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity) - && Objects.equals(queryVectorBuilder, other.queryVectorBuilder); + && Objects.equals(queryVectorBuilder, other.queryVectorBuilder) + && Objects.equals(rescoreOversample, other.rescoreOversample); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index de084cd4582e2..1d949c657c2f2 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1674,7 +1674,7 @@ public void testByteVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 128, 0, 0 }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1683,7 +1683,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 0.0f, 0f, -129.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat( e.getMessage(), @@ -1692,7 +1700,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.5f, 0.0f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1701,7 +1709,7 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { 0, 0.0f, -0.25f }), 3, 3, null, null, null, null) ); assertThat( e.getMessage(), @@ -1710,7 +1718,15 @@ public void testByteVectorQueryBoundaries() throws IOException { e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1722,6 +1738,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1738,6 +1755,7 @@ public void testByteVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1765,7 +1783,15 @@ public void testFloatVectorQueryBoundaries() throws IOException { Exception e = expectThrows( IllegalArgumentException.class, - () -> denseVectorFieldType.createKnnQuery(VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), 3, 3, null, null, null) + () -> denseVectorFieldType.createKnnQuery( + VectorData.fromFloats(new float[] { Float.NaN, 0f, 0.0f }), + 3, + 3, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];")); @@ -1777,6 +1803,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); @@ -1793,6 +1820,7 @@ public void testFloatVectorQueryBoundaries() throws IOException { 3, null, null, + null, null ) ); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 9e819f38eae6e..e6ccd05c8f6e5 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -170,7 +170,7 @@ public void testCreateNestedKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, producer); + Query query = field.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class)); } { @@ -191,11 +191,11 @@ public void testCreateNestedKnnQuery() { floatQueryVector[i] = queryVector[i]; } VectorData vectorData = new VectorData(null, queryVector); - Query query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + Query query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); vectorData = new VectorData(floatQueryVector, null); - query = field.createKnnQuery(vectorData, 10, 10, null, null, producer); + query = field.createKnnQuery(vectorData, 10, 10, null, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -256,7 +256,15 @@ public void testFloatCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), 10, 10, null, null, null) + () -> unindexedField.createKnnQuery( + VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }), + 10, + 10, + null, + null, + null, + null + ) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -276,7 +284,7 @@ public void testFloatCreateKnnQuery() { } e = expectThrows( IllegalArgumentException.class, - () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null) + () -> dotProductField.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors.")); @@ -292,7 +300,7 @@ public void testFloatCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[BBQ_MIN_DIMS]), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } @@ -313,7 +321,7 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomFloat(); } - Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(VectorData.fromFloats(queryVector), 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnFloatVectorQuery.class)); } @@ -333,7 +341,7 @@ public void testCreateKnnQueryMaxDims() { queryVector[i] = randomByte(); } VectorData vectorData = new VectorData(null, queryVector); - Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -351,7 +359,7 @@ public void testByteCreateKnnQuery() { ); IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null) + () -> unindexedField.createKnnQuery(VectorData.fromFloats(new float[] { 0.3f, 0.1f, 1.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]")); @@ -367,13 +375,13 @@ public void testByteCreateKnnQuery() { ); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(VectorData.fromFloats(new float[] { 0.0f, 0.0f, 0.0f }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); e = expectThrows( IllegalArgumentException.class, - () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null) + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null, null) ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } From be7644487a4e4a1786ed6952b7601c61d1fcf7a5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Nov 2024 17:36:28 +0100 Subject: [PATCH 02/56] Change API to use "rescore": {"oversample": 1.0} --- .../retriever/RetrieverTelemetryIT.java | 2 +- .../vectors/DenseVectorFieldMapper.java | 12 +- .../search/vectors/KnnSearchBuilder.java | 2 +- .../vectors/KnnSearchRequestParser.java | 2 +- .../search/vectors/KnnVectorQueryBuilder.java | 110 ++++++++++++------ .../search/vectors/RescoreVectorBuilder.java | 87 ++++++++++++++ .../mapper/SemanticTextFieldMapper.java | 2 +- 7 files changed, 171 insertions(+), 46 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java index 537ace30e88f0..73debb6c9b985 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java @@ -98,7 +98,7 @@ public void testTelemetryForRetrievers() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index fd2b3c19316ac..69dbdce5832a3 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2097,8 +2097,10 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); - int adjustedNumCands = Math.max(adjustedK, numCands); + Integer adjustedK = k == null || rescoreOversample == null + ? null + : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); + int adjustedNumCands = Math.max(adjustedK == null ? 0 : adjustedK, numCands); Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) @@ -2149,8 +2151,10 @@ && isNotUnitVector(squaredMagnitude)) { } } - int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); - int adjustedNumCands = Math.max(adjustedK, numCands); + Integer adjustedK = k == null || rescoreOversample == null + ? k + : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); + int adjustedNumCands = adjustedK == null ? numCands : Math.max(adjustedK, numCands); Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 41673a0e7edb0..90b89d9ff1a13 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -407,7 +407,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, similarity).boost(boost) + return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java index a28448336ab3f..81b00f1329591 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java @@ -256,7 +256,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (numCands > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null); + return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, null); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 6c41374818611..b250f6484bda6 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -67,9 +67,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>( "knn", @@ -80,8 +80,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD ); - PARSER.declareFloat(optionalConstructorArg(), RESCORE_VECTOR_OVERSAMPLE); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_FIELD, + ObjectParser.ValueType.OBJECT_OR_NULL + ); PARSER.declareFieldArray( KnnVectorQueryBuilder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), @@ -123,10 +128,17 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { private final Float vectorSimilarity; private final QueryVectorBuilder queryVectorBuilder; private final Supplier queryVectorSupplier; - private final Float rescoreOversample; + private final RescoreVectorBuilder rescoreVectorBuilder; - public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity, null); + public KnnVectorQueryBuilder( + String fieldName, + float[] queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } public KnnVectorQueryBuilder( @@ -136,27 +148,29 @@ public KnnVectorQueryBuilder( Integer numCands, Float vectorSimilarity ) { - this(fieldName, null, queryVectorBuilder, null, k, numCands, vectorSimilarity, null); + this(fieldName, null, queryVectorBuilder, null, k, numCands, null, vectorSimilarity); } - public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, vectorSimilarity); - } - - public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer k, Integer numCands, Float vectorSimilarity) { - this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity); + public KnnVectorQueryBuilder( + String fieldName, + byte[] queryVector, + Integer k, + Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity + ) { + this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } - private KnnVectorQueryBuilder( + public KnnVectorQueryBuilder( String fieldName, VectorData queryVector, - QueryVectorBuilder queryVectorBuilder, - Supplier queryVectorSupplier, Integer k, Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float vectorSimilarity ) { - this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity, null); + this(fieldName, queryVector, null, null, k, numCands, rescoreVectorBuilder, vectorSimilarity); } private KnnVectorQueryBuilder( @@ -166,8 +180,8 @@ private KnnVectorQueryBuilder( Supplier queryVectorSupplier, Integer k, Integer numCands, - Float vectorSimilarity, - Float rescoreOversample + RescoreVectorBuilder rescoreVectorBuilder, + Float vectorSimilarity ) { if (k != null && k < 1) { throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0"); @@ -204,7 +218,7 @@ private KnnVectorQueryBuilder( this.vectorSimilarity = vectorSimilarity; this.queryVectorBuilder = queryVectorBuilder; this.queryVectorSupplier = queryVectorSupplier; - this.rescoreOversample = rescoreOversample; + this.rescoreVectorBuilder = rescoreVectorBuilder; } public KnnVectorQueryBuilder(StreamInput in) throws IOException { @@ -246,9 +260,9 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { this.queryVectorBuilder = null; } if (in.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { - this.rescoreOversample = in.readOptionalFloat(); + this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new); } else { - this.rescoreOversample = null; + this.rescoreVectorBuilder = null; } this.queryVectorSupplier = null; @@ -276,10 +290,6 @@ public Integer numCands() { return numCands; } - public Float rescoreOversample() { - return rescoreOversample; - } - public List filterQueries() { return filterQueries; } @@ -289,6 +299,10 @@ public QueryVectorBuilder queryVectorBuilder() { return queryVectorBuilder; } + public RescoreVectorBuilder rescoreVectorBuilder() { + return rescoreVectorBuilder; + } + public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) { Objects.requireNonNull(filterQuery); this.filterQueries.add(filterQuery); @@ -356,7 +370,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalNamedWriteable(queryVectorBuilder); } if (out.getTransportVersion().onOrAfter(KNN_QUERY_RESCORE_OVERSAMPLE)) { - out.writeOptionalFloat(rescoreOversample); + out.writeOptionalWriteable(rescoreVectorBuilder); } } @@ -391,6 +405,11 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep } builder.endArray(); } + if (rescoreVectorBuilder != null) { + builder.startObject(RESCORE_FIELD.getPreferredName()); + rescoreVectorBuilder.toXContent(builder, params); + builder.endObject(); + } boostAndQueryNameToXContent(builder); builder.endObject(); } @@ -406,7 +425,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { if (queryVectorSupplier.get() == null) { return this; } - return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, vectorSimilarity).boost(boost) + return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, rescoreVectorBuilder, vectorSimilarity) + .boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } @@ -428,9 +448,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, k, numCands, vectorSimilarity).boost( - boost - ).queryName(queryName).addFilterQueries(filterQueries); + return new KnnVectorQueryBuilder( + fieldName, + queryVector, + queryVectorBuilder, + toSet::get, + k, + numCands, + rescoreVectorBuilder, + vectorSimilarity + ).boost(boost).queryName(queryName).addFilterQueries(filterQueries); } if (ctx.convertToInnerHitsRewriteContext() != null) { return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName); @@ -448,10 +475,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, k, numCands, vectorSimilarity) - .boost(boost) - .queryName(queryName) - .addFilterQueries(rewrittenQueries); + return new KnnVectorQueryBuilder( + fieldName, + queryVector, + queryVectorBuilder, + queryVectorSupplier, + k, + numCands, + rescoreVectorBuilder, + vectorSimilarity + ).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries); } return this; } @@ -495,6 +528,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; String parentPath = context.nestedLookup().getNestedParent(fieldName); + Float rescoreOversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample(); if (parentPath != null) { final BitSetProducer parentBitSet; @@ -550,7 +584,7 @@ protected int doHashCode() { filterQueries, vectorSimilarity, queryVectorBuilder, - rescoreOversample + rescoreVectorBuilder ); } @@ -563,7 +597,7 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity) && Objects.equals(queryVectorBuilder, other.queryVectorBuilder) - && Objects.equals(rescoreOversample, other.rescoreOversample); + && Objects.equals(rescoreVectorBuilder, other.rescoreVectorBuilder); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java new file mode 100644 index 0000000000000..26c151b5d2f72 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -0,0 +1,87 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class RescoreVectorBuilder implements Writeable, ToXContentObject { + + public static final ParseField OVERSAMPLE_FIELD = new ParseField("oversample"); + public static final int MIN_OVERSAMPLE = 1; + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "rescore", + args -> new RescoreVectorBuilder((Float) args[0]) + ); + + static { + PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), OVERSAMPLE_FIELD); + } + + // Oversample is required as of now as it is the only field in the rescore vector + // that may change in the future, so we treat it as optional + private final Float oversample; + + public RescoreVectorBuilder(Float oversample) { + Objects.requireNonNull(oversample, "[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be set"); + if (oversample <= MIN_OVERSAMPLE) { + throw new IllegalArgumentException("[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be > " + MIN_OVERSAMPLE); + } + this.oversample = oversample; + } + + public RescoreVectorBuilder(StreamInput in) throws IOException { + this.oversample = in.readOptionalFloat(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalFloat(oversample); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (oversample != null) { + builder.field(OVERSAMPLE_FIELD.getPreferredName(), oversample); + } + + return builder; + } + + public static RescoreVectorBuilder fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RescoreVectorBuilder that = (RescoreVectorBuilder) o; + return Objects.equals(oversample, that.oversample); + } + + @Override + public int hashCode() { + return Objects.hashCode(oversample); + } + + public Float oversample() { + return oversample; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index f0cb612c9082f..c04ca72b0e643 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -556,7 +556,7 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost ); } - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null, null); + yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null, null,null); } default -> throw new IllegalStateException( "Field [" From bd920c5ed4390de88ac43a96f1caea52ae18a5bc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Nov 2024 17:36:55 +0100 Subject: [PATCH 03/56] Add tests --- .../search/KnnSearchSingleNodeTests.java | 4 +- .../index/query/NestedQueryBuilderTests.java | 1 + ...AbstractKnnVectorQueryBuilderTestCase.java | 109 +++++++++++++----- .../KnnByteVectorQueryBuilderTests.java | 10 +- .../KnnFloatVectorQueryBuilderTests.java | 10 +- .../search/vectors/KnnSearchBuilderTests.java | 2 +- 6 files changed, 101 insertions(+), 35 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index 042890001c2ea..a52e3bc910bc2 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -417,7 +417,9 @@ public void testKnnSearchAction() throws IOException { // how the action works (it builds a kNN query under the hood) float[] queryVector = randomVector(); assertResponse( - client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null)).setSize(2), + client().prepareSearch("index1", "index2") + .setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null, null)) + .setSize(2), response -> { // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard assertHitCount(response, 5 * 2); diff --git a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java index 6076665e26824..7f4f95cdd2416 100644 --- a/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java @@ -270,6 +270,7 @@ public void testKnnRewriteForInnerHits() throws IOException { new float[] { 1.0f, 2.0f, 3.0f }, null, 1, + null, null ); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index f93bdd14f0645..d603ad7e39b1f 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -42,6 +42,7 @@ import java.util.ArrayList; import java.util.List; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -56,7 +57,13 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa abstract DenseVectorFieldMapper.ElementType elementType(); - abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity); + abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ); @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { @@ -88,7 +95,13 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; Integer k = randomBoolean() ? null : randomIntBetween(1, 100); int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000); - KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, k, numCands, randomFloat()); + KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder( + fieldName, + k, + numCands, + randomRescoreVectorBuilder(), + randomFloat() + ); if (randomBoolean()) { List filters = new ArrayList<>(); @@ -99,11 +112,24 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { } queryBuilder.addFilterQueries(filters); } + return queryBuilder; } + protected RescoreVectorBuilder randomRescoreVectorBuilder() { + if (randomBoolean()) { + return null; + } + + return new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + } + @Override protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { + if (queryBuilder.rescoreVectorBuilder() != null) { + assertTrue(query instanceof org.apache.lucene.queries.function.FunctionScoreQuery); + query = ((org.apache.lucene.queries.function.FunctionScoreQuery) query).getWrappedQuery(); + } if (queryBuilder.getVectorSimilarity() != null) { assertTrue(query instanceof VectorSimilarityQuery); Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery(); @@ -126,21 +152,17 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que BooleanQuery booleanQuery = builder.build(); Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; // The field should always be resolved to the concrete field + Integer k = queryBuilder.k(); + Integer numCands = queryBuilder.numCands(); + if (queryBuilder.rescoreVectorBuilder() != null) { + Float rescoreOversample = queryBuilder.rescoreVectorBuilder().oversample(); + k = k == null ? null : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); + numCands = numCands == null ? null : Math.max(k == null ? 0 : k, numCands); + } + Query knnVectorQueryBuilt = switch (elementType()) { - case BYTE, BIT -> new ESKnnByteVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asByteVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); - case FLOAT -> new ESKnnFloatVectorQuery( - VECTOR_FIELD, - queryBuilder.queryVector().asFloatVector(), - queryBuilder.k(), - queryBuilder.numCands(), - filterQuery - ); + case BYTE, BIT -> new ESKnnByteVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), k, numCands, filterQuery); + case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, numCands, filterQuery); }; if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); @@ -150,7 +172,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que public void testWrongDimension() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null, null); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); assertThat( e.getMessage(), @@ -160,7 +182,7 @@ public void testWrongDimension() { public void testNonexistentField() { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); context.setAllowUnmappedFields(false); QueryShardException e = expectThrows(QueryShardException.class, () -> query.doToQuery(context)); assertThat(e.getMessage(), containsString("No field mapping can be found for the field with name [nonexistent]")); @@ -168,7 +190,7 @@ public void testNonexistentField() { public void testNonexistentFieldReturnEmpty() throws IOException { SearchExecutionContext context = createSearchExecutionContext(); - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null, null); Query queryNone = query.doToQuery(context); assertThat(queryNone, instanceOf(MatchNoDocsQuery.class)); } @@ -180,6 +202,7 @@ public void testWrongFieldType() { new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, + null, null ); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); @@ -191,14 +214,14 @@ public void testNumCandsLessThanK() { int numCands = 3; IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null) + () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @Override public void testValidOutput() { - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null, null); String expected = """ { "knn" : { @@ -213,7 +236,7 @@ public void testValidOutput() { }"""; assertEquals(expected, query.toString()); - KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null); + KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null, null); String expected2 = """ { "knn" : { @@ -240,6 +263,7 @@ public void testMustRewrite() throws IOException { new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION, null, + null, null ); query.addFilterQuery(termQuery); @@ -254,9 +278,14 @@ public void testMustRewrite() throws IOException { public void testBWCVersionSerializationFilters() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()); + KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()); TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween( random(), TransportVersions.V_8_0_0, @@ -268,10 +297,14 @@ public void testBWCVersionSerializationFilters() throws IOException { public void testBWCVersionSerializationSimilarity() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); - KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null) - .queryName(query.queryName()) - .boost(query.boost()) - .addFilterQueries(query.filterQueries()); + KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder( + query.getFieldName(), + vectorData, + null, + query.numCands(), + null, + null + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0); } @@ -289,11 +322,29 @@ public void testBWCVersionSerializationQuery() throws IOException { vectorData, null, query.numCands(), + null, similarity ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryOlderVersion, differentQueryVersion); } + public void testBWCVersionSerializationRescoreVector() throws IOException { + KnnVectorQueryBuilder query = createTestQueryBuilder(); + KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder( + query.getFieldName(), + query.queryVector(), + query.k(), + query.numCands(), + null, + query.getVectorSimilarity() + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); + assertBWCSerialization( + query, + queryNoRescoreVector, + TransportVersionUtils.randomVersionBetween(random(), TransportVersions.V_8_8_0, TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE) + ); + } + private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException { assertSerialization(bwcQuery, version); try (BytesStreamOutput output = new BytesStreamOutput()) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 0fc2304e904a4..980e506c0ca35 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -18,11 +18,17 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { + protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { byte[] vector = new byte[VECTOR_DIMENSION]; for (int i = 0; i < vector.length; i++) { vector[i] = randomByte(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index ba2245ced3305..75b1f395c57e7 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -18,11 +18,17 @@ DenseVectorFieldMapper.ElementType elementType() { } @Override - KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) { + KnnVectorQueryBuilder createKnnVectorQueryBuilder( + String fieldName, + Integer k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { float[] vector = new float[VECTOR_DIMENSION]; for (int i = 0; i < vector.length; i++) { vector[i] = randomFloat(); } - return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity); + return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 2184e8af54aed..3753e8ce874cb 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -167,7 +167,7 @@ public void testToQueryBuilder() { builder.addFilterQuery(filter); } - QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, similarity).addFilterQueries(filterQueries) + QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, null, similarity).addFilterQueries(filterQueries) .boost(boost); assertEquals(expected, builder.toQueryBuilder()); } From 91204a1bfa0ad3fee9f440ad9a010f0d4c1e5fc8 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 13 Nov 2024 17:47:01 +0100 Subject: [PATCH 04/56] Fix inference module --- .../xpack/inference/mapper/SemanticTextFieldMapper.java | 2 +- .../TextSimilarityRankRetrieverTelemetryTests.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index c04ca72b0e643..8aa53430110b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -556,7 +556,7 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, float boost ); } - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null, null,null); + yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null, null, null); } default -> throw new IllegalStateException( "Field [" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index 916703446995d..9ec88e959b6c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -116,7 +116,7 @@ public void testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } From b44ec483280c6a0223ba09037a190c6bd7fff1ad Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Nov 2024 20:28:34 +0100 Subject: [PATCH 05/56] Fix knn query usage in other modules --- .../elasticsearch/percolator/PercolatorQuerySearchIT.java | 2 +- .../xpack/rank/rrf/RRFRetrieverTelemetryIT.java | 2 +- .../integration/DocumentLevelSecurityTests.java | 2 +- .../elasticsearch/integration/FieldLevelSecurityTests.java | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java index 05f456b7f2229..8a7f1405f8f4e 100644 --- a/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java +++ b/modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java @@ -1359,7 +1359,7 @@ public void testKnnQueryNotSupportedInPercolator() throws IOException { """); indicesAdmin().prepareCreate("index1").setMapping(mappings).get(); ensureGreen(); - QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null); + QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null, null); IndexRequestBuilder indexRequestBuilder = prepareIndex("index1").setId("knn_query1") .setSource(jsonBuilder().startObject().field("my_query", knnVectorQueryBuilder).endObject()); diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java index 4eaea9a596361..f93c13682739f 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java @@ -117,7 +117,7 @@ public void testTelemetryForRRFRetriever() throws IOException { { performSearch( new SearchSourceBuilder().retriever( - new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)) ) ); } diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java index 87ca7d279c709..12b75c787d6e9 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java @@ -884,7 +884,7 @@ public void testKnnSearch() throws Exception { // Since there's no kNN search action at the transport layer, we just emulate // how the action works (it builds a kNN query under the hood) float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f }; - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50, 50, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50, 50, null, null); if (randomBoolean()) { query.addFilterQuery(new WildcardQueryBuilder("other", "value*")); diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java index 66c8c0a5b1b52..6c7ba15b773ba 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java @@ -441,7 +441,7 @@ public void testKnnSearch() throws IOException { // Since there's no kNN search action at the transport layer, we just emulate // how the action works (it builds a kNN query under the hood) float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f }; - KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null, null); // user1 has access to vector field, so the query should match with the document: assertResponse( @@ -475,7 +475,7 @@ public void testKnnSearch() throws IOException { } ); // user1 can access field1, so the filtered query should match with the document: - KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null).addFilterQuery( + KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null, null).addFilterQuery( QueryBuilders.matchQuery("field1", "value1") ); assertHitCount( @@ -486,7 +486,7 @@ public void testKnnSearch() throws IOException { ); // user1 cannot access field2, so the filtered query should not match with the document: - KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null).addFilterQuery( + KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null, null).addFilterQuery( QueryBuilders.matchQuery("field2", "value2") ); assertHitCount( From 2a9e30059447e7c272150162e9abc65a0cf6dcbf Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 14 Nov 2024 23:54:30 +0100 Subject: [PATCH 06/56] Add rescore vector builder to KnnSearchBuilder --- .../search/nested/VectorNestedIT.java | 2 +- .../search/profile/dfs/DfsProfilerIT.java | 2 + .../retriever/RetrieverTelemetryIT.java | 8 +- .../search/vectors/KnnSearchBuilder.java | 97 +++++++++++++-- .../action/search/DfsQueryPhaseTests.java | 4 +- .../search/KnnSearchSingleNodeTests.java | 20 ++-- .../action/search/SearchRequestTests.java | 48 ++++++-- .../search/TransportSearchActionTests.java | 4 +- .../action/search/RestSearchActionTests.java | 2 +- .../builder/SearchSourceBuilderTests.java | 2 +- .../search/vectors/KnnSearchBuilderTests.java | 113 ++++++++++++++---- .../search/RandomSearchRequestGenerator.java | 9 +- .../AbstractQueryVectorBuilderTestCase.java | 3 + .../xpack/rank/rrf/RRFRankBuilder.java | 3 +- 14 files changed, 249 insertions(+), 68 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java index d1021715ceffc..aaab14941d4bb 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/VectorNestedIT.java @@ -69,7 +69,7 @@ public void testSimpleNested() throws Exception { assertResponse( prepareSearch("test").setKnnSearch( - List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null).innerHit(new InnerHitBuilder())) + List.of(new KnnSearchBuilder("nested.vector", new float[] { 1, 1, 1 }, 1, 1, null, null).innerHit(new InnerHitBuilder())) ).setAllowPartialSearchResults(false), response -> assertThat(response.getHits().getHits().length, greaterThan(0)) ); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java index 876edc282c903..95d69a6ebaa86 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.QueryProfileShardResult; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.XContentFactory; @@ -71,6 +72,7 @@ public void testProfileDfs() throws Exception { new float[] { randomFloat(), randomFloat(), randomFloat() }, randomIntBetween(5, 10), 50, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? null : randomFloat() ); if (randomBoolean()) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java index 73debb6c9b985..40849bea5512e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java @@ -84,7 +84,9 @@ public void testTelemetryForRetrievers() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -112,7 +114,9 @@ public void testTelemetryForRetrievers() throws IOException { // search#5 - t // his will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#6 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 90b89d9ff1a13..81646f4a84d10 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -56,6 +56,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea public static final ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD; public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD; public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits"); + public static final ParseField RESCORE_FIELD = new ParseField("rescore"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("knn", args -> { @@ -65,7 +66,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea .queryVectorBuilder((QueryVectorBuilder) args[4]) .k((Integer) args[2]) .numCandidates((Integer) args[3]) - .similarity((Float) args[5]); + .similarity((Float) args[5]) + .rescoreVectorBuilder((RescoreVectorBuilder) args[6]); }); static { @@ -78,13 +80,18 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea ); PARSER.declareInt(optionalConstructorArg(), K_FIELD); PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); - PARSER.declareNamedObject( optionalConstructorArg(), (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), QUERY_VECTOR_BUILDER_FIELD ); PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_FIELD, + ObjectParser.ValueType.OBJECT_OR_NULL + ); PARSER.declareFieldArray( KnnSearchBuilder.Builder::addFilterQueries, (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p), @@ -116,6 +123,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw String queryName; float boost = DEFAULT_BOOST; InnerHitBuilder innerHitBuilder; + final RescoreVectorBuilder rescoreVectorBuilder; /** * Defines a kNN search. @@ -124,14 +132,23 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw * @param queryVector the query vector * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard + * @param rescoreVectorBuilder rescore vector information */ - public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, Float similarity) { + public KnnSearchBuilder( + String field, + float[] queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { this( field, Objects.requireNonNull(VectorData.fromFloats(queryVector), format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands, + rescoreVectorBuilder, similarity ); } @@ -144,8 +161,15 @@ public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard */ - public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCands, Float similarity) { - this(field, queryVector, null, k, numCands, similarity); + public KnnSearchBuilder( + String field, + VectorData queryVector, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { + this(field, queryVector, null, k, numCands, rescoreVectorBuilder, similarity); } /** @@ -156,13 +180,21 @@ public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCand * @param k the final number of nearest neighbors to return as top hits * @param numCands the number of nearest neighbor candidates to consider per shard */ - public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands, Float similarity) { + public KnnSearchBuilder( + String field, + QueryVectorBuilder queryVectorBuilder, + int k, + int numCands, + RescoreVectorBuilder rescoreVectorBuilder, + Float similarity + ) { this( field, null, Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())), k, numCands, + rescoreVectorBuilder, similarity ); } @@ -173,9 +205,22 @@ public KnnSearchBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { - this(field, queryVectorBuilder, queryVector, new ArrayList<>(), k, numCands, similarity, null, null, DEFAULT_BOOST); + this( + field, + queryVectorBuilder, + queryVector, + new ArrayList<>(), + k, + numCands, + rescoreVectorBuilder, + similarity, + null, + null, + DEFAULT_BOOST + ); } private KnnSearchBuilder( @@ -183,6 +228,7 @@ private KnnSearchBuilder( Supplier querySupplier, Integer k, Integer numCands, + RescoreVectorBuilder rescoreVectorBuilder, List filterQueries, Float similarity ) { @@ -194,6 +240,7 @@ private KnnSearchBuilder( this.filterQueries = filterQueries; this.querySupplier = querySupplier; this.similarity = similarity; + this.rescoreVectorBuilder = rescoreVectorBuilder; } private KnnSearchBuilder( @@ -203,6 +250,7 @@ private KnnSearchBuilder( List filterQueries, int k, int numCandidates, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity, InnerHitBuilder innerHitBuilder, String queryName, @@ -242,6 +290,7 @@ private KnnSearchBuilder( this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCandidates; + this.rescoreVectorBuilder = rescoreVectorBuilder; this.innerHitBuilder = innerHitBuilder; this.similarity = similarity; this.queryName = queryName; @@ -280,6 +329,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(V_8_11_X)) { this.innerHitBuilder = in.readOptionalWriteable(InnerHitBuilder::new); } + if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) { + this.rescoreVectorBuilder = in.readOptional(RescoreVectorBuilder::new); + } else { + this.rescoreVectorBuilder = null; + } } public int k() { @@ -290,6 +344,10 @@ public int getNumCands() { return numCands; } + public RescoreVectorBuilder getRescoreVectorBuilder() { + return rescoreVectorBuilder; + } + public QueryVectorBuilder getQueryVectorBuilder() { return queryVectorBuilder; } @@ -358,7 +416,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { if (querySupplier.get() == null) { return this; } - return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, similarity).boost(boost) + return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries) .innerHit(innerHitBuilder); @@ -381,7 +439,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { } ll.onResponse(null); }))); - return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost) + return new KnnSearchBuilder(field, toSet::get, k, numCands, rescoreVectorBuilder, filterQueries, similarity).boost(boost) .queryName(queryName) .innerHit(innerHitBuilder); } @@ -395,7 +453,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException { rewrittenQueries.add(rewrittenQuery); } if (changed) { - return new KnnSearchBuilder(field, queryVector, k, numCands, similarity).boost(boost) + return new KnnSearchBuilder(field, queryVector, k, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(rewrittenQueries) .innerHit(innerHitBuilder); @@ -407,7 +465,7 @@ public KnnVectorQueryBuilder toQueryBuilder() { if (queryVectorBuilder != null) { throw new IllegalArgumentException("missing rewrite"); } - return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null, similarity).boost(boost) + return new KnnVectorQueryBuilder(field, queryVector, null, numCands, rescoreVectorBuilder, similarity).boost(boost) .queryName(queryName) .addFilterQueries(filterQueries); } @@ -423,6 +481,7 @@ public boolean equals(Object o) { KnnSearchBuilder that = (KnnSearchBuilder) o; return k == that.k && numCands == that.numCands + && Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder) && Objects.equals(field, that.field) && Objects.equals(queryVector, that.queryVector) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) @@ -442,6 +501,7 @@ public int hashCode() { numCands, querySupplier, queryVectorBuilder, + rescoreVectorBuilder, similarity, Objects.hashCode(queryVector), Objects.hashCode(filterQueries), @@ -486,6 +546,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (queryName != null) { builder.field(NAME_FIELD.getPreferredName(), queryName); } + if (rescoreVectorBuilder != null) { + builder.startObject(RESCORE_FIELD.getPreferredName()); + rescoreVectorBuilder.toXContent(builder, params); + builder.endObject(); + } return builder; } @@ -526,6 +591,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(V_8_11_X)) { out.writeOptionalWriteable(innerHitBuilder); } + if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE)) { + out.writeOptionalWriteable(rescoreVectorBuilder); + } } public static class Builder { @@ -540,6 +608,7 @@ public static class Builder { private String queryName; private float boost = DEFAULT_BOOST; private InnerHitBuilder innerHitBuilder; + private RescoreVectorBuilder rescoreVectorBuilder; public Builder addFilterQueries(List filterQueries) { Objects.requireNonNull(filterQueries); @@ -592,6 +661,11 @@ public Builder similarity(Float similarity) { return this; } + public Builder rescoreVectorBuilder(RescoreVectorBuilder rescoreVectorBuilder) { + this.rescoreVectorBuilder = rescoreVectorBuilder; + return this; + } + public KnnSearchBuilder build(int size) { int requestSize = size < 0 ? DEFAULT_SIZE : size; int adjustedK = k == null ? requestSize : k; @@ -605,6 +679,7 @@ public KnnSearchBuilder build(int size) { filterQueries, adjustedK, adjustedNumCandidates, + rescoreVectorBuilder, similarity, innerHitBuilder, queryName, diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 64362daf7f75c..193855a4c835f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -344,8 +344,8 @@ public void testRewriteShardSearchRequestWithRank() { SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) .knnSearch( List.of( - new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null), - new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null) + new KnnSearchBuilder("vector", new float[] { 0.0f }, 10, 100, null, null), + new KnnSearchBuilder("vector2", new float[] { 0.0f }, 10, 100, null, null) ) ) .rankBuilder(new TestRankBuilder(100)); diff --git a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java index a52e3bc910bc2..353188af8be3c 100644 --- a/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java @@ -63,7 +63,7 @@ public void testKnnSearchRemovedVector() throws IOException { client().prepareUpdate("index", "0").setDoc("vector", (Object) null).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null, null).boost(5.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -107,7 +107,7 @@ public void testKnnWithQuery() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f).queryName("knn"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f).queryName("knn"); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -156,7 +156,7 @@ public void testKnnFilter() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).addFilterQuery( QueryBuilders.termsQuery("field", "second") ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { @@ -199,7 +199,7 @@ public void testKnnFilterWithRewrite() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery( + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).addFilterQuery( QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field")) ); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { @@ -246,8 +246,8 @@ public void testMultiKnnClauses() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(20f, 21f); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f); - KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null).boost(10.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null).boost(5.0f); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null).boost(10.0f); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch, knnSearch2)) @@ -308,8 +308,8 @@ public void testMultiKnnClausesSameDoc() throws IOException { float[] queryVector = randomVector(); // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null); - KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null, null); + KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null, null); assertResponse( client().prepareSearch("index") .setKnnSearch(List.of(knnSearch)) @@ -381,7 +381,7 @@ public void testKnnFilteredAlias() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null, null); final int expectedHitCount = expectedHits; assertResponse(client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10), response -> { assertHitCount(response, expectedHitCount); @@ -452,7 +452,7 @@ public void testKnnVectorsWith4096Dims() throws IOException { indicesAdmin().prepareRefresh("index").get(); float[] queryVector = randomVector(4096); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, null).boost(5.0f); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 50, null, null).boost(5.0f); assertResponse(client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10), response -> { assertHitCount(response, 3); assertEquals(3, response.getHits().getHits().length); diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java index 526961d74bf52..2afb5235ddf82 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java @@ -39,6 +39,7 @@ import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.suggest.term.TermSuggestionBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TransportVersionUtils; @@ -135,8 +136,22 @@ public void testSerializationMultiKNN() throws Exception { searchRequest.source() .knnSearch( List.of( - new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat()), - new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 4, 12, 41 }, 3, 5, randomBoolean() ? null : randomFloat()) + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 1, 2 }, + 5, + 10, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ), + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 4, 12, 41 }, + 3, + 5, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ) ) ); expectThrows( @@ -151,7 +166,16 @@ public void testSerializationMultiKNN() throws Exception { searchRequest.source() .knnSearch( - List.of(new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat())) + List.of( + new KnnSearchBuilder( + randomAlphaOfLength(10), + new float[] { 1, 2 }, + 5, + 10, + randomRescoreVectorBuilder(), + randomBoolean() ? null : randomFloat() + ) + ) ); // Shouldn't throw because its just one KNN request copyWriteable( @@ -162,6 +186,10 @@ public void testSerializationMultiKNN() throws Exception { ); } + private static RescoreVectorBuilder randomRescoreVectorBuilder() { + return randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + } + public void testRandomVersionSerialization() throws IOException { SearchRequest searchRequest = createSearchRequest(); TransportVersion version = TransportVersionUtils.randomVersion(random()); @@ -474,7 +502,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(0) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -486,7 +514,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(1)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(2) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -513,7 +541,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) ).scroll(new TimeValue(1000)); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -524,7 +552,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(9)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); assertNotNull(validationErrors); @@ -538,7 +566,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(3)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .size(3) .from(4) ); @@ -549,7 +577,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .addRescorer(new QueryRescorerBuilder(QueryBuilders.termQuery("rescore", "another term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); @@ -561,7 +589,7 @@ public QueryBuilder topDocsQuery() { SearchRequest searchRequest = new SearchRequest().source( new SearchSourceBuilder().rankBuilder(new TestRankBuilder(100)) .query(QueryBuilders.termQuery("field", "term")) - .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null))) + .knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 0f }, 10, 100, null, null))) .suggest(new SuggestBuilder().setGlobalText("test").addSuggestion("suggestion", new TermSuggestionBuilder("term"))) ); ActionRequestValidationException validationErrors = searchRequest.validate(); diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index a9de118c6b859..ed3d26141fe04 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -1367,7 +1367,7 @@ public void testShouldMinimizeRoundtrips() throws Exception { { SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null))); + source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); searchRequest.source(source); searchRequest.setCcsMinimizeRoundtrips(true); @@ -1382,7 +1382,7 @@ public void testAdjustSearchType() { // If the search includes kNN, we should always use DFS_QUERY_THEN_FETCH SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder source = new SearchSourceBuilder(); - source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null))); + source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null, null))); searchRequest.source(source); TransportSearchAction.adjustSearchType(searchRequest, randomBoolean()); diff --git a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java index 4822b1c64cf41..aaa706acf525a 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java @@ -83,7 +83,7 @@ public void testValidateSearchRequest() { .build(); SearchRequest searchRequest = new SearchRequest(); - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null, null); searchRequest.source(new SearchSourceBuilder().knnSearch(List.of(knnSearch))); Exception ex = expectThrows( diff --git a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index 380b5189b3efc..070773a2a7d42 100644 --- a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -825,7 +825,7 @@ public void testSearchSectionsUsageCollection() throws IOException { searchSourceBuilder.fetchField("field"); // these are not correct runtime mappings but they are counted compared to empty object searchSourceBuilder.runtimeMappings(Collections.singletonMap("field", "keyword")); - searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null))); + searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null, null))); searchSourceBuilder.pointInTimeBuilder(new PointInTimeBuilder(new BytesArray("pitid"))); searchSourceBuilder.docValueField("field"); searchSourceBuilder.storedField("field"); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 3753e8ce874cb..ec2ddebcfc2ca 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -52,8 +52,18 @@ public static KnnSearchBuilder randomTestInstance() { float[] vector = randomVector(dim); int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); - - KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat()); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + + KnnSearchBuilder builder = new KnnSearchBuilder( + field, + vector, + k, + numCands, + rescoreVectorBuilder, + randomBoolean() ? null : randomFloat() + ); if (randomBoolean()) { builder.boost(randomFloat()); } @@ -100,46 +110,90 @@ protected KnnSearchBuilder createTestInstance() { @Override protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { - switch (random().nextInt(7)) { + switch (random().nextInt(8)) { case 0: String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5)); - return new KnnSearchBuilder(newField, instance.queryVector, instance.k, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + newField, + instance.queryVector, + instance.k, + instance.numCands, + instance.rescoreVectorBuilder, + instance.similarity + ).boost(instance.boost); case 1: float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5)); - return new KnnSearchBuilder(instance.field, newVector, instance.k, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + newVector, + instance.k, + instance.numCands, + instance.rescoreVectorBuilder, + instance.similarity + ).boost(instance.boost); case 2: // given how the test instance is created, we have a 20-value gap between `k` and `numCands` so we SHOULD be safe Integer newK = randomValueOtherThan(instance.k, () -> instance.k + ESTestCase.randomInt(10)); - return new KnnSearchBuilder(instance.field, instance.queryVector, newK, instance.numCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + newK, + instance.numCands, + instance.rescoreVectorBuilder, + instance.similarity + ).boost(instance.boost); case 3: Integer newNumCands = randomValueOtherThan(instance.numCands, () -> instance.numCands + ESTestCase.randomInt(100)); - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, newNumCands, instance.similarity).boost( - instance.boost - ); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + newNumCands, + instance.rescoreVectorBuilder, + instance.similarity + ).boost(instance.boost); case 4: - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity) - .addFilterQueries(instance.filterQueries) + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + instance.rescoreVectorBuilder, + instance.similarity + ).addFilterQueries(instance.filterQueries) .addFilterQuery(QueryBuilders.termQuery("new_field", "new-value")) .boost(instance.boost); case 5: float newBoost = randomValueOtherThan(instance.boost, ESTestCase::randomFloat); - return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity) - .addFilterQueries(instance.filterQueries) - .boost(newBoost); + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + instance.rescoreVectorBuilder, + instance.similarity + ).addFilterQueries(instance.filterQueries).boost(newBoost); case 6: return new KnnSearchBuilder( instance.field, instance.queryVector, instance.k, instance.numCands, + instance.rescoreVectorBuilder, randomValueOtherThan(instance.similarity, ESTestCase::randomFloat) ).addFilterQueries(instance.filterQueries).boost(instance.boost); + case 7: + return new KnnSearchBuilder( + instance.field, + instance.queryVector, + instance.k, + instance.numCands, + randomValueOtherThan( + instance.rescoreVectorBuilder, + () -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)) + ), + instance.similarity + ).addFilterQueries(instance.filterQueries).boost(instance.boost); default: throw new IllegalStateException(); } @@ -151,7 +205,10 @@ public void testToQueryBuilder() { int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); Float similarity = randomBoolean() ? null : randomFloat(); - KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, similarity); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, rescoreVectorBuilder, similarity); float boost = AbstractQueryBuilder.DEFAULT_BOOST; if (randomBoolean()) { @@ -167,15 +224,16 @@ public void testToQueryBuilder() { builder.addFilterQuery(filter); } - QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, null, similarity).addFilterQueries(filterQueries) - .boost(boost); + QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, rescoreVectorBuilder, similarity).addFilterQueries( + filterQueries + ).boost(boost); assertEquals(expected, builder.toQueryBuilder()); } public void testNumCandsLessThanK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null) + () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]")); } @@ -183,7 +241,7 @@ public void testNumCandsLessThanK() { public void testNumCandsExceedsLimit() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null) + () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null, null) ); assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]")); } @@ -191,18 +249,20 @@ public void testNumCandsExceedsLimit() { public void testInvalidK() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null) + () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null, null) ); assertThat(e.getMessage(), containsString("[k] must be greater than 0")); } public void testRewrite() throws Exception { float[] expectedArray = randomVector(randomIntBetween(10, 1024)); + RescoreVectorBuilder expectedRescore = new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); KnnSearchBuilder searchBuilder = new KnnSearchBuilder( "field", new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray), 5, 10, + expectedRescore, 1f ); searchBuilder.boost(randomFloat()); @@ -220,6 +280,7 @@ public void testRewrite() throws Exception { assertThat(rewritten.filterQueries, hasSize(1)); assertThat(rewritten.similarity, equalTo(1f)); assertThat(((RewriteableQuery) rewritten.filterQueries.get(0)).rewrites, equalTo(1)); + assertThat(rewritten.rescoreVectorBuilder, equalTo(expectedRescore)); } public static float[] randomVector(int dim) { diff --git a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java index 363d34ca3ff86..6e8cf735983aa 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java +++ b/test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -61,6 +62,7 @@ import static org.elasticsearch.test.ESTestCase.randomByte; import static org.elasticsearch.test.ESTestCase.randomDouble; import static org.elasticsearch.test.ESTestCase.randomFloat; +import static org.elasticsearch.test.ESTestCase.randomFloatBetween; import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.test.ESTestCase.randomInt; import static org.elasticsearch.test.ESTestCase.randomIntBetween; @@ -264,7 +266,12 @@ public static SearchSourceBuilder randomSearchSourceBuilder( } int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k, 1000); - knnSearchBuilders.add(new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat())); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + knnSearchBuilders.add( + new KnnSearchBuilder(field, vector, k, numCands, rescoreVectorBuilder, randomBoolean() ? null : randomFloat()) + ); } builder.knnSearch(knnSearchBuilders); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java index e00dc9f693ff3..1ca6ef0b43a38 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -97,6 +98,7 @@ public final void testKnnSearchBuilderWireSerialization() throws IOException { createTestInstance(), 5, 10, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? null : randomFloat() ); searchBuilder.queryName(randomAlphaOfLengthBetween(5, 10)); @@ -120,6 +122,7 @@ public final void testKnnSearchRewrite() throws Exception { queryVectorBuilder, 5, 10, + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomBoolean() ? null : randomFloat() ); KnnSearchBuilder serialized = copyWriteable( diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index df65aac5b79b8..14e97ae7251db 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -215,8 +215,9 @@ public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate Date: Thu, 14 Nov 2024 23:58:49 +0100 Subject: [PATCH 07/56] Add vector rescore builder to kNN retriever --- .../search/retriever/KnnRetrieverBuilder.java | 25 ++++++++++++++++--- .../KnnRetrieverBuilderParsingTests.java | 15 +++++++++-- .../RankDocsRetrieverBuilderTests.java | 3 ++- ...SimilarityRankRetrieverTelemetryTests.java | 2 +- .../xpack/rank/rrf/RRFRetrieverBuilderIT.java | 18 ++++++------- .../rrf/RRFRetrieverBuilderNestedDocsIT.java | 2 +- .../rank/rrf/RRFRetrieverTelemetryIT.java | 4 +-- 7 files changed, 50 insertions(+), 19 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 8be9a78dae154..188c3c5848162 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -20,8 +20,10 @@ import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; @@ -52,6 +54,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity"); + public static final ParseField RESCORE_FIELD = new ParseField("rescore"); @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -73,7 +76,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { (QueryVectorBuilder) args[2], (int) args[3], (int) args[4], - (Float) args[5] + (RescoreVectorBuilder) args[6], (Float) args[5] ); } ); @@ -89,6 +92,12 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { PARSER.declareInt(constructorArg(), K_FIELD); PARSER.declareInt(constructorArg(), NUM_CANDS_FIELD); PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> RescoreVectorBuilder.fromXContent(p), + RESCORE_FIELD, + ObjectParser.ValueType.OBJECT_OR_NULL + ); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -104,6 +113,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP private final QueryVectorBuilder queryVectorBuilder; private final int k; private final int numCands; + private final RescoreVectorBuilder rescoreVectorBuilder; private final Float similarity; public KnnRetrieverBuilder( @@ -112,6 +122,7 @@ public KnnRetrieverBuilder( QueryVectorBuilder queryVectorBuilder, int k, int numCands, + RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { if (queryVector == null && queryVectorBuilder == null) { @@ -137,6 +148,7 @@ public KnnRetrieverBuilder( this.k = k; this.numCands = numCands; this.similarity = similarity; + this.rescoreVectorBuilder = rescoreVectorBuilder; } private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVector, QueryVectorBuilder queryVectorBuilder) { @@ -148,6 +160,7 @@ private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier queryVe this.similarity = clone.similarity; this.retrieverName = clone.retrieverName; this.preFilterQueryBuilders = clone.preFilterQueryBuilders; + this.rescoreVectorBuilder = clone.rescoreVectorBuilder; } @Override @@ -229,6 +242,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder null, k, numCands, + rescoreVectorBuilder, similarity ); if (preFilterQueryBuilders != null) { @@ -261,6 +275,10 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept if (similarity != null) { builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity); } + + if (rescoreVectorBuilder != null) { + builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); + } } @Override @@ -272,12 +290,13 @@ public boolean doEquals(Object o) { && ((queryVector == null && that.queryVector == null) || (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get()))) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) - && Objects.equals(similarity, that.similarity); + && Objects.equals(similarity, that.similarity) + && Objects.equals(rescoreVectorBuilder, that.rescoreVectorBuilder); } @Override public int doHashCode() { - int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); + int result = Objects.hash(field, queryVectorBuilder, k, numCands, rescoreVectorBuilder, similarity); result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null); return result; } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 7923cb5f0d918..7929c2f03dbdb 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -51,8 +52,18 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() { int k = randomIntBetween(1, 100); int numCands = randomIntBetween(k + 20, 1000); Float similarity = randomBoolean() ? null : randomFloat(); - - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(field, vector, null, k, numCands, similarity); + RescoreVectorBuilder rescoreVectorBuilder = randomBoolean() + ? null + : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); + + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + field, + vector, + null, + k, + numCands, + rescoreVectorBuilder, similarity + ); List preFilterQueryBuilders = new ArrayList<>(); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index af6782c45dce8..c971a48504287 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -69,7 +70,7 @@ private List innerRetrievers(QueryRewriteContext queryRewriteC null, randomInt(10), randomIntBetween(10, 100), - randomFloat() + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomFloat() ); if (randomBoolean()) { knnRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index 9ec88e959b6c7..e817c7cad2935 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -102,7 +102,7 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null))); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index 37e1807d138aa..04f005810717c 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -183,7 +183,7 @@ public void testRRFPagination() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -233,7 +233,7 @@ public void testRRFWithAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -288,7 +288,7 @@ public void testRRFWithCollapse() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -345,7 +345,7 @@ public void testRRFRetrieverWithCollapseAndAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -411,7 +411,7 @@ public void testMultipleRRFRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -430,7 +430,7 @@ public void testMultipleRRFRetrievers() { ), // this one bring just doc 7 which should be ranked first eventually new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null), + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, (RescoreVectorBuilder) args[6], null), null ) ), @@ -477,7 +477,7 @@ public void testRRFExplainWithNamedRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -536,7 +536,7 @@ public void testRRFExplainWithAnotherNestedRRF() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( Arrays.asList( @@ -773,7 +773,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws throw new IllegalStateException("Should not be called"); } }; - var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null); + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, (RescoreVectorBuilder) args[6], null); var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); var rrf = new RRFRetrieverBuilder( List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java index b1358f11bf633..4ce243288994c 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -149,7 +149,7 @@ public void testRRFRetrieverWithNestedQuery() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, (RescoreVectorBuilder) args[6], null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java index f93c13682739f..f217a82ced635 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java @@ -103,7 +103,7 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null))); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -136,7 +136,7 @@ public void testTelemetryForRRFRetriever() throws IOException { new RRFRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null), + new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null), null ), new CompoundRetrieverBuilder.RetrieverSource( From 955da1f8295fa742f5f85ce6de2192f7b3da39eb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 15 Nov 2024 00:06:22 +0100 Subject: [PATCH 08/56] Fix refactoring, spotless --- .../search/retriever/KnnRetrieverBuilder.java | 3 ++- .../KnnRetrieverBuilderParsingTests.java | 3 ++- .../RankDocsRetrieverBuilderTests.java | 3 ++- ...SimilarityRankRetrieverTelemetryTests.java | 10 +++++-- .../xpack/rank/rrf/RRFRetrieverBuilderIT.java | 26 ++++++++++++------- .../rrf/RRFRetrieverBuilderNestedDocsIT.java | 2 +- .../rank/rrf/RRFRetrieverTelemetryIT.java | 10 ++++--- .../xpack/rank/rrf/RRFRankBuilder.java | 2 +- 8 files changed, 40 insertions(+), 19 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 188c3c5848162..97c87d755ca25 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -76,7 +76,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { (QueryVectorBuilder) args[2], (int) args[3], (int) args[4], - (RescoreVectorBuilder) args[6], (Float) args[5] + (RescoreVectorBuilder) args[6], + (Float) args[5] ); } ); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 7929c2f03dbdb..0213a83385739 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -62,7 +62,8 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() { null, k, numCands, - rescoreVectorBuilder, similarity + rescoreVectorBuilder, + similarity ); List preFilterQueryBuilders = new ArrayList<>(); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index c971a48504287..d3fd98d4c9ef5 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -70,7 +70,8 @@ private List innerRetrievers(QueryRewriteContext queryRewriteC null, randomInt(10), randomIntBetween(10, 100), - randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomFloat() + randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), + randomFloat() ); if (randomBoolean()) { knnRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index e817c7cad2935..9b48abab25889 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -102,7 +102,11 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null))); + performSearch( + new SearchSourceBuilder().retriever( + new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null) + ) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -146,7 +150,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index 04f005810717c..ced8d56e4b7cb 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -183,7 +183,15 @@ public void testRRFPagination() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( + VECTOR_FIELD, + new float[] { 2.0f }, + null, + 10, + 100, + null, + null + ); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -233,7 +241,7 @@ public void testRRFWithAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -288,7 +296,7 @@ public void testRRFWithCollapse() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -345,7 +353,7 @@ public void testRRFRetrieverWithCollapseAndAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -411,7 +419,7 @@ public void testMultipleRRFRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -430,7 +438,7 @@ public void testMultipleRRFRetrievers() { ), // this one bring just doc 7 which should be ranked first eventually new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, (RescoreVectorBuilder) args[6], null), + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null), null ) ), @@ -477,7 +485,7 @@ public void testRRFExplainWithNamedRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -536,7 +544,7 @@ public void testRRFExplainWithAnotherNestedRRF() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( Arrays.asList( @@ -773,7 +781,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws throw new IllegalStateException("Should not be called"); } }; - var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, (RescoreVectorBuilder) args[6], null); + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null); var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); var rrf = new RRFRetrieverBuilder( List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java index 4ce243288994c..a00b940bbed62 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -149,7 +149,7 @@ public void testRRFRetrieverWithNestedQuery() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, (RescoreVectorBuilder) args[6], null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java index f217a82ced635..9bc1cd80ea381 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java @@ -103,7 +103,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { - performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null))); + performSearch( + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) + ); } // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under @@ -136,7 +138,7 @@ public void testTelemetryForRRFRetriever() throws IOException { new RRFRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null), + new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null), null ), new CompoundRetrieverBuilder.RetrieverSource( @@ -153,7 +155,9 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#6 - this will record 1 entry for "knn" in `sections` { - performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + performSearch( + new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null))) + ); } // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java index 14e97ae7251db..b5bca57478684 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java @@ -217,7 +217,7 @@ public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate Date: Mon, 18 Nov 2024 16:53:37 +0100 Subject: [PATCH 09/56] Check oversampling is not used for quantized types --- .../vectors/DenseVectorFieldMapper.java | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 69dbdce5832a3..6792202462297 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -1219,7 +1219,7 @@ public final int hashCode() { } private enum VectorIndexType { - HNSW("hnsw") { + HNSW("hnsw", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1246,7 +1246,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT8_HNSW("int8_hnsw") { + INT8_HNSW("int8_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1278,7 +1278,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT4_HNSW("int4_hnsw") { + INT4_HNSW("int4_hnsw", true) { public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); Object efConstructionNode = indexOptionsMap.remove("ef_construction"); @@ -1309,7 +1309,7 @@ public boolean supportsDimension(int dims) { return dims % 2 == 0; } }, - FLAT("flat") { + FLAT("flat", false) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1326,7 +1326,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT8_FLAT("int8_flat") { + INT8_FLAT("int8_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1348,7 +1348,7 @@ public boolean supportsDimension(int dims) { return true; } }, - INT4_FLAT("int4_flat") { + INT4_FLAT("int4_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object confidenceIntervalNode = indexOptionsMap.remove("confidence_interval"); @@ -1370,7 +1370,7 @@ public boolean supportsDimension(int dims) { return dims % 2 == 0; } }, - BBQ_HNSW("bbq_hnsw") { + BBQ_HNSW("bbq_hnsw", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { Object mNode = indexOptionsMap.remove("m"); @@ -1397,7 +1397,7 @@ public boolean supportsDimension(int dims) { return dims >= BBQ_MIN_DIMS; } }, - BBQ_FLAT("bbq_flat") { + BBQ_FLAT("bbq_flat", true) { @Override public IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap) { MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap); @@ -1420,9 +1420,11 @@ static Optional fromString(String type) { } private final String name; + private final boolean quantized; - VectorIndexType(String name) { + VectorIndexType(String name, boolean quantized) { this.name = name; + this.quantized = quantized; } abstract IndexOptions parseIndexOptions(String fieldName, Map indexOptionsMap); @@ -1431,6 +1433,10 @@ static Optional fromString(String type) { public abstract boolean supportsDimension(int dims); + public boolean isQuantized() { + return quantized; + } + @Override public String toString() { return name; @@ -2013,6 +2019,16 @@ public Query createKnnQuery( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } + if (rescoreOversample != null && indexOptions.type.isQuantized() == false) { + throw new IllegalArgumentException( + "cannot use rescore oversample on field [" + + name() + + "], that uses non-quantized type [" + + indexOptions.type + + "]. " + + "Only quantized index option types support rescore oversample." + ); + } return switch (getElementType()) { case BYTE -> createKnnByteQuery( queryVector.asByteVector(), From bc1e5c692158ff6dbcb80da390cc374c6ac8eaec Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 20 Nov 2024 18:04:41 +0100 Subject: [PATCH 10/56] Minor refactoring to reuse KnnScoreDocQuery --- .../search/vectors/KnnScoreDocQuery.java | 38 ++++++++++++++----- .../vectors/KnnScoreDocQueryBuilder.java | 24 +----------- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index bb83b8528c6c8..db7484dddc226 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -9,6 +9,7 @@ package org.elasticsearch.search.vectors; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; @@ -37,7 +38,13 @@ public class KnnScoreDocQuery extends Query { private final int[] docs; private final float[] scores; + + // the indexes in docs and scores corresponding to the first matching document in each segment. + // If a segment has no matching documents, it should be assigned the index of the next segment that does. + // There should be a final entry that is always docs.length-1. private final int[] segmentStarts; + // an object identifying the reader context that was used to build this query + private final Object contextIdentity; /** @@ -45,18 +52,31 @@ public class KnnScoreDocQuery extends Query { * * @param docs the global doc IDs of documents that match, in ascending order * @param scores the scores of the matching documents - * @param segmentStarts the indexes in docs and scores corresponding to the first matching - * document in each segment. If a segment has no matching documents, it should be assigned - * the index of the next segment that does. There should be a final entry that is always - * docs.length-1. - * @param contextIdentity an object identifying the reader context that was used to build this - * query + * @param reader IndexReader */ - KnnScoreDocQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) { this.docs = docs; this.scores = scores; - this.segmentStarts = segmentStarts; - this.contextIdentity = contextIdentity; + this.segmentStarts = findSegmentStarts(reader, docs); + this.contextIdentity = reader.getContext().id(); + } + + private static int[] findSegmentStarts(IndexReader reader, int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index f52addefc8b1c..10bee9ec66c2c 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.vectors; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.TransportVersion; @@ -25,7 +24,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Arrays; import java.util.Objects; /** @@ -153,9 +151,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { scores[i] = scoreDocs[i].score; } - IndexReader reader = context.getIndexReader(); - int[] segmentStarts = findSegmentStarts(reader, docs); - return new KnnScoreDocQuery(docs, scores, segmentStarts, reader.getContext().id()); + return new KnnScoreDocQuery(docs, scores, context.getIndexReader()); } @Override @@ -169,24 +165,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return super.doRewrite(queryRewriteContext); } - private static int[] findSegmentStarts(IndexReader reader, int[] docs) { - int[] starts = new int[reader.leaves().size() + 1]; - starts[starts.length - 1] = docs.length; - if (starts.length == 2) { - return starts; - } - int resultIndex = 0; - for (int i = 1; i < starts.length - 1; i++) { - int upper = reader.leaves().get(i).docBase; - resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); - if (resultIndex < 0) { - resultIndex = -1 - resultIndex; - } - starts[i] = resultIndex; - } - return starts; - } - @Override protected boolean doEquals(KnnScoreDocQueryBuilder other) { if (scoreDocs.length != other.scoreDocs.length) { From a7936da7e137fec3ce20f55aec9f11cff6583b87 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 20 Nov 2024 18:05:41 +0100 Subject: [PATCH 11/56] Use KnnRescoreVectorQuery to perform rescoring and limiting the number of results from each shard --- .../vectors/DenseVectorFieldMapper.java | 63 +++----- .../search/vectors/KnnRescoreVectorQuery.java | 151 ++++++++++++++++++ 2 files changed, 172 insertions(+), 42 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 6792202462297..4d7366dd3dda9 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -30,7 +30,6 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.queries.function.FunctionScoreQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; @@ -71,6 +70,7 @@ import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.KnnRescoreVectorQuery; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; @@ -2019,16 +2019,6 @@ public Query createKnnQuery( "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" ); } - if (rescoreOversample != null && indexOptions.type.isQuantized() == false) { - throw new IllegalArgumentException( - "cannot use rescore oversample on field [" - + name() - + "], that uses non-quantized type [" - + indexOptions.type - + "]. " - + "Only quantized index option types support rescore oversample." - ); - } return switch (getElementType()) { case BYTE -> createKnnByteQuery( queryVector.asByteVector(), @@ -2060,6 +2050,10 @@ public Query createKnnQuery( }; } + private boolean needsRescore(Float rescoreOversample) { + return rescoreOversample != null && (indexOptions == null || indexOptions.type == null || indexOptions.type.isQuantized()); + } + private Query createKnnBitQuery( byte[] queryVector, Integer k, @@ -2084,17 +2078,6 @@ private Query createKnnBitQuery( similarity.score(similarityThreshold, elementType, dims) ); } - if (rescoreOversample != null) { - knnQuery = new FunctionScoreQuery( - knnQuery, - new VectorSimilarityByteValueSource( - name(), - queryVector, - similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE) - ) - ); - - } return knnQuery; } @@ -2113,7 +2096,7 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - Integer adjustedK = k == null || rescoreOversample == null + Integer adjustedK = k == null || needsRescore(rescoreOversample) == false ? null : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); int adjustedNumCands = Math.max(adjustedK == null ? 0 : adjustedK, numCands); @@ -2128,16 +2111,14 @@ private Query createKnnByteQuery( similarity.score(similarityThreshold, elementType, dims) ); } - if (rescoreOversample != null) { - knnQuery = new FunctionScoreQuery( - knnQuery, - new VectorSimilarityByteValueSource( - name(), - queryVector, - similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE) - ) + if (needsRescore(rescoreOversample)) { + knnQuery = new KnnRescoreVectorQuery( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE), + k, + knnQuery ); - } return knnQuery; } @@ -2167,7 +2148,7 @@ && isNotUnitVector(squaredMagnitude)) { } } - Integer adjustedK = k == null || rescoreOversample == null + Integer adjustedK = k == null || needsRescore(rescoreOversample) == false ? k : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); int adjustedNumCands = adjustedK == null ? numCands : Math.max(adjustedK, numCands); @@ -2181,16 +2162,14 @@ && isNotUnitVector(squaredMagnitude)) { similarity.score(similarityThreshold, elementType, dims) ); } - if (rescoreOversample != null) { - knnQuery = new FunctionScoreQuery( - knnQuery, - new VectorSimilarityFloatValueSource( - name(), - queryVector, - similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT) - ) + if (needsRescore(rescoreOversample)) { + knnQuery = new KnnRescoreVectorQuery( + name(), + queryVector, + similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), + k, + knnQuery ); - } return knnQuery; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java new file mode 100644 index 0000000000000..092217d1be7ab --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java @@ -0,0 +1,151 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; +import org.elasticsearch.search.profile.query.QueryProfiler; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +/** + * Wraps a kNN vector query to rescore the results using the non-quantized vectors + */ +public class KnnRescoreVectorQuery extends Query implements ProfilingQuery { + private final String fieldName; + private final byte[] byteTarget; + private final float[] floatTarget; + private final VectorSimilarityFunction vectorSimilarityFunction; + private final Integer k; + private final Query vectorQuery; + + private long vectorOpsCount; + + public KnnRescoreVectorQuery( + String fieldName, + byte[] byteTarget, + VectorSimilarityFunction vectorSimilarityFunction, + Integer k, + Query vectorQuery + ) { + this.fieldName = fieldName; + this.byteTarget = byteTarget; + this.floatTarget = null; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.k = k; + this.vectorQuery = vectorQuery; + } + + public KnnRescoreVectorQuery( + String fieldName, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + Integer k, + Query vectorQuery + ) { + this.fieldName = fieldName; + this.byteTarget = null; + this.floatTarget = floatTarget; + this.vectorSimilarityFunction = vectorSimilarityFunction; + this.k = k; + this.vectorQuery = vectorQuery; + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + Query rewritten = super.rewrite(searcher); + if (rewritten != this) { + return rewritten; + } + + final DoubleValuesSource valueSource; + if (byteTarget != null) { + valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction); + } else { + valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + } + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(vectorQuery, valueSource); + Query query = searcher.rewrite(functionScoreQuery); + + if (k == null) { + // No need to calculate top k - let the request size limit the results + return query; + } + + TopDocs topDocs = searcher.search(query, k); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + int[] docIds = new int[scoreDocs.length]; + float[] scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docIds[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; + } + + vectorOpsCount = scoreDocs.length; + + return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + } + + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.setVectorOpsCount(vectorOpsCount); + } + + @Override + public void visit(QueryVisitor visitor) { + if (visitor.acceptField(fieldName)) { + visitor.visitLeaf(this); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KnnRescoreVectorQuery that = (KnnRescoreVectorQuery) o; + return Objects.equals(fieldName, that.fieldName) + && Objects.deepEquals(byteTarget, that.byteTarget) + && Objects.deepEquals(floatTarget, that.floatTarget) + && vectorSimilarityFunction == that.vectorSimilarityFunction + && Objects.equals(k, that.k) + && Objects.equals(vectorQuery, that.vectorQuery); + } + + @Override + public int hashCode() { + return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, vectorQuery); + } + + @Override + public String toString(String field) { + final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{"); + sb.append("fieldName='").append(fieldName).append('\''); + if (byteTarget != null) { + sb.append(", byteTarget=").append(Arrays.toString(byteTarget)); + } else { + sb.append(", floatTarget=").append(Arrays.toString(floatTarget)); + } + sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction); + sb.append(", k=").append(k); + sb.append(", vectorQuery=").append(vectorQuery); + sb.append('}'); + return sb.toString(); + } +} From f5080a6c4590f406ef8a82ffcf7955674435e5ae Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 21 Nov 2024 15:03:05 +0100 Subject: [PATCH 12/56] Small name refactoring, fix adjusting parameters --- .../vectors/DenseVectorFieldMapper.java | 26 ++++++++------ .../search/vectors/ESKnnByteVectorQuery.java | 4 +++ .../search/vectors/ESKnnFloatVectorQuery.java | 4 +++ ...rQuery.java => RescoreKnnVectorQuery.java} | 34 ++++++++++++------- 4 files changed, 44 insertions(+), 24 deletions(-) rename server/src/main/java/org/elasticsearch/search/vectors/{KnnRescoreVectorQuery.java => RescoreKnnVectorQuery.java} (87%) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 4d7366dd3dda9..eefdc48e2595b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -70,7 +70,7 @@ import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; -import org.elasticsearch.search.vectors.KnnRescoreVectorQuery; +import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; @@ -2096,10 +2096,12 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - Integer adjustedK = k == null || needsRescore(rescoreOversample) == false - ? null - : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); - int adjustedNumCands = Math.max(adjustedK == null ? 0 : adjustedK, numCands); + Integer adjustedK = k; + int adjustedNumCands = numCands; + if (needsRescore(rescoreOversample) && adjustedK != null) { + adjustedK = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); + adjustedNumCands = Math.max(adjustedK, numCands); + } Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) @@ -2112,7 +2114,7 @@ private Query createKnnByteQuery( ); } if (needsRescore(rescoreOversample)) { - knnQuery = new KnnRescoreVectorQuery( + knnQuery = new RescoreKnnVectorQuery( name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE), @@ -2148,10 +2150,12 @@ && isNotUnitVector(squaredMagnitude)) { } } - Integer adjustedK = k == null || needsRescore(rescoreOversample) == false - ? k - : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); - int adjustedNumCands = adjustedK == null ? numCands : Math.max(adjustedK, numCands); + Integer adjustedK = k; + int adjustedNumCands = numCands; + if (needsRescore(rescoreOversample) && adjustedK != null) { + adjustedK = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); + adjustedNumCands = Math.max(adjustedK, numCands); + } Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); @@ -2163,7 +2167,7 @@ && isNotUnitVector(squaredMagnitude)) { ); } if (needsRescore(rescoreOversample)) { - knnQuery = new KnnRescoreVectorQuery( + knnQuery = new RescoreKnnVectorQuery( name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 9363f67a7350b..b2d2d06179a6e 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -35,4 +35,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.setVectorOpsCount(vectorOpsCount); } + + public Integer kParam() { + return kParam; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index be0437af9131d..b492d1ea46f75 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -35,4 +35,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.setVectorOpsCount(vectorOpsCount); } + + public Integer kParam() { + return kParam; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java similarity index 87% rename from server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java rename to server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 092217d1be7ab..a3fe03732d5af 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnRescoreVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -28,44 +28,44 @@ /** * Wraps a kNN vector query to rescore the results using the non-quantized vectors */ -public class KnnRescoreVectorQuery extends Query implements ProfilingQuery { +public class RescoreKnnVectorQuery extends Query implements ProfilingQuery { private final String fieldName; private final byte[] byteTarget; private final float[] floatTarget; private final VectorSimilarityFunction vectorSimilarityFunction; private final Integer k; - private final Query vectorQuery; + private final Query innerQuery; private long vectorOpsCount; - public KnnRescoreVectorQuery( + public RescoreKnnVectorQuery( String fieldName, byte[] byteTarget, VectorSimilarityFunction vectorSimilarityFunction, Integer k, - Query vectorQuery + Query innerQuery ) { this.fieldName = fieldName; this.byteTarget = byteTarget; this.floatTarget = null; this.vectorSimilarityFunction = vectorSimilarityFunction; this.k = k; - this.vectorQuery = vectorQuery; + this.innerQuery = innerQuery; } - public KnnRescoreVectorQuery( + public RescoreKnnVectorQuery( String fieldName, float[] floatTarget, VectorSimilarityFunction vectorSimilarityFunction, Integer k, - Query vectorQuery + Query innerQuery ) { this.fieldName = fieldName; this.byteTarget = null; this.floatTarget = floatTarget; this.vectorSimilarityFunction = vectorSimilarityFunction; this.k = k; - this.vectorQuery = vectorQuery; + this.innerQuery = innerQuery; } @Override @@ -81,7 +81,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException { } else { valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); } - FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(vectorQuery, valueSource); + FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); Query query = searcher.rewrite(functionScoreQuery); if (k == null) { @@ -103,6 +103,14 @@ public Query rewrite(IndexSearcher searcher) throws IOException { return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); } + public Query innerQuery() { + return innerQuery; + } + + public Integer k() { + return k; + } + @Override public void profile(QueryProfiler queryProfiler) { queryProfiler.setVectorOpsCount(vectorOpsCount); @@ -119,18 +127,18 @@ public void visit(QueryVisitor visitor) { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - KnnRescoreVectorQuery that = (KnnRescoreVectorQuery) o; + RescoreKnnVectorQuery that = (RescoreKnnVectorQuery) o; return Objects.equals(fieldName, that.fieldName) && Objects.deepEquals(byteTarget, that.byteTarget) && Objects.deepEquals(floatTarget, that.floatTarget) && vectorSimilarityFunction == that.vectorSimilarityFunction && Objects.equals(k, that.k) - && Objects.equals(vectorQuery, that.vectorQuery); + && Objects.equals(innerQuery, that.innerQuery); } @Override public int hashCode() { - return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, vectorQuery); + return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery); } @Override @@ -144,7 +152,7 @@ public String toString(String field) { } sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction); sb.append(", k=").append(k); - sb.append(", vectorQuery=").append(vectorQuery); + sb.append(", vectorQuery=").append(innerQuery); sb.append('}'); return sb.toString(); } From 39e16762e205bef7fb25190b3dff7269001fef9d Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 21 Nov 2024 15:03:11 +0100 Subject: [PATCH 13/56] Add testing --- .../vectors/DenseVectorFieldTypeTests.java | 130 ++++++++++++++++-- ...AbstractKnnVectorQueryBuilderTestCase.java | 6 +- 2 files changed, 121 insertions(+), 15 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index e6ccd05c8f6e5..10b15e0c97650 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -20,9 +20,13 @@ import org.elasticsearch.index.mapper.FieldTypeTestCase; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; +import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; +import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; import java.io.IOException; @@ -31,8 +35,12 @@ import java.util.Set; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; public class DenseVectorFieldTypeTests extends FieldTypeTestCase { private final boolean indexed; @@ -69,11 +77,27 @@ private DenseVectorFieldMapper.IndexOptions randomIndexOptionsAll() { ); } + private DenseVectorFieldMapper.IndexOptions randomIndexOptionsHnswQuantized() { + return randomFrom( + new DenseVectorFieldMapper.Int8HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + ), + new DenseVectorFieldMapper.Int4HnswIndexOptions( + randomIntBetween(1, 100), + randomIntBetween(1, 10_000), + randomFrom((Float) null, 0f, (float) randomDoubleBetween(0.9, 1.0, true)) + ), + new DenseVectorFieldMapper.BBQHnswIndexOptions(randomIntBetween(1, 100), randomIntBetween(1, 10_000)) + ); + } + private DenseVectorFieldType createFloatFieldType() { return new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, indexed, VectorSimilarity.COSINE, @@ -86,7 +110,7 @@ private DenseVectorFieldType createByteFieldType() { return new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 5, true, VectorSimilarity.COSINE, @@ -159,7 +183,7 @@ public void testCreateNestedKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, dims, true, VectorSimilarity.COSINE, @@ -177,7 +201,7 @@ public void testCreateNestedKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, dims, true, VectorSimilarity.COSINE, @@ -209,7 +233,7 @@ public void testExactKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, dims, true, VectorSimilarity.COSINE, @@ -227,7 +251,7 @@ public void testExactKnnQuery() { DenseVectorFieldType field = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, dims, true, VectorSimilarity.COSINE, @@ -247,7 +271,7 @@ public void testFloatCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, 4, false, VectorSimilarity.COSINE, @@ -271,7 +295,7 @@ public void testFloatCreateKnnQuery() { DenseVectorFieldType dotProductField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, true, VectorSimilarity.DOT_PRODUCT, @@ -291,7 +315,7 @@ public void testFloatCreateKnnQuery() { DenseVectorFieldType cosineField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, BBQ_MIN_DIMS, true, VectorSimilarity.COSINE, @@ -310,7 +334,7 @@ public void testCreateKnnQueryMaxDims() { DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.FLOAT, + FLOAT, 4096, true, VectorSimilarity.COSINE, @@ -329,7 +353,7 @@ public void testCreateKnnQueryMaxDims() { DenseVectorFieldType fieldWith4096dims = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 4096, true, VectorSimilarity.COSINE, @@ -350,7 +374,7 @@ public void testByteCreateKnnQuery() { DenseVectorFieldType unindexedField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 3, false, VectorSimilarity.COSINE, @@ -366,7 +390,7 @@ public void testByteCreateKnnQuery() { DenseVectorFieldType cosineField = new DenseVectorFieldType( "f", IndexVersion.current(), - DenseVectorFieldMapper.ElementType.BYTE, + BYTE, 3, true, VectorSimilarity.COSINE, @@ -385,4 +409,84 @@ public void testByteCreateKnnQuery() { ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } + + public void testRescoreOversampleUsedWithoutQuantization() { + ElementType elementType = randomFrom(BYTE, FLOAT); + DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType( + "f", + IndexVersion.current(), + elementType, + 3, + true, + VectorSimilarity.COSINE, + randomIndexOptionsNonQuantized(), + Collections.emptyMap() + ); + + Query knnQuery = nonQuantizedField.createKnnQuery( + new VectorData(null, new byte[]{1, 4, 10}), + 10, + 100, + randomFloatBetween(1.0F, 10.0F, false), + null, + null, + null + ); + + if (elementType == BYTE) { + ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } else { + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } + } + + public void testRescoreOversampleModifiesKnnParams() { + DenseVectorFieldType fieldType = new DenseVectorFieldType( + "f", + IndexVersion.current(), + randomFrom(BYTE, FLOAT), + 3, + true, + VectorSimilarity.COSINE, + randomIndexOptionsHnswQuantized(), + Collections.emptyMap() + ); + + // Total results is k, internal k is multiplied by oversample + checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 25, 200); + // If numCands < k, update numCands to k + checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 25, 25); + // Oversampling limit + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000, 10000); + checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000, 10000); + } + + private static void checkRescoreQueryParameters( + DenseVectorFieldType fieldType, + int k, + int candidates, + float oversample, + int expectedResults, + int expectedK, + int expectedCandidates + ) { + Query query = fieldType.createKnnQuery(new VectorData(null, new byte[] { 1, 4, 10}), k, candidates, oversample, null, null, null); + + RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; + if (fieldType.getElementType() == BYTE) { + ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) rescoreQuery.innerQuery(); + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); + assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); + } else { + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); + assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); + } + } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index d603ad7e39b1f..ee447ff2872e7 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -23,6 +23,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource; +import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.index.query.InnerHitsRewriteContext; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; @@ -127,8 +129,8 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() { @Override protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { if (queryBuilder.rescoreVectorBuilder() != null) { - assertTrue(query instanceof org.apache.lucene.queries.function.FunctionScoreQuery); - query = ((org.apache.lucene.queries.function.FunctionScoreQuery) query).getWrappedQuery(); + RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; + query = rescoreQuery.innerQuery(); } if (queryBuilder.getVectorSimilarity() != null) { assertTrue(query instanceof VectorSimilarityQuery); From 9946e8d3b645925348699967ae1b7fa74a1f3d9a Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 27 Nov 2024 10:15:17 +0100 Subject: [PATCH 14/56] Add tests for RescoreKnnVectorQuery --- .../vectors/RescoreKnnVectorQueryTests.java | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java new file mode 100644 index 0000000000000..984d6f194d666 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.vectors; + + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.equalTo; + +public class RescoreKnnVectorQueryTests extends ESTestCase { + + public static final String FIELD_NAME = "float_vector"; + + public void testRescoresTopK() throws Exception { + int numDocs = randomIntBetween(10, 100); + testRescoreDocs(numDocs, randomIntBetween(5, numDocs - 1)); + } + + public void testRescoresNoKParameter() throws Exception { + testRescoreDocs(randomIntBetween(10, 100), null); + } + + private void testRescoreDocs(int numDocs, Integer k) throws Exception { + int numDims = randomIntBetween(5, 100); + + if (k == null) { + k = numDocs; + } + + try (Directory d = newDirectory()) { + try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { + for (int i = 0; i < numDocs; i++) { + Document document = new Document(); + float[] vector = randomVector(numDims); + KnnFloatVectorField vectorField = new KnnFloatVectorField( + FIELD_NAME, vector); + document.add(vectorField); + w.addDocument(document); + } + w.commit(); + w.forceMerge(1); + } + + try (IndexReader reader = DirectoryReader.open(d)) { + float[] queryVector = randomVector(numDims); + + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, new MatchAllDocsQuery()); + + IndexSearcher searcher = newSearcher(reader, true, false); + TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); + Map rescoredDocs = Arrays.stream(docs.scoreDocs).collect(Collectors.toMap( + scoreDoc -> scoreDoc.doc, + scoreDoc -> scoreDoc.score) + ); + + assertThat(rescoredDocs.size(), equalTo(k)); + + Collection rescoredScores = new ArrayList<>(rescoredDocs.values()); + PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); + + for (LeafReaderContext leafReaderContext : reader.leaves()) { + FloatVectorValues floatVectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + float[] vector = floatVectorValues.vectorValue(iterator.index()); + float score = VectorSimilarityFunction.COSINE.compare(queryVector, vector); + topK.add(score); + int docId = iterator.docID(); + if (rescoredDocs.containsKey(docId)) { + assertThat(rescoredDocs.get(docId), equalTo(score)); + rescoredDocs.remove(docId); + } + } + } + + assertThat(rescoredDocs.size(), equalTo(0)); + + // Check top scoring docs are contained in rescored docs + for (int i = 0; i < k; i++) { + Float topScore = topK.poll(); + if (rescoredScores.contains(topScore) == false ) { + fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); + } + } + } + } + } + + private static float[] randomVector(int numDims) { + float[] vector = new float[numDims]; + for (int j = 0; j < numDims; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return vector; + } + +} From 4fbbadd0f2d3d41098416b44e62c3808902798bc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 27 Nov 2024 10:15:50 +0100 Subject: [PATCH 15/56] Spotless --- .../search/vectors/RescoreKnnVectorQuery.java | 2 ++ .../vectors/DenseVectorFieldTypeTests.java | 4 ++-- ...AbstractKnnVectorQueryBuilderTestCase.java | 2 -- .../vectors/RescoreKnnVectorQueryTests.java | 19 ++++++++++--------- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index a3fe03732d5af..15cbbbffc9f27 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -70,6 +70,8 @@ public RescoreKnnVectorQuery( @Override public Query rewrite(IndexSearcher searcher) throws IOException { + assert byteTarget == null ^ floatTarget == null : "Either byteTarget or floatTarget must be set"; + Query rewritten = super.rewrite(searcher); if (rewritten != this) { return rewritten; diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 10b15e0c97650..714ed282d06a3 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -424,7 +424,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { ); Query knnQuery = nonQuantizedField.createKnnQuery( - new VectorData(null, new byte[]{1, 4, 10}), + new VectorData(null, new byte[] { 1, 4, 10 }), 10, 100, randomFloatBetween(1.0F, 10.0F, false), @@ -474,7 +474,7 @@ private static void checkRescoreQueryParameters( int expectedK, int expectedCandidates ) { - Query query = fieldType.createKnnQuery(new VectorData(null, new byte[] { 1, 4, 10}), k, candidates, oversample, null, null, null); + Query query = fieldType.createKnnQuery(new VectorData(null, new byte[] { 1, 4, 10 }), k, candidates, oversample, null, null, null); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; if (fieldType.getElementType() == BYTE) { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index ee447ff2872e7..c75058a977364 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -23,8 +23,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource; -import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.index.query.InnerHitsRewriteContext; import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 984d6f194d666..640b7785f737d 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.vectors; - import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; @@ -60,8 +59,7 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception { for (int i = 0; i < numDocs; i++) { Document document = new Document(); float[] vector = randomVector(numDims); - KnnFloatVectorField vectorField = new KnnFloatVectorField( - FIELD_NAME, vector); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); document.add(vectorField); w.addDocument(document); } @@ -73,14 +71,17 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception { float[] queryVector = randomVector(numDims); RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( - FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, k, new MatchAllDocsQuery()); + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + k, + new MatchAllDocsQuery() + ); IndexSearcher searcher = newSearcher(reader, true, false); TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); - Map rescoredDocs = Arrays.stream(docs.scoreDocs).collect(Collectors.toMap( - scoreDoc -> scoreDoc.doc, - scoreDoc -> scoreDoc.score) - ); + Map rescoredDocs = Arrays.stream(docs.scoreDocs) + .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); assertThat(rescoredDocs.size(), equalTo(k)); @@ -107,7 +108,7 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception { // Check top scoring docs are contained in rescored docs for (int i = 0; i < k; i++) { Float topScore = topK.poll(); - if (rescoredScores.contains(topScore) == false ) { + if (rescoredScores.contains(topScore) == false) { fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); } } From 0dab8ea7473d85fb0f7a166a712b5bf9a09b6722 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 29 Nov 2024 14:51:14 +0100 Subject: [PATCH 16/56] Add test for knn retriever --- .../search/retriever/KnnRetrieverBuilder.java | 10 ++++++++-- .../retriever/KnnRetrieverBuilderParsingTests.java | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 97c87d755ca25..da6254201072b 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -257,7 +257,11 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder searchSourceBuilder.knnSearch(knnSearchBuilders); } - // ---- FOR TESTING XCONTENT PARSING ---- + RescoreVectorBuilder rescoreVectorBuilder() { + return rescoreVectorBuilder; + } + +// ---- FOR TESTING XCONTENT PARSING ---- @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { @@ -278,7 +282,9 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } if (rescoreVectorBuilder != null) { - builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); + builder.startObject(RESCORE_FIELD.getPreferredName()); + rescoreVectorBuilder.toXContent(builder, params); + builder.endObject(); } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index 0213a83385739..da28b0eff441f 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -105,6 +105,7 @@ public void testRewrite() throws IOException { assertNull(source.query()); assertThat(source.knnSearch().size(), equalTo(1)); assertThat(source.knnSearch().get(0).getFilterQueries().size(), equalTo(knnRetriever.preFilterQueryBuilders.size())); + assertThat(source.knnSearch().get(0).getRescoreVectorBuilder(), equalTo(knnRetriever.rescoreVectorBuilder())); for (int j = 0; j < knnRetriever.preFilterQueryBuilders.size(); j++) { assertThat( source.knnSearch().get(0).getFilterQueries().get(j), From 257b75d1d7a9fe1e989a011fd26babe9b8c667fa Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 29 Nov 2024 15:41:22 +0100 Subject: [PATCH 17/56] Add tests --- .../search/vectors/RescoreVectorBuilder.java | 2 +- .../search/vectors/KnnSearchBuilderTests.java | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index 26c151b5d2f72..2471b29dc1193 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -24,7 +24,7 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { public static final ParseField OVERSAMPLE_FIELD = new ParseField("oversample"); - public static final int MIN_OVERSAMPLE = 1; + public static final float MIN_OVERSAMPLE = 1.0F; private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "rescore", args -> new RescoreVectorBuilder((Float) args[0]) diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index ec2ddebcfc2ca..1a54d7c8420da 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -254,6 +254,14 @@ public void testInvalidK() { assertThat(e.getMessage(), containsString("[k] must be greater than 0")); } + public void testInvalidRescoreVectorBuilder() { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, new RescoreVectorBuilder(1.0F), null) + ); + assertThat(e.getMessage(), containsString("[oversample] must be > 1.0")); + } + public void testRewrite() throws Exception { float[] expectedArray = randomVector(randomIntBetween(10, 1024)); RescoreVectorBuilder expectedRescore = new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)); From 81384f2175694083b9a472f02a78b9dfe7476b18 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 29 Nov 2024 17:15:01 +0100 Subject: [PATCH 18/56] Parameterize recore knn vector query tests --- .../vectors/RescoreKnnVectorQueryTests.java | 221 ++++++++++++++---- 1 file changed, 181 insertions(+), 40 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 640b7785f737d..36f1ab5cc6d57 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -9,13 +9,18 @@ package org.elasticsearch.search.vectors; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; @@ -24,9 +29,12 @@ import org.apache.lucene.store.Directory; import org.elasticsearch.test.ESTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.PriorityQueue; import java.util.stream.Collectors; @@ -37,65 +45,56 @@ public class RescoreKnnVectorQueryTests extends ESTestCase { public static final String FIELD_NAME = "float_vector"; + private final int numDocs; + private final VectorProvider vectorProvider; + private final Integer k; - public void testRescoresTopK() throws Exception { - int numDocs = randomIntBetween(10, 100); - testRescoreDocs(numDocs, randomIntBetween(5, numDocs - 1)); - } - - public void testRescoresNoKParameter() throws Exception { - testRescoreDocs(randomIntBetween(10, 100), null); + public RescoreKnnVectorQueryTests(VectorProvider vectorProvider, boolean useK) { + this.vectorProvider = vectorProvider; + this.numDocs = randomIntBetween(10, 100);; + this.k = useK ? randomIntBetween(1, numDocs - 1) : null; } - private void testRescoreDocs(int numDocs, Integer k) throws Exception { + public void testRescoreDocs() throws Exception { int numDims = randomIntBetween(5, 100); + Integer adjustedK = k; if (k == null) { - k = numDocs; + adjustedK = numDocs; } try (Directory d = newDirectory()) { - try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { - for (int i = 0; i < numDocs; i++) { - Document document = new Document(); - float[] vector = randomVector(numDims); - KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); - document.add(vectorField); - w.addDocument(document); - } - w.commit(); - w.forceMerge(1); - } + addRandomDocuments(numDocs, d, numDims, vectorProvider); try (IndexReader reader = DirectoryReader.open(d)) { - float[] queryVector = randomVector(numDims); - RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( - FIELD_NAME, - queryVector, - VectorSimilarityFunction.COSINE, - k, - new MatchAllDocsQuery() - ); + // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query + // and thus we're rescoring the top k docs. + VectorData queryVector = vectorProvider.randomVector(numDims); + RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(queryVector, adjustedK); IndexSearcher searcher = newSearcher(reader, true, false); TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); Map rescoredDocs = Arrays.stream(docs.scoreDocs) .collect(Collectors.toMap(scoreDoc -> scoreDoc.doc, scoreDoc -> scoreDoc.score)); - assertThat(rescoredDocs.size(), equalTo(k)); + assertThat(rescoredDocs.size(), equalTo(adjustedK)); + + Collection rescoredScores = new HashSet<>(rescoredDocs.values()); - Collection rescoredScores = new ArrayList<>(rescoredDocs.values()); + // Collect all docs sequentially, and score them using the similarity function to get the top K scores PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); for (LeafReaderContext leafReaderContext : reader.leaves()) { - FloatVectorValues floatVectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); - KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + KnnVectorValues vectorValues = vectorProvider.vectorValues(leafReaderContext.reader()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); while (iterator.nextDoc() != NO_MORE_DOCS) { - float[] vector = floatVectorValues.vectorValue(iterator.index()); - float score = VectorSimilarityFunction.COSINE.compare(queryVector, vector); + VectorData vectorData = vectorProvider.dataVectorForDoc(vectorValues, iterator.docID()); + float score = vectorProvider.score(queryVector, vectorData); topK.add(score); int docId = iterator.docID(); + // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it + // to ensure we found them all if (rescoredDocs.containsKey(docId)) { assertThat(rescoredDocs.get(docId), equalTo(score)); rescoredDocs.remove(docId); @@ -106,7 +105,7 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception { assertThat(rescoredDocs.size(), equalTo(0)); // Check top scoring docs are contained in rescored docs - for (int i = 0; i < k; i++) { + for (int i = 0; i < adjustedK; i++) { Float topScore = topK.poll(); if (rescoredScores.contains(topScore) == false) { fail("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores); @@ -116,12 +115,154 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception { } } - private static float[] randomVector(int numDims) { - float[] vector = new float[numDims]; - for (int j = 0; j < numDims; j++) { - vector[j] = randomFloatBetween(0, 1, true); + private interface VectorProvider { + VectorData randomVector(int numDimensions); + + RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k); + + KnnVectorValues vectorValues(LeafReader leafReader) throws IOException; + + void addVectorField(Document document, VectorData vector); + + VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException; + + float score(VectorData queryVector, VectorData dataVector); + } + + private static class FloatVectorProvider implements VectorProvider { + @Override + public VectorData randomVector(int numDimensions) { + float[] vector = new float[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return VectorData.fromFloats(vector); + } + + @Override + public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) { + return new RescoreKnnVectorQuery( + FIELD_NAME, + queryVector.floatVector(), + VectorSimilarityFunction.COSINE, + k, + new MatchAllDocsQuery() + ); + } + + @Override + public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException { + return leafReader.getFloatVectorValues(FIELD_NAME); + } + + @Override + public void addVectorField(Document document, VectorData vector) { + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector.floatVector()); + document.add(vectorField); + } + + @Override + public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException { + return VectorData.fromFloats(((FloatVectorValues)vectorValues).vectorValue(docId)); + } + + @Override + public float score(VectorData queryVector, VectorData dataVector) { + return VectorSimilarityFunction.COSINE.compare(queryVector.floatVector(), dataVector.floatVector()); } - return vector; } + private static class ByteVectorProvider implements VectorProvider { + @Override + public VectorData randomVector(int numDimensions) { + byte[] vector = new byte[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomByte(); + } + return VectorData.fromBytes(vector); + } + + @Override + public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) { + return new RescoreKnnVectorQuery( + FIELD_NAME, + queryVector.byteVector(), + VectorSimilarityFunction.COSINE, + k, + new MatchAllDocsQuery() + ); + } + + @Override + public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException { + return leafReader.getByteVectorValues(FIELD_NAME); + } + + @Override + public void addVectorField(Document document, VectorData vector) { + KnnByteVectorField vectorField = new KnnByteVectorField(FIELD_NAME, vector.byteVector()); + document.add(vectorField); + } + + @Override + public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException { + return VectorData.fromBytes(((ByteVectorValues)vectorValues).vectorValue(docId)); + } + + @Override + public float score(VectorData queryVector, VectorData dataVector) { + return VectorSimilarityFunction.COSINE.compare(queryVector.byteVector(), dataVector.byteVector()); + } + } + + private static void addRandomDocuments(int numDocs, Directory d, int numDims, VectorProvider vectorProvider) throws IOException { + try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { + for (int i = 0; i < numDocs; i++) { + Document document = new Document(); + VectorData vector = vectorProvider.randomVector(numDims); + vectorProvider.addVectorField(document, vector); + w.addDocument(document); + } + w.commit(); + w.forceMerge(1); + } + } + + @ParametersFactory + public static Iterable parameters() { + + List params = new ArrayList<>(); + params.add(new Object[] {new FloatVectorProvider(), true}); + params.add(new Object[] {new FloatVectorProvider(), false}); + params.add(new Object[] {new ByteVectorProvider(), true}); + params.add(new Object[] {new ByteVectorProvider(), false}); + + return params; + } + +// public void testProfiling() throws Exception { +// int numDocs = randomIntBetween(10, 100); +// int numDims = randomIntBetween(5, 100); +// +// try (Directory d = newDirectory()) { +// addRandomDocuments(numDocs, d, numDims, vectorProvider); +// +// try (IndexReader reader = DirectoryReader.open(d)) { +// float[] queryVector = randomVector(numDims); +// +// RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( +// FIELD_NAME, +// queryVector, +// VectorSimilarityFunction.COSINE, +// randomIntBetween(5, numDocs - 1), +// new MatchAllDocsQuery() +// ); +// +// IndexSearcher searcher = newSearcher(reader, true, false); +// QueryProfiler queryProfiler = new QueryProfiler(); +// rescoreKnnVectorQuery.profile(queryProfiler); +// } +// } +// } + } From 1347d4b8d8600dd7caff7716957e0fea9ccb7ccb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 29 Nov 2024 17:54:33 +0100 Subject: [PATCH 19/56] Adds profiling, including a small refactoring of the QueryProfiler interface --- .../search/profile/query/QueryProfiler.java | 4 +- ...iversifyingChildrenByteKnnVectorQuery.java | 2 +- ...versifyingChildrenFloatKnnVectorQuery.java | 2 +- .../search/vectors/ESKnnByteVectorQuery.java | 2 +- .../search/vectors/ESKnnFloatVectorQuery.java | 2 +- .../search/vectors/RescoreKnnVectorQuery.java | 10 +- .../search/vectors/VectorSimilarityQuery.java | 10 +- .../vectors/RescoreKnnVectorQueryTests.java | 156 ++++++++++++------ 8 files changed, 123 insertions(+), 65 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java index 98ddfa95bf156..59196fdfba714 100644 --- a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java +++ b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java @@ -39,8 +39,8 @@ public QueryProfiler() { super(new InternalQueryProfileTree()); } - public void setVectorOpsCount(long vectorOpsCount) { - this.vectorOpsCount = vectorOpsCount; + public void addVectorOpsCount(long vectorOpsCount) { + this.vectorOpsCount += vectorOpsCount; } public long getVectorOpsCount() { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index 9f3d83b4da082..40803586a31ce 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java index 3907bdf89bc6f..f69aca5339daf 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java @@ -40,6 +40,6 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index b2d2d06179a6e..1de5471cb96ac 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -33,7 +33,7 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } public Integer kParam() { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index b492d1ea46f75..0d37a46625a9e 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -33,7 +33,7 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + queryProfiler.addVectorOpsCount(vectorOpsCount); } public Integer kParam() { diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 15cbbbffc9f27..0680c2e4ac438 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -72,11 +72,6 @@ public RescoreKnnVectorQuery( public Query rewrite(IndexSearcher searcher) throws IOException { assert byteTarget == null ^ floatTarget == null : "Either byteTarget or floatTarget must be set"; - Query rewritten = super.rewrite(searcher); - if (rewritten != this) { - return rewritten; - } - final DoubleValuesSource valueSource; if (byteTarget != null) { valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction); @@ -115,7 +110,10 @@ public Integer k() { @Override public void profile(QueryProfiler queryProfiler) { - queryProfiler.setVectorOpsCount(vectorOpsCount); + if (innerQuery instanceof ProfilingQuery profilingQuery) { + profilingQuery.profile(queryProfiler); + } + queryProfiler.addVectorOpsCount(vectorOpsCount); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java index 5219778047bcd..74daafe747ca4 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java @@ -21,6 +21,7 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.elasticsearch.common.lucene.search.function.MinScoreScorer; +import org.elasticsearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.Objects; @@ -30,7 +31,7 @@ /** * This query provides a simple post-filter for the provided Query. The query is assumed to be a Knn(Float|Byte)VectorQuery. */ -public class VectorSimilarityQuery extends Query { +public class VectorSimilarityQuery extends Query implements ProfilingQuery { private final float similarity; private final float docScore; private final Query innerKnnQuery; @@ -78,6 +79,13 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return new MinScoreWeight(innerWeight, docScore, similarity, this, boost); } + @Override + public void profile(QueryProfiler queryProfiler) { + if (innerKnnQuery instanceof ProfilingQuery profilingQuery) { + profilingQuery.profile(queryProfiler); + } + } + @Override public String toString(String field) { return "VectorSimilarityQuery[" diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 36f1ab5cc6d57..8c68fca1b3640 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -25,8 +25,13 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; +import org.elasticsearch.search.profile.query.QueryProfiler; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -41,6 +46,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; public class RescoreKnnVectorQueryTests extends ESTestCase { @@ -51,7 +57,8 @@ public class RescoreKnnVectorQueryTests extends ESTestCase { public RescoreKnnVectorQueryTests(VectorProvider vectorProvider, boolean useK) { this.vectorProvider = vectorProvider; - this.numDocs = randomIntBetween(10, 100);; + this.numDocs = randomIntBetween(10, 100); + ; this.k = useK ? randomIntBetween(1, numDocs - 1) : null; } @@ -71,7 +78,11 @@ public void testRescoreDocs() throws Exception { // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query // and thus we're rescoring the top k docs. VectorData queryVector = vectorProvider.randomVector(numDims); - RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(queryVector, adjustedK); + RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery( + queryVector, + adjustedK, + new MatchAllDocsQuery() + ); IndexSearcher searcher = newSearcher(reader, true, false); TopDocs docs = searcher.search(rescoreKnnVectorQuery, numDocs); @@ -115,10 +126,90 @@ public void testRescoreDocs() throws Exception { } } + public void testProfiling() throws Exception { + int numDims = randomIntBetween(5, 100); + + try (Directory d = newDirectory()) { + addRandomDocuments(numDocs, d, numDims, vectorProvider); + + try (IndexReader reader = DirectoryReader.open(d)) { + VectorData queryVector = vectorProvider.randomVector(numDims); + + checkProfiling(queryVector, reader, new MatchAllDocsQuery()); + checkProfiling(queryVector, reader, new MockProfilingQuery(randomIntBetween(1, 100))); + } + } + } + + private void checkProfiling(VectorData queryVector, IndexReader reader, Query innerQuery) throws IOException { + RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(queryVector, k, innerQuery); + IndexSearcher searcher = newSearcher(reader, true, false); + searcher.search(rescoreKnnVectorQuery, numDocs); + + QueryProfiler queryProfiler = new QueryProfiler(); + rescoreKnnVectorQuery.profile(queryProfiler); + + long expectedVectorOpsCount = 0; + if (k != null) { + expectedVectorOpsCount += k; + } + if (innerQuery instanceof ProfilingQuery profilingQuery) { + QueryProfiler anotherProfiler = new QueryProfiler(); + profilingQuery.profile(anotherProfiler); + assertThat(anotherProfiler.getVectorOpsCount(), greaterThan(0L)); + expectedVectorOpsCount += anotherProfiler.getVectorOpsCount(); + } + + assertThat(queryProfiler.getVectorOpsCount(), equalTo(expectedVectorOpsCount)); + } + + /** + * A mock query that is used to test profiling + */ + private static class MockProfilingQuery extends Query implements ProfilingQuery { + + private final long vectorOpsCount; + + private MockProfilingQuery(long vectorOpsCount) { + this.vectorOpsCount = vectorOpsCount; + } + + @Override + public String toString(String field) { + return ""; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new MatchAllDocsQuery().createWeight(searcher, scoreMode, boost); + } + + @Override + public void visit(QueryVisitor visitor) {} + + @Override + public boolean equals(Object obj) { + return obj instanceof MockProfilingQuery; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + } + + /** + * Vector operations depend on the type of vector field used. This interface abstracts the operations needed to perform the tests + */ private interface VectorProvider { VectorData randomVector(int numDimensions); - RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k); + RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery); KnnVectorValues vectorValues(LeafReader leafReader) throws IOException; @@ -140,14 +231,8 @@ public VectorData randomVector(int numDimensions) { } @Override - public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) { - return new RescoreKnnVectorQuery( - FIELD_NAME, - queryVector.floatVector(), - VectorSimilarityFunction.COSINE, - k, - new MatchAllDocsQuery() - ); + public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery) { + return new RescoreKnnVectorQuery(FIELD_NAME, queryVector.floatVector(), VectorSimilarityFunction.COSINE, k, innerQuery); } @Override @@ -163,7 +248,7 @@ public void addVectorField(Document document, VectorData vector) { @Override public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException { - return VectorData.fromFloats(((FloatVectorValues)vectorValues).vectorValue(docId)); + return VectorData.fromFloats(((FloatVectorValues) vectorValues).vectorValue(docId)); } @Override @@ -183,14 +268,8 @@ public VectorData randomVector(int numDimensions) { } @Override - public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k) { - return new RescoreKnnVectorQuery( - FIELD_NAME, - queryVector.byteVector(), - VectorSimilarityFunction.COSINE, - k, - new MatchAllDocsQuery() - ); + public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery) { + return new RescoreKnnVectorQuery(FIELD_NAME, queryVector.byteVector(), VectorSimilarityFunction.COSINE, k, innerQuery); } @Override @@ -206,7 +285,7 @@ public void addVectorField(Document document, VectorData vector) { @Override public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException { - return VectorData.fromBytes(((ByteVectorValues)vectorValues).vectorValue(docId)); + return VectorData.fromBytes(((ByteVectorValues) vectorValues).vectorValue(docId)); } @Override @@ -230,39 +309,12 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims, Ve @ParametersFactory public static Iterable parameters() { - List params = new ArrayList<>(); - params.add(new Object[] {new FloatVectorProvider(), true}); - params.add(new Object[] {new FloatVectorProvider(), false}); - params.add(new Object[] {new ByteVectorProvider(), true}); - params.add(new Object[] {new ByteVectorProvider(), false}); + params.add(new Object[] { new FloatVectorProvider(), true }); + params.add(new Object[] { new FloatVectorProvider(), false }); + params.add(new Object[] { new ByteVectorProvider(), true }); + params.add(new Object[] { new ByteVectorProvider(), false }); return params; } - -// public void testProfiling() throws Exception { -// int numDocs = randomIntBetween(10, 100); -// int numDims = randomIntBetween(5, 100); -// -// try (Directory d = newDirectory()) { -// addRandomDocuments(numDocs, d, numDims, vectorProvider); -// -// try (IndexReader reader = DirectoryReader.open(d)) { -// float[] queryVector = randomVector(numDims); -// -// RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( -// FIELD_NAME, -// queryVector, -// VectorSimilarityFunction.COSINE, -// randomIntBetween(5, numDocs - 1), -// new MatchAllDocsQuery() -// ); -// -// IndexSearcher searcher = newSearcher(reader, true, false); -// QueryProfiler queryProfiler = new QueryProfiler(); -// rescoreKnnVectorQuery.profile(queryProfiler); -// } -// } -// } - } From 916ac830d9b7451ce86be3fed3780cc0859331ed Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 29 Nov 2024 18:08:36 +0100 Subject: [PATCH 20/56] Spotless --- .../elasticsearch/search/retriever/KnnRetrieverBuilder.java | 2 +- .../TextSimilarityRankRetrieverTelemetryTests.java | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index da6254201072b..63422b60534bd 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -261,7 +261,7 @@ RescoreVectorBuilder rescoreVectorBuilder() { return rescoreVectorBuilder; } -// ---- FOR TESTING XCONTENT PARSING ---- + // ---- FOR TESTING XCONTENT PARSING ---- @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java index 9b48abab25889..084a7f3de4a53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -103,9 +103,7 @@ public void testTelemetryForRRFRetriever() throws IOException { // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` { performSearch( - new SearchSourceBuilder().retriever( - new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null) - ) + new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)) ); } From 229ce2d110e48faeb827273f230c98240cd194e1 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 2 Dec 2024 15:49:42 +0100 Subject: [PATCH 21/56] Add YAML tests --- .../search.vectors/41_knn_search_bbq_hnsw.yml | 80 +++++++++++++++++++ .../41_knn_search_byte_quantized.yml | 79 ++++++++++++++++++ .../41_knn_search_half_byte_quantized.yml | 78 ++++++++++++++++++ .../search.vectors/42_knn_search_bbq_flat.yml | 80 +++++++++++++++++++ .../42_knn_search_int4_flat.yml | 77 ++++++++++++++++++ .../42_knn_search_int8_flat.yml | 77 ++++++++++++++++++ 6 files changed, 471 insertions(+) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 188c155e4a836..31f934d58b7ec 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -87,6 +87,86 @@ setup: # here we verify that are last hit is always the worst one - match: { hits.hits.2._id: "1" } +--- +"Vector rescoring has similar ordering as knn, same scoring as exact search for kNN section": + - skip: + features: "headers" + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._id: knn_id0 } + - set: { hits.hits.1._id: knn_id1 } + - set: { hits.hits.2._id: knn_id2 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Comparing to knn search, we already have changes in ordering and scoring + - match: { hits.hits.0._id: $knn_id1 } + - match: { hits.hits.1._id: $knn_id0 } + - match: { hits.hits.2._id: $knn_id2 } + + # Get rescoring scores + - match: { hits.total: 3 } + - set: { hits.hits.0._id: rescore_id0 } + - set: { hits.hits.1._id: rescore_id1 } + - set: { hits.hits.2._id: rescore_id2 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] + + # Check same ordering (which will not be true for larger datasets) + # and scoring (which should be for the elements that are present in both) + - match: { hits.total: 3 } + - match: { hits.hits.0._id: $rescore_id0 } + - match: { hits.hits.1._id: $rescore_id1 } + - match: { hits.hits.2._id: $rescore_id2 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + --- "Test bad quantization parameters": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index b7a5517309949..2bd4f97614390 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -368,6 +368,85 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +# Won't be true for larger datasets, but this helps checking kNN vs rescoring vs exact search +"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - skip: + features: "headers" + + # kNN search + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: hnsw_byte_quantized + body: + size: 3 + query: + knn: + k: 3 + num_candidates: 3 + field: vector + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: { hits.total: 3 } + - set: { hits.hits.0._id: knn_id0 } + - set: { hits.hits.1._id: knn_id1 } + - set: { hits.hits.2._id: knn_id2 } + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: hnsw_byte_quantized + body: + size: 3 + query: + knn: + k: 3 + num_candidates: 3 + field: vector + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + rescore: + oversample: 1.5 + + # Check same ordering (which will not be true for larger datasets) + - match: { hits.total: 3 } + - match: { hits.hits.0._id: $knn_id0 } + - match: { hits.hits.1._id: $knn_id1 } + - match: { hits.hits.2._id: $knn_id2 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + # Check same ordering (which will not be true for larger datasets) + # and scoring (which should be for the elements that are present in both) + - match: { hits.total: 3 } + - match: { hits.hits.0._id: $knn_id0 } + - match: { hits.hits.1._id: $knn_id1 } + - match: { hits.hits.2._id: $knn_id2 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + --- "Test bad quantization parameters": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 5f1af2ca5c52f..ecf9d08396f30 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -549,6 +549,84 @@ setup: - match: { hits.hits.1._id: "2"} - match: { hits.hits.2._id: "3"} --- +"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - skip: + features: "headers" + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: hnsw_byte_quantized + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._id: knn_id0 } + - set: { hits.hits.1._id: knn_id1 } + - set: { hits.hits.2._id: knn_id2 } + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: hnsw_byte_quantized + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Comparing to knn search + - match: { hits.hits.0._id: $knn_id0 } + - match: { hits.hits.1._id: $knn_id1 } + - match: { hits.hits.2._id: $knn_id2 } + + # Get rescoring scores + - match: { hits.total: 3 } + - set: { hits.hits.0._id: rescore_id0 } + - set: { hits.hits.1._id: rescore_id1 } + - set: { hits.hits.2._id: rescore_id2 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8] + + # Check same ordering (which will not be true for larger datasets) + # and scoring (which should be for the elements that are present in both) + - match: { hits.total: 3 } + - match: { hits.hits.0._id: $rescore_id0 } + - match: { hits.hits.1._id: $rescore_id1 } + - match: { hits.hits.2._id: $rescore_id2 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test odd dimensions fail indexing": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index ed7a8dd5df65d..960dbeb786c9a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -87,6 +87,86 @@ setup: # here we verify that are last hit is always the worst one - match: { hits.hits.2._id: "1" } --- +"Vector rescoring has similar ordering as knn, same scoring as exact search for kNN section": + - skip: + features: "headers" + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._id: knn_id0 } + - set: { hits.hits.1._id: knn_id1 } + - set: { hits.hits.2._id: knn_id2 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Comparing to knn search, we already have changes in ordering and scoring + - match: { hits.hits.0._id: $knn_id1 } + - match: { hits.hits.1._id: $knn_id0 } + - match: { hits.hits.2._id: $knn_id2 } + + # Get rescoring scores + - match: { hits.total: 3 } + - set: { hits.hits.0._id: rescore_id0 } + - set: { hits.hits.1._id: rescore_id1 } + - set: { hits.hits.2._id: rescore_id2 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + + # Check same ordering (which will not be true for larger datasets) + # and scoring (which should be for the elements that are present in both) + - match: { hits.total: 3 } + - match: { hits.hits.0._id: $rescore_id0 } + - match: { hits.hits.1._id: $rescore_id1 } + - match: { hits.hits.2._id: $rescore_id2 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index b9a0b16f2bd7a..4661510780a5f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -344,3 +344,80 @@ setup: index: dynamic_dim_hnsw_quantized body: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +--- +"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - skip: + features: "headers" + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int4_flat + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._id: knn_id0 } + - set: { hits.hits.1._id: knn_id1 } + - set: { hits.hits.2._id: knn_id2 } + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: int4_flat + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Comparing to knn search + - match: { hits.hits.0._id: $knn_id0 } + - match: { hits.hits.1._id: $knn_id1 } + - match: { hits.hits.2._id: $knn_id2 } + + # Get rescoring scores + - match: { hits.total: 3 } + - set: { hits.hits.0._id: rescore_id0 } + - set: { hits.hits.1._id: rescore_id1 } + - set: { hits.hits.2._id: rescore_id2 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8] + + # Check same ordering (which will not be true for larger datasets) + # and scoring (which should be for the elements that are present in both) + - match: { hits.total: 3 } + - match: { hits.hits.0._id: $rescore_id0 } + - match: { hits.hits.1._id: $rescore_id1 } + - match: { hits.hits.2._id: $rescore_id2 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 139747c5e7ee5..2e02e9063e4da 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -262,6 +262,83 @@ setup: - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} --- +"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - skip: + features: "headers" + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: int8_flat + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._id: knn_id0 } + - set: { hits.hits.1._id: knn_id1 } + - set: { hits.hits.2._id: knn_id2 } + + # Rescore + - do: + headers: + Content-Type: application/json + search: + index: int8_flat + rest_total_hits_as_int: true + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Comparing to knn search + - match: { hits.hits.0._id: $knn_id0 } + - match: { hits.hits.1._id: $knn_id1 } + - match: { hits.hits.2._id: $knn_id2 } + + # Get rescoring scores + - match: { hits.total: 3 } + - set: { hits.hits.0._id: rescore_id0 } + - set: { hits.hits.1._id: rescore_id1 } + - set: { hits.hits.2._id: rescore_id2 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + + # Check same ordering (which will not be true for larger datasets) + # and scoring (which should be for the elements that are present in both) + - match: { hits.total: 3 } + - match: { hits.hits.0._id: $rescore_id0 } + - match: { hits.hits.1._id: $rescore_id1 } + - match: { hits.hits.2._id: $rescore_id2 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- "Test bad parameters": - do: catch: bad_request From 934eedb4207289241bd68b68349b128e81476e6b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 2 Dec 2024 17:46:00 +0100 Subject: [PATCH 22/56] Minor documentation / style fixes --- .../mapper/vectors/DenseVectorFieldMapper.java | 18 ++---------------- .../VectorSimilarityByteValueSource.java | 4 ++++ .../VectorSimilarityFloatValueSource.java | 4 ++++ .../search/profile/query/QueryProfiler.java | 8 ++++++++ .../search/vectors/KnnScoreDocQuery.java | 2 +- .../search/vectors/RescoreKnnVectorQuery.java | 3 ++- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index eefdc48e2595b..e953462c9e490 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -122,10 +122,9 @@ public static boolean isNotUnitVector(float magnitude) { public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions - public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates - public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector public static final int MAGNITUDE_BYTES = 4; + public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; @@ -2038,15 +2037,7 @@ public Query createKnnQuery( similarityThreshold, parentFilter ); - case BIT -> createKnnBitQuery( - queryVector.asByteVector(), - k, - numCands, - rescoreOversample, - filter, - similarityThreshold, - parentFilter - ); + case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); }; } @@ -2058,16 +2049,11 @@ private Query createKnnBitQuery( byte[] queryVector, Integer k, int numCands, - Float rescoreOversample, Query filter, Float similarityThreshold, BitSetProducer parentFilter ) { elementType.checkDimensions(dims, queryVector.length); - if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { - float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); - elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); - } Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java index 661ae2ac2fcd4..96209592e5240 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java @@ -23,6 +23,10 @@ import java.util.Arrays; import java.util.Objects; +/** + * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnByteVectorField, using the + * original vector values stored in the index + */ public class VectorSimilarityByteValueSource extends DoubleValuesSource { private final String field; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index 13d42272f4744..f3e1f1683c945 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -23,6 +23,10 @@ import java.util.Arrays; import java.util.Objects; +/** + * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the + * original vector values stored in the index + */ public class VectorSimilarityFloatValueSource extends DoubleValuesSource { private final String field; diff --git a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java index 59196fdfba714..23ce52b6c5b82 100644 --- a/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java +++ b/server/src/main/java/org/elasticsearch/search/profile/query/QueryProfiler.java @@ -39,10 +39,18 @@ public QueryProfiler() { super(new InternalQueryProfileTree()); } + /** + * Adds a number of vector operations to the current count + * @param vectorOpsCount number of vector ops to add to the profiler + */ public void addVectorOpsCount(long vectorOpsCount) { this.vectorOpsCount += vectorOpsCount; } + /** + * Retrieves the number of vector operations performed by the queries + * @return number of vector operations performed by the queries + */ public long getVectorOpsCount() { return this.vectorOpsCount; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java index db7484dddc226..3d13f3cd82b9c 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQuery.java @@ -43,8 +43,8 @@ public class KnnScoreDocQuery extends Query { // If a segment has no matching documents, it should be assigned the index of the next segment that does. // There should be a final entry that is always docs.length-1. private final int[] segmentStarts; - // an object identifying the reader context that was used to build this query + // an object identifying the reader context that was used to build this query private final Object contextIdentity; /** diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 0680c2e4ac438..569d3c858ac98 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -26,7 +26,7 @@ import java.util.Objects; /** - * Wraps a kNN vector query to rescore the results using the non-quantized vectors + * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field */ public class RescoreKnnVectorQuery extends Query implements ProfilingQuery { private final String fieldName; @@ -86,6 +86,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException { return query; } + // Retrieve top k documents from the rescored query TopDocs topDocs = searcher.search(query, k); ScoreDoc[] scoreDocs = topDocs.scoreDocs; int[] docIds = new int[scoreDocs.length]; From cca6e3963ef7c226d163dbcb44c8051208d85df9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 2 Dec 2024 18:23:40 +0100 Subject: [PATCH 23/56] Fix test commpilation --- .../xpack/rank/rrf/RRFRankMultiShardIT.java | 40 +++++++++---------- .../xpack/rank/rrf/RRFRankSingleShardIT.java | 34 ++++++++-------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java index b4cf409f5fd72..457c57410d168 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankMultiShardIT.java @@ -136,7 +136,7 @@ public void setupSuiteScopeCluster() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); assertResponse( prepareSearch("tiny_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -167,7 +167,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -208,8 +208,8 @@ public void testBM25AndKnn() { public void testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -260,8 +260,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(true) @@ -389,8 +389,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(false) @@ -457,8 +457,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(51, 1)) .setTrackTotalHits(true) @@ -704,7 +704,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -762,7 +762,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -837,8 +837,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -899,8 +899,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(101, 1)) .setTrackTotalHits(false) @@ -979,7 +979,7 @@ public void testBasicRRFExplain() { // the first result should be the one present in both queries (i.e. doc with text0: 10 and vector: [10]) and the other ones // should only match the knn query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1045,7 +1045,7 @@ public void testRRFExplainUnknownField() { // in this test we try knn with a query on an unknown field that would be rewritten to MatchNoneQuery // so we expect results and explanations only for the first part float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) @@ -1112,7 +1112,7 @@ public void testRRFExplainOneUnknownFieldSubSearches() { // while the other one would produce a match. // So, we'd have a total of 3 queries, a (rewritten) MatchNoneQuery, a TermQuery, and a kNN query float[] queryVector = { 9f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null).queryName("my_knn_search"); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null).queryName("my_knn_search"); assertResponse( prepareSearch("nrd_index").setRankBuilder(new RRFRankBuilder(100, 1)) .setKnnSearch(List.of(knnSearch)) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java index ed26aa50ffa62..a4e7db3b3e3fe 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankSingleShardIT.java @@ -131,7 +131,7 @@ public void setupIndices() throws Exception { public void testTotalDocsSmallerThanSize() { float[] queryVector = { 0.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 3, 3, null, null); assertResponse( client().prepareSearch("tiny_index") @@ -164,7 +164,7 @@ public void testTotalDocsSmallerThanSize() { public void testBM25AndKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -206,8 +206,8 @@ public void testBM25AndKnn() { public void testMultipleOnlyKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -259,8 +259,8 @@ public void testMultipleOnlyKnn() { public void testBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -332,7 +332,7 @@ public void testBM25AndMultipleKnn() { public void testBM25AndKnnWithBucketAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -390,8 +390,8 @@ public void testBM25AndKnnWithBucketAggregation() { public void testMultipleOnlyKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -459,8 +459,8 @@ public void testMultipleOnlyKnnWithAggregation() { public void testBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 51, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 51, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(51, 1)) @@ -709,7 +709,7 @@ public void testMultiBM25WithAggregation() { public void testMultiBM25AndSingleKnn() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -768,7 +768,7 @@ public void testMultiBM25AndSingleKnn() { public void testMultiBM25AndSingleKnnWithAggregation() { float[] queryVector = { 500.0f }; - KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null); + KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector_asc", queryVector, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -844,8 +844,8 @@ public void testMultiBM25AndSingleKnnWithAggregation() { public void testMultiBM25AndMultipleKnn() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) @@ -907,8 +907,8 @@ public void testMultiBM25AndMultipleKnn() { public void testMultiBM25AndMultipleKnnWithAggregation() { float[] queryVectorAsc = { 500.0f }; float[] queryVectorDesc = { 500.0f }; - KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null); - KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null); + KnnSearchBuilder knnSearchAsc = new KnnSearchBuilder("vector_asc", queryVectorAsc, 101, 1001, null, null); + KnnSearchBuilder knnSearchDesc = new KnnSearchBuilder("vector_desc", queryVectorDesc, 101, 1001, null, null); assertResponse( client().prepareSearch("nrd_index") .setRankBuilder(new RRFRankBuilder(101, 1)) From d95db48b9ee36063bc67a96cfd0880d87ab282ed Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 2 Dec 2024 18:23:52 +0100 Subject: [PATCH 24/56] Properly implement advanceExact() --- .../index/mapper/vectors/VectorSimilarityByteValueSource.java | 4 +--- .../mapper/vectors/VectorSimilarityFloatValueSource.java | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java index 96209592e5240..48da96c0c77b0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java @@ -14,7 +14,6 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; @@ -56,8 +55,7 @@ public double doubleValue() throws IOException { @Override public boolean advanceExact(int doc) throws IOException { - docId = doc; - return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS; + return iterator.advance(doc) == doc; } }; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index f3e1f1683c945..7b1da0e3b006b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -14,7 +14,6 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; @@ -56,8 +55,7 @@ public double doubleValue() throws IOException { @Override public boolean advanceExact(int doc) throws IOException { - docId = doc; - return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS; + return iterator.advance(doc) == doc; } }; } From ee904b0905b91cdb84a2fc587ecf9d09c15d66f6 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 2 Dec 2024 20:28:52 +0100 Subject: [PATCH 25/56] Fix VectorSimilarityFloatValueSource implementation for advanceExact --- .../mapper/vectors/VectorSimilarityByteValueSource.java | 6 ++++-- .../mapper/vectors/VectorSimilarityFloatValueSource.java | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java index 48da96c0c77b0..9ec668edffb79 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java @@ -14,6 +14,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; @@ -43,7 +44,7 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws final LeafReader reader = ctx.reader(); ByteVectorValues vectorValues = reader.getByteVectorValues(field); - KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); return new DoubleValues() { private int docId = -1; @@ -55,7 +56,8 @@ public double doubleValue() throws IOException { @Override public boolean advanceExact(int doc) throws IOException { - return iterator.advance(doc) == doc; + docId = doc; + return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc; } }; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index 7b1da0e3b006b..bc2d771e03aa7 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -14,6 +14,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; @@ -43,7 +44,7 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws final LeafReader reader = ctx.reader(); FloatVectorValues vectorValues = reader.getFloatVectorValues(field); - KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); return new DoubleValues() { private int docId = -1; @@ -55,7 +56,8 @@ public double doubleValue() throws IOException { @Override public boolean advanceExact(int doc) throws IOException { - return iterator.advance(doc) == doc; + docId = doc; + return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc; } }; } From 0d77521aad2ed6118da14faf01d2de19cd2bb79f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 2 Dec 2024 20:29:25 +0100 Subject: [PATCH 26/56] Add capability for BwC tests --- .../test/search.vectors/41_knn_search_bbq_hnsw.yml | 7 +++++++ .../test/search.vectors/41_knn_search_byte_quantized.yml | 8 +++++++- .../search.vectors/41_knn_search_half_byte_quantized.yml | 7 +++++++ .../test/search.vectors/42_knn_search_bbq_flat.yml | 8 ++++++++ .../test/search.vectors/42_knn_search_int4_flat.yml | 7 +++++++ .../test/search.vectors/42_knn_search_int8_flat.yml | 7 +++++++ .../rest/action/search/SearchCapabilities.java | 2 ++ 7 files changed, 45 insertions(+), 1 deletion(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 31f934d58b7ec..fe04ee08e4b10 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -89,6 +89,13 @@ setup: --- "Vector rescoring has similar ordering as knn, same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index 2bd4f97614390..d63ddfb1eb680 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -372,9 +372,15 @@ setup: --- # Won't be true for larger datasets, but this helps checking kNN vs rescoring vs exact search "Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - # kNN search - do: headers: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index ecf9d08396f30..4ec46b5578c02 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -550,6 +550,13 @@ setup: - match: { hits.hits.2._id: "3"} --- "Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 960dbeb786c9a..60f4a499ca629 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -88,6 +88,13 @@ setup: - match: { hits.hits.2._id: "1" } --- "Vector rescoring has similar ordering as knn, same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - do: @@ -147,6 +154,7 @@ setup: Content-Type: application/json search: rest_total_hits_as_int: true + index: bbq_flat body: query: script_score: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index 4661510780a5f..cfae5ecace868 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -346,6 +346,13 @@ setup: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] --- "Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 2e02e9063e4da..b6d0f0343d592 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -263,6 +263,13 @@ setup: - lte: {hits.hits.2._score: 0.791} --- "Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - do: diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 794b30aa5aab2..f9123890dc740 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -44,6 +44,7 @@ private SearchCapabilities() {} private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim_with_bugfix"; private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; + private static final String KNN_QUANTIZED_VECTOR_RESCORE = "knn_quantized_vector_rescore"; public static final Set CAPABILITIES; static { @@ -55,6 +56,7 @@ private SearchCapabilities() {} capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER); capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS); + capabilities.add(KNN_QUANTIZED_VECTOR_RESCORE); if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS); From 095a951854e50a60a4e1043934d85db0e057052f Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:15:52 +0100 Subject: [PATCH 27/56] Update docs/changelog/116663.yaml --- docs/changelog/116663.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/116663.yaml diff --git a/docs/changelog/116663.yaml b/docs/changelog/116663.yaml new file mode 100644 index 0000000000000..40bcdea29bc31 --- /dev/null +++ b/docs/changelog/116663.yaml @@ -0,0 +1,5 @@ +pr: 116663 +summary: KNN vector rescoring for quantized vectors +area: Vector Search +type: feature +issues: [] From 732bd7d9e4a2c262fbcdce7b8e74a28cfbb2ee33 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 15:22:22 +0100 Subject: [PATCH 28/56] Correctly implement profiling. Rename ProfilingQuery to QueryProfilerProvider --- .../search.vectors/210_knn_search_profile.yml | 104 ++++++++++++++++++ .../VectorSimilarityByteValueSource.java | 11 +- .../VectorSimilarityFloatValueSource.java | 11 +- .../elasticsearch/search/dfs/DfsPhase.java | 6 +- ...iversifyingChildrenByteKnnVectorQuery.java | 2 +- ...versifyingChildrenFloatKnnVectorQuery.java | 2 +- .../search/vectors/ESKnnByteVectorQuery.java | 2 +- .../search/vectors/ESKnnFloatVectorQuery.java | 2 +- ...gQuery.java => QueryProfilerProvider.java} | 2 +- .../search/vectors/RescoreKnnVectorQuery.java | 54 +++++---- .../search/vectors/VectorSimilarityQuery.java | 6 +- .../vectors/RescoreKnnVectorQueryTests.java | 25 +++-- 12 files changed, 183 insertions(+), 44 deletions(-) create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml rename server/src/main/java/org/elasticsearch/search/vectors/{ProfilingQuery.java => QueryProfilerProvider.java} (96%) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml new file mode 100644 index 0000000000000..5d9f14c3e353f --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -0,0 +1,104 @@ +setup: + - requires: + cluster_features: "mapper.vectors.bbq" + reason: 'kNN float to better-binary quantization is required' + - do: + indices.create: + index: bbq_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + dims: 64 + index: true + similarity: l2_norm + index_options: + type: bbq_hnsw + another_vector: + type: dense_vector + dims: 64 + index: true + similarity: l2_norm + index_options: + type: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "1" + body: + name: cow.jpg + vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] + another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "2" + body: + name: moose.jpg + vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] + another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "3" + body: + name: rabbit.jpg + vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] + another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + indices.forcemerge: + index: bbq_hnsw + max_num_segments: 1 + +--- +"Profile rescored knn search": + - do: + search: + index: bbq_hnsw + body: + profile: true + knn: + field: vector + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 3 } + + - do: + search: + index: bbq_hnsw + body: + profile: true + knn: + field: vector + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + "rescore": + "oversample": 2.0 + + # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java index 9ec668edffb79..33864fb76a310 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java @@ -18,6 +18,8 @@ import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; +import org.elasticsearch.search.profile.query.QueryProfiler; +import org.elasticsearch.search.vectors.QueryProfilerProvider; import java.io.IOException; import java.util.Arrays; @@ -27,11 +29,12 @@ * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnByteVectorField, using the * original vector values stored in the index */ -public class VectorSimilarityByteValueSource extends DoubleValuesSource { +public class VectorSimilarityByteValueSource extends DoubleValuesSource implements QueryProfilerProvider { private final String field; private final byte[] target; private final VectorSimilarityFunction vectorSimilarityFunction; + private long vectorOpsCount; public VectorSimilarityByteValueSource(String field, byte[] target, VectorSimilarityFunction vectorSimilarityFunction) { this.field = field; @@ -51,6 +54,7 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws @Override public double doubleValue() throws IOException { + vectorOpsCount++; return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId)); } @@ -72,6 +76,11 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { return this; } + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + @Override public int hashCode() { return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index bc2d771e03aa7..c8e128f7fe3fe 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -18,6 +18,8 @@ import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; +import org.elasticsearch.search.profile.query.QueryProfiler; +import org.elasticsearch.search.vectors.QueryProfilerProvider; import java.io.IOException; import java.util.Arrays; @@ -27,11 +29,12 @@ * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the * original vector values stored in the index */ -public class VectorSimilarityFloatValueSource extends DoubleValuesSource { +public class VectorSimilarityFloatValueSource extends DoubleValuesSource implements QueryProfilerProvider { private final String field; private final float[] target; private final VectorSimilarityFunction vectorSimilarityFunction; + private long vectorOpsCount; public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) { this.field = field; @@ -51,6 +54,7 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws @Override public double doubleValue() throws IOException { + vectorOpsCount++; return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId)); } @@ -72,6 +76,11 @@ public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { return this; } + @Override + public void profile(QueryProfiler queryProfiler) { + queryProfiler.addVectorOpsCount(vectorOpsCount); + } + @Override public int hashCode() { return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); diff --git a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java index 76b3f45ffb84a..6a99b51ac679c 100644 --- a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java +++ b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java @@ -34,7 +34,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; -import org.elasticsearch.search.vectors.ProfilingQuery; +import org.elasticsearch.search.vectors.QueryProfilerProvider; import org.elasticsearch.tasks.TaskCancelledException; import java.io.IOException; @@ -224,8 +224,8 @@ static DfsKnnResults singleKnnSearch(Query knnQuery, int k, Profilers profilers, ); topDocs = searcher.search(knnQuery, ipcm); - if (knnQuery instanceof ProfilingQuery profilingQuery) { - profilingQuery.profile(knnProfiler); + if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(knnProfiler); } knnProfiler.setCollectorResult(ipcm.getCollectorTree()); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java index 40803586a31ce..b7f129f674036 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java @@ -15,7 +15,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements ProfilingQuery { +public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java index f69aca5339daf..cb323bbe3932a 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java @@ -15,7 +15,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements ProfilingQuery { +public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 1de5471cb96ac..5c199f42093b1 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -14,7 +14,7 @@ import org.apache.lucene.search.TopDocs; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements ProfilingQuery { +public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java index 0d37a46625a9e..b7b9d092ceeac 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java @@ -14,7 +14,7 @@ import org.apache.lucene.search.TopDocs; import org.elasticsearch.search.profile.query.QueryProfiler; -public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements ProfilingQuery { +public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider { private final Integer kParam; private long vectorOpsCount; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java similarity index 96% rename from server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java rename to server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java index 4d36d8eae57cc..47b0e1e299968 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ProfilingQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/QueryProfilerProvider.java @@ -18,7 +18,7 @@ * must provide an implementation for profile() to store profiling information in the {@link QueryProfiler}. */ -public interface ProfilingQuery { +public interface QueryProfilerProvider { /** * Store the profiling information in the {@link QueryProfiler} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 569d3c858ac98..b13f9490d2b3a 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.queries.function.FunctionScoreQuery; +import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; @@ -28,7 +29,7 @@ /** * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field */ -public class RescoreKnnVectorQuery extends Query implements ProfilingQuery { +public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider { private final String fieldName; private final byte[] byteTarget; private final float[] floatTarget; @@ -36,7 +37,7 @@ public class RescoreKnnVectorQuery extends Query implements ProfilingQuery { private final Integer k; private final Query innerQuery; - private long vectorOpsCount; + private QueryProfilerProvider vectorProfiling; public RescoreKnnVectorQuery( String fieldName, @@ -45,12 +46,7 @@ public RescoreKnnVectorQuery( Integer k, Query innerQuery ) { - this.fieldName = fieldName; - this.byteTarget = byteTarget; - this.floatTarget = null; - this.vectorSimilarityFunction = vectorSimilarityFunction; - this.k = k; - this.innerQuery = innerQuery; + this(fieldName, byteTarget, null, vectorSimilarityFunction, k, innerQuery, null); } public RescoreKnnVectorQuery( @@ -60,29 +56,47 @@ public RescoreKnnVectorQuery( Integer k, Query innerQuery ) { + this(fieldName, null, floatTarget, vectorSimilarityFunction, k, innerQuery, null); + } + + private RescoreKnnVectorQuery( + String fieldName, + byte[] byteTarget, + float[] floatTarget, + VectorSimilarityFunction vectorSimilarityFunction, + Integer k, + Query innerQuery, + QueryProfilerProvider queryProfilerProvider + ) { + if ((byteTarget == null ^ floatTarget == null) == false) { + throw new IllegalArgumentException("Either byteTarget or floatTarget must be set"); + } + this.fieldName = fieldName; - this.byteTarget = null; + this.byteTarget = byteTarget; this.floatTarget = floatTarget; this.vectorSimilarityFunction = vectorSimilarityFunction; this.k = k; this.innerQuery = innerQuery; + this.vectorProfiling = queryProfilerProvider; } @Override public Query rewrite(IndexSearcher searcher) throws IOException { - assert byteTarget == null ^ floatTarget == null : "Either byteTarget or floatTarget must be set"; - final DoubleValuesSource valueSource; if (byteTarget != null) { valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction); } else { valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); } + // Vector similarity DoubleValueSource keep track of the compared vectors - we need that in case we don't need + // to calculate top k and return directly the query + vectorProfiling = (QueryProfilerProvider) valueSource; FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); Query query = searcher.rewrite(functionScoreQuery); if (k == null) { - // No need to calculate top k - let the request size limit the results + // No need to calculate top k - let the request size limit the results. return query; } @@ -96,8 +110,6 @@ public Query rewrite(IndexSearcher searcher) throws IOException { scores[i] = scoreDocs[i].score; } - vectorOpsCount = scoreDocs.length; - return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); } @@ -111,17 +123,19 @@ public Integer k() { @Override public void profile(QueryProfiler queryProfiler) { - if (innerQuery instanceof ProfilingQuery profilingQuery) { - profilingQuery.profile(queryProfiler); + if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(queryProfiler); } - queryProfiler.addVectorOpsCount(vectorOpsCount); + + if (vectorProfiling == null) { + throw new IllegalStateException("Query should have been rewritten"); + } + vectorProfiling.profile(queryProfiler); } @Override public void visit(QueryVisitor visitor) { - if (visitor.acceptField(fieldName)) { - visitor.visitLeaf(this); - } + innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this)); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java index 74daafe747ca4..d91994f843541 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java @@ -31,7 +31,7 @@ /** * This query provides a simple post-filter for the provided Query. The query is assumed to be a Knn(Float|Byte)VectorQuery. */ -public class VectorSimilarityQuery extends Query implements ProfilingQuery { +public class VectorSimilarityQuery extends Query implements QueryProfilerProvider { private final float similarity; private final float docScore; private final Query innerKnnQuery; @@ -81,8 +81,8 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo @Override public void profile(QueryProfiler queryProfiler) { - if (innerKnnQuery instanceof ProfilingQuery profilingQuery) { - profilingQuery.profile(queryProfiler); + if (innerKnnQuery instanceof QueryProfilerProvider queryProfilerProvider) { + queryProfilerProvider.profile(queryProfiler); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 8c68fca1b3640..cbaa8eb81bcf4 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -136,7 +137,7 @@ public void testProfiling() throws Exception { VectorData queryVector = vectorProvider.randomVector(numDims); checkProfiling(queryVector, reader, new MatchAllDocsQuery()); - checkProfiling(queryVector, reader, new MockProfilingQuery(randomIntBetween(1, 100))); + checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); } } } @@ -149,13 +150,10 @@ private void checkProfiling(VectorData queryVector, IndexReader reader, Query in QueryProfiler queryProfiler = new QueryProfiler(); rescoreKnnVectorQuery.profile(queryProfiler); - long expectedVectorOpsCount = 0; - if (k != null) { - expectedVectorOpsCount += k; - } - if (innerQuery instanceof ProfilingQuery profilingQuery) { + long expectedVectorOpsCount = numDocs; + if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { QueryProfiler anotherProfiler = new QueryProfiler(); - profilingQuery.profile(anotherProfiler); + queryProfilerProvider.profile(anotherProfiler); assertThat(anotherProfiler.getVectorOpsCount(), greaterThan(0L)); expectedVectorOpsCount += anotherProfiler.getVectorOpsCount(); } @@ -166,11 +164,11 @@ private void checkProfiling(VectorData queryVector, IndexReader reader, Query in /** * A mock query that is used to test profiling */ - private static class MockProfilingQuery extends Query implements ProfilingQuery { + private static class MockQueryProfilerProvider extends Query implements QueryProfilerProvider { private final long vectorOpsCount; - private MockProfilingQuery(long vectorOpsCount) { + private MockQueryProfilerProvider(long vectorOpsCount) { this.vectorOpsCount = vectorOpsCount; } @@ -181,7 +179,12 @@ public String toString(String field) { @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - return new MatchAllDocsQuery().createWeight(searcher, scoreMode, boost); + throw new UnsupportedEncodingException("Should have been rewritten"); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return new MatchAllDocsQuery(); } @Override @@ -189,7 +192,7 @@ public void visit(QueryVisitor visitor) {} @Override public boolean equals(Object obj) { - return obj instanceof MockProfilingQuery; + return obj instanceof MockQueryProfilerProvider; } @Override From b5e6309c5bd02d92488d83a4368f2a10f34676fb Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 16:47:05 +0100 Subject: [PATCH 29/56] Fix toString() --- .../elasticsearch/search/vectors/RescoreKnnVectorQuery.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index b13f9490d2b3a..e754ba0c05688 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -161,9 +161,9 @@ public String toString(String field) { final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{"); sb.append("fieldName='").append(fieldName).append('\''); if (byteTarget != null) { - sb.append(", byteTarget=").append(Arrays.toString(byteTarget)); + sb.append(", byteTarget=").append(byteTarget[0]).append("..."); } else { - sb.append(", floatTarget=").append(Arrays.toString(floatTarget)); + sb.append(", floatTarget=").append(floatTarget[0]).append("..."); } sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction); sb.append(", k=").append(k); From 3120c5cd5a78fb361296d4ab56925e833d319fc1 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 17:49:57 +0100 Subject: [PATCH 30/56] YAML tests do not check doc ordering, just scores --- .../search.vectors/41_knn_search_bbq_hnsw.yml | 38 ++----------------- .../41_knn_search_byte_quantized.yml | 34 ++--------------- .../41_knn_search_half_byte_quantized.yml | 37 ++---------------- .../search.vectors/42_knn_search_bbq_flat.yml | 38 ++----------------- .../42_knn_search_int4_flat.yml | 37 ++---------------- .../42_knn_search_int8_flat.yml | 36 ++---------------- 6 files changed, 21 insertions(+), 199 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index fe04ee08e4b10..92b882892cee4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -88,7 +88,7 @@ setup: - match: { hits.hits.2._id: "1" } --- -"Vector rescoring has similar ordering as knn, same scoring as exact search for kNN section": +"Vector rescoring has same scoring as exact search for kNN section": - requires: reason: 'Quantized vector rescoring is required' test_runner_features: [capabilities] @@ -98,26 +98,6 @@ setup: capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - - do: - headers: - Content-Type: application/json - search: - rest_total_hits_as_int: true - index: bbq_hnsw - body: - knn: - field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] - k: 3 - num_candidates: 3 - - - match: { hits.total: 3 } - - set: { hits.hits.0._id: knn_id0 } - - set: { hits.hits.1._id: knn_id1 } - - set: { hits.hits.2._id: knn_id2 } - - set: { hits.hits.0._score: knn_score0 } - - set: { hits.hits.1._score: knn_score1 } - - set: { hits.hits.2._score: knn_score2 } # Rescore - do: @@ -135,16 +115,8 @@ setup: rescore: oversample: 1.5 - # Comparing to knn search, we already have changes in ordering and scoring - - match: { hits.hits.0._id: $knn_id1 } - - match: { hits.hits.1._id: $knn_id0 } - - match: { hits.hits.2._id: $knn_id2 } - - # Get rescoring scores + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - set: { hits.hits.0._id: rescore_id0 } - - set: { hits.hits.1._id: rescore_id1 } - - set: { hits.hits.2._id: rescore_id2 } - set: { hits.hits.0._score: rescore_score0 } - set: { hits.hits.1._score: rescore_score1 } - set: { hits.hits.2._score: rescore_score2 } @@ -164,12 +136,8 @@ setup: params: query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] - # Check same ordering (which will not be true for larger datasets) - # and scoring (which should be for the elements that are present in both) + # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } - - match: { hits.hits.0._id: $rescore_id0 } - - match: { hits.hits.1._id: $rescore_id1 } - - match: { hits.hits.2._id: $rescore_id2 } - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index d63ddfb1eb680..7ff20e1beb8a5 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -371,7 +371,7 @@ setup: --- # Won't be true for larger datasets, but this helps checking kNN vs rescoring vs exact search -"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": +"Vector rescoring has the same scoring as exact search for kNN section": - requires: reason: 'Quantized vector rescoring is required' test_runner_features: [capabilities] @@ -381,26 +381,6 @@ setup: capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - # kNN search - - do: - headers: - Content-Type: application/json - search: - rest_total_hits_as_int: true - index: hnsw_byte_quantized - body: - size: 3 - query: - knn: - k: 3 - num_candidates: 3 - field: vector - query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] - - - match: { hits.total: 3 } - - set: { hits.hits.0._id: knn_id0 } - - set: { hits.hits.1._id: knn_id1 } - - set: { hits.hits.2._id: knn_id2 } # Rescore - do: @@ -420,15 +400,13 @@ setup: rescore: oversample: 1.5 - # Check same ordering (which will not be true for larger datasets) + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - match: { hits.hits.0._id: $knn_id0 } - - match: { hits.hits.1._id: $knn_id1 } - - match: { hits.hits.2._id: $knn_id2 } - set: { hits.hits.0._score: rescore_score0 } - set: { hits.hits.1._score: rescore_score1 } - set: { hits.hits.2._score: rescore_score2 } + # Exact knn via script score - do: headers: Content-Type: application/json @@ -443,12 +421,8 @@ setup: params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] - # Check same ordering (which will not be true for larger datasets) - # and scoring (which should be for the elements that are present in both) + # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } - - match: { hits.hits.0._id: $knn_id0 } - - match: { hits.hits.1._id: $knn_id1 } - - match: { hits.hits.2._id: $knn_id2 } - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 4ec46b5578c02..867e47624873f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -549,7 +549,7 @@ setup: - match: { hits.hits.1._id: "2"} - match: { hits.hits.2._id: "3"} --- -"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": +"Vector rescoring has the same scoring as exact search for kNN section": - requires: reason: 'Quantized vector rescoring is required' test_runner_features: [capabilities] @@ -559,24 +559,6 @@ setup: capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - - do: - headers: - Content-Type: application/json - search: - rest_total_hits_as_int: true - index: hnsw_byte_quantized - body: - fields: [ "name" ] - knn: - field: vector - query_vector: [-0.5, 90.0, -10, 14.8] - k: 3 - num_candidates: 3 - - - match: { hits.total: 3 } - - set: { hits.hits.0._id: knn_id0 } - - set: { hits.hits.1._id: knn_id1 } - - set: { hits.hits.2._id: knn_id2 } # Rescore - do: @@ -595,20 +577,13 @@ setup: rescore: oversample: 1.5 - # Comparing to knn search - - match: { hits.hits.0._id: $knn_id0 } - - match: { hits.hits.1._id: $knn_id1 } - - match: { hits.hits.2._id: $knn_id2 } - - # Get rescoring scores + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - set: { hits.hits.0._id: rescore_id0 } - - set: { hits.hits.1._id: rescore_id1 } - - set: { hits.hits.2._id: rescore_id2 } - set: { hits.hits.0._score: rescore_score0 } - set: { hits.hits.1._score: rescore_score1 } - set: { hits.hits.2._score: rescore_score2 } + # Exact knn via script score - do: headers: Content-Type: application/json @@ -623,12 +598,8 @@ setup: params: query_vector: [-0.5, 90.0, -10, 14.8] - # Check same ordering (which will not be true for larger datasets) - # and scoring (which should be for the elements that are present in both) + # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } - - match: { hits.hits.0._id: $rescore_id0 } - - match: { hits.hits.1._id: $rescore_id1 } - - match: { hits.hits.2._id: $rescore_id2 } - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 60f4a499ca629..48560747365eb 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -87,7 +87,7 @@ setup: # here we verify that are last hit is always the worst one - match: { hits.hits.2._id: "1" } --- -"Vector rescoring has similar ordering as knn, same scoring as exact search for kNN section": +"Vector rescoring has same scoring as exact search for kNN section": - requires: reason: 'Quantized vector rescoring is required' test_runner_features: [capabilities] @@ -97,26 +97,6 @@ setup: capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - - do: - headers: - Content-Type: application/json - search: - rest_total_hits_as_int: true - index: bbq_flat - body: - knn: - field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] - k: 3 - num_candidates: 3 - - - match: { hits.total: 3 } - - set: { hits.hits.0._id: knn_id0 } - - set: { hits.hits.1._id: knn_id1 } - - set: { hits.hits.2._id: knn_id2 } - - set: { hits.hits.0._score: knn_score0 } - - set: { hits.hits.1._score: knn_score1 } - - set: { hits.hits.2._score: knn_score2 } # Rescore - do: @@ -134,16 +114,8 @@ setup: rescore: oversample: 1.5 - # Comparing to knn search, we already have changes in ordering and scoring - - match: { hits.hits.0._id: $knn_id1 } - - match: { hits.hits.1._id: $knn_id0 } - - match: { hits.hits.2._id: $knn_id2 } - - # Get rescoring scores + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - set: { hits.hits.0._id: rescore_id0 } - - set: { hits.hits.1._id: rescore_id1 } - - set: { hits.hits.2._id: rescore_id2 } - set: { hits.hits.0._score: rescore_score0 } - set: { hits.hits.1._score: rescore_score1 } - set: { hits.hits.2._score: rescore_score2 } @@ -164,12 +136,8 @@ setup: params: query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] - # Check same ordering (which will not be true for larger datasets) - # and scoring (which should be for the elements that are present in both) + # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } - - match: { hits.hits.0._id: $rescore_id0 } - - match: { hits.hits.1._id: $rescore_id1 } - - match: { hits.hits.2._id: $rescore_id2 } - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index cfae5ecace868..3be3cf70cdc69 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -345,7 +345,7 @@ setup: body: vector: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] --- -"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": +"Vector rescoring has the same scoring as exact search for kNN section": - requires: reason: 'Quantized vector rescoring is required' test_runner_features: [capabilities] @@ -355,24 +355,6 @@ setup: capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - - do: - headers: - Content-Type: application/json - search: - rest_total_hits_as_int: true - index: int4_flat - body: - fields: [ "name" ] - knn: - field: vector - query_vector: [-0.5, 90.0, -10, 14.8] - k: 3 - num_candidates: 3 - - - match: { hits.total: 3 } - - set: { hits.hits.0._id: knn_id0 } - - set: { hits.hits.1._id: knn_id1 } - - set: { hits.hits.2._id: knn_id2 } # Rescore - do: @@ -391,20 +373,13 @@ setup: rescore: oversample: 1.5 - # Comparing to knn search - - match: { hits.hits.0._id: $knn_id0 } - - match: { hits.hits.1._id: $knn_id1 } - - match: { hits.hits.2._id: $knn_id2 } - - # Get rescoring scores + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - set: { hits.hits.0._id: rescore_id0 } - - set: { hits.hits.1._id: rescore_id1 } - - set: { hits.hits.2._id: rescore_id2 } - set: { hits.hits.0._score: rescore_score0 } - set: { hits.hits.1._score: rescore_score1 } - set: { hits.hits.2._score: rescore_score2 } + # Exact knn via script score - do: headers: Content-Type: application/json @@ -419,12 +394,8 @@ setup: params: query_vector: [-0.5, 90.0, -10, 14.8] - # Check same ordering (which will not be true for larger datasets) - # and scoring (which should be for the elements that are present in both) + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - match: { hits.hits.0._id: $rescore_id0 } - - match: { hits.hits.1._id: $rescore_id1 } - - match: { hits.hits.2._id: $rescore_id2 } - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index b6d0f0343d592..282dd7df1f038 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -262,7 +262,7 @@ setup: - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} --- -"Vector rescoring has same ordering as knn, same scoring as exact search for kNN section": +"Vector rescoring has the same scoring as exact search for kNN section": - requires: reason: 'Quantized vector rescoring is required' test_runner_features: [capabilities] @@ -272,24 +272,6 @@ setup: capabilities: [knn_quantized_vector_rescore] - skip: features: "headers" - - do: - headers: - Content-Type: application/json - search: - rest_total_hits_as_int: true - index: int8_flat - body: - fields: [ "name" ] - knn: - field: vector - query_vector: [-0.5, 90.0, -10, 14.8, -156.0] - k: 3 - num_candidates: 3 - - - match: { hits.total: 3 } - - set: { hits.hits.0._id: knn_id0 } - - set: { hits.hits.1._id: knn_id1 } - - set: { hits.hits.2._id: knn_id2 } # Rescore - do: @@ -308,16 +290,8 @@ setup: rescore: oversample: 1.5 - # Comparing to knn search - - match: { hits.hits.0._id: $knn_id0 } - - match: { hits.hits.1._id: $knn_id1 } - - match: { hits.hits.2._id: $knn_id2 } - - # Get rescoring scores + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - set: { hits.hits.0._id: rescore_id0 } - - set: { hits.hits.1._id: rescore_id1 } - - set: { hits.hits.2._id: rescore_id2 } - set: { hits.hits.0._score: rescore_score0 } - set: { hits.hits.1._score: rescore_score1 } - set: { hits.hits.2._score: rescore_score2 } @@ -336,12 +310,8 @@ setup: params: query_vector: [-0.5, 90.0, -10, 14.8, -156.0] - # Check same ordering (which will not be true for larger datasets) - # and scoring (which should be for the elements that are present in both) + # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } - - match: { hits.hits.0._id: $rescore_id0 } - - match: { hits.hits.1._id: $rescore_id1 } - - match: { hits.hits.2._id: $rescore_id2 } - match: { hits.hits.0._score: $rescore_score0 } - match: { hits.hits.1._score: $rescore_score1 } - match: { hits.hits.2._score: $rescore_score2 } From da018b861637d3be2cb3523e7b02ea5bfb849569 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 18:05:06 +0100 Subject: [PATCH 31/56] Add tests for rescoring on non-quantized values --- .../test/search.vectors/40_knn_search.yml | 55 ++++++++++++++++++ .../search.vectors/42_knn_search_flat.yml | 55 ++++++++++++++++++ .../test/search.vectors/45_knn_search_bit.yml | 56 ++++++++++++++++++ .../search.vectors/45_knn_search_bit_flat.yml | 56 ++++++++++++++++++ .../search.vectors/45_knn_search_byte.yml | 57 +++++++++++++++++++ 5 files changed, 279 insertions(+) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index b3d86a066550e..e3179a0065c2e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -541,3 +541,58 @@ setup: num_candidates: 3 - match: { hits.total.value: 0 } +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml index 1b439967ba163..9258ce99a31aa 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml @@ -257,6 +257,61 @@ setup: - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} --- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: flat + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: flat + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [-0.5, 90.0, -10, 14.8, -156.0] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } +--- "Test bad parameters": - do: catch: bad_request diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml index 02576ad1b2b01..342b206f42a8b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml @@ -405,3 +405,59 @@ setup: - match: {hits.hits.0._id: "1"} - match: {hits.hits.0._source.vector1: [2, -1, 1, 4, -3]} - match: {hits.hits.0._source.vector2: [2, -1, 1, 4, -3]} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml index ec7bde4de8435..dd8e544417fd4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml @@ -221,3 +221,59 @@ setup: similarity: l2_norm index_options: type: int8_hnsw + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml index 0cedfaa873095..9b105813f7ec6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml @@ -254,3 +254,60 @@ setup: filter: {"term": {"name": "cow.jpg"}} - length: {hits.hits: 0} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Non-rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + + # Get scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: knn_score0 } + - set: { hits.hits.1._score: knn_score1 } + - set: { hits.hits.2._score: knn_score2 } + + # Rescored knn + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 3 + num_candidates: 3 + rescore: + oversample: 1.5 + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $knn_score0 } + - match: { hits.hits.1._score: $knn_score1 } + - match: { hits.hits.2._score: $knn_score2 } + From 1d5426b48c9e7caba29652cf0a7949974e238309 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 18:15:44 +0100 Subject: [PATCH 32/56] having null index type means no quantization, as INT8 is explicit in mappings --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index e953462c9e490..1b3abfd6ebe3e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2042,7 +2042,7 @@ public Query createKnnQuery( } private boolean needsRescore(Float rescoreOversample) { - return rescoreOversample != null && (indexOptions == null || indexOptions.type == null || indexOptions.type.isQuantized()); + return rescoreOversample != null && (indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized()); } private Query createKnnBitQuery( From 2379596c364312ea42d79f52c333e63dc9272074 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 19:34:41 +0100 Subject: [PATCH 33/56] Bytes can't be quantized - remove all infra for byte vectors in rescoring --- .../vectors/DenseVectorFieldMapper.java | 31 +--- .../VectorSimilarityByteValueSource.java | 108 ------------- .../search/vectors/ESKnnByteVectorQuery.java | 4 - .../search/vectors/RescoreKnnVectorQuery.java | 50 +----- .../vectors/DenseVectorFieldTypeTests.java | 33 +--- .../vectors/RescoreKnnVectorQueryTests.java | 148 ++++-------------- 6 files changed, 49 insertions(+), 325 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 1b3abfd6ebe3e..4b516434ae6c4 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2019,15 +2019,7 @@ public Query createKnnQuery( ); } return switch (getElementType()) { - case BYTE -> createKnnByteQuery( - queryVector.asByteVector(), - k, - numCands, - filter, - rescoreOversample, - similarityThreshold, - parentFilter - ); + case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter); case FLOAT -> createKnnFloatQuery( queryVector.asFloatVector(), k, @@ -2072,7 +2064,6 @@ private Query createKnnByteQuery( Integer k, int numCands, Query filter, - Float rescoreOversample, Float similarityThreshold, BitSetProducer parentFilter ) { @@ -2082,16 +2073,9 @@ private Query createKnnByteQuery( float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); } - Integer adjustedK = k; - int adjustedNumCands = numCands; - if (needsRescore(rescoreOversample) && adjustedK != null) { - adjustedK = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); - adjustedNumCands = Math.max(adjustedK, numCands); - } - Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) - : new ESKnnByteVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter) + : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2099,15 +2083,6 @@ private Query createKnnByteQuery( similarity.score(similarityThreshold, elementType, dims) ); } - if (needsRescore(rescoreOversample)) { - knnQuery = new RescoreKnnVectorQuery( - name(), - queryVector, - similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE), - k, - knnQuery - ); - } return knnQuery; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java deleted file mode 100644 index 33864fb76a310..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityByteValueSource.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.index.mapper.vectors; - -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.DoubleValues; -import org.apache.lucene.search.DoubleValuesSource; -import org.apache.lucene.search.IndexSearcher; -import org.elasticsearch.search.profile.query.QueryProfiler; -import org.elasticsearch.search.vectors.QueryProfilerProvider; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Objects; - -/** - * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnByteVectorField, using the - * original vector values stored in the index - */ -public class VectorSimilarityByteValueSource extends DoubleValuesSource implements QueryProfilerProvider { - - private final String field; - private final byte[] target; - private final VectorSimilarityFunction vectorSimilarityFunction; - private long vectorOpsCount; - - public VectorSimilarityByteValueSource(String field, byte[] target, VectorSimilarityFunction vectorSimilarityFunction) { - this.field = field; - this.target = target; - this.vectorSimilarityFunction = vectorSimilarityFunction; - } - - @Override - public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException { - final LeafReader reader = ctx.reader(); - - ByteVectorValues vectorValues = reader.getByteVectorValues(field); - final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); - - return new DoubleValues() { - private int docId = -1; - - @Override - public double doubleValue() throws IOException { - vectorOpsCount++; - return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId)); - } - - @Override - public boolean advanceExact(int doc) throws IOException { - docId = doc; - return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc; - } - }; - } - - @Override - public boolean needsScores() { - return false; - } - - @Override - public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException { - return this; - } - - @Override - public void profile(QueryProfiler queryProfiler) { - queryProfiler.addVectorOpsCount(vectorOpsCount); - } - - @Override - public int hashCode() { - return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - VectorSimilarityByteValueSource that = (VectorSimilarityByteValueSource) o; - return Objects.equals(field, that.field) - && Objects.deepEquals(target, that.target) - && vectorSimilarityFunction == that.vectorSimilarityFunction; - } - - @Override - public String toString() { - return "VectorSimilarityByteValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")"; - } - - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return false; - } -} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 5c199f42093b1..935bb42e75bbe 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -35,8 +35,4 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.addVectorOpsCount(vectorOpsCount); } - - public Integer kParam() { - return kParam; - } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index e754ba0c05688..b45135b6bf4f3 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -18,7 +18,6 @@ import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; -import org.elasticsearch.index.mapper.vectors.VectorSimilarityByteValueSource; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -31,7 +30,6 @@ */ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider { private final String fieldName; - private final byte[] byteTarget; private final float[] floatTarget; private final VectorSimilarityFunction vectorSimilarityFunction; private final Integer k; @@ -39,16 +37,6 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide private QueryProfilerProvider vectorProfiling; - public RescoreKnnVectorQuery( - String fieldName, - byte[] byteTarget, - VectorSimilarityFunction vectorSimilarityFunction, - Integer k, - Query innerQuery - ) { - this(fieldName, byteTarget, null, vectorSimilarityFunction, k, innerQuery, null); - } - public RescoreKnnVectorQuery( String fieldName, float[] floatTarget, @@ -56,41 +44,18 @@ public RescoreKnnVectorQuery( Integer k, Query innerQuery ) { - this(fieldName, null, floatTarget, vectorSimilarityFunction, k, innerQuery, null); - } - - private RescoreKnnVectorQuery( - String fieldName, - byte[] byteTarget, - float[] floatTarget, - VectorSimilarityFunction vectorSimilarityFunction, - Integer k, - Query innerQuery, - QueryProfilerProvider queryProfilerProvider - ) { - if ((byteTarget == null ^ floatTarget == null) == false) { - throw new IllegalArgumentException("Either byteTarget or floatTarget must be set"); - } - this.fieldName = fieldName; - this.byteTarget = byteTarget; this.floatTarget = floatTarget; this.vectorSimilarityFunction = vectorSimilarityFunction; this.k = k; this.innerQuery = innerQuery; - this.vectorProfiling = queryProfilerProvider; } @Override public Query rewrite(IndexSearcher searcher) throws IOException { - final DoubleValuesSource valueSource; - if (byteTarget != null) { - valueSource = new VectorSimilarityByteValueSource(fieldName, byteTarget, vectorSimilarityFunction); - } else { - valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); - } - // Vector similarity DoubleValueSource keep track of the compared vectors - we need that in case we don't need - // to calculate top k and return directly the query + DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction); + // Vector similarity VectorSimilarityFloatValueSource keep track of the compared vectors - we need that in case we don't need + // to calculate top k and return directly the query to understand how many comparisons were done vectorProfiling = (QueryProfilerProvider) valueSource; FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); Query query = searcher.rewrite(functionScoreQuery); @@ -144,7 +109,6 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; RescoreKnnVectorQuery that = (RescoreKnnVectorQuery) o; return Objects.equals(fieldName, that.fieldName) - && Objects.deepEquals(byteTarget, that.byteTarget) && Objects.deepEquals(floatTarget, that.floatTarget) && vectorSimilarityFunction == that.vectorSimilarityFunction && Objects.equals(k, that.k) @@ -153,18 +117,14 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(fieldName, Arrays.hashCode(byteTarget), Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery); + return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery); } @Override public String toString(String field) { final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{"); sb.append("fieldName='").append(fieldName).append('\''); - if (byteTarget != null) { - sb.append(", byteTarget=").append(byteTarget[0]).append("..."); - } else { - sb.append(", floatTarget=").append(floatTarget[0]).append("..."); - } + sb.append(", floatTarget=").append(floatTarget[0]).append("..."); sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction); sb.append(", k=").append(k); sb.append(", vectorQuery=").append(innerQuery); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 714ed282d06a3..5a92e9b920b41 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -20,11 +20,9 @@ import org.elasticsearch.index.mapper.FieldTypeTestCase; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; -import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; @@ -411,11 +409,10 @@ public void testByteCreateKnnQuery() { } public void testRescoreOversampleUsedWithoutQuantization() { - ElementType elementType = randomFrom(BYTE, FLOAT); DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType( "f", IndexVersion.current(), - elementType, + FLOAT, 3, true, VectorSimilarity.COSINE, @@ -433,15 +430,9 @@ public void testRescoreOversampleUsedWithoutQuantization() { null ); - if (elementType == BYTE) { - ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); - } else { - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); - } + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); } public void testRescoreOversampleModifiesKnnParams() { @@ -475,18 +466,10 @@ private static void checkRescoreQueryParameters( int expectedCandidates ) { Query query = fieldType.createKnnQuery(new VectorData(null, new byte[] { 1, 4, 10 }), k, candidates, oversample, null, null, null); - RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; - if (fieldType.getElementType() == BYTE) { - ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) rescoreQuery.innerQuery(); - assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); - assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); - assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); - } else { - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); - assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); - assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); - assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); - } + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); + assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index cbaa8eb81bcf4..7bbe7dcc155c5 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -12,15 +12,12 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import org.apache.lucene.document.Document; -import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; @@ -53,13 +50,10 @@ public class RescoreKnnVectorQueryTests extends ESTestCase { public static final String FIELD_NAME = "float_vector"; private final int numDocs; - private final VectorProvider vectorProvider; private final Integer k; - public RescoreKnnVectorQueryTests(VectorProvider vectorProvider, boolean useK) { - this.vectorProvider = vectorProvider; + public RescoreKnnVectorQueryTests(boolean useK) { this.numDocs = randomIntBetween(10, 100); - ; this.k = useK ? randomIntBetween(1, numDocs - 1) : null; } @@ -72,15 +66,17 @@ public void testRescoreDocs() throws Exception { } try (Directory d = newDirectory()) { - addRandomDocuments(numDocs, d, numDims, vectorProvider); + addRandomDocuments(numDocs, d, numDims); try (IndexReader reader = DirectoryReader.open(d)) { // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query // and thus we're rescoring the top k docs. - VectorData queryVector = vectorProvider.randomVector(numDims); - RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery( + float[] queryVector = randomVector(numDims); + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + FIELD_NAME, queryVector, + VectorSimilarityFunction.COSINE, adjustedK, new MatchAllDocsQuery() ); @@ -98,11 +94,11 @@ public void testRescoreDocs() throws Exception { PriorityQueue topK = new PriorityQueue<>((o1, o2) -> Float.compare(o2, o1)); for (LeafReaderContext leafReaderContext : reader.leaves()) { - KnnVectorValues vectorValues = vectorProvider.vectorValues(leafReaderContext.reader()); + FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(FIELD_NAME); KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); while (iterator.nextDoc() != NO_MORE_DOCS) { - VectorData vectorData = vectorProvider.dataVectorForDoc(vectorValues, iterator.docID()); - float score = vectorProvider.score(queryVector, vectorData); + float[] vectorData = vectorValues.vectorValue(iterator.docID()); + float score = VectorSimilarityFunction.COSINE.compare(queryVector, vectorData); topK.add(score); int docId = iterator.docID(); // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it @@ -131,10 +127,10 @@ public void testProfiling() throws Exception { int numDims = randomIntBetween(5, 100); try (Directory d = newDirectory()) { - addRandomDocuments(numDocs, d, numDims, vectorProvider); + addRandomDocuments(numDocs, d, numDims); try (IndexReader reader = DirectoryReader.open(d)) { - VectorData queryVector = vectorProvider.randomVector(numDims); + float[] queryVector = randomVector(numDims); checkProfiling(queryVector, reader, new MatchAllDocsQuery()); checkProfiling(queryVector, reader, new MockQueryProfilerProvider(randomIntBetween(1, 100))); @@ -142,8 +138,14 @@ public void testProfiling() throws Exception { } } - private void checkProfiling(VectorData queryVector, IndexReader reader, Query innerQuery) throws IOException { - RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider.createRescoreQuery(queryVector, k, innerQuery); + private void checkProfiling(float[] queryVector, IndexReader reader, Query innerQuery) throws IOException { + RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery( + FIELD_NAME, + queryVector, + VectorSimilarityFunction.COSINE, + k, + innerQuery + ); IndexSearcher searcher = newSearcher(reader, true, false); searcher.search(rescoreKnnVectorQuery, numDocs); @@ -161,6 +163,14 @@ private void checkProfiling(VectorData queryVector, IndexReader reader, Query in assertThat(queryProfiler.getVectorOpsCount(), equalTo(expectedVectorOpsCount)); } + private static float[] randomVector(int numDimensions) { + float[] vector = new float[numDimensions]; + for (int j = 0; j < numDimensions; j++) { + vector[j] = randomFloatBetween(0, 1, true); + } + return vector; + } + /** * A mock query that is used to test profiling */ @@ -206,103 +216,13 @@ public void profile(QueryProfiler queryProfiler) { } } - /** - * Vector operations depend on the type of vector field used. This interface abstracts the operations needed to perform the tests - */ - private interface VectorProvider { - VectorData randomVector(int numDimensions); - - RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery); - - KnnVectorValues vectorValues(LeafReader leafReader) throws IOException; - - void addVectorField(Document document, VectorData vector); - - VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException; - - float score(VectorData queryVector, VectorData dataVector); - } - - private static class FloatVectorProvider implements VectorProvider { - @Override - public VectorData randomVector(int numDimensions) { - float[] vector = new float[numDimensions]; - for (int j = 0; j < numDimensions; j++) { - vector[j] = randomFloatBetween(0, 1, true); - } - return VectorData.fromFloats(vector); - } - - @Override - public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery) { - return new RescoreKnnVectorQuery(FIELD_NAME, queryVector.floatVector(), VectorSimilarityFunction.COSINE, k, innerQuery); - } - - @Override - public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException { - return leafReader.getFloatVectorValues(FIELD_NAME); - } - - @Override - public void addVectorField(Document document, VectorData vector) { - KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector.floatVector()); - document.add(vectorField); - } - - @Override - public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException { - return VectorData.fromFloats(((FloatVectorValues) vectorValues).vectorValue(docId)); - } - - @Override - public float score(VectorData queryVector, VectorData dataVector) { - return VectorSimilarityFunction.COSINE.compare(queryVector.floatVector(), dataVector.floatVector()); - } - } - - private static class ByteVectorProvider implements VectorProvider { - @Override - public VectorData randomVector(int numDimensions) { - byte[] vector = new byte[numDimensions]; - for (int j = 0; j < numDimensions; j++) { - vector[j] = randomByte(); - } - return VectorData.fromBytes(vector); - } - - @Override - public RescoreKnnVectorQuery createRescoreQuery(VectorData queryVector, Integer k, Query innerQuery) { - return new RescoreKnnVectorQuery(FIELD_NAME, queryVector.byteVector(), VectorSimilarityFunction.COSINE, k, innerQuery); - } - - @Override - public KnnVectorValues vectorValues(LeafReader leafReader) throws IOException { - return leafReader.getByteVectorValues(FIELD_NAME); - } - - @Override - public void addVectorField(Document document, VectorData vector) { - KnnByteVectorField vectorField = new KnnByteVectorField(FIELD_NAME, vector.byteVector()); - document.add(vectorField); - } - - @Override - public VectorData dataVectorForDoc(KnnVectorValues vectorValues, int docId) throws IOException { - return VectorData.fromBytes(((ByteVectorValues) vectorValues).vectorValue(docId)); - } - - @Override - public float score(VectorData queryVector, VectorData dataVector) { - return VectorSimilarityFunction.COSINE.compare(queryVector.byteVector(), dataVector.byteVector()); - } - } - - private static void addRandomDocuments(int numDocs, Directory d, int numDims, VectorProvider vectorProvider) throws IOException { + private static void addRandomDocuments(int numDocs, Directory d, int numDims) throws IOException { try (IndexWriter w = new IndexWriter(d, newIndexWriterConfig())) { for (int i = 0; i < numDocs; i++) { Document document = new Document(); - VectorData vector = vectorProvider.randomVector(numDims); - vectorProvider.addVectorField(document, vector); + float[] vector = randomVector(numDims); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME, vector); + document.add(vectorField); w.addDocument(document); } w.commit(); @@ -313,10 +233,8 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims, Ve @ParametersFactory public static Iterable parameters() { List params = new ArrayList<>(); - params.add(new Object[] { new FloatVectorProvider(), true }); - params.add(new Object[] { new FloatVectorProvider(), false }); - params.add(new Object[] { new ByteVectorProvider(), true }); - params.add(new Object[] { new ByteVectorProvider(), false }); + params.add(new Object[] { true }); + params.add(new Object[] { false }); return params; } From b0c2221e315b4daa513835fbd8f527fdf13517d4 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 19:36:06 +0100 Subject: [PATCH 34/56] Add assertion to advanceExact() --- .../index/mapper/vectors/VectorSimilarityFloatValueSource.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index c8e128f7fe3fe..540533a1b2b7d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -60,6 +60,7 @@ public double doubleValue() throws IOException { @Override public boolean advanceExact(int doc) throws IOException { + assert doc > iterator.docID(); docId = doc; return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc; } From 9f4cebd296a2c8c829a0c30cb088cd8e520af150 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 19:38:45 +0100 Subject: [PATCH 35/56] Parsing / toXContent improvements --- .../elasticsearch/search/retriever/KnnRetrieverBuilder.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 63422b60534bd..e8db0f270a10c 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -97,7 +97,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { optionalConstructorArg(), (p, c) -> RescoreVectorBuilder.fromXContent(p), RESCORE_FIELD, - ObjectParser.ValueType.OBJECT_OR_NULL + ObjectParser.ValueType.OBJECT ); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -282,9 +282,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } if (rescoreVectorBuilder != null) { - builder.startObject(RESCORE_FIELD.getPreferredName()); - rescoreVectorBuilder.toXContent(builder, params); - builder.endObject(); + builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); } } From fab7395b54f368e3a0c9513ea584eceadd945999 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 20:11:14 +0100 Subject: [PATCH 36/56] private access for field, use getter instead --- .../search/vectors/KnnSearchBuilder.java | 4 ++-- .../search/vectors/KnnSearchBuilderTests.java | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 81646f4a84d10..9f5466a2e4875 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -90,7 +90,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea optionalConstructorArg(), (p, c) -> RescoreVectorBuilder.fromXContent(p), RESCORE_FIELD, - ObjectParser.ValueType.OBJECT_OR_NULL + ObjectParser.ValueType.OBJECT ); PARSER.declareFieldArray( KnnSearchBuilder.Builder::addFilterQueries, @@ -123,7 +123,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw String queryName; float boost = DEFAULT_BOOST; InnerHitBuilder innerHitBuilder; - final RescoreVectorBuilder rescoreVectorBuilder; + private final RescoreVectorBuilder rescoreVectorBuilder; /** * Defines a kNN search. diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 1a54d7c8420da..5e4d5f3b79eea 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -118,7 +118,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.queryVector, instance.k, instance.numCands, - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); case 1: @@ -128,7 +128,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { newVector, instance.k, instance.numCands, - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); case 2: @@ -139,7 +139,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.queryVector, newK, instance.numCands, - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); case 3: @@ -149,7 +149,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.queryVector, instance.k, newNumCands, - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), instance.similarity ).boost(instance.boost); case 4: @@ -158,7 +158,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.queryVector, instance.k, instance.numCands, - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), instance.similarity ).addFilterQueries(instance.filterQueries) .addFilterQuery(QueryBuilders.termQuery("new_field", "new-value")) @@ -170,7 +170,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.queryVector, instance.k, instance.numCands, - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), instance.similarity ).addFilterQueries(instance.filterQueries).boost(newBoost); case 6: @@ -179,7 +179,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.queryVector, instance.k, instance.numCands, - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), randomValueOtherThan(instance.similarity, ESTestCase::randomFloat) ).addFilterQueries(instance.filterQueries).boost(instance.boost); case 7: @@ -189,7 +189,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.k, instance.numCands, randomValueOtherThan( - instance.rescoreVectorBuilder, + instance.getRescoreVectorBuilder(), () -> new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)) ), instance.similarity @@ -288,7 +288,7 @@ public void testRewrite() throws Exception { assertThat(rewritten.filterQueries, hasSize(1)); assertThat(rewritten.similarity, equalTo(1f)); assertThat(((RewriteableQuery) rewritten.filterQueries.get(0)).rewrites, equalTo(1)); - assertThat(rewritten.rescoreVectorBuilder, equalTo(expectedRescore)); + assertThat(rewritten.getRescoreVectorBuilder(), equalTo(expectedRescore)); } public static float[] randomVector(int dim) { From e03c8e96007879a9593abc1e12758ceaead42929 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 20:12:54 +0100 Subject: [PATCH 37/56] toXContent improvements --- .../org/elasticsearch/search/vectors/KnnSearchBuilder.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 9f5466a2e4875..0f89dd44367f0 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -547,9 +547,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(NAME_FIELD.getPreferredName(), queryName); } if (rescoreVectorBuilder != null) { - builder.startObject(RESCORE_FIELD.getPreferredName()); - rescoreVectorBuilder.toXContent(builder, params); - builder.endObject(); + builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); } return builder; From 69b545120e9017ca54923155f679ef6b0c626b13 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 4 Dec 2024 20:14:44 +0100 Subject: [PATCH 38/56] Fix toXContent / parsing --- .../elasticsearch/search/vectors/KnnVectorQueryBuilder.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index b250f6484bda6..e737e354d53a7 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -105,7 +105,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder RescoreVectorBuilder.fromXContent(p), RESCORE_FIELD, - ObjectParser.ValueType.OBJECT_OR_NULL + ObjectParser.ValueType.OBJECT ); PARSER.declareFieldArray( KnnVectorQueryBuilder::addFilterQueries, @@ -406,9 +406,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.endArray(); } if (rescoreVectorBuilder != null) { - builder.startObject(RESCORE_FIELD.getPreferredName()); - rescoreVectorBuilder.toXContent(builder, params); - builder.endObject(); + builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); } boostAndQueryNameToXContent(builder); builder.endObject(); From ddc60947ec76568e1ca2162b2cc956dc3528dc88 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 09:25:32 +0100 Subject: [PATCH 39/56] Make rescore parameter mandatory --- .../search/vectors/RescoreVectorBuilder.java | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index 2471b29dc1193..4273a85b584f4 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -31,12 +31,11 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { ); static { - PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), OVERSAMPLE_FIELD); + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), OVERSAMPLE_FIELD); } // Oversample is required as of now as it is the only field in the rescore vector - // that may change in the future, so we treat it as optional - private final Float oversample; + private final float oversample; public RescoreVectorBuilder(Float oversample) { Objects.requireNonNull(oversample, "[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be set"); @@ -47,20 +46,19 @@ public RescoreVectorBuilder(Float oversample) { } public RescoreVectorBuilder(StreamInput in) throws IOException { - this.oversample = in.readOptionalFloat(); + this.oversample = in.readFloat(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalFloat(oversample); + out.writeFloat(oversample); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - if (oversample != null) { - builder.field(OVERSAMPLE_FIELD.getPreferredName(), oversample); - } - + builder.startObject(); + builder.field(OVERSAMPLE_FIELD.getPreferredName(), oversample); + builder.endObject(); return builder; } From 3ef07fa9ca2b5ef58583f2589baa76b0571296ab Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 09:26:47 +0100 Subject: [PATCH 40/56] Add index types to vector query builder tests --- ...AbstractKnnVectorQueryBuilderTestCase.java | 50 +++++++++++++++---- .../KnnByteVectorQueryBuilderTests.java | 5 ++ .../KnnFloatVectorQueryBuilderTests.java | 5 ++ 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index c75058a977364..7ab1a93822094 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -41,6 +41,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; @@ -53,7 +56,19 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase { private static final String VECTOR_FIELD = "vector"; private static final String VECTOR_ALIAS_FIELD = "vector_alias"; - static final int VECTOR_DIMENSION = 3; + protected final String indexType = indexType(); + protected final int VECTOR_DIMENSION = indexType.contains("bbq") ? 64 : 3; + protected static final Set QUANTIZED_INDEX_TYPES = Set.of( + "int8_hnsw", + "int4_hnsw", + "bbq_hnsw", + "int8_flat", + "int4_flat", + "bbq_flat" + ); + protected static final Set NON_QUANTIZED_INDEX_TYPES = Set.of("hnsw", "flat"); + protected static final Set ALL_INDEX_TYPES = Stream.concat(QUANTIZED_INDEX_TYPES.stream(), NON_QUANTIZED_INDEX_TYPES.stream()) + .collect(Collectors.toUnmodifiableSet()); abstract DenseVectorFieldMapper.ElementType elementType(); @@ -65,8 +80,15 @@ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( Float similarity ); + protected boolean isQuantizedElementType() { + return QUANTIZED_INDEX_TYPES.contains(indexType()); + } + + protected abstract String indexType(); + @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject("properties") @@ -76,6 +98,9 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws .field("index", true) .field("similarity", "l2_norm") .field("element_type", elementType()) + .startObject("index_options") + .field("type", indexType) + .endObject() .endObject() .startObject(VECTOR_ALIAS_FIELD) .field("type", "alias") @@ -126,7 +151,7 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() { @Override protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { - if (queryBuilder.rescoreVectorBuilder() != null) { + if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; query = rescoreQuery.innerQuery(); } @@ -154,7 +179,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que // The field should always be resolved to the concrete field Integer k = queryBuilder.k(); Integer numCands = queryBuilder.numCands(); - if (queryBuilder.rescoreVectorBuilder() != null) { + if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { Float rescoreOversample = queryBuilder.rescoreVectorBuilder().oversample(); k = k == null ? null : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); numCands = numCands == null ? null : Math.max(k == null ? 0 : k, numCands); @@ -330,19 +355,24 @@ public void testBWCVersionSerializationQuery() throws IOException { public void testBWCVersionSerializationRescoreVector() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); + TransportVersion version = TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.V_8_8_1, + TransportVersionUtils.getPreviousVersion(TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE) + ); + VectorData vectorData = version.onOrAfter(TransportVersions.V_8_14_0) + ? query.queryVector() + : VectorData.fromFloats(query.queryVector().asFloatVector()); + Integer k = version.before(TransportVersions.V_8_15_0) ? null : query.k(); KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder( query.getFieldName(), - query.queryVector(), - query.k(), + vectorData, + k, query.numCands(), null, query.getVectorSimilarity() ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); - assertBWCSerialization( - query, - queryNoRescoreVector, - TransportVersionUtils.randomVersionBetween(random(), TransportVersions.V_8_8_0, TransportVersions.KNN_QUERY_RESCORE_OVERSAMPLE) - ); + assertBWCSerialization(query, queryNoRescoreVector, version); } private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 980e506c0ca35..9986ecaa96160 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -31,4 +31,9 @@ protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( } return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); } + + @Override + protected String indexType() { + return randomFrom(NON_QUANTIZED_INDEX_TYPES); + } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index 75b1f395c57e7..2e628557ed16a 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -31,4 +31,9 @@ KnnVectorQueryBuilder createKnnVectorQueryBuilder( } return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, rescoreVectorBuilder, similarity); } + + @Override + protected String indexType() { + return randomFrom(ALL_INDEX_TYPES); + } } From fd9188fd7476d179192ecc7b235092361fd80594 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 10:01:28 +0100 Subject: [PATCH 41/56] Fix compilation on rrf plugin --- .../org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index fb326d63a69cc..6854fc436038f 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -764,6 +764,7 @@ public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }), 10, 10, + null, null ); source.retriever( From 863b6e631afee3e9fa9ddd0f8d85a630b0c09301 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 11:20:33 +0100 Subject: [PATCH 42/56] Fix tests for vector query builders to ensure multiple dimensions / index types can be used --- ...AbstractKnnVectorQueryBuilderTestCase.java | 34 +++++++++++++++---- .../KnnByteVectorQueryBuilderTests.java | 4 +-- .../KnnFloatVectorQueryBuilderTests.java | 4 +-- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 7ab1a93822094..83a66a94945a4 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -37,6 +37,7 @@ import org.elasticsearch.test.TransportVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; +import org.junit.Before; import java.io.IOException; import java.util.ArrayList; @@ -56,8 +57,6 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase { private static final String VECTOR_FIELD = "vector"; private static final String VECTOR_ALIAS_FIELD = "vector_alias"; - protected final String indexType = indexType(); - protected final int VECTOR_DIMENSION = indexType.contains("bbq") ? 64 : 3; protected static final Set QUANTIZED_INDEX_TYPES = Set.of( "int8_hnsw", "int4_hnsw", @@ -69,6 +68,15 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa protected static final Set NON_QUANTIZED_INDEX_TYPES = Set.of("hnsw", "flat"); protected static final Set ALL_INDEX_TYPES = Stream.concat(QUANTIZED_INDEX_TYPES.stream(), NON_QUANTIZED_INDEX_TYPES.stream()) .collect(Collectors.toUnmodifiableSet()); + protected static String indexType; + protected static int vectorDimensions; + + @Before + private void checkIndexTypeAndDimensions() { + // Check that these are initialized - should be done as part of the createAdditionalMappings method + assertNotNull(indexType); + assertNotEquals(0, vectorDimensions); + } abstract DenseVectorFieldMapper.ElementType elementType(); @@ -81,20 +89,32 @@ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder( ); protected boolean isQuantizedElementType() { - return QUANTIZED_INDEX_TYPES.contains(indexType()); + return QUANTIZED_INDEX_TYPES.contains(indexType); } - protected abstract String indexType(); + protected abstract String randomIndexType(); @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { + // These fields are initialized here, as mappings are initialized only once per test class. + // We want the subclasses to be able to override the index type and vector dimensions so we don't make this static / BeforeClass + // for initialization. + indexType = randomIndexType(); + if (indexType.contains("bbq")) { + vectorDimensions = 64; + } else if (indexType.contains("int4")) { + vectorDimensions = 4; + } else { + vectorDimensions = 3; + } + XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(VECTOR_FIELD) .field("type", "dense_vector") - .field("dims", VECTOR_DIMENSION) + .field("dims", vectorDimensions) .field("index", true) .field("similarity", "l2_norm") .field("element_type", elementType()) @@ -201,7 +221,7 @@ public void testWrongDimension() { IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); assertThat( e.getMessage(), - containsString("The query vector has a different number of dimensions [2] than the document vectors [3]") + containsString("The query vector has a different number of dimensions [2] than the document vectors [" + vectorDimensions + "]") ); } @@ -286,7 +306,7 @@ public void testMustRewrite() throws IOException { KnnVectorQueryBuilder query = new KnnVectorQueryBuilder( VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, - VECTOR_DIMENSION, + vectorDimensions, null, null, null diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 9986ecaa96160..f6c2e754cec63 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -25,7 +25,7 @@ protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { - byte[] vector = new byte[VECTOR_DIMENSION]; + byte[] vector = new byte[vectorDimensions]; for (int i = 0; i < vector.length; i++) { vector[i] = randomByte(); } @@ -33,7 +33,7 @@ protected KnnVectorQueryBuilder createKnnVectorQueryBuilder( } @Override - protected String indexType() { + protected String randomIndexType() { return randomFrom(NON_QUANTIZED_INDEX_TYPES); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index 2e628557ed16a..6f67e4be29a06 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -25,7 +25,7 @@ KnnVectorQueryBuilder createKnnVectorQueryBuilder( RescoreVectorBuilder rescoreVectorBuilder, Float similarity ) { - float[] vector = new float[VECTOR_DIMENSION]; + float[] vector = new float[vectorDimensions]; for (int i = 0; i < vector.length; i++) { vector[i] = randomFloat(); } @@ -33,7 +33,7 @@ KnnVectorQueryBuilder createKnnVectorQueryBuilder( } @Override - protected String indexType() { + protected String randomIndexType() { return randomFrom(ALL_INDEX_TYPES); } } From 9c9773f75916f04f8b9bd9d8b6396a75e5847776 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 12:03:56 +0100 Subject: [PATCH 43/56] Fix test to use just floats --- .../mapper/vectors/DenseVectorFieldTypeTests.java | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 5a92e9b920b41..3c39687c22cf6 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -412,7 +412,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType( "f", IndexVersion.current(), - FLOAT, + randomFrom(FLOAT, BYTE), 3, true, VectorSimilarity.COSINE, @@ -439,7 +439,7 @@ public void testRescoreOversampleModifiesKnnParams() { DenseVectorFieldType fieldType = new DenseVectorFieldType( "f", IndexVersion.current(), - randomFrom(BYTE, FLOAT), + FLOAT, 3, true, VectorSimilarity.COSINE, @@ -465,7 +465,15 @@ private static void checkRescoreQueryParameters( int expectedK, int expectedCandidates ) { - Query query = fieldType.createKnnQuery(new VectorData(null, new byte[] { 1, 4, 10 }), k, candidates, oversample, null, null, null); + Query query = fieldType.createKnnQuery( + VectorData.fromFloats(new float[] { 1, 4, 10 }), + k, + candidates, + oversample, + null, + null, + null + ); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); From 7d083b72b2f63146a3a7e2d886feb336d07b11e2 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 13:40:19 +0100 Subject: [PATCH 44/56] Fix test, add coverage for byte element types --- .../search/vectors/ESKnnByteVectorQuery.java | 4 ++++ .../vectors/DenseVectorFieldTypeTests.java | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java index 935bb42e75bbe..5c199f42093b1 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java @@ -35,4 +35,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { public void profile(QueryProfiler queryProfiler) { queryProfiler.addVectorOpsCount(vectorOpsCount); } + + public Integer kParam() { + return kParam; + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 3c39687c22cf6..a5c1e8201376f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.vectors.DenseVectorQuery; +import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; import org.elasticsearch.search.vectors.RescoreKnnVectorQuery; import org.elasticsearch.search.vectors.VectorData; @@ -409,10 +410,11 @@ public void testByteCreateKnnQuery() { } public void testRescoreOversampleUsedWithoutQuantization() { + DenseVectorFieldMapper.ElementType elementType = randomFrom(FLOAT, BYTE); DenseVectorFieldType nonQuantizedField = new DenseVectorFieldType( "f", IndexVersion.current(), - randomFrom(FLOAT, BYTE), + elementType, 3, true, VectorSimilarity.COSINE, @@ -430,9 +432,15 @@ public void testRescoreOversampleUsedWithoutQuantization() { null ); - ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; - assertThat(esKnnQuery.getK(), is(100)); - assertThat(esKnnQuery.kParam(), is(10)); + if (elementType == BYTE) { + ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } else { + ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery; + assertThat(esKnnQuery.getK(), is(100)); + assertThat(esKnnQuery.kParam(), is(10)); + } } public void testRescoreOversampleModifiesKnnParams() { From a8633e523ae7e1bbce110247023f45c658cafb20 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 14:56:27 +0100 Subject: [PATCH 45/56] Fix YAML test capabilities, add profile test for similarity --- .../search.vectors/210_knn_search_profile.yml | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml index 5d9f14c3e353f..b642f2def53bf 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -1,7 +1,14 @@ setup: - requires: - cluster_features: "mapper.vectors.bbq" - reason: 'kNN float to better-binary quantization is required' + reason: 'Quantized vector rescoring is required' + test_runner_features: [ capabilities ] + capabilities: + - method: GET + path: /_search + capabilities: [ knn_quantized_vector_rescore ] + - skip: + features: "headers" + - do: indices.create: index: bbq_hnsw @@ -34,8 +41,8 @@ setup: id: "1" body: name: cow.jpg - vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] - another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + vector: [ 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0 ] + another_vector: [ 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0 ] # Flush in order to provoke a merge later - do: indices.flush: @@ -47,8 +54,8 @@ setup: id: "2" body: name: moose.jpg - vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] - another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + vector: [ 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0 ] + another_vector: [ 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120 ] # Flush in order to provoke a merge later - do: indices.flush: @@ -60,8 +67,8 @@ setup: id: "3" body: name: rabbit.jpg - vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] - another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + vector: [ 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0 ] + another_vector: [ 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0 ] # Flush in order to provoke a merge later - do: indices.flush: @@ -81,7 +88,7 @@ setup: profile: true knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] k: 3 num_candidates: 3 @@ -94,9 +101,27 @@ setup: profile: true knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] + k: 3 + num_candidates: 3 + "rescore": + "oversample": 2.0 + + # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard + - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } + + # Search with similarity to check number of operations are propagated correctly + - do: + search: + index: bbq_hnsw + body: + profile: true + knn: + field: vector + query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] k: 3 num_candidates: 3 + similarity: 100000.0 "rescore": "oversample": 2.0 From ef8ac8cdf8a092d1ff6167458aac5daecd50e5be Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 16:47:23 +0100 Subject: [PATCH 46/56] Rename "rescore" to "rescore_vector" --- .../test/search.vectors/210_knn_search_profile.yml | 4 ++-- .../rest-api-spec/test/search.vectors/40_knn_search.yml | 2 +- .../test/search.vectors/41_knn_search_bbq_hnsw.yml | 2 +- .../test/search.vectors/41_knn_search_byte_quantized.yml | 2 +- .../search.vectors/41_knn_search_half_byte_quantized.yml | 2 +- .../test/search.vectors/42_knn_search_bbq_flat.yml | 2 +- .../test/search.vectors/42_knn_search_flat.yml | 2 +- .../test/search.vectors/42_knn_search_int4_flat.yml | 2 +- .../test/search.vectors/42_knn_search_int8_flat.yml | 2 +- .../rest-api-spec/test/search.vectors/45_knn_search_bit.yml | 2 +- .../test/search.vectors/45_knn_search_bit_flat.yml | 2 +- .../test/search.vectors/45_knn_search_byte.yml | 2 +- .../elasticsearch/search/retriever/KnnRetrieverBuilder.java | 6 +++--- .../org/elasticsearch/search/vectors/KnnSearchBuilder.java | 6 +++--- .../elasticsearch/search/vectors/KnnVectorQueryBuilder.java | 6 +++--- .../elasticsearch/search/vectors/RescoreVectorBuilder.java | 2 +- 16 files changed, 23 insertions(+), 23 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml index b642f2def53bf..ee66e1aa92213 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -104,7 +104,7 @@ setup: query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] k: 3 num_candidates: 3 - "rescore": + "rescore_vector": "oversample": 2.0 # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard @@ -122,7 +122,7 @@ setup: k: 3 num_candidates: 3 similarity: 100000.0 - "rescore": + "rescore_vector": "oversample": 2.0 # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index e3179a0065c2e..82b1d0740070b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -588,7 +588,7 @@ setup: query_vector: [-0.5, 90.0, -10, 14.8, -156.0] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 92b882892cee4..f4cdd8b50ac4b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -112,7 +112,7 @@ setup: query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index 7ff20e1beb8a5..fbf6c7c5bfdca 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -397,7 +397,7 @@ setup: num_candidates: 3 field: vector query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] - rescore: + rescore_vector: oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 867e47624873f..b170cd66ceefb 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -574,7 +574,7 @@ setup: query_vector: [-0.5, 90.0, -10, 14.8] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 48560747365eb..f28ca559b423b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -111,7 +111,7 @@ setup: query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml index 9258ce99a31aa..f0529f4209288 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml @@ -303,7 +303,7 @@ setup: query_vector: [-0.5, 90.0, -10, 14.8, -156.0] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index 3be3cf70cdc69..dbe751c87084b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -370,7 +370,7 @@ setup: query_vector: [-0.5, 90.0, -10, 14.8] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index 282dd7df1f038..def124a8baab8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -287,7 +287,7 @@ setup: query_vector: [-0.5, 90.0, -10, 14.8, -156.0] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml index 342b206f42a8b..0a8f045ea2e77 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml @@ -453,7 +453,7 @@ setup: query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml index dd8e544417fd4..65836416558b7 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml @@ -269,7 +269,7 @@ setup: query_vector: [127, 127, -128, -128, 127] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml index 9b105813f7ec6..2f637d9ca3e6d 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml @@ -302,7 +302,7 @@ setup: query_vector: [127, 127, -128, -128, 127] k: 3 num_candidates: 3 - rescore: + rescore_vector: oversample: 1.5 # Compare scores as hit IDs may change depending on how things are distributed diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 65a1b3b7e9218..b29546ded75cd 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -54,7 +54,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector"); public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder"); public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity"); - public static final ParseField RESCORE_FIELD = new ParseField("rescore"); + public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector"); @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -96,7 +96,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder { PARSER.declareField( optionalConstructorArg(), (p, c) -> RescoreVectorBuilder.fromXContent(p), - RESCORE_FIELD, + RESCORE_VECTOR_FIELD, ObjectParser.ValueType.OBJECT ); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); @@ -281,7 +281,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } if (rescoreVectorBuilder != null) { - builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 0f89dd44367f0..b18ce2dff65cb 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -56,7 +56,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea public static final ParseField NAME_FIELD = AbstractQueryBuilder.NAME_FIELD; public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD; public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits"); - public static final ParseField RESCORE_FIELD = new ParseField("rescore"); + public static final ParseField RESCORE_VECTOR_FIELD = new ParseField("rescore_vector"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("knn", args -> { @@ -89,7 +89,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea PARSER.declareField( optionalConstructorArg(), (p, c) -> RescoreVectorBuilder.fromXContent(p), - RESCORE_FIELD, + RESCORE_VECTOR_FIELD, ObjectParser.ValueType.OBJECT ); PARSER.declareFieldArray( @@ -547,7 +547,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(NAME_FIELD.getPreferredName(), queryName); } if (rescoreVectorBuilder != null) { - builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); } return builder; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index e737e354d53a7..b99dadf7f34bc 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -69,7 +69,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>( "knn", @@ -104,7 +104,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder RescoreVectorBuilder.fromXContent(p), - RESCORE_FIELD, + RESCORE_VECTOR_FIELD, ObjectParser.ValueType.OBJECT ); PARSER.declareFieldArray( @@ -406,7 +406,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep builder.endArray(); } if (rescoreVectorBuilder != null) { - builder.field(RESCORE_FIELD.getPreferredName(), rescoreVectorBuilder); + builder.field(RESCORE_VECTOR_FIELD.getPreferredName(), rescoreVectorBuilder); } boostAndQueryNameToXContent(builder); builder.endObject(); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index 4273a85b584f4..3619553a94674 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -26,7 +26,7 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { public static final ParseField OVERSAMPLE_FIELD = new ParseField("oversample"); public static final float MIN_OVERSAMPLE = 1.0F; private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "rescore", + "rescore_vector", args -> new RescoreVectorBuilder((Float) args[0]) ); From 4c65c8affa0c670dd0db040a12509ce3df514530 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 5 Dec 2024 21:55:48 +0100 Subject: [PATCH 47/56] Fix sneaky bug on iterator --- .../mapper/vectors/VectorSimilarityFloatValueSource.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index 540533a1b2b7d..04b2a906d6dad 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -50,18 +50,14 @@ public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); return new DoubleValues() { - private int docId = -1; - @Override public double doubleValue() throws IOException { vectorOpsCount++; - return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId)); + return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index())); } @Override public boolean advanceExact(int doc) throws IOException { - assert doc > iterator.docID(); - docId = doc; return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc; } }; From 9907aadec6871d5450787055e7127ea7011786fc Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 9 Dec 2024 17:46:30 +0100 Subject: [PATCH 48/56] Use 'num_candidates_factor' parameter and update num_candidates instead of k --- .../search.vectors/210_knn_search_profile.yml | 4 +-- .../test/search.vectors/40_knn_search.yml | 2 +- .../search.vectors/41_knn_search_bbq_hnsw.yml | 2 +- .../41_knn_search_byte_quantized.yml | 2 +- .../41_knn_search_half_byte_quantized.yml | 2 +- .../search.vectors/42_knn_search_bbq_flat.yml | 2 +- .../search.vectors/42_knn_search_flat.yml | 2 +- .../42_knn_search_int4_flat.yml | 2 +- .../42_knn_search_int8_flat.yml | 2 +- .../test/search.vectors/45_knn_search_bit.yml | 2 +- .../search.vectors/45_knn_search_bit_flat.yml | 2 +- .../search.vectors/45_knn_search_byte.yml | 2 +- .../vectors/DenseVectorFieldMapper.java | 24 +++++++------- .../search/vectors/KnnVectorQueryBuilder.java | 6 ++-- .../search/vectors/RescoreVectorBuilder.java | 32 +++++++++---------- ...AbstractKnnVectorQueryBuilderTestCase.java | 9 +++--- 16 files changed, 50 insertions(+), 47 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml index ee66e1aa92213..fd73b2783b8b6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -105,7 +105,7 @@ setup: k: 3 num_candidates: 3 "rescore_vector": - "oversample": 2.0 + "num_candidates_factor": 2.0 # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } @@ -123,7 +123,7 @@ setup: num_candidates: 3 similarity: 100000.0 "rescore_vector": - "oversample": 2.0 + "num_candidates_factor": 2.0 # We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index 82b1d0740070b..534db18b5eb9c 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -589,7 +589,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index f4cdd8b50ac4b..0a1e672412c7b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -113,7 +113,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml index fbf6c7c5bfdca..b1e35789e8737 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_byte_quantized.yml @@ -398,7 +398,7 @@ setup: field: vector query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index b170cd66ceefb..54e9eadf42e0b 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -575,7 +575,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index f28ca559b423b..ccacbe8f053bb 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -112,7 +112,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml index f0529f4209288..a59aedceff3d3 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_flat.yml @@ -304,7 +304,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml index dbe751c87084b..6796a92122f9a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int4_flat.yml @@ -371,7 +371,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml index def124a8baab8..d1d312449cb70 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_int8_flat.yml @@ -288,7 +288,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Get rescoring scores - hit ordering may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml index 0a8f045ea2e77..effa3fff61525 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml @@ -454,7 +454,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml index 65836416558b7..cdc1d9c64763e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml @@ -270,7 +270,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml index 2f637d9ca3e6d..213b571a0b4be 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml @@ -303,7 +303,7 @@ setup: k: 3 num_candidates: 3 rescore_vector: - oversample: 1.5 + num_candidates_factor: 1.5 # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0d0a68dee2570..523a4b257e527 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -124,7 +124,7 @@ public static boolean isNotUnitVector(float magnitude) { public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector public static final int MAGNITUDE_BYTES = 4; - public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates + public static final int NUM_CANDS_OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates private static DenseVectorFieldMapper toType(FieldMapper in) { return (DenseVectorFieldMapper) in; @@ -2008,7 +2008,7 @@ public Query createKnnQuery( VectorData queryVector, Integer k, int numCands, - Float rescoreOversample, + Float numCandsFactor, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2024,7 +2024,7 @@ public Query createKnnQuery( queryVector.asFloatVector(), k, numCands, - rescoreOversample, + numCandsFactor, filter, similarityThreshold, parentFilter @@ -2090,7 +2090,7 @@ private Query createKnnFloatQuery( float[] queryVector, Integer k, int numCands, - Float rescoreOversample, + Float numCandsFactor, Query filter, Float similarityThreshold, BitSetProducer parentFilter @@ -2111,15 +2111,17 @@ && isNotUnitVector(squaredMagnitude)) { } } - Integer adjustedK = k; int adjustedNumCands = numCands; - if (needsRescore(rescoreOversample) && adjustedK != null) { - adjustedK = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample)); - adjustedNumCands = Math.max(adjustedK, numCands); + if (needsRescore(numCandsFactor)) { + // We shouldn't have less than k candidates (or 1 in case k is not set) to rescore + int minCands = k == null ? 1 : k; + // k <= numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise. + adjustedNumCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor)); + adjustedNumCands = Math.min(adjustedNumCands, NUM_CANDS_OVERSAMPLE_LIMIT); } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); + ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, adjustedNumCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, k, adjustedNumCands, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -2127,7 +2129,7 @@ && isNotUnitVector(squaredMagnitude)) { similarity.score(similarityThreshold, elementType, dims) ); } - if (needsRescore(rescoreOversample)) { + if (needsRescore(numCandsFactor)) { knnQuery = new RescoreKnnVectorQuery( name(), queryVector, diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index b99dadf7f34bc..c868274eb8a1b 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -526,7 +526,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; String parentPath = context.nestedLookup().getNestedParent(fieldName); - Float rescoreOversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample(); + Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor(); if (parentPath != null) { final BitSetProducer parentBitSet; @@ -563,13 +563,13 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { queryVector, k, adjustedNumCands, - rescoreOversample, + numCandidatesFactor, filterQuery, vectorSimilarity, parentBitSet ); } - return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, rescoreOversample, filterQuery, vectorSimilarity, null); + return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, numCandidatesFactor, filterQuery, vectorSimilarity, null); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index 3619553a94674..cf8f52e0edeed 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -23,41 +23,41 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { - public static final ParseField OVERSAMPLE_FIELD = new ParseField("oversample"); - public static final float MIN_OVERSAMPLE = 1.0F; + public static final ParseField NUM_CANDIDATES_FACTOR_FIELD = new ParseField("num_candidates_factor"); + public static final float MIN_OVERSAMPLE = 0.0F; private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "rescore_vector", args -> new RescoreVectorBuilder((Float) args[0]) ); static { - PARSER.declareFloat(ConstructingObjectParser.constructorArg(), OVERSAMPLE_FIELD); + PARSER.declareFloat(ConstructingObjectParser.constructorArg(), NUM_CANDIDATES_FACTOR_FIELD); } // Oversample is required as of now as it is the only field in the rescore vector - private final float oversample; + private final float numCandidatesFactor; - public RescoreVectorBuilder(Float oversample) { - Objects.requireNonNull(oversample, "[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be set"); - if (oversample <= MIN_OVERSAMPLE) { - throw new IllegalArgumentException("[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be > " + MIN_OVERSAMPLE); + public RescoreVectorBuilder(Float numCandidatesFactor) { + Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set"); + if (numCandidatesFactor <= MIN_OVERSAMPLE) { + throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be > " + MIN_OVERSAMPLE); } - this.oversample = oversample; + this.numCandidatesFactor = numCandidatesFactor; } public RescoreVectorBuilder(StreamInput in) throws IOException { - this.oversample = in.readFloat(); + this.numCandidatesFactor = in.readFloat(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeFloat(oversample); + out.writeFloat(numCandidatesFactor); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(OVERSAMPLE_FIELD.getPreferredName(), oversample); + builder.field(NUM_CANDIDATES_FACTOR_FIELD.getPreferredName(), numCandidatesFactor); builder.endObject(); return builder; } @@ -71,15 +71,15 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; RescoreVectorBuilder that = (RescoreVectorBuilder) o; - return Objects.equals(oversample, that.oversample); + return Objects.equals(numCandidatesFactor, that.numCandidatesFactor); } @Override public int hashCode() { - return Objects.hashCode(oversample); + return Objects.hashCode(numCandidatesFactor); } - public Float oversample() { - return oversample; + public Float numCandidatesFactor() { + return numCandidatesFactor; } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 83a66a94945a4..7c976b603e6f0 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -46,7 +46,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_OVERSAMPLE_LIMIT; import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -200,9 +200,10 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que Integer k = queryBuilder.k(); Integer numCands = queryBuilder.numCands(); if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { - Float rescoreOversample = queryBuilder.rescoreVectorBuilder().oversample(); - k = k == null ? null : Integer.valueOf(Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample))); - numCands = numCands == null ? null : Math.max(k == null ? 0 : k, numCands); + Float numCandsFactor = queryBuilder.rescoreVectorBuilder().numCandidatesFactor(); + int minCands = k == null ? 1 : k; + numCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor)); + numCands = Math.min(numCands, NUM_CANDS_OVERSAMPLE_LIMIT); } Query knnVectorQueryBuilt = switch (elementType()) { From 978cff3072bb0e27fcd0300e582bd9f2d0f07925 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 9 Dec 2024 17:46:44 +0100 Subject: [PATCH 49/56] Add knn retriever YAML tests --- .../search.retrievers/20_knn_retriever.yml | 58 ++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml index d08a8e2a6d39c..e49f0634a4887 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/20_knn_retriever.yml @@ -18,7 +18,7 @@ setup: dims: 5 index: true index_options: - type: hnsw + type: int8_hnsw similarity: l2_norm - do: @@ -73,3 +73,59 @@ setup: - match: {hits.total.value: 1} - match: {hits.hits.0._id: "3"} - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + +--- +"Vector rescoring has no effect for non-quantized vectors and provides same results as non-rescored knn": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: index1 + body: + knn: + field: vector + query_vector: [2, 2, 2, 2, 3] + k: 3 + num_candidates: 3 + rescore_vector: + num_candidates_factor: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: index1 + body: + query: + script_score: + query: {match_all: {} } + script: + source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + params: + query_vector: [2, 2, 2, 2, 3] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } From 9412be088b58c8d5bf1356b0bb1a63eb98d64888 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 9 Dec 2024 20:08:19 +0100 Subject: [PATCH 50/56] Simplify logic for RescoreKnnVectorQuery now that k is not modifiable --- .../vectors/DenseVectorFieldMapper.java | 1 - .../search/vectors/RescoreKnnVectorQuery.java | 51 +++++-------------- .../vectors/DenseVectorFieldTypeTests.java | 28 +++++----- .../search/vectors/KnnSearchBuilderTests.java | 4 +- .../vectors/RescoreKnnVectorQueryTests.java | 2 - 5 files changed, 31 insertions(+), 55 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 523a4b257e527..e10163a876ab0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2134,7 +2134,6 @@ && isNotUnitVector(squaredMagnitude)) { name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), - k, knnQuery ); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index b45135b6bf4f3..4fa9a7964a173 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -16,8 +16,6 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -32,7 +30,6 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide private final String fieldName; private final float[] floatTarget; private final VectorSimilarityFunction vectorSimilarityFunction; - private final Integer k; private final Query innerQuery; private QueryProfilerProvider vectorProfiling; @@ -41,13 +38,11 @@ public RescoreKnnVectorQuery( String fieldName, float[] floatTarget, VectorSimilarityFunction vectorSimilarityFunction, - Integer k, Query innerQuery ) { this.fieldName = fieldName; this.floatTarget = floatTarget; this.vectorSimilarityFunction = vectorSimilarityFunction; - this.k = k; this.innerQuery = innerQuery; } @@ -58,34 +53,13 @@ public Query rewrite(IndexSearcher searcher) throws IOException { // to calculate top k and return directly the query to understand how many comparisons were done vectorProfiling = (QueryProfilerProvider) valueSource; FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); - Query query = searcher.rewrite(functionScoreQuery); - - if (k == null) { - // No need to calculate top k - let the request size limit the results. - return query; - } - - // Retrieve top k documents from the rescored query - TopDocs topDocs = searcher.search(query, k); - ScoreDoc[] scoreDocs = topDocs.scoreDocs; - int[] docIds = new int[scoreDocs.length]; - float[] scores = new float[scoreDocs.length]; - for (int i = 0; i < scoreDocs.length; i++) { - docIds[i] = scoreDocs[i].doc; - scores[i] = scoreDocs[i].score; - } - - return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); + return searcher.rewrite(functionScoreQuery); } public Query innerQuery() { return innerQuery; } - public Integer k() { - return k; - } - @Override public void profile(QueryProfiler queryProfiler) { if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { @@ -111,24 +85,27 @@ public boolean equals(Object o) { return Objects.equals(fieldName, that.fieldName) && Objects.deepEquals(floatTarget, that.floatTarget) && vectorSimilarityFunction == that.vectorSimilarityFunction - && Objects.equals(k, that.k) && Objects.equals(innerQuery, that.innerQuery); } @Override public int hashCode() { - return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery); + return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, innerQuery); } @Override public String toString(String field) { - final StringBuilder sb = new StringBuilder("KnnRescoreVectorQuery{"); - sb.append("fieldName='").append(fieldName).append('\''); - sb.append(", floatTarget=").append(floatTarget[0]).append("..."); - sb.append(", vectorSimilarityFunction=").append(vectorSimilarityFunction); - sb.append(", k=").append(k); - sb.append(", vectorQuery=").append(innerQuery); - sb.append('}'); - return sb.toString(); + return "KnnRescoreVectorQuery{" + + "fieldName='" + + fieldName + + '\'' + + ", floatTarget=" + + floatTarget[0] + + "..." + + ", vectorSimilarityFunction=" + + vectorSimilarityFunction + + ", vectorQuery=" + + innerQuery + + '}'; } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index a5c1e8201376f..b9841aeb24f62 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -443,7 +443,7 @@ public void testRescoreOversampleUsedWithoutQuantization() { } } - public void testRescoreOversampleModifiesKnnParams() { + public void testRescoreOversampleModifiesNumCandidates() { DenseVectorFieldType fieldType = new DenseVectorFieldType( "f", IndexVersion.current(), @@ -456,36 +456,38 @@ public void testRescoreOversampleModifiesKnnParams() { ); // Total results is k, internal k is multiplied by oversample - checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 25, 200); + checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 500); // If numCands < k, update numCands to k - checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 25, 25); - // Oversampling limit - checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000, 10000); - checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000, 10000); + checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 50); + // Oversampling limits for num candidates + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000); + checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000); + // Oversampling is capped at k as a minimum + checkRescoreQueryParameters(fieldType, 10, 100, 0.01F, 10, 10); + // Oversampling is capped at 1 as a minimum if k is not specified + checkRescoreQueryParameters(fieldType, null, 100, 0.0001F, null, 1); } private static void checkRescoreQueryParameters( DenseVectorFieldType fieldType, - int k, + Integer k, int candidates, - float oversample, - int expectedResults, - int expectedK, + float numCandsFactor, + Integer expectedK, int expectedCandidates ) { Query query = fieldType.createKnnQuery( VectorData.fromFloats(new float[] { 1, 4, 10 }), k, candidates, - oversample, + numCandsFactor, null, null, null ); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); - assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); - assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); + assertThat("Unexpected total results", esKnnQuery.kParam(), equalTo(expectedK)); assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index 5e4d5f3b79eea..cbd80b197372e 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -257,9 +257,9 @@ public void testInvalidK() { public void testInvalidRescoreVectorBuilder() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, new RescoreVectorBuilder(1.0F), null) + () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.0F), null) ); - assertThat(e.getMessage(), containsString("[oversample] must be > 1.0")); + assertThat(e.getMessage(), containsString("[num_candidates_factor] must be > 0.0")); } public void testRewrite() throws Exception { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 7bbe7dcc155c5..713e48d8b291e 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -77,7 +77,6 @@ public void testRescoreDocs() throws Exception { FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, - adjustedK, new MatchAllDocsQuery() ); @@ -143,7 +142,6 @@ private void checkProfiling(float[] queryVector, IndexReader reader, Query inner FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, - k, innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); From 7045500f8462b28988cbdac0d1275527e026f8fe Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 9 Dec 2024 20:50:03 +0100 Subject: [PATCH 51/56] Limit for rescoring factor is 1.0, so we can't have less rescored docs than num_candidates --- .../index/mapper/vectors/DenseVectorFieldMapper.java | 5 +---- .../elasticsearch/search/vectors/RescoreVectorBuilder.java | 6 +++--- .../index/mapper/vectors/DenseVectorFieldTypeTests.java | 4 ---- .../elasticsearch/search/vectors/KnnSearchBuilderTests.java | 4 ++-- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index e10163a876ab0..7526023b49567 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2113,11 +2113,8 @@ && isNotUnitVector(squaredMagnitude)) { int adjustedNumCands = numCands; if (needsRescore(numCandsFactor)) { - // We shouldn't have less than k candidates (or 1 in case k is not set) to rescore - int minCands = k == null ? 1 : k; // k <= numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise. - adjustedNumCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor)); - adjustedNumCands = Math.min(adjustedNumCands, NUM_CANDS_OVERSAMPLE_LIMIT); + adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT); } Query knnQuery = parentFilter != null ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, adjustedNumCands, parentFilter) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index cf8f52e0edeed..5c1abd15989ee 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -24,7 +24,7 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { public static final ParseField NUM_CANDIDATES_FACTOR_FIELD = new ParseField("num_candidates_factor"); - public static final float MIN_OVERSAMPLE = 0.0F; + public static final float MIN_OVERSAMPLE = 1.0F; private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "rescore_vector", args -> new RescoreVectorBuilder((Float) args[0]) @@ -39,8 +39,8 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { public RescoreVectorBuilder(Float numCandidatesFactor) { Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set"); - if (numCandidatesFactor <= MIN_OVERSAMPLE) { - throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be > " + MIN_OVERSAMPLE); + if (numCandidatesFactor < MIN_OVERSAMPLE) { + throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE); } this.numCandidatesFactor = numCandidatesFactor; } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index b9841aeb24f62..1fc6dc70b8d95 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -462,10 +462,6 @@ public void testRescoreOversampleModifiesNumCandidates() { // Oversampling limits for num candidates checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000); checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000); - // Oversampling is capped at k as a minimum - checkRescoreQueryParameters(fieldType, 10, 100, 0.01F, 10, 10); - // Oversampling is capped at 1 as a minimum if k is not specified - checkRescoreQueryParameters(fieldType, null, 100, 0.0001F, null, 1); } private static void checkRescoreQueryParameters( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index cbd80b197372e..a39438af5b72a 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -257,9 +257,9 @@ public void testInvalidK() { public void testInvalidRescoreVectorBuilder() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.0F), null) + () -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.99F), null) ); - assertThat(e.getMessage(), containsString("[num_candidates_factor] must be > 0.0")); + assertThat(e.getMessage(), containsString("[num_candidates_factor] must be >= 1.0")); } public void testRewrite() throws Exception { From 497d8e257ec5c964fc5f37f61f89d6e09996a0b4 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 10 Dec 2024 08:04:49 +0100 Subject: [PATCH 52/56] Fix test after merge --- .../xpack/inference/highlight/SemanticTextHighlighterTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java index 7dc4d99e06acc..78743409ca178 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java @@ -91,7 +91,7 @@ public void testDenseVector() throws Exception { Map queryMap = (Map) queries.get("dense_vector_1"); float[] vector = readDenseVector(queryMap.get("embeddings")); var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5); - KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null); + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null, null); NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max); var shardRequest = createShardSearchRequest(nestedQueryBuilder); var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON); From 83238a51037c9f79412ea611f7a1df26686cccc9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 10 Dec 2024 12:11:00 +0100 Subject: [PATCH 53/56] Vector similarity needs to wrap the new rescoring query and not the other way round --- .../vectors/DenseVectorFieldMapper.java | 24 +++++---- .../search/vectors/RescoreKnnVectorQuery.java | 49 +++++++++++++------ .../search/vectors/VectorSimilarityQuery.java | 3 +- .../vectors/DenseVectorFieldTypeTests.java | 14 +++--- ...AbstractKnnVectorQueryBuilderTestCase.java | 21 +++----- .../vectors/RescoreKnnVectorQueryTests.java | 2 + 6 files changed, 69 insertions(+), 44 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0956de8d26a60..8c6e874ff577f 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2111,29 +2111,33 @@ && isNotUnitVector(squaredMagnitude)) { } } + Integer adjustedK = k; int adjustedNumCands = numCands; if (needsRescore(numCandsFactor)) { - // k <= numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise. + // Get all candidates, get top k as part of rescoring + adjustedK = null; + // numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise. adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT); } Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, adjustedNumCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, k, adjustedNumCands, filter); - if (similarityThreshold != null) { - knnQuery = new VectorSimilarityQuery( - knnQuery, - similarityThreshold, - similarity.score(similarityThreshold, elementType, dims) - ); - } + ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter); if (needsRescore(numCandsFactor)) { knnQuery = new RescoreKnnVectorQuery( name(), queryVector, similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT), + k, knnQuery ); } + if (similarityThreshold != null) { + knnQuery = new VectorSimilarityQuery( + knnQuery, + similarityThreshold, + similarity.score(similarityThreshold, elementType, dims) + ); + } return knnQuery; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 4fa9a7964a173..b1b2d55648585 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -16,6 +16,8 @@ import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -30,6 +32,7 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide private final String fieldName; private final float[] floatTarget; private final VectorSimilarityFunction vectorSimilarityFunction; + private final Integer k; private final Query innerQuery; private QueryProfilerProvider vectorProfiling; @@ -38,11 +41,13 @@ public RescoreKnnVectorQuery( String fieldName, float[] floatTarget, VectorSimilarityFunction vectorSimilarityFunction, + Integer k, Query innerQuery ) { this.fieldName = fieldName; this.floatTarget = floatTarget; this.vectorSimilarityFunction = vectorSimilarityFunction; + this.k = k; this.innerQuery = innerQuery; } @@ -53,13 +58,34 @@ public Query rewrite(IndexSearcher searcher) throws IOException { // to calculate top k and return directly the query to understand how many comparisons were done vectorProfiling = (QueryProfilerProvider) valueSource; FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource); - return searcher.rewrite(functionScoreQuery); + Query query = searcher.rewrite(functionScoreQuery); + + if (k == null) { + // No need to calculate top k - let the request size limit the results. + return query; + } + + // Retrieve top k documents from the rescored query + TopDocs topDocs = searcher.search(query, k); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + int[] docIds = new int[scoreDocs.length]; + float[] scores = new float[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + docIds[i] = scoreDocs[i].doc; + scores[i] = scoreDocs[i].score; + } + + return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader()); } public Query innerQuery() { return innerQuery; } + public Integer k() { + return k; + } + @Override public void profile(QueryProfiler queryProfiler) { if (innerQuery instanceof QueryProfilerProvider queryProfilerProvider) { @@ -85,27 +111,22 @@ public boolean equals(Object o) { return Objects.equals(fieldName, that.fieldName) && Objects.deepEquals(floatTarget, that.floatTarget) && vectorSimilarityFunction == that.vectorSimilarityFunction + && Objects.equals(k, that.k) && Objects.equals(innerQuery, that.innerQuery); } @Override public int hashCode() { - return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, innerQuery); + return Objects.hash(fieldName, Arrays.hashCode(floatTarget), vectorSimilarityFunction, k, innerQuery); } @Override public String toString(String field) { - return "KnnRescoreVectorQuery{" - + "fieldName='" - + fieldName - + '\'' - + ", floatTarget=" - + floatTarget[0] - + "..." - + ", vectorSimilarityFunction=" - + vectorSimilarityFunction - + ", vectorQuery=" - + innerQuery - + '}'; + return "KnnRescoreVectorQuery{" + "fieldName='" + fieldName + '\'' + + ", floatTarget=" + floatTarget[0] + "..." + + ", vectorSimilarityFunction=" + vectorSimilarityFunction + + ", k=" + k + + ", vectorQuery=" + innerQuery + + '}'; } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java index d91994f843541..4cd267bd07186 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java @@ -29,7 +29,8 @@ import static org.elasticsearch.common.Strings.format; /** - * This query provides a simple post-filter for the provided Query. The query is assumed to be a Knn(Float|Byte)VectorQuery. + * This query provides a simple post-filter for the provided Query to limit the results of the inner query to those that have a similarity + * above a certain threshold */ public class VectorSimilarityQuery extends Query implements QueryProfilerProvider { private final float similarity; diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index 1fc6dc70b8d95..d37b4a4bacb4e 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -456,12 +456,12 @@ public void testRescoreOversampleModifiesNumCandidates() { ); // Total results is k, internal k is multiplied by oversample - checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 10, 500); + checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, null, 500, 10); // If numCands < k, update numCands to k - checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 10, 50); + checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, null, 50, 10); // Oversampling limits for num candidates - checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, 1000, 10000); - checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, 5000, 10000); + checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, null, 10000, 1000); + checkRescoreQueryParameters(fieldType, 5000, 7500, 2.5F, null, 10000, 5000); } private static void checkRescoreQueryParameters( @@ -470,7 +470,8 @@ private static void checkRescoreQueryParameters( int candidates, float numCandsFactor, Integer expectedK, - int expectedCandidates + int expectedCandidates, + int expectedResults ) { Query query = fieldType.createKnnQuery( VectorData.fromFloats(new float[] { 1, 4, 10 }), @@ -483,7 +484,8 @@ private static void checkRescoreQueryParameters( ); RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery(); - assertThat("Unexpected total results", esKnnQuery.kParam(), equalTo(expectedK)); + assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults)); + assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK)); assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates)); } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 7c976b603e6f0..375712ee60861 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -171,23 +171,18 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() { @Override protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { + if (queryBuilder.getVectorSimilarity() != null) { + assertTrue(query instanceof VectorSimilarityQuery); + assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity())); + query = ((VectorSimilarityQuery) query).getInnerKnnQuery(); + } if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) { RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query; query = rescoreQuery.innerQuery(); } - if (queryBuilder.getVectorSimilarity() != null) { - assertTrue(query instanceof VectorSimilarityQuery); - Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery(); - assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity())); - switch (elementType()) { - case FLOAT -> assertTrue(knnQuery instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(knnQuery instanceof ESKnnByteVectorQuery); - } - } else { - switch (elementType()) { - case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); - case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); - } + switch (elementType()) { + case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery); + case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery); } BooleanQuery.Builder builder = new BooleanQuery.Builder(); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 713e48d8b291e..7bbe7dcc155c5 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -77,6 +77,7 @@ public void testRescoreDocs() throws Exception { FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, + adjustedK, new MatchAllDocsQuery() ); @@ -142,6 +143,7 @@ private void checkProfiling(float[] queryVector, IndexReader reader, Query inner FIELD_NAME, queryVector, VectorSimilarityFunction.COSINE, + k, innerQuery ); IndexSearcher searcher = newSearcher(reader, true, false); From fa975ded0d7d48c5ae168c778426768717beab97 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 10 Dec 2024 12:11:24 +0100 Subject: [PATCH 54/56] Change similarity to MIP in tests --- .../search.vectors/210_knn_search_profile.yml | 76 ++++++++++--------- .../search.vectors/41_knn_search_bbq_hnsw.yml | 20 ++++- .../search.vectors/42_knn_search_bbq_flat.yml | 23 ++++-- 3 files changed, 77 insertions(+), 42 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml index fd73b2783b8b6..d4bf5e7e9807f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/210_knn_search_profile.yml @@ -18,20 +18,11 @@ setup: number_of_shards: 1 mappings: properties: - name: - type: keyword vector: type: dense_vector dims: 64 index: true - similarity: l2_norm - index_options: - type: bbq_hnsw - another_vector: - type: dense_vector - dims: 64 - index: true - similarity: l2_norm + similarity: max_inner_product index_options: type: bbq_hnsw @@ -40,9 +31,14 @@ setup: index: bbq_hnsw id: "1" body: - name: cow.jpg - vector: [ 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0 ] - another_vector: [ 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0 ] + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] # Flush in order to provoke a merge later - do: indices.flush: @@ -53,9 +49,14 @@ setup: index: bbq_hnsw id: "2" body: - name: moose.jpg - vector: [ 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0 ] - another_vector: [ 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120 ] + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] # Flush in order to provoke a merge later - do: indices.flush: @@ -67,8 +68,14 @@ setup: id: "3" body: name: rabbit.jpg - vector: [ 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0 ] - another_vector: [ 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0 ] + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] # Flush in order to provoke a merge later - do: indices.flush: @@ -78,21 +85,8 @@ setup: indices.forcemerge: index: bbq_hnsw max_num_segments: 1 - --- "Profile rescored knn search": - - do: - search: - index: bbq_hnsw - body: - profile: true - knn: - field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] - k: 3 - num_candidates: 3 - - - match: { profile.shards.0.dfs.knn.0.vector_operations_count: 3 } - do: search: @@ -101,7 +95,14 @@ setup: profile: true knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 "rescore_vector": @@ -118,10 +119,17 @@ setup: profile: true knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 - similarity: 100000.0 + similarity: 100000 "rescore_vector": "num_candidates_factor": 2.0 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 30dd7d439f92d..2567a4ac597d9 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -129,7 +129,14 @@ setup: body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 rescore_vector: @@ -152,9 +159,16 @@ setup: script_score: query: {match_all: {} } script: - source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" params: - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0 ] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index 516114269e588..a3cd624ef0ab8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -128,7 +128,14 @@ setup: body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 rescore_vector: @@ -150,12 +157,18 @@ setup: body: query: script_score: - query: {match_all: {} } + query: { match_all: {} } script: - source: "1.0 / (1.0 + Math.pow(l2norm(params.query_vector, 'vector'), 2.0))" + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" params: - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] - + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] # Compare scores as hit IDs may change depending on how things are distributed - match: { hits.total: 3 } - match: { hits.hits.0._score: $rescore_score0 } From 74a22f8351f5808b16036f39d3008fea50e7641f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 10 Dec 2024 12:37:02 +0100 Subject: [PATCH 55/56] Spotless --- .../search/vectors/RescoreKnnVectorQuery.java | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index b1b2d55648585..014c3265c1833 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -122,11 +122,19 @@ public int hashCode() { @Override public String toString(String field) { - return "KnnRescoreVectorQuery{" + "fieldName='" + fieldName + '\'' + - ", floatTarget=" + floatTarget[0] + "..." + - ", vectorSimilarityFunction=" + vectorSimilarityFunction + - ", k=" + k + - ", vectorQuery=" + innerQuery + - '}'; + return "KnnRescoreVectorQuery{" + + "fieldName='" + + fieldName + + '\'' + + ", floatTarget=" + + floatTarget[0] + + "..." + + ", vectorSimilarityFunction=" + + vectorSimilarityFunction + + ", k=" + + k + + ", vectorQuery=" + + innerQuery + + '}'; } } From a256de9e2af66c072da8df19761e69bf3cb7bb70 Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Wed, 11 Dec 2024 07:50:30 +0100 Subject: [PATCH 56/56] Apply suggestions from code review Co-authored-by: Benjamin Trent --- .../mapper/vectors/VectorSimilarityFloatValueSource.java | 4 ++-- .../elasticsearch/search/vectors/RescoreKnnVectorQuery.java | 2 +- .../elasticsearch/search/vectors/RescoreVectorBuilder.java | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java index 04b2a906d6dad..74a7dbe168e6b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java @@ -89,13 +89,13 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; VectorSimilarityFloatValueSource that = (VectorSimilarityFloatValueSource) o; return Objects.equals(field, that.field) - && Objects.deepEquals(target, that.target) + && Arrays.equals(target, that.target) && vectorSimilarityFunction == that.vectorSimilarityFunction; } @Override public String toString() { - return "VectorSimilarityFloatValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")"; + return "VectorSimilarityFloatValueSource(" + field + ", [" + target[0] + ",...], " + vectorSimilarityFunction + ")"; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java index 014c3265c1833..a9c606b1f8618 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java @@ -109,7 +109,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; RescoreKnnVectorQuery that = (RescoreKnnVectorQuery) o; return Objects.equals(fieldName, that.fieldName) - && Objects.deepEquals(floatTarget, that.floatTarget) + && Arrays.equals(floatTarget, that.floatTarget) && vectorSimilarityFunction == that.vectorSimilarityFunction && Objects.equals(k, that.k) && Objects.equals(innerQuery, that.innerQuery); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index 5c1abd15989ee..4604d4f0ea325 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -37,7 +37,7 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { // Oversample is required as of now as it is the only field in the rescore vector private final float numCandidatesFactor; - public RescoreVectorBuilder(Float numCandidatesFactor) { + public RescoreVectorBuilder(float numCandidatesFactor) { Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set"); if (numCandidatesFactor < MIN_OVERSAMPLE) { throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE); @@ -79,7 +79,7 @@ public int hashCode() { return Objects.hashCode(numCandidatesFactor); } - public Float numCandidatesFactor() { + public float numCandidatesFactor() { return numCandidatesFactor; } }