Skip to content

Commit

Permalink
Support cosine similarity in kNN search (#79500)
Browse files Browse the repository at this point in the history
This PR adds support for `cosine` similarity:

```
"mappings": {
  "properties": {
    "my_vector": {
      "type": "dense_vector",
      "dims": 128,
      "index": true,
      "similarity": "cosine"
    }
  }
}
```

Unlike `dot_product`, which requires vectors to be of unit length, this
similarity can handle vectors with any magnitude.

This PR also adds validation around `dot_product` to help catch mistakes. When
indexing vectors, we double-check that each vector has unit length. We also
check that kNN query vectors have unit length.
  • Loading branch information
jtibshirani committed Oct 21, 2021
1 parent 7823e5a commit 7f01138
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ setup:
type: dense_vector
dims: 5
index: true
similarity: dot_product

similarity: cosine
- do:
index:
index: test-index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
type: dense_vector
dims: 3
index: true
similarity: dot_product
similarity: cosine

- do:
bulk:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.PerFieldKnnVectorsFormatFieldMapper;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
Expand All @@ -31,14 +27,20 @@
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.PerFieldKnnVectorsFormatFieldMapper;
import org.elasticsearch.index.mapper.SimpleMappedFieldType;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.xpack.vectors.query.KnnVectorFieldExistsQuery;
import org.elasticsearch.xpack.vectors.query.KnnVectorQueryBuilder;
import org.elasticsearch.xpack.vectors.query.VectorIndexFieldData;

import java.io.IOException;
Expand Down Expand Up @@ -107,7 +109,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
return new DenseVectorFieldMapper(
name,
new DenseVectorFieldType(context.buildFullName(name), indexVersionCreated,
dims.getValue(), indexed.getValue(), meta.getValue()),
dims.getValue(), indexed.getValue(), similarity.getValue(), meta.getValue()),
dims.getValue(),
indexed.getValue(),
similarity.getValue(),
Expand All @@ -120,6 +122,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {

enum VectorSimilarity {
l2_norm(VectorSimilarityFunction.EUCLIDEAN),
cosine(VectorSimilarityFunction.COSINE),
dot_product(VectorSimilarityFunction.DOT_PRODUCT);

public final VectorSimilarityFunction function;
Expand Down Expand Up @@ -195,19 +198,18 @@ public String toString() {
public static final class DenseVectorFieldType extends SimpleMappedFieldType {
private final int dims;
private final boolean indexed;
private final VectorSimilarity similarity;
private final Version indexVersionCreated;

public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, boolean indexed, Map<String, String> meta) {
public DenseVectorFieldType(String name, Version indexVersionCreated, int dims,boolean indexed,
VectorSimilarity similarity, Map<String, String> meta) {
super(name, indexed, false, indexed == false, TextSearchInfo.NONE, meta);
this.dims = dims;
this.indexed = indexed;
this.similarity = similarity;
this.indexVersionCreated = indexVersionCreated;
}

public int dims() {
return dims;
}

@Override
public String typeName() {
return CONTENT_TYPE;
Expand Down Expand Up @@ -257,6 +259,48 @@ public Query termQuery(Object value, SearchExecutionContext context) {
throw new IllegalArgumentException(
"Field [" + name() + "] of type [" + typeName() + "] doesn't support queries");
}

public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) {
if (isSearchable() == false) {
throw new IllegalArgumentException("[" + KnnVectorQueryBuilder.NAME + "] " +
"queries are not supported if [index] is disabled");
}

if (queryVector.length != dims) {
throw new IllegalArgumentException("the query vector has a different dimension [" + queryVector.length + "] "
+ "than the index vectors [" + dims + "]");
}

if (similarity == VectorSimilarity.dot_product) {
double squaredMagnitude = 0.0;
for (float e : queryVector) {
squaredMagnitude += e * e;
}
checkVectorMagnitude(queryVector, squaredMagnitude);
}
return new KnnVectorQuery(name(), queryVector, numCands);
}

private void checkVectorMagnitude(float[] vector, double squaredMagnitude) {
if (Math.abs(squaredMagnitude - 1.0f) > 1e-4) {
// Include the first five elements of the invalid vector in the error message
StringBuilder sb = new StringBuilder("The [" + VectorSimilarity.dot_product.name() + "] similarity can " +
"only be used with unit-length vectors. Preview of invalid vector: ");
sb.append("[");
for (int i = 0; i < Math.min(5, vector.length); i++) {
if (i > 0) {
sb.append(", ");
}
sb.append(vector[i]);
}
if (vector.length >= 5) {
sb.append(", ...");
}
sb.append("]");

throw new IllegalArgumentException(sb.toString());
}
}
}

private final int dims;
Expand Down Expand Up @@ -301,13 +345,20 @@ public void parse(DocumentParserContext context) throws IOException {

private Field parseKnnVector(DocumentParserContext context) throws IOException {
float[] vector = new float[dims];
double squaredMagnitude = 0.0;
int index = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
checkDimensionExceeded(index, context);
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
vector[index++] = context.parser().floatValue(true);

float value = context.parser().floatValue(true);
vector[index++] = value;
squaredMagnitude += value * value;
}
checkDimensionMatches(index, context);
if (similarity == VectorSimilarity.dot_product) {
fieldType().checkVectorMagnitude(vector, squaredMagnitude);
}
return new KnnVectorField(fieldType().name(), vector, similarity.function);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package org.elasticsearch.xpack.vectors.query;

import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -92,14 +91,7 @@ protected Query doToQuery(SearchExecutionContext context) {
}

DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
if (queryVector.length != vectorFieldType.dims()) {
throw new IllegalArgumentException("the query vector has a different dimension [" + queryVector.length + "] "
+ "than the index vectors [" + vectorFieldType.dims() + "]");
}
if (vectorFieldType.isSearchable() == false) {
throw new IllegalArgumentException("[" + NAME + "] queries are not supported if [index] is disabled");
}
return new KnnVectorQuery(fieldType.name(), queryVector, numCands);
return vectorFieldType.createKnnQuery(queryVector, numCands);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ protected void minimalMapping(XContentBuilder b) throws IOException {

@Override
protected Object getSampleValueForDocument() {
return List.of(1, 2, 3, 4);
return List.of(0.5, 0.5, 0.5, 0.5);
}

@Override
Expand Down Expand Up @@ -204,7 +204,7 @@ public void testIndexedVector() throws Exception {
.field("index", true)
.field("similarity", similarity.name())));

float[] vector = {-12.1f, 100.7f, -4};
float[] vector = {-0.5f, 0.5f, 0.7071f};
ParsedDocument doc1 = mapper.parse(source(b -> b.array("field", vector)));

IndexableField[] fields = doc1.rootDoc().getFields("field");
Expand All @@ -220,6 +220,30 @@ public void testIndexedVector() throws Exception {
assertEquals(similarity.function, vectorField.fieldType().vectorSimilarityFunction());
}

public void testDotProductWithInvalidNorm() throws Exception {
DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b
.field("type", "dense_vector")
.field("dims", 3)
.field("index", true)
.field("similarity", VectorSimilarity.dot_product)));
float[] vector = {-12.1f, 2.7f, -4};
MapperParsingException e = expectThrows(MapperParsingException.class, () -> mapper.parse(source(b -> b.array("field", vector))));
assertNotNull(e.getCause());
assertThat(e.getCause().getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors. " +
"Preview of invalid vector: [-12.1, 2.7, -4.0]"));

DocumentMapper mapperWithLargerDim = createDocumentMapper(fieldMapping(b -> b
.field("type", "dense_vector")
.field("dims", 6)
.field("index", true)
.field("similarity", VectorSimilarity.dot_product)));
float[] largerVector = {-12.1f, 2.7f, -4, 1.05f, 10.0f, 29.9f};
e = expectThrows(MapperParsingException.class, () -> mapperWithLargerDim.parse(source(b -> b.array("field", largerVector))));
assertNotNull(e.getCause());
assertThat(e.getCause().getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors. " +
"Preview of invalid vector: [-12.1, 2.7, -4.0, 1.05, 10.0, ...]"));
}

public void testInvalidParameters() {
MapperParsingException e = expectThrows(MapperParsingException.class,
() -> createDocumentMapper(fieldMapping(b -> b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,70 @@

import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.FieldTypeTestCase;
import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper.VectorSimilarity;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.Matchers.containsString;

public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
private final boolean indexed;

public DenseVectorFieldTypeTests() {
this.indexed = randomBoolean();
}

private DenseVectorFieldMapper.DenseVectorFieldType createFieldType() {
return new DenseVectorFieldMapper.DenseVectorFieldType("f", Version.CURRENT, 5, indexed, Collections.emptyMap());
private DenseVectorFieldType createFieldType() {
return new DenseVectorFieldType("f", Version.CURRENT, 5, indexed, VectorSimilarity.cosine, Collections.emptyMap());
}

public void testHasDocValues() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertNotEquals(indexed, ft.hasDocValues());
}

public void testIsSearchable() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertEquals(indexed, ft.isSearchable());
}

public void testIsAggregatable() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertFalse(ft.isAggregatable());
}

public void testFielddataBuilder() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
assertNotNull(ft.fielddataBuilder("index", () -> {
throw new UnsupportedOperationException();
}));
}

public void testDocValueFormat() {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
expectThrows(IllegalArgumentException.class, () -> ft.docValueFormat(null, null));
}

public void testFetchSourceValue() throws IOException {
DenseVectorFieldMapper.DenseVectorFieldType ft = createFieldType();
DenseVectorFieldType ft = createFieldType();
List<Double> vector = List.of(0.0, 1.0, 2.0, 3.0, 4.0);
assertEquals(vector, fetchSourceValue(ft, vector));
}

public void testCreateKnnQuery() {
DenseVectorFieldType unindexedField = new DenseVectorFieldType("f", Version.CURRENT,
3, false, VectorSimilarity.cosine, Collections.emptyMap());
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> unindexedField.createKnnQuery(
new float[]{0.3f, 0.1f, 1.0f}, 10));
assertThat(e.getMessage(), containsString("[knn] queries are not supported if [index] is disabled"));

DenseVectorFieldType dotProductField = new DenseVectorFieldType("f", Version.CURRENT,
3, true, VectorSimilarity.dot_product, Collections.emptyMap());
e = expectThrows(IllegalArgumentException.class, () -> dotProductField.createKnnQuery(
new float[]{0.3f, 0.1f, 1.0f}, 10));
assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,6 @@ public void testWrongFieldType() {
assertThat(e.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields"));
}

public void testUnindexedField() {
SearchExecutionContext context = createSearchExecutionContext();
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(UNINDEXED_VECTOR_FIELD,
new float[]{1.0f, 1.0f, 1.0f}, 10);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
assertThat(e.getMessage(), containsString("[knn] queries are not supported if [index] is disabled"));
}

@Override
public void testValidOutput() {
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] {1.0f, 2.0f, 3.0f}, 10);
Expand Down

0 comments on commit 7f01138

Please sign in to comment.