Skip to content

Commit

Permalink
[ML] Allow users to specify similarity field (#106493)
Browse files Browse the repository at this point in the history
* Allow users to specify similarity

* Adding l2_norm and e5 fields

* Bumping minimum versions for services

* Cleaning up

* Fixing merge issue

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
jonathan-buttner and elasticmachine committed Mar 27, 2024
1 parent c276b45 commit 8f28a7a
Show file tree
Hide file tree
Showing 25 changed files with 504 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ static TransportVersion def(int id) {
public static final TransportVersion USE_DATA_STREAM_GLOBAL_RETENTION = def(8_613_00_0);
public static final TransportVersion ML_COMPLETION_INFERENCE_SERVICE_ADDED = def(8_614_00_0);
public static final TransportVersion ML_INFERENCE_EMBEDDING_BYTE_ADDED = def(8_615_00_0);
public static final TransportVersion ML_INFERENCE_L2_NORM_SIMILARITY_ADDED = def(8_616_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,18 @@

package org.elasticsearch.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;

import java.util.EnumSet;
import java.util.Locale;

public enum SimilarityMeasure {
COSINE,
DOT_PRODUCT;
DOT_PRODUCT,
L2_NORM;

private static final EnumSet<SimilarityMeasure> BEFORE_L2_NORM_ENUMS = EnumSet.range(COSINE, DOT_PRODUCT);

@Override
public String toString() {
Expand All @@ -22,4 +29,21 @@ public String toString() {
public static SimilarityMeasure fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}

/**
* Returns a similarity measure that is known based on the transport version provided. If the similarity enum was not yet
* introduced it will be defaulted to null.
*
* @param similarityMeasure the value to translate if necessary
* @param version the version that dictates the translation
* @return the similarity that is known to the version passed in
*/
public static SimilarityMeasure translateSimilarity(SimilarityMeasure similarityMeasure, TransportVersion version) {
if (version.before(TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED)
&& BEFORE_L2_NORM_ENUMS.contains(similarityMeasure) == false) {
return null;
}

return similarityMeasure;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model) {
Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
model.getServiceSettings().getCommonSettings().getUri(),
model.getServiceSettings().getCommonSettings().uri(),
"Cohere embeddings"
);
requestCreator = new CohereEmbeddingsExecutableRequestCreator(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private static ResponseHandler createEmbeddingsHandler() {

public CohereEmbeddingsExecutableRequestCreator(CohereEmbeddingsModel model) {
this.model = Objects.requireNonNull(model);
account = new CohereAccount(this.model.getServiceSettings().getCommonSettings().getUri(), this.model.getSecretSettings().apiKey());
account = new CohereAccount(this.model.getServiceSettings().getCommonSettings().uri(), this.model.getSecretSettings().apiKey());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public CohereEmbeddingsRequest(CohereAccount account, List<String> input, Cohere
this.input = Objects.requireNonNull(input);
uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri);
taskSettings = embeddingsModel.getTaskSettings();
model = embeddingsModel.getServiceSettings().getCommonSettings().getModelId();
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
inferenceEntityId = embeddingsModel.getInferenceEntityId();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,14 @@ public static SecureString extractRequiredSecureString(
}

public static SimilarityMeasure extractSimilarity(Map<String, Object> map, String scope, ValidationException validationException) {
String similarity = extractOptionalString(map, SIMILARITY, scope, validationException);

if (similarity != null) {
try {
return SimilarityMeasure.fromString(similarity);
} catch (IllegalArgumentException iae) {
validationException.addValidationError("[" + scope + "] Unknown similarity measure [" + similarity + "]");
}
}

return null;
return extractOptionalEnum(
map,
SIMILARITY,
scope,
SimilarityMeasure::fromString,
EnumSet.allOf(SimilarityMeasure.class),
validationException
);
}

public static String extractRequiredString(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,16 @@ public void checkModelConfig(Model model, ActionListener<Model> listener) {
}

private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) {
var similarityFromModel = model.getServiceSettings().similarity();
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

CohereEmbeddingsServiceSettings serviceSettings = new CohereEmbeddingsServiceSettings(
new CohereServiceSettings(
model.getServiceSettings().getCommonSettings().getUri(),
SimilarityMeasure.DOT_PRODUCT,
model.getServiceSettings().getCommonSettings().uri(),
similarityToUse,
embeddingSize,
model.getServiceSettings().getCommonSettings().getMaxInputTokens(),
model.getServiceSettings().getCommonSettings().getModelId()
model.getServiceSettings().getCommonSettings().maxInputTokens(),
model.getServiceSettings().getCommonSettings().modelId()
),
model.getServiceSettings().getEmbeddingType()
);
Expand All @@ -232,6 +235,6 @@ private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsMo

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_EMBEDDING_BYTE_ADDED;
return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ public static CohereServiceSettings fromMap(Map<String, Object> map, Configurati
throw validationException;
}

return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, getModelId(oldModelId, modelId));
return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, modelId(oldModelId, modelId));
}

private static String getModelId(@Nullable String model, @Nullable String modelId) {
private static String modelId(@Nullable String model, @Nullable String modelId) {
return modelId != null ? modelId : model;
}

Expand Down Expand Up @@ -110,23 +110,25 @@ public CohereServiceSettings(StreamInput in) throws IOException {
modelId = in.readOptionalString();
}

public URI getUri() {
public URI uri() {
return uri;
}

public SimilarityMeasure getSimilarity() {
@Override
public SimilarityMeasure similarity() {
return similarity;
}

public Integer getDimensions() {
@Override
public Integer dimensions() {
return dimensions;
}

public Integer getMaxInputTokens() {
public Integer maxInputTokens() {
return maxInputTokens;
}

public String getModelId() {
public String modelId() {
return modelId;
}

Expand Down Expand Up @@ -179,7 +181,7 @@ public TransportVersion getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
var uriToWrite = uri != null ? uri.toString() : null;
out.writeOptionalString(uriToWrite);
out.writeOptionalEnum(similarity);
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
out.writeOptionalVInt(dimensions);
out.writeOptionalVInt(maxInputTokens);
out.writeOptionalString(modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
Expand Down Expand Up @@ -96,6 +97,16 @@ public CohereServiceSettings getCommonSettings() {
return commonSettings;
}

@Override
public SimilarityMeasure similarity() {
return commonSettings.similarity();
}

@Override
public Integer dimensions() {
return commonSettings.dimensions();
}

public CohereEmbeddingType getEmbeddingType() {
return embeddingType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ public boolean isInClusterService() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED;
return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;

Expand All @@ -26,6 +27,9 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt

public static final String NAME = "multilingual_e5_small_service_settings";

static final int DIMENSIONS = 384;
static final SimilarityMeasure SIMILARITY = SimilarityMeasure.COSINE;

public MultilingualE5SmallInternalServiceSettings(int numAllocations, int numThreads, String modelId) {
super(numAllocations, numThreads, modelId);
}
Expand All @@ -45,6 +49,16 @@ public MultilingualE5SmallInternalServiceSettings(StreamInput in) throws IOExcep
*/
public static MultilingualE5SmallInternalServiceSettings.Builder fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();
var requestFields = extractRequestFields(map, validationException);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return createBuilder(requestFields);
}

private static RequestFields extractRequestFields(Map<String, Object> map, ValidationException validationException) {
Integer numAllocations = ServiceUtils.removeAsType(map, NUM_ALLOCATIONS, Integer.class);
Integer numThreads = ServiceUtils.removeAsType(map, NUM_THREADS, Integer.class);

Expand All @@ -62,26 +76,23 @@ public static MultilingualE5SmallInternalServiceSettings.Builder fromMap(Map<Str
}
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new RequestFields(numAllocations, numThreads, modelId);
}

private static MultilingualE5SmallInternalServiceSettings.Builder createBuilder(RequestFields requestFields) {
var builder = new InternalServiceSettings.Builder() {
@Override
public MultilingualE5SmallInternalServiceSettings build() {
return new MultilingualE5SmallInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId());
}
};
builder.setNumAllocations(numAllocations);
builder.setNumThreads(numThreads);
builder.setModelId(modelId);
builder.setNumAllocations(requestFields.numAllocations);
builder.setNumThreads(requestFields.numThreads);
builder.setModelId(requestFields.modelId);
return builder;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return super.toXContent(builder, params);
}
private record RequestFields(@Nullable Integer numAllocations, @Nullable Integer numThreads, @Nullable String modelId) {}

@Override
public boolean isFragment() {
Expand All @@ -103,9 +114,14 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}

@Override
public SimilarityMeasure similarity() {
return SIMILARITY;
}

@Override
public Integer dimensions() {
return 384;
return DIMENSIONS;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public void checkModelConfig(Model model, ActionListener<Model> listener) {
private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
var serviceSettings = new HuggingFaceServiceSettings(
model.getServiceSettings().uri(),
null, // Similarity measure is unknown
model.getServiceSettings().similarity(), // we don't know the similarity but use whatever the user specified
embeddingSize,
model.getTokenLimit()
);
Expand All @@ -76,6 +76,6 @@ public String name() {

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_12_0;
return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public TransportVersion getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
out.writeString(uri.toString());
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
out.writeOptionalEnum(similarity);
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
out.writeOptionalVInt(dimensions);
out.writeOptionalVInt(maxInputTokens);
}
Expand All @@ -145,10 +145,12 @@ public URI uri() {
return uri;
}

@Override
public SimilarityMeasure similarity() {
return similarity;
}

@Override
public Integer dimensions() {
return dimensions;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,14 @@ private OpenAiEmbeddingsModel updateModelWithEmbeddingDetails(OpenAiEmbeddingsMo
);
}

var similarityFromModel = model.getServiceSettings().similarity();
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

OpenAiEmbeddingsServiceSettings serviceSettings = new OpenAiEmbeddingsServiceSettings(
model.getServiceSettings().modelId(),
model.getServiceSettings().uri(),
model.getServiceSettings().organizationId(),
SimilarityMeasure.DOT_PRODUCT,
similarityToUse,
embeddingSize,
model.getServiceSettings().maxInputTokens(),
model.getServiceSettings().dimensionsSetByUser()
Expand All @@ -263,7 +266,7 @@ private OpenAiEmbeddingsModel updateModelWithEmbeddingDetails(OpenAiEmbeddingsMo

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_COMPLETION_INFERENCE_SERVICE_ADDED;
return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,12 @@ public String organizationId() {
return organizationId;
}

@Override
public SimilarityMeasure similarity() {
return similarity;
}

@Override
public Integer dimensions() {
return dimensions;
}
Expand Down Expand Up @@ -277,8 +279,9 @@ public void writeTo(StreamOutput out) throws IOException {
var uriToWrite = uri != null ? uri.toString() : null;
out.writeOptionalString(uriToWrite);
out.writeOptionalString(organizationId);

if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
out.writeOptionalEnum(similarity);
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
out.writeOptionalVInt(dimensions);
out.writeOptionalVInt(maxInputTokens);
}
Expand Down

0 comments on commit 8f28a7a

Please sign in to comment.