From aa60000111c39a66d5eb748651e7a75ee0378df4 Mon Sep 17 00:00:00 2001 From: demoncoder-crypto Date: Wed, 9 Apr 2025 20:46:46 +0530 Subject: [PATCH 1/2] feat: [ML] Support binary embeddings from Amazon Bedrock Titan (#125378) --- .../amazonbedrock/AmazonBedrockConstants.java | 4 +- .../AmazonBedrockServiceSettings.java | 56 +++- .../AmazonBedrockEmbeddingsEntityFactory.java | 2 +- ...onBedrockTitanEmbeddingsRequestEntity.java | 9 +- .../AmazonBedrockEmbeddingsResponse.java | 64 +++-- .../AmazonBedrockServiceTests.java | 246 ++++++++++++------ ...BedrockEmbeddingsServiceSettingsTests.java | 115 +++++++- ...rockTitanEmbeddingsRequestEntityTests.java | 12 +- 8 files changed, 388 insertions(+), 120 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java index 1755dac2ac13f..fcc525fd7d050 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java @@ -11,8 +11,9 @@ public class AmazonBedrockConstants { public static final String ACCESS_KEY_FIELD = "access_key"; public static final String SECRET_KEY_FIELD = "secret_key"; public static final String REGION_FIELD = "region"; - public static final String MODEL_FIELD = "model"; + public static final String MODEL_FIELD = "model_id"; public static final String PROVIDER_FIELD = "provider"; + public static final String EMBEDDING_TYPE_FIELD = "embedding_type"; public static final String TEMPERATURE_FIELD = "temperature"; public static final String TOP_P_FIELD = "top_p"; @@ -24,4 +25,5 @@ public class AmazonBedrockConstants { public static final int DEFAULT_MAX_CHUNK_SIZE = 2048; + private AmazonBedrockConstants() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java index f11f5818f635c..377893e826e7a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java @@ -24,21 +24,45 @@ import java.util.EnumSet; import java.util.Map; import java.util.Objects; +import java.util.Optional; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; public abstract class AmazonBedrockServiceSettings extends FilteredXContentObject implements ServiceSettings { protected static final String AMAZON_BEDROCK_BASE_NAME = "amazon_bedrock"; + public enum AmazonBedrockEmbeddingType { + FLOAT, + BINARY; + + public static AmazonBedrockEmbeddingType fromString(String value) { + return switch (value.toLowerCase()) { + case "float" -> FLOAT; + case "binary" -> BINARY; + default -> throw new IllegalArgumentException("unknown value for embedding type: " + value); + }; + } + + @Override + public String toString() { + return name().toLowerCase(); + } + } + + protected static final AmazonBedrockEmbeddingType DEFAULT_EMBEDDING_TYPE = AmazonBedrockEmbeddingType.FLOAT; + protected final String region; protected final String model; protected final AmazonBedrockProvider provider; protected final RateLimitSettings rateLimitSettings; + protected final AmazonBedrockEmbeddingType embeddingType; // the default requests per minute are defined as per-model in the "Runtime quotas" on AWS // see: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html @@ -69,15 +93,24 @@ protected static AmazonBedrockServiceSettings.BaseAmazonBedrockCommonSettings fr AMAZON_BEDROCK_BASE_NAME, context ); + AmazonBedrockEmbeddingType embeddingType = extractOptionalEnum( + map, + EMBEDDING_TYPE_FIELD, + ModelConfigurations.SERVICE_SETTINGS, + AmazonBedrockEmbeddingType::fromString, + EnumSet.allOf(AmazonBedrockEmbeddingType.class), + validationException + ).orElse(DEFAULT_EMBEDDING_TYPE); - return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings); + return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings, embeddingType); } protected record BaseAmazonBedrockCommonSettings( String region, String model, AmazonBedrockProvider provider, - @Nullable RateLimitSettings rateLimitSettings + @Nullable RateLimitSettings rateLimitSettings, + AmazonBedrockEmbeddingType embeddingType ) {} protected AmazonBedrockServiceSettings(StreamInput in) throws IOException { @@ -85,18 +118,25 @@ protected AmazonBedrockServiceSettings(StreamInput in) throws IOException { this.model = in.readString(); this.provider = in.readEnum(AmazonBedrockProvider.class); this.rateLimitSettings = new RateLimitSettings(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.V_9_0_0)) { // Version set for BWC + this.embeddingType = in.readEnum(AmazonBedrockEmbeddingType.class); + } else { + this.embeddingType = DEFAULT_EMBEDDING_TYPE; + } } protected AmazonBedrockServiceSettings( String region, String model, AmazonBedrockProvider provider, - @Nullable RateLimitSettings rateLimitSettings + @Nullable RateLimitSettings rateLimitSettings, + AmazonBedrockEmbeddingType embeddingType ) { this.region = Objects.requireNonNull(region); this.model = Objects.requireNonNull(model); this.provider = Objects.requireNonNull(provider); this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.embeddingType = Objects.requireNonNullElse(embeddingType, DEFAULT_EMBEDDING_TYPE); } @Override @@ -121,12 +161,19 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + public AmazonBedrockEmbeddingType embeddingType() { + return embeddingType; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(region); out.writeString(model); out.writeEnum(provider); rateLimitSettings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.V_9_0_0)) { // Version set for BWC + out.writeEnum(embeddingType); + } } public void addBaseXContent(XContentBuilder builder, Params params) throws IOException { @@ -137,6 +184,9 @@ protected void addXContentFragmentOfExposedFields(XContentBuilder builder, Param builder.field(REGION_FIELD, region); builder.field(MODEL_FIELD, model); builder.field(PROVIDER_FIELD, provider.name()); + if (embeddingType != DEFAULT_EMBEDDING_TYPE) { + builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString()); + } rateLimitSettings.toXContent(builder, params); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java index 0bd1b191f050f..e5536f7b73895 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java @@ -36,7 +36,7 @@ public static ToXContent createEntity( if (truncatedInput.size() > 1) { throw new ElasticsearchException("[input] cannot contain more than one string"); } - return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0)); + return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0), serviceSettings.embeddingType()); } case COHERE -> { return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java index 3ce0a433502bf..13637795d3b7a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java @@ -12,19 +12,26 @@ import java.io.IOException; import java.util.Objects; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; -public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText) implements ToXContentObject { +public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText, AmazonBedrockEmbeddingType embeddingType) + implements ToXContentObject { private static final String INPUT_TEXT_FIELD = "inputText"; public AmazonBedrockTitanEmbeddingsRequestEntity { Objects.requireNonNull(inputText); + Objects.requireNonNull(embeddingType); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(INPUT_TEXT_FIELD, inputText); + if (embeddingType == AmazonBedrockEmbeddingType.BINARY) { + builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toString()); + } builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java index 831bf9938c211..30f2728f40f3f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java @@ -16,16 +16,19 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBytesResults; import org.elasticsearch.xpack.inference.external.response.XContentUtils; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponse; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Base64; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; @@ -42,13 +45,13 @@ public AmazonBedrockEmbeddingsResponse(InvokeModelResponse invokeModelResult) { @Override public InferenceServiceResults accept(AmazonBedrockRequest request) { if (request instanceof AmazonBedrockEmbeddingsRequest asEmbeddingsRequest) { - return fromResponse(result, asEmbeddingsRequest.provider()); + return fromResponse(result, asEmbeddingsRequest); } throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); } - public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse response, AmazonBedrockProvider provider) { + public static TextEmbeddingResults fromResponse(InvokeModelResponse response, AmazonBedrockEmbeddingsRequest request) { var charset = StandardCharsets.UTF_8; var bodyText = String.valueOf(charset.decode(response.body().asByteBuffer())); @@ -61,28 +64,33 @@ public static TextEmbeddingFloatResults fromResponse(InvokeModelResponse respons XContentParser.Token token = jsonParser.currentToken(); ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - var embeddingList = parseEmbeddings(jsonParser, provider); - - return new TextEmbeddingFloatResults(embeddingList); + var embeddingType = request.getServiceSettings().embeddingType(); + if (embeddingType == AmazonBedrockEmbeddingType.BINARY) { + var embeddingList = parseBinaryEmbeddings(jsonParser, request.provider()); + return new TextEmbeddingBytesResults(embeddingList); + } else { + var embeddingList = parseFloatEmbeddings(jsonParser, request.provider()); + return new TextEmbeddingFloatResults(embeddingList); + } } catch (IOException e) { throw new ElasticsearchException(e); } } - private static List parseEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider) + private static List parseFloatEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider) throws IOException { switch (provider) { case AMAZONTITAN -> { - return parseTitanEmbeddings(jsonParser); + return parseTitanFloatEmbeddings(jsonParser); } case COHERE -> { - return parseCohereEmbeddings(jsonParser); + return parseCohereFloatEmbeddings(jsonParser); } default -> throw new IOException("Unsupported provider [" + provider + "]"); } } - private static List parseTitanEmbeddings(XContentParser parser) throws IOException { + private static List parseTitanFloatEmbeddings(XContentParser parser) throws IOException { /* Titan response: { @@ -92,11 +100,11 @@ private static List parseTitanEmbeddings(XC */ positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); - var embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); + TextEmbeddingResults.InferredValue embeddingValues = TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); return List.of(embeddingValues); } - private static List parseCohereEmbeddings(XContentParser parser) throws IOException { + private static List parseCohereFloatEmbeddings(XContentParser parser) throws IOException { /* Cohere response: { @@ -111,17 +119,43 @@ private static List parseCohereEmbeddings(X */ positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); - List embeddingList = parseList( + List embeddingList = parseList( parser, - AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem + AmazonBedrockEmbeddingsResponse::parseCohereFloatEmbeddingsListItem ); return embeddingList; } - private static TextEmbeddingFloatResults.Embedding parseCohereEmbeddingsListItem(XContentParser parser) throws IOException { + private static TextEmbeddingResults.InferredValue parseCohereFloatEmbeddingsListItem(XContentParser parser) throws IOException { List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList); } + private static List parseBinaryEmbeddings(XContentParser jsonParser, AmazonBedrockProvider provider) + throws IOException { + switch (provider) { + case AMAZONTITAN -> { + return parseTitanBinaryEmbeddings(jsonParser); + } + default -> throw new IOException("Binary embeddings not supported for provider [" + provider + "]"); + } + } + + private static List parseTitanBinaryEmbeddings(XContentParser parser) throws IOException { + /* + Titan Binary response (structure assumed based on float version): + { + "embedding": "", + "inputTextTokenCount": int + } + */ + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + String base64Embedding = parser.text(); + byte[] embeddingBytes = Base64.getDecoder().decode(base64Embedding); + + TextEmbeddingResults.InferredValue embeddingValue = TextEmbeddingBytesResults.Embedding.of(embeddingBytes); + return List.of(embeddingValue); + } + } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 688bd3d4afc56..7a16b35f83e3a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -38,6 +38,8 @@ import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBytesResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -58,11 +60,13 @@ import org.junit.Before; import java.io.IOException; +import java.util.Base64; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; @@ -84,6 +88,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -1005,92 +1010,138 @@ public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderTh } public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); + var mockSenderFactory = mock(AmazonBedrockMockRequestSender.Factory.class); + var mockSender = mock(AmazonBedrockMockRequestSender.class); + when(mockSenderFactory.createSender()).thenReturn(mockSender); + + try (var service = createAmazonBedrockService(mockSenderFactory)) { + // Test Titan Float (Default) + var modelTitanFloat = AmazonBedrockEmbeddingsModelTests.createRandomTitanModelWithEmbeddingType( + AmazonBedrockEmbeddingType.FLOAT + ); + List titanFloatEmbeddings = List.of(1.0f, 2.0f, 3.0f); + String titanFloatResponseBody = createTitanResponseBody(titanFloatEmbeddings, null); + when(mockSender.sendEmbeddingsRequest(modelTitanFloat, List.of("input1"))).thenReturn( + new BytesArray(titanFloatResponseBody) + ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); + var futureTitanFloat = new PlainActionFuture(); + service.infer( + modelTitanFloat, + null, + null, + null, + List.of("input1"), + false, + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + futureTitanFloat + ); + var resultTitanFloat = futureTitanFloat.actionGet(TIMEOUT); + assertThat(resultTitanFloat, instanceOf(TextEmbeddingFloatResults.class)); + assertThat( + ((TextEmbeddingFloatResults) resultTitanFloat).getResults(), + Matchers.contains( + buildExpectationFloat(titanFloatEmbeddings) + ) + ); - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.COHERE, - "access", - "secret" - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - null, - null, - null, - List.of("abc"), - false, - new HashMap<>(), - InputType.CLASSIFICATION, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + // Test Titan Binary + var modelTitanBinary = AmazonBedrockEmbeddingsModelTests.createRandomTitanModelWithEmbeddingType( + AmazonBedrockEmbeddingType.BINARY + ); + byte[] binaryEmbeddingBytes = new byte[]{1, 2, 3, 4}; + String binaryEmbeddingBase64 = Base64.getEncoder().encodeToString(binaryEmbeddingBytes); + String titanBinaryResponseBody = createTitanResponseBody(null, binaryEmbeddingBase64); + when(mockSender.sendEmbeddingsRequest(modelTitanBinary, List.of("input1"))).thenReturn( + new BytesArray(titanBinaryResponseBody) + ); - var result = listener.actionGet(TIMEOUT); + var futureTitanBinary = new PlainActionFuture(); + service.infer( + modelTitanBinary, + new InferenceAction.Request(null, List.of("input1")), + Map.of(), + TIMEOUT, + futureTitanBinary + ); + var resultTitanBinary = futureTitanBinary.actionGet(TIMEOUT); + assertThat(resultTitanBinary, instanceOf(TextEmbeddingBytesResults.class)); + var byteResults = ((TextEmbeddingBytesResults) resultTitanBinary).getResults(); + assertThat(byteResults, hasSize(1)); + assertArrayEquals(binaryEmbeddingBytes, byteResults.get(0).getEmbedding()); + + // Test Cohere + var modelCohere = AmazonBedrockEmbeddingsModelTests.createRandomCohereModel(); + List cohereEmbeddings1 = List.of(4.0f, 5.0f); + String cohereResponseBody = createCohereResponseBody(List.of(cohereEmbeddings1)); + when(mockSender.sendEmbeddingsRequest(modelCohere, List.of("input1"))).thenReturn( + new BytesArray(cohereResponseBody) + ); + var futureCohere = new PlainActionFuture(); + service.infer( + modelCohere, + null, + null, + null, + List.of("input1"), + false, + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + futureCohere + ); + var resultCohere = futureCohere.actionGet(TIMEOUT); + assertThat(resultCohere, instanceOf(TextEmbeddingFloatResults.class)); + assertThat( + ((TextEmbeddingFloatResults) resultCohere).getResults(), + Matchers.contains(buildExpectationFloat(cohereEmbeddings1)) + ); - assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F })))); - } + // Verify interactions + verify(mockSenderFactory, times(1)).createSender(); + verifyNoMoreInteractions(mockSenderFactory); } } public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); + var mockSenderFactory = mock(AmazonBedrockMockRequestSender.Factory.class); + var mockSender = spy(new AmazonBedrockMockRequestSender()); + when(mockSenderFactory.createSender()).thenReturn(mockSender); - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); + try (var service = createAmazonBedrockService(mockSenderFactory)) { + var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); + mockSender.enqueue(mockResults); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); - requestSender.enqueue(mockResults); + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); - var model = AmazonBedrockChatCompletionModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - "access", - "secret" - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - null, - null, - null, - List.of("abc"), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + var result = listener.actionGet(TIMEOUT); - var result = listener.actionGet(TIMEOUT); + assertThat(result.asMap(), Matchers.is(buildExpectationCompletion(List.of("test result")))); - assertThat(result.asMap(), Matchers.is(buildExpectationCompletion(List.of("test result")))); - } + verify(mockSenderFactory, times(1)).createSender(); + verifyNoMoreInteractions(mockSenderFactory); } } @@ -1247,17 +1298,12 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { } private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); + var mockSenderFactory = mock(AmazonBedrockMockRequestSender.Factory.class); + var mockSender = spy(new AmazonBedrockMockRequestSender()); + when(mockSenderFactory.createSender()).thenReturn(mockSender); - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + try (var service = createAmazonBedrockService(mockSenderFactory)) { + try (var requestSender = (AmazonBedrockMockRequestSender) mockSender) { { var mockResults1 = new TextEmbeddingFloatResults( List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) @@ -1309,6 +1355,9 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep ); } } + + verify(mockSenderFactory, times(1)).createSender(); + verifyNoMoreInteractions(mockSenderFactory); } } @@ -1369,4 +1418,35 @@ private Utils.PersistedConfig getPersistedConfigMap( new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) ); } + + // Helper to create a Titan response body + private String createTitanResponseBody(@Nullable List embeddingFloats, @Nullable String embeddingBase64) { + StringBuilder sb = new StringBuilder("{"); + if (embeddingFloats != null) { + sb.append("\"embedding\": ["); + sb.append(embeddingFloats.stream().map(String::valueOf).collect(Collectors.joining(","))); + sb.append("]"); + } else if (embeddingBase64 != null) { + sb.append("\"embedding\": \"").append(embeddingBase64).append("\""); + } else { + fail("Either float or base64 embedding must be provided for Titan response"); + } + sb.append(", \"inputTextTokenCount\": 10}"); + return sb.toString(); + } + + // Helper to create a Cohere response body + private String createCohereResponseBody(List> embeddings) { + StringBuilder sb = new StringBuilder("{\n \"embeddings\": ["); + for (int i = 0; i < embeddings.size(); i++) { + sb.append("["); + sb.append(embeddings.get(i).stream().map(String::valueOf).collect(Collectors.joining(","))); + sb.append("]"); + if (i < embeddings.size() - 1) { + sb.append(","); + } + } + sb.append("], \"id\": \"test-id\", \"response_type\" : \"embeddings_floats\", \"texts\": [\"input1\"]}"); + return sb.toString(); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java index a100b89e1db6e..3acff99925397 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java @@ -35,9 +35,11 @@ import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; public class AmazonBedrockEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase< AmazonBedrockEmbeddingsServiceSettings> { @@ -47,10 +49,53 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { var model = "model-id"; var provider = "amazontitan"; var maxInputTokens = 512; - var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap( - createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, SimilarityMeasure.COSINE), - ConfigurationParseContext.REQUEST + var map = createEmbeddingsRequestSettingsMap( + region, + model, + provider, + null, + null, + maxInputTokens, + SimilarityMeasure.COSINE, + null // Default embedding type (float) + ); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + null, + AmazonBedrockEmbeddingType.FLOAT + ) + ) + ); + } + + public void testFromMap_Request_WithBinaryEmbeddingType_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + + var map = createEmbeddingsRequestSettingsMap( + region, + model, + provider, + null, + null, + maxInputTokens, + SimilarityMeasure.COSINE, + "binary" ); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST); assertThat( serviceSettings, @@ -63,12 +108,41 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { false, maxInputTokens, SimilarityMeasure.COSINE, - null + null, + AmazonBedrockEmbeddingType.BINARY ) ) ); } + public void testFromMap_Request_WithInvalidEmbeddingType_Throws() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + + var map = createEmbeddingsRequestSettingsMap( + region, + model, + provider, + null, + null, + maxInputTokens, + SimilarityMeasure.COSINE, + "invalid_type" // Invalid embedding type + ); + + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not support value [invalid_type] for setting [embedding_type]") + ); + } + public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() { var region = "region"; var model = "model-id"; @@ -90,7 +164,8 @@ public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() { false, maxInputTokens, SimilarityMeasure.COSINE, - new RateLimitSettings(3) + new RateLimitSettings(3), + AmazonBedrockEmbeddingType.FLOAT ) ) ); @@ -115,7 +190,8 @@ public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNot false, maxInputTokens, SimilarityMeasure.COSINE, - null + null, + AmazonBedrockEmbeddingType.FLOAT ) ) ); @@ -209,7 +285,8 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { false, maxInputTokens, SimilarityMeasure.COSINE, - null + null, + AmazonBedrockEmbeddingType.FLOAT ) ) ); @@ -225,7 +302,7 @@ public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIs assertThat( serviceSettings, - is(new AmazonBedrockEmbeddingsServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null, true, null, null, null)) + is(new AmazonBedrockEmbeddingsServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null, true, null, null, null, AmazonBedrockEmbeddingType.FLOAT)) ); } @@ -248,7 +325,8 @@ public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIs true, null, SimilarityMeasure.DOT_PRODUCT, - null + null, + AmazonBedrockEmbeddingType.FLOAT ) ) ); @@ -281,7 +359,8 @@ public void testToXContent_WritesDimensionsSetByUserTrue() throws IOException { true, null, null, - new RateLimitSettings(2) + new RateLimitSettings(2), + AmazonBedrockEmbeddingType.FLOAT ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -302,7 +381,8 @@ public void testToXContent_WritesAllValues() throws IOException { false, 512, null, - new RateLimitSettings(3) + new RateLimitSettings(3), + AmazonBedrockEmbeddingType.FLOAT ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -323,7 +403,8 @@ public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() t false, 512, null, - new RateLimitSettings(3) + new RateLimitSettings(3), + AmazonBedrockEmbeddingType.FLOAT ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -343,7 +424,8 @@ public static HashMap createEmbeddingsRequestSettingsMap( @Nullable Integer dimensions, @Nullable Boolean dimensionsSetByUser, @Nullable Integer maxTokens, - @Nullable SimilarityMeasure similarityMeasure + @Nullable SimilarityMeasure similarityMeasure, + @Nullable String embeddingType ) { var map = new HashMap(Map.of(REGION_FIELD, region, MODEL_FIELD, model, PROVIDER_FIELD, provider)); @@ -363,6 +445,10 @@ public static HashMap createEmbeddingsRequestSettingsMap( map.put(SIMILARITY, similarityMeasure.toString()); } + if (embeddingType != null) { + map.put(EMBEDDING_TYPE_FIELD, embeddingType); + } + return map; } @@ -398,7 +484,8 @@ private static AmazonBedrockEmbeddingsServiceSettings createRandom() { randomBoolean(), randomFrom(new Integer[] { null, randomNonNegativeInt() }), randomFrom(new SimilarityMeasure[] { null, randomFrom(SimilarityMeasure.values()) }), - RateLimitSettingsTests.createRandom() + RateLimitSettingsTests.createRandom(), + randomFrom(AmazonBedrockEmbeddingType.values()) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java index e25be112307b8..ffe1cb3a23eb8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java @@ -9,16 +9,24 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockTitanEmbeddingsRequestEntity; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; import java.io.IOException; import static org.hamcrest.Matchers.is; public class AmazonBedrockTitanEmbeddingsRequestEntityTests extends ESTestCase { - public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { - var entity = new AmazonBedrockTitanEmbeddingsRequestEntity("test input"); + public void testRequestEntity_WithFloatType_GeneratesExpectedJsonBody() throws IOException { + var entity = new AmazonBedrockTitanEmbeddingsRequestEntity("test input", AmazonBedrockEmbeddingType.FLOAT); var builder = new AmazonBedrockJsonBuilder(entity); var result = builder.getStringContent(); assertThat(result, is("{\"inputText\":\"test input\"}")); } + + public void testRequestEntity_WithBinaryType_GeneratesExpectedJsonBody() throws IOException { + var entity = new AmazonBedrockTitanEmbeddingsRequestEntity("test input", AmazonBedrockEmbeddingType.BINARY); + var builder = new AmazonBedrockJsonBuilder(entity); + var result = builder.getStringContent(); + assertThat(result, is("{\"inputText\":\"test input\",\"embedding_type\":\"binary\"}")); + } } From 0208e59ea8e2633cd7b7bf902e4d10d6ade4032c Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 11 Apr 2025 08:30:54 +0000 Subject: [PATCH 2/2] [CI] Auto commit changes from spotless --- .../AmazonBedrockServiceSettings.java | 5 ++-- ...onBedrockTitanEmbeddingsRequestEntity.java | 6 ++-- .../AmazonBedrockEmbeddingsResponse.java | 6 ++-- .../AmazonBedrockServiceTests.java | 29 +++++-------------- ...BedrockEmbeddingsServiceSettingsTests.java | 18 ++++++++++-- ...rockTitanEmbeddingsRequestEntityTests.java | 2 +- 6 files changed, 32 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java index 377893e826e7a..fc3060f271586 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java @@ -24,15 +24,14 @@ import java.util.EnumSet; import java.util.Map; import java.util.Objects; -import java.util.Optional; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; -import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; public abstract class AmazonBedrockServiceSettings extends FilteredXContentObject implements ServiceSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java index 13637795d3b7a..9007cb7c5bfa1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java @@ -9,14 +9,16 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; import java.io.IOException; import java.util.Objects; -import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; + import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText, AmazonBedrockEmbeddingType embeddingType) - implements ToXContentObject { + implements + ToXContentObject { private static final String INPUT_TEXT_FIELD = "inputText"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java index 30f2728f40f3f..043df40bbf5bb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/embeddings/AmazonBedrockEmbeddingsResponse.java @@ -16,19 +16,19 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBytesResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.inference.external.response.XContentUtils; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.amazonbedrock.response.AmazonBedrockResponse; -import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.List; import java.util.Base64; +import java.util.List; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index dbaf5cdb18498..80e406c76d741 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -38,9 +38,8 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBytesResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -1053,9 +1052,7 @@ public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { ); List titanFloatEmbeddings = List.of(1.0f, 2.0f, 3.0f); String titanFloatResponseBody = createTitanResponseBody(titanFloatEmbeddings, null); - when(mockSender.sendEmbeddingsRequest(modelTitanFloat, List.of("input1"))).thenReturn( - new BytesArray(titanFloatResponseBody) - ); + when(mockSender.sendEmbeddingsRequest(modelTitanFloat, List.of("input1"))).thenReturn(new BytesArray(titanFloatResponseBody)); var futureTitanFloat = new PlainActionFuture(); service.infer( @@ -1074,30 +1071,20 @@ public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { assertThat(resultTitanFloat, instanceOf(TextEmbeddingFloatResults.class)); assertThat( ((TextEmbeddingFloatResults) resultTitanFloat).getResults(), - Matchers.contains( - buildExpectationFloat(titanFloatEmbeddings) - ) + Matchers.contains(buildExpectationFloat(titanFloatEmbeddings)) ); // Test Titan Binary var modelTitanBinary = AmazonBedrockEmbeddingsModelTests.createRandomTitanModelWithEmbeddingType( AmazonBedrockEmbeddingType.BINARY ); - byte[] binaryEmbeddingBytes = new byte[]{1, 2, 3, 4}; + byte[] binaryEmbeddingBytes = new byte[] { 1, 2, 3, 4 }; String binaryEmbeddingBase64 = Base64.getEncoder().encodeToString(binaryEmbeddingBytes); String titanBinaryResponseBody = createTitanResponseBody(null, binaryEmbeddingBase64); - when(mockSender.sendEmbeddingsRequest(modelTitanBinary, List.of("input1"))).thenReturn( - new BytesArray(titanBinaryResponseBody) - ); + when(mockSender.sendEmbeddingsRequest(modelTitanBinary, List.of("input1"))).thenReturn(new BytesArray(titanBinaryResponseBody)); var futureTitanBinary = new PlainActionFuture(); - service.infer( - modelTitanBinary, - new InferenceAction.Request(null, List.of("input1")), - Map.of(), - TIMEOUT, - futureTitanBinary - ); + service.infer(modelTitanBinary, new InferenceAction.Request(null, List.of("input1")), Map.of(), TIMEOUT, futureTitanBinary); var resultTitanBinary = futureTitanBinary.actionGet(TIMEOUT); assertThat(resultTitanBinary, instanceOf(TextEmbeddingBytesResults.class)); var byteResults = ((TextEmbeddingBytesResults) resultTitanBinary).getResults(); @@ -1108,9 +1095,7 @@ public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { var modelCohere = AmazonBedrockEmbeddingsModelTests.createRandomCohereModel(); List cohereEmbeddings1 = List.of(4.0f, 5.0f); String cohereResponseBody = createCohereResponseBody(List.of(cohereEmbeddings1)); - when(mockSender.sendEmbeddingsRequest(modelCohere, List.of("input1"))).thenReturn( - new BytesArray(cohereResponseBody) - ); + when(mockSender.sendEmbeddingsRequest(modelCohere, List.of("input1"))).thenReturn(new BytesArray(cohereResponseBody)); var futureCohere = new PlainActionFuture(); service.infer( modelCohere, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java index 3acff99925397..3207b6df71ec9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; @@ -32,14 +33,13 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; -import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.EMBEDDING_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; public class AmazonBedrockEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase< AmazonBedrockEmbeddingsServiceSettings> { @@ -302,7 +302,19 @@ public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIs assertThat( serviceSettings, - is(new AmazonBedrockEmbeddingsServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null, true, null, null, null, AmazonBedrockEmbeddingType.FLOAT)) + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + true, + null, + null, + null, + AmazonBedrockEmbeddingType.FLOAT + ) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java index ffe1cb3a23eb8..3e0f6926c639c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/AmazonBedrockTitanEmbeddingsRequestEntityTests.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.inference.services.amazonbedrock.request; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockTitanEmbeddingsRequestEntity; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings.AmazonBedrockEmbeddingType; +import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockTitanEmbeddingsRequestEntity; import java.io.IOException;