Skip to content

Commit

Permalink
Add support for configuring HNSW parameters (#79193)
Browse files Browse the repository at this point in the history
This PR extends the dense_vector type to allow configure HNSW params in
`index_options`:
`m` – max number of connections for each  node,
`ef_construction` – number  of candidate neighbours to track while searching
the graph for each newly inserted node.

```
"mappings": {
  "properties": {
    "my_vector": {
      "type": "dense_vector",
      "dims": 128,
      "index": true,
      "similarity": "l2_norm",
      "index_options": {
        "type" : "hnsw",
        "m" : 15,
        "ef_construction" : 50
      }
    }
  }
}
```

`index_options` as an object is optional. If not provided, the default values from the
current codec will be used.
If `index_options` is provided,  that all parameters related to the specific type
must be provided. 

Relates to #78473
  • Loading branch information
mayya-sharipova committed Oct 18, 2021
1 parent d50c9e8 commit bdf8ca9
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 20 deletions.
Expand Up @@ -38,9 +38,9 @@ public CodecService(@Nullable MapperService mapperService) {
codecs.put(BEST_COMPRESSION_CODEC, new Lucene90Codec(Lucene90Codec.Mode.BEST_COMPRESSION));
} else {
codecs.put(DEFAULT_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
new PerFieldMapperCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
codecs.put(BEST_COMPRESSION_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
new PerFieldMapperCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
}
codecs.put(LUCENE_DEFAULT_CODEC, Codec.getDefault());
for (String codec : Codec.availableCodecs()) {
Expand Down
Expand Up @@ -10,31 +10,32 @@

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.mapper.MapperService;

/**
* {@link PerFieldMappingPostingFormatCodec This postings format} is the default
* {@link PostingsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} per field. This
* allows users to change the low level postings format for individual fields
* per index in real time via the mapping API. If no specific postings format is
* configured for a specific field the default postings format is used.
* {@link PerFieldMapperCodec This Lucene codec} provides the default
* {@link PostingsFormat} and {@link KnnVectorsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} and {@link KnnVectorsFormat} per field. This
* allows users to change the low level postings format and vectors format for individual fields
* per index in real time via the mapping API. If no specific postings format or vector format is
* configured for a specific field the default postings or vector format is used.
*/
public class PerFieldMappingPostingFormatCodec extends Lucene90Codec {
public class PerFieldMapperCodec extends Lucene90Codec {
private final MapperService mapperService;

private final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat();

static {
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class) :
"PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMapperCodec.class) :
"PerFieldMapperCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
}

public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService) {
public PerFieldMapperCodec(Mode compressionMode, MapperService mapperService) {
super(compressionMode);
this.mapperService = mapperService;
}
Expand All @@ -48,6 +49,15 @@ public PostingsFormat getPostingsFormatForField(String field) {
return format;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
KnnVectorsFormat format = mapperService.mappingLookup().getKnnVectorsFormatForField(field);
if (format == null) {
return super.getKnnVectorsFormatForField(field);
}
return format;
}

@Override
public DocValuesFormat getDocValuesFormatForField(String field) {
return docValuesFormat;
Expand Down
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.index.IndexSettings;
Expand Down Expand Up @@ -228,6 +229,20 @@ public PostingsFormat getPostingsFormat(String field) {
return completionFields.contains(field) ? CompletionFieldMapper.postingsFormat() : null;
}

/**
* Returns the knn vectors format for a particular field
* @param field the field to retrieve a knn vectors format for
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
Mapper fieldMapper = fieldMappers.get(field);
if (fieldMapper instanceof PerFieldKnnVectorsFormatFieldMapper) {
return ((PerFieldKnnVectorsFormatFieldMapper) fieldMapper).getKnnVectorsFormatForField();
} else {
return null;
}
}

void checkLimits(IndexSettings settings) {
checkFieldLimit(settings.getMappingTotalFieldsLimit());
checkObjectDepthLimit(settings.getMappingDepthLimit());
Expand Down
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;

/**
* Field mapper used for the only purpose to provide a custom knn vectors format.
* For internal use only.
*/

public interface PerFieldKnnVectorsFormatFieldMapper {

/**
* Returns the knn vectors format that is customly set up for this field
* or {@code null} if the format is not set up.
* @return the knn vectors format for the field, or {@code null} if the default format should be used
*/
KnnVectorsFormat getKnnVectorsFormatForField();
}
Expand Up @@ -40,7 +40,7 @@ public class CodecTests extends ESTestCase {

public void testResolveDefaultCodecs() throws Exception {
CodecService codecService = createCodecService();
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingPostingFormatCodec.class));
assertThat(codecService.codec("default"), instanceOf(PerFieldMapperCodec.class));
assertThat(codecService.codec("default"), instanceOf(Lucene90Codec.class));
}

Expand Down
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.unit.Fuzziness;
import org.elasticsearch.index.codec.PerFieldMapperCodec;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
Expand All @@ -38,7 +39,6 @@
import org.elasticsearch.index.analysis.IndexAnalyzers;
import org.elasticsearch.index.analysis.NamedAnalyzer;
import org.elasticsearch.index.codec.CodecService;
import org.elasticsearch.index.codec.PerFieldMappingPostingFormatCodec;
import org.hamcrest.FeatureMatcher;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
Expand Down Expand Up @@ -122,8 +122,8 @@ public void testPostingsFormat() throws IOException {
MapperService mapperService = createMapperService(fieldMapping(this::minimalMapping));
CodecService codecService = new CodecService(mapperService);
Codec codec = codecService.codec("default");
assertThat(codec, instanceOf(PerFieldMappingPostingFormatCodec.class));
PerFieldMappingPostingFormatCodec perFieldCodec = (PerFieldMappingPostingFormatCodec) codec;
assertThat(codec, instanceOf(PerFieldMapperCodec.class));
PerFieldMapperCodec perFieldCodec = (PerFieldMapperCodec) codec;
assertThat(perFieldCodec.getPostingsFormatForField("field"), instanceOf(Completion90PostingsFormat.class));
}

Expand Down
Expand Up @@ -18,6 +18,7 @@ setup:
dims: 5
index: true
similarity: dot_product

- do:
index:
index: test-index
Expand Down
Expand Up @@ -19,6 +19,10 @@ setup:
dims: 3
index: true
similarity: l2_norm
index_options:
type: hnsw
m: 15
ef_construction: 50

---
"Indexing of Dense vectors should error when dims don't match defined in the mapping":
Expand Down
Expand Up @@ -8,6 +8,8 @@

package org.elasticsearch.xpack.vectors.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
Expand All @@ -16,6 +18,10 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.PerFieldKnnVectorsFormatFieldMapper;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
Expand All @@ -40,14 +46,15 @@
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* A {@link FieldMapper} for indexing a dense vector of floats.
*/
public class DenseVectorFieldMapper extends FieldMapper {
public class DenseVectorFieldMapper extends FieldMapper implements PerFieldKnnVectorsFormatFieldMapper {

public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; //maximum allowed number of dimensions
Expand All @@ -73,6 +80,8 @@ public static class Builder extends FieldMapper.Builder {
private final Parameter<Boolean> indexed = Parameter.indexParam(m -> toType(m).indexed, false);
private final Parameter<VectorSimilarity> similarity = Parameter.enumParam(
"similarity", false, m -> toType(m).similarity, null, VectorSimilarity.class);
private final Parameter<IndexOptions> indexOptions = new Parameter<>("index_options", false, () -> null,
(n, c, o) -> o == null ? null : parseIndexOptions(n, o), m -> toType(m).indexOptions);
private final Parameter<Map<String, String>> meta = Parameter.metaParam();

final Version indexVersionCreated;
Expand All @@ -84,11 +93,13 @@ public Builder(String name, Version indexVersionCreated) {
this.indexed.requiresParameters(similarity);
this.similarity.setSerializerCheck((id, ic, v) -> v != null);
this.similarity.requiresParameters(indexed);
this.indexOptions.requiresParameters(indexed);
this.indexOptions.setSerializerCheck((id, ic, v) -> v != null);
}

@Override
protected List<Parameter<?>> getParameters() {
return List.of(dims, indexed, similarity, meta);
return List.of(dims, indexed, similarity, indexOptions, meta);
}

@Override
Expand All @@ -100,6 +111,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
dims.getValue(),
indexed.getValue(),
similarity.getValue(),
indexOptions.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, context),
copyTo.build());
Expand All @@ -116,6 +128,67 @@ enum VectorSimilarity {
}
}

private abstract static class IndexOptions implements ToXContent {
final String type;
IndexOptions(String type) {
this.type = type;
}
}

private static class HnswIndexOptions extends IndexOptions {
private final int m;
private final int efConstruction;

static IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
Object mNode = indexOptionsMap.remove("m");
Object efConstructionNode = indexOptionsMap.remove("ef_construction");
if (mNode == null) {
throw new MapperParsingException("[index_options] of type [hnsw] requires field [m] to be configured");
}
if (efConstructionNode == null) {
throw new MapperParsingException("[index_options] of type [hnsw] requires field [ef_construction] to be configured");
}
int m = XContentMapValues.nodeIntegerValue(mNode);
int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode);
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
return new HnswIndexOptions(m, efConstruction);
}

private HnswIndexOptions(int m, int efConstruction) {
super("hnsw");
this.m = m;
this.efConstruction = efConstruction;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("type", type);
builder.field("m", m);
builder.field("ef_construction", efConstruction);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
HnswIndexOptions that = (HnswIndexOptions) o;
return m == that.m && efConstruction == that.efConstruction;
}

@Override
public int hashCode() {
return Objects.hash(type, m, efConstruction);
}

@Override
public String toString() {
return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + " }";
}
}

public static final TypeParser PARSER
= new TypeParser((n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE));

Expand Down Expand Up @@ -185,15 +258,17 @@ public Query termQuery(Object value, SearchExecutionContext context) {
private final int dims;
private final boolean indexed;
private final VectorSimilarity similarity;
private final IndexOptions indexOptions;
private final Version indexCreatedVersion;

private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims,
boolean indexed, VectorSimilarity similarity,
private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims, boolean indexed,
VectorSimilarity similarity, IndexOptions indexOptions,
Version indexCreatedVersion, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, mappedFieldType, multiFields, copyTo);
this.dims = dims;
this.indexed = indexed;
this.similarity = similarity;
this.indexOptions = indexOptions;
this.indexCreatedVersion = indexCreatedVersion;
}

Expand Down Expand Up @@ -289,4 +364,29 @@ protected String contentType() {
public FieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName(), indexCreatedVersion).init(this);
}

private static IndexOptions parseIndexOptions(String fieldName, Object propNode) {
@SuppressWarnings("unchecked")
Map<String, ?> indexOptionsMap = (Map<String, ?>) propNode;
Object typeNode = indexOptionsMap.remove("type");
if (typeNode == null) {
throw new MapperParsingException("[index_options] requires field [type] to be configured");
}
String type = XContentMapValues.nodeStringValue(typeNode);
if (type.equals("hnsw")) {
return HnswIndexOptions.parseIndexOptions(fieldName, indexOptionsMap);
} else {
throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]");
}
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField() {
if (indexOptions == null) {
return null; // use default format
} else {
HnswIndexOptions hnswIndexOptions = (HnswIndexOptions) indexOptions;
return new Lucene90HnswVectorsFormat(hnswIndexOptions.m, hnswIndexOptions.efConstruction);
}
}
}

0 comments on commit bdf8ca9

Please sign in to comment.