Skip to content

Commit

Permalink
[ML] Adding elementType to service settings and persisting byte inste…
Browse files Browse the repository at this point in the history
…ad of int8 (#106700)

* Adding element type method and storing byte instead of int8

* Adding more tests and checking for null

* Converting between element type and cohere embedding type

* Update server/src/main/java/org/elasticsearch/inference/ServiceSettings.java

Co-authored-by: David Kyle <david.kyle@elastic.co>

* enum tests

---------

Co-authored-by: David Kyle <david.kyle@elastic.co>
  • Loading branch information
jonathan-buttner and davidkyle committed Mar 27, 2024
1 parent 720188e commit 2b6b7ae
Show file tree
Hide file tree
Showing 16 changed files with 308 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ static TransportVersion def(int id) {
public static final TransportVersion KNN_QUERY_VECTOR_BUILDER = def(8_612_00_0);
public static final TransportVersion USE_DATA_STREAM_GLOBAL_RETENTION = def(8_613_00_0);
public static final TransportVersion ML_COMPLETION_INFERENCE_SERVICE_ADDED = def(8_614_00_0);
public static final TransportVersion ML_INFERENCE_EMBEDDING_BYTE_ADDED = def(8_615_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,10 @@ static Function<StringBuilder, StringBuilder> errorByteElementsAppender(byte[] v
}

public abstract double computeDotProduct(VectorData vectorData);

public static ElementType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}
}

static final Map<String, ElementType> namesToElementType = Map.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.xcontent.ToXContentObject;

public interface ServiceSettings extends ToXContentObject, VersionedNamedWriteable {
Expand Down Expand Up @@ -36,4 +37,13 @@ default Integer dimensions() {
return null;
}

/**
* The data type for the embeddings this service works with. Defaults to null,
* Text Embedding models should return a non-null value
*
* @return the element type
*/
default DenseVectorFieldMapper.ElementType elementType() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}

if (embeddingType != null) {
builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType));
builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString()));
}

if (taskSettings.getTruncation() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,6 @@ private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsMo

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED;
return TransportVersions.ML_INFERENCE_EMBEDDING_BYTE_ADDED;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@

package org.elasticsearch.xpack.inference.services.cohere.embeddings;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Map;

/**
* Defines the type of embedding that the cohere api should return for a request.
Expand All @@ -20,22 +28,94 @@ public enum CohereEmbeddingType {
/**
* Use this when you want to get back the default float embeddings. Valid for all models.
*/
FLOAT,
FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT),
/**
* Use this when you want to get back signed int8 embeddings. Valid for only v3 models.
*/
INT8;
INT8(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
/**
* This is a synonym for INT8
*/
BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8);

private static final class RequestConstants {
private static final String FLOAT = "float";
private static final String INT8 = "int8";
}

private static final Map<DenseVectorFieldMapper.ElementType, CohereEmbeddingType> ELEMENT_TYPE_TO_COHERE_EMBEDDING = Map.of(
DenseVectorFieldMapper.ElementType.FLOAT,
FLOAT,
DenseVectorFieldMapper.ElementType.BYTE,
BYTE
);
static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
ELEMENT_TYPE_TO_COHERE_EMBEDDING.keySet()
);

private final DenseVectorFieldMapper.ElementType elementType;
private final String requestString;

CohereEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) {
this.elementType = elementType;
this.requestString = requestString;
}

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
}

public String toRequestString() {
return requestString;
}

public static String toLowerCase(CohereEmbeddingType type) {
return type.toString().toLowerCase(Locale.ROOT);
}

public static CohereEmbeddingType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
}

public static CohereEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) {
var embedding = ELEMENT_TYPE_TO_COHERE_EMBEDDING.get(elementType);

if (embedding == null) {
var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream()
.map(value -> value.toString().toLowerCase(Locale.ROOT))
.toArray(String[]::new);
Arrays.sort(validElementTypes);

throw new IllegalArgumentException(
Strings.format(
"Element type [%s] does not map to a Cohere embedding value, must be one of [%s]",
elementType,
String.join(", ", validElementTypes)
)
);
}

return embedding;
}

public DenseVectorFieldMapper.ElementType toElementType() {
return elementType;
}

/**
* Returns an embedding type that is known based on the transport version provided. If the embedding type enum was not yet
* introduced it will be defaulted INT8.
*
* @param embeddingType the value to translate if necessary
* @param version the version that dictates the translation
* @return the embedding type that is known to the version passed in
*/
public static CohereEmbeddingType translateToVersion(CohereEmbeddingType embeddingType, TransportVersion version) {
if (version.before(TransportVersions.ML_INFERENCE_EMBEDDING_BYTE_ADDED) && embeddingType == BYTE) {
return INT8;
}

return embeddingType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -22,32 +22,21 @@

import java.io.IOException;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;

public class CohereEmbeddingsServiceSettings implements ServiceSettings {
public static final String NAME = "cohere_embeddings_service_settings";

static final String EMBEDDING_TYPE = "embedding_type";
static final String EMBEDDING_TYPE_BYTE = "byte";

public static CohereEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
var commonServiceSettings = CohereServiceSettings.fromMap(map, context);
translateEmbeddingType(map, context);

CohereEmbeddingType embeddingTypes = extractOptionalEnum(
map,
EMBEDDING_TYPE,
ModelConfigurations.SERVICE_SETTINGS,
CohereEmbeddingType::fromString,
EnumSet.allOf(CohereEmbeddingType.class),
validationException
);

CohereEmbeddingType embeddingTypes = parseEmbeddingType(map, context, validationException);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
Expand All @@ -56,37 +45,51 @@ public static CohereEmbeddingsServiceSettings fromMap(Map<String, Object> map, C
return new CohereEmbeddingsServiceSettings(commonServiceSettings, embeddingTypes);
}

private static void translateEmbeddingType(Map<String, Object> map, ConfigurationParseContext context) {
if (ConfigurationParseContext.isRequestContext(context) == false || map.containsKey(EMBEDDING_TYPE) == false) {
return;
private static CohereEmbeddingType parseEmbeddingType(
Map<String, Object> map,
ConfigurationParseContext context,
ValidationException validationException
) {
if (context == ConfigurationParseContext.REQUEST) {
return Objects.requireNonNullElse(
extractOptionalEnum(
map,
EMBEDDING_TYPE,
ModelConfigurations.SERVICE_SETTINGS,
CohereEmbeddingType::fromString,
EnumSet.allOf(CohereEmbeddingType.class),
validationException
),
CohereEmbeddingType.FLOAT
);
}

ValidationException validationException = new ValidationException();

String embeddingType = extractRequiredString(map, EMBEDDING_TYPE, ModelConfigurations.SERVICE_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
DenseVectorFieldMapper.ElementType elementType = Objects.requireNonNullElse(
extractOptionalEnum(
map,
EMBEDDING_TYPE,
ModelConfigurations.SERVICE_SETTINGS,
DenseVectorFieldMapper.ElementType::fromString,
CohereEmbeddingType.SUPPORTED_ELEMENT_TYPES,
validationException
),
DenseVectorFieldMapper.ElementType.FLOAT
);

assert embeddingType != null;
if (embeddingType.toLowerCase(Locale.ROOT).equals(EMBEDDING_TYPE_BYTE)) {
map.put(EMBEDDING_TYPE, CohereEmbeddingType.INT8.toString());
} else {
map.put(EMBEDDING_TYPE, embeddingType);
}
return CohereEmbeddingType.fromElementType(elementType);
}

private final CohereServiceSettings commonSettings;
private final CohereEmbeddingType embeddingType;

public CohereEmbeddingsServiceSettings(CohereServiceSettings commonSettings, @Nullable CohereEmbeddingType embeddingType) {
public CohereEmbeddingsServiceSettings(CohereServiceSettings commonSettings, CohereEmbeddingType embeddingType) {
this.commonSettings = commonSettings;
this.embeddingType = embeddingType;
this.embeddingType = Objects.requireNonNull(embeddingType);
}

public CohereEmbeddingsServiceSettings(StreamInput in) throws IOException {
commonSettings = new CohereServiceSettings(in);
embeddingType = in.readOptionalEnum(CohereEmbeddingType.class);
embeddingType = Objects.requireNonNullElse(in.readOptionalEnum(CohereEmbeddingType.class), CohereEmbeddingType.FLOAT);
}

public CohereServiceSettings getCommonSettings() {
Expand All @@ -97,6 +100,11 @@ public CohereEmbeddingType getEmbeddingType() {
return embeddingType;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType();
}

@Override
public String getWriteableName() {
return NAME;
Expand All @@ -107,7 +115,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();

commonSettings.toXContentFragment(builder);
builder.field(EMBEDDING_TYPE, embeddingType);
builder.field(EMBEDDING_TYPE, elementType());

builder.endObject();
return builder;
Expand All @@ -126,7 +134,7 @@ public TransportVersion getMinimalSupportedVersion() {
@Override
public void writeTo(StreamOutput out) throws IOException {
commonSettings.writeTo(out);
out.writeOptionalEnum(embeddingType);
out.writeOptionalEnum(CohereEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;
Expand Down Expand Up @@ -106,4 +107,9 @@ public void writeTo(StreamOutput out) throws IOException {
public Integer dimensions() {
return 384;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
Expand Down Expand Up @@ -156,6 +157,11 @@ public Integer maxInputTokens() {
return maxInputTokens;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
Expand Down Expand Up @@ -211,6 +212,11 @@ public String modelId() {
return modelId;
}

@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ public void testXContent_InputTypeSearch_EmbeddingTypesInt8_TruncateNone() throw
{"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}"""));
}

public void testXContent_InputTypeSearch_EmbeddingTypesByte_TruncateNone() throws IOException {
var entity = new CohereEmbeddingsRequestEntity(
List.of("abc"),
new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
"model",
CohereEmbeddingType.BYTE
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);

MatcherAssert.assertThat(xContentResult, is("""
{"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}"""));
}

public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null);

Expand Down

0 comments on commit 2b6b7ae

Please sign in to comment.