Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for configuring HNSW parameters #79193

Merged
merged 4 commits into from Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -38,9 +38,9 @@ public CodecService(@Nullable MapperService mapperService) {
codecs.put(BEST_COMPRESSION_CODEC, new Lucene90Codec(Lucene90Codec.Mode.BEST_COMPRESSION));
} else {
codecs.put(DEFAULT_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
new PerFieldMappingCodec(Lucene90Codec.Mode.BEST_SPEED, mapperService));
codecs.put(BEST_COMPRESSION_CODEC,
new PerFieldMappingPostingFormatCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
new PerFieldMappingCodec(Lucene90Codec.Mode.BEST_COMPRESSION, mapperService));
}
codecs.put(LUCENE_DEFAULT_CODEC, Codec.getDefault());
for (String codec : Codec.availableCodecs()) {
Expand Down
Expand Up @@ -10,31 +10,32 @@

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90Codec;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.mapper.MapperService;

/**
* {@link PerFieldMappingPostingFormatCodec This postings format} is the default
* {@link PostingsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} per field. This
* allows users to change the low level postings format for individual fields
* per index in real time via the mapping API. If no specific postings format is
* configured for a specific field the default postings format is used.
* {@link PerFieldMappingCodec This Lucene codec} provides the default
* {@link PostingsFormat} and {@link KnnVectorsFormat} for Elasticsearch. It utilizes the
* {@link MapperService} to lookup a {@link PostingsFormat} and {@link KnnVectorsFormat} per field. This
* allows users to change the low level postings format and vectors format for individual fields
* per index in real time via the mapping API. If no specific postings format or vectors format is
* configured for a specific field, the default postings or vectors format is used.
*/
public class PerFieldMappingPostingFormatCodec extends Lucene90Codec {
public class PerFieldMappingCodec extends Lucene90Codec {
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
private final MapperService mapperService;

private final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat();

static {
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingPostingFormatCodec.class) :
"PerFieldMappingPostingFormatCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
assert Codec.forName(Lucene.LATEST_CODEC).getClass().isAssignableFrom(PerFieldMappingCodec.class) :
"PerFieldMappingCodec must subclass the latest " + "lucene codec: " + Lucene.LATEST_CODEC;
}

public PerFieldMappingPostingFormatCodec(Mode compressionMode, MapperService mapperService) {
public PerFieldMappingCodec(Mode compressionMode, MapperService mapperService) {
super(compressionMode);
this.mapperService = mapperService;
}
Expand All @@ -48,6 +49,15 @@ public PostingsFormat getPostingsFormatForField(String field) {
return format;
}

/**
 * Resolves the knn vectors format for {@code field} through the mapping lookup,
 * deferring to the parent codec's format when the lookup returns {@code null}.
 */
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
    final KnnVectorsFormat mapped = mapperService.mappingLookup().getKnnVectorsFormatForField(field);
    return mapped != null ? mapped : super.getKnnVectorsFormatForField(field);
}

@Override
public DocValuesFormat getDocValuesFormatForField(String field) {
return docValuesFormat;
Expand Down
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.PostingsFormat;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.index.IndexSettings;
Expand Down Expand Up @@ -228,6 +229,20 @@ public PostingsFormat getPostingsFormat(String field) {
return completionFields.contains(field) ? CompletionFieldMapper.postingsFormat() : null;
}

/**
 * Looks up the knn vectors format configured for a field.
 * @param field name of the field whose knn vectors format is requested
 * @return the format supplied by the field's mapper, or {@code null} when the mapper is not a
 *         {@link PerFieldKnnVectorsFormatFieldMapper} (or itself returns {@code null}), in which
 *         case the default format should be used
 */
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
    final Mapper mapper = fieldMappers.get(field);
    if ((mapper instanceof PerFieldKnnVectorsFormatFieldMapper) == false) {
        // no mapper for this field, or one that cannot provide a custom format
        return null;
    }
    return ((PerFieldKnnVectorsFormatFieldMapper) mapper).getKnnVectorsFormatForField();
}

void checkLimits(IndexSettings settings) {
checkFieldLimit(settings.getMappingTotalFieldsLimit());
checkObjectDepthLimit(settings.getMappingDepthLimit());
Expand Down
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;

/**
 * Implemented by field mappers that can supply a custom knn vectors format for their field.
 * For internal use only.
 */

public interface PerFieldKnnVectorsFormatFieldMapper {

/**
 * Returns the custom knn vectors format configured for this field, or {@code null} if
 * no format is configured or if the configured format matches the default format.
 * @return the knn vectors format for the field, or {@code null} if the default format should be used
 */
KnnVectorsFormat getKnnVectorsFormatForField();
}
Expand Up @@ -40,7 +40,7 @@ public class CodecTests extends ESTestCase {

public void testResolveDefaultCodecs() throws Exception {
CodecService codecService = createCodecService();
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingPostingFormatCodec.class));
assertThat(codecService.codec("default"), instanceOf(PerFieldMappingCodec.class));
assertThat(codecService.codec("default"), instanceOf(Lucene90Codec.class));
}

Expand Down
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.unit.Fuzziness;
import org.elasticsearch.index.codec.PerFieldMappingCodec;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
Expand All @@ -38,7 +39,6 @@
import org.elasticsearch.index.analysis.IndexAnalyzers;
import org.elasticsearch.index.analysis.NamedAnalyzer;
import org.elasticsearch.index.codec.CodecService;
import org.elasticsearch.index.codec.PerFieldMappingPostingFormatCodec;
import org.hamcrest.FeatureMatcher;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
Expand Down Expand Up @@ -122,8 +122,8 @@ public void testPostingsFormat() throws IOException {
MapperService mapperService = createMapperService(fieldMapping(this::minimalMapping));
CodecService codecService = new CodecService(mapperService);
Codec codec = codecService.codec("default");
assertThat(codec, instanceOf(PerFieldMappingPostingFormatCodec.class));
PerFieldMappingPostingFormatCodec perFieldCodec = (PerFieldMappingPostingFormatCodec) codec;
assertThat(codec, instanceOf(PerFieldMappingCodec.class));
PerFieldMappingCodec perFieldCodec = (PerFieldMappingCodec) codec;
assertThat(perFieldCodec.getPostingsFormatForField("field"), instanceOf(Completion90PostingsFormat.class));
}

Expand Down
Expand Up @@ -18,6 +18,7 @@ setup:
dims: 5
index: true
similarity: dot_product

- do:
index:
index: test-index
Expand Down
Expand Up @@ -19,6 +19,10 @@ setup:
dims: 3
index: true
similarity: l2_norm
index_options:
type: hnsw
m: 15
ef_construction: 50

---
"Indexing of Dense vectors should error when dims don't match defined in the mapping":
Expand Down
Expand Up @@ -8,6 +8,8 @@

package org.elasticsearch.xpack.vectors.mapper;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
Expand All @@ -16,6 +18,10 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.index.mapper.MappingParser;
import org.elasticsearch.index.mapper.PerFieldKnnVectorsFormatFieldMapper;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser.Token;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.fielddata.IndexFieldData;
Expand All @@ -40,17 +46,21 @@
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import static org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene90.Lucene90HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* A {@link FieldMapper} for indexing a dense vector of floats.
*/
public class DenseVectorFieldMapper extends FieldMapper {
public class DenseVectorFieldMapper extends FieldMapper implements PerFieldKnnVectorsFormatFieldMapper {

public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 2048; //maximum allowed number of dimensions
public static final IndexOptions DEFAULT_INDEX_OPTIONS = new HnswIndexOptions(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
private static final byte INT_BYTES = 4;

private static DenseVectorFieldMapper toType(FieldMapper in) {
Expand All @@ -73,6 +83,8 @@ public static class Builder extends FieldMapper.Builder {
private final Parameter<Boolean> indexed = Parameter.indexParam(m -> toType(m).indexed, false);
private final Parameter<VectorSimilarity> similarity = Parameter.enumParam(
"similarity", false, m -> toType(m).similarity, null, VectorSimilarity.class);
private final Parameter<IndexOptions> indexOptions = new Parameter<>("index_options", false, () -> null,
(n, c, o) -> o == null ? null : parseIndexOptions(n, o), m -> toType(m).indexOptions);
private final Parameter<Map<String, String>> meta = Parameter.metaParam();

final Version indexVersionCreated;
Expand All @@ -84,11 +96,13 @@ public Builder(String name, Version indexVersionCreated) {
this.indexed.requiresParameters(similarity);
this.similarity.setSerializerCheck((id, ic, v) -> v != null);
this.similarity.requiresParameters(indexed);
this.indexOptions.requiresParameters(indexed);
this.indexOptions.setSerializerCheck((id, ic, v) -> v != null);
}

@Override
protected List<Parameter<?>> getParameters() {
return List.of(dims, indexed, similarity, meta);
return List.of(dims, indexed, similarity, indexOptions, meta);
}

@Override
Expand All @@ -100,6 +114,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
dims.getValue(),
indexed.getValue(),
similarity.getValue(),
indexOptions.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, context),
copyTo.build());
Expand All @@ -116,6 +131,71 @@ enum VectorSimilarity {
}
}

/**
 * Base class for the {@code index_options} of a dense vector field.
 * It only stores the name of the options type; rendering to XContent is left to subclasses.
 */
abstract static class IndexOptions implements ToXContent {
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
final String type;
IndexOptions(String type) {
this.type = type;
}
}

/**
 * Index options of type {@code hnsw}: holds the {@code m} and {@code ef_construction} values
 * configured in a dense vector field's {@code index_options}.
 */
static class HnswIndexOptions extends IndexOptions {
    private final int m;
    private final int efConstruction;

    private HnswIndexOptions(int m, int efConstruction) {
        super("hnsw");
        this.m = m;
        this.efConstruction = efConstruction;
    }

    /**
     * Builds the options from the entries of an {@code index_options} map.
     * Both {@code m} and {@code ef_construction} are removed from the map and are mandatory;
     * the map is then passed to {@code MappingParser.checkNoRemainingFields}.
     *
     * @param fieldName name of the field being mapped
     * @param indexOptionsMap the options map; the entries read here are removed from it
     * @return {@code DEFAULT_INDEX_OPTIONS} when both values equal the statically imported Lucene
     *         defaults, otherwise a new instance
     * @throws MapperParsingException if {@code m} or {@code ef_construction} is missing
     */
    static IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
        final Object rawM = indexOptionsMap.remove("m");
        final Object rawEfConstruction = indexOptionsMap.remove("ef_construction");
        if (rawM == null) {
            throw new MapperParsingException("[index_options] of type [hnsw] requires field [m] to be configured");
        }
        if (rawEfConstruction == null) {
            throw new MapperParsingException("[index_options] of type [hnsw] requires field [ef_construction] to be configured");
        }
        final int maxConn = XContentMapValues.nodeIntegerValue(rawM);
        final int beamWidth = XContentMapValues.nodeIntegerValue(rawEfConstruction);
        MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
        // Reuse the shared default instance so callers can detect "default" by identity.
        final boolean matchesDefaults = maxConn == DEFAULT_MAX_CONN && beamWidth == DEFAULT_BEAM_WIDTH;
        return matchesDefaults ? DEFAULT_INDEX_OPTIONS : new HnswIndexOptions(maxConn, beamWidth);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.startObject()
            .field("type", type)
            .field("m", m)
            .field("ef_construction", efConstruction)
            .endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        final HnswIndexOptions other = (HnswIndexOptions) o;
        return efConstruction == other.efConstruction && m == other.m;
    }

    @Override
    public int hashCode() {
        return Objects.hash(type, m, efConstruction);
    }

    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("{type=");
        text.append(type);
        text.append(", m=").append(m);
        text.append(", ef_construction=").append(efConstruction);
        text.append(" }");
        return text.toString();
    }
}

public static final TypeParser PARSER
= new TypeParser((n, c) -> new Builder(n, c.indexVersionCreated()), notInMultiFields(CONTENT_TYPE));

Expand Down Expand Up @@ -185,15 +265,17 @@ public Query termQuery(Object value, SearchExecutionContext context) {
private final int dims;
private final boolean indexed;
private final VectorSimilarity similarity;
private final IndexOptions indexOptions;
private final Version indexCreatedVersion;

private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims,
boolean indexed, VectorSimilarity similarity,
private DenseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldType, int dims, boolean indexed,
VectorSimilarity similarity, IndexOptions indexOptions,
Version indexCreatedVersion, MultiFields multiFields, CopyTo copyTo) {
super(simpleName, mappedFieldType, multiFields, copyTo);
this.dims = dims;
this.indexed = indexed;
this.similarity = similarity;
this.indexOptions = indexOptions;
this.indexCreatedVersion = indexCreatedVersion;
}

Expand Down Expand Up @@ -289,4 +371,29 @@ protected String contentType() {
public FieldMapper.Builder getMergeBuilder() {
    // Create a builder for this field's name and index version, then initialise it from this mapper.
    final Builder fresh = new Builder(simpleName(), indexCreatedVersion);
    return fresh.init(this);
}

/**
 * Parses the {@code index_options} object of a dense vector field.
 * The {@code type} entry is removed from the map and selects the type-specific parser;
 * only {@code hnsw} is handled here.
 *
 * @param fieldName name of the field being mapped; used in the unknown-type error message and
 *                  passed on to the type-specific parser
 * @param propNode the raw {@code index_options} value; it is cast to a map without a type check
 * @return the parsed index options
 * @throws MapperParsingException if {@code type} is missing or is not {@code hnsw}
 */
public static IndexOptions parseIndexOptions(String fieldName, Object propNode) {
mayya-sharipova marked this conversation as resolved.
Show resolved Hide resolved
@SuppressWarnings("unchecked")
Map<String, ?> indexOptionsMap = (Map<String, ?>) propNode;
// remove() rather than get(): the remaining entries are handed to the type-specific parser
Object typeNode = indexOptionsMap.remove("type");
if (typeNode == null) {
throw new MapperParsingException("[index_options] requires field [type] to be configured");
}
String type = XContentMapValues.nodeStringValue(typeNode);
if (type.equals("hnsw")) {
return HnswIndexOptions.parseIndexOptions(fieldName, indexOptionsMap);
} else {
throw new MapperParsingException("Unknown vector index options type [" + type + "] for field [" + fieldName + "]");
}
}

/**
 * Supplies the knn vectors format for this field.
 * @return {@code null} when no index options are set or they are the shared
 *         {@code DEFAULT_INDEX_OPTIONS} instance; otherwise a {@link Lucene90HnswVectorsFormat}
 *         built from the configured {@code m} and {@code ef_construction}
 */
@Override
public KnnVectorsFormat getKnnVectorsFormatForField() {
    final boolean useDefaultFormat = indexOptions == null || indexOptions == DEFAULT_INDEX_OPTIONS;
    if (useDefaultFormat) {
        return null; // use default format
    }
    final HnswIndexOptions hnsw = (HnswIndexOptions) indexOptions;
    return new Lucene90HnswVectorsFormat(hnsw.m, hnsw.efConstruction);
}
}