From ec0338d5ce5d85bd963613acb4ed3ec338cff29d Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Wed, 8 Jan 2025 12:08:45 +0100 Subject: [PATCH] [Inference API] Add Jina AI API to do inference for Embedding and Rerank models (#118652) # Conflicts: # x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java --- docs/changelog/118652.yaml | 5 + .../org/elasticsearch/TransportVersions.java | 1 + .../InferenceNamedWriteablesProvider.java | 28 + .../xpack/inference/InferencePlugin.java | 2 + .../action/jinaai/JinaAIActionCreator.java | 58 + .../action/jinaai/JinaAIActionVisitor.java | 21 + .../inference/external/http/HttpResult.java | 5 + .../JinaAIEmbeddingsRequestManager.java | 57 + .../http/sender/JinaAIRequestManager.java | 28 + .../sender/JinaAIRerankRequestManager.java | 56 + .../external/jinaai/JinaAIAccount.java | 32 + .../jinaai/JinaAIResponseHandler.java | 62 + .../jinaai/JinaAIEmbeddingsRequest.java | 84 + .../jinaai/JinaAIEmbeddingsRequestEntity.java | 67 + .../request/jinaai/JinaAIRequest.java | 26 + .../request/jinaai/JinaAIRerankRequest.java | 86 + .../jinaai/JinaAIRerankRequestEntity.java | 58 + .../external/request/jinaai/JinaAIUtils.java | 26 + .../JinaAIEmbeddingsResponseEntity.java | 110 + .../jinaai/JinaAIErrorResponseEntity.java | 46 + .../jinaai/JinaAIRerankResponseEntity.java | 158 ++ .../services/jinaai/JinaAIModel.java | 68 + .../JinaAIRateLimitServiceSettings.java | 15 + .../services/jinaai/JinaAIService.java | 358 +++ .../services/jinaai/JinaAIServiceFields.java | 13 + .../jinaai/JinaAIServiceSettings.java | 159 ++ .../embeddings/JinaAIEmbeddingsModel.java | 140 ++ .../JinaAIEmbeddingsServiceSettings.java | 162 ++ .../JinaAIEmbeddingsTaskSettings.java | 183 ++ .../jinaai/rerank/JinaAIRerankModel.java | 148 ++ .../rerank/JinaAIRerankServiceSettings.java | 113 + .../rerank/JinaAIRerankTaskSettings.java | 166 ++ .../jinaai/JinaAIResponseHandlerTests.java | 138 ++ .../JinaAIEmbeddingsRequestEntityTests.java | 54 + .../jinaai/JinaAIEmbeddingsRequestTests.java | 101 + .../request/jinaai/JinaAIRequestTests.java | 36 + .../JinaAIRerankRequestEntityTests.java | 140 ++ .../jinaai/JinaAIRerankRequestTests.java | 110 + .../request/jinaai/JinaAIUtilsTests.java | 23 + .../JinaAIEmbeddingsResponseEntityTests.java | 397 ++++ .../JinaAIErrorResponseEntityTests.java | 51 + .../JinaAIRerankResponseEntityTests.java | 180 ++ .../jinaai/JinaAIServiceSettingsTests.java | 174 ++ .../services/jinaai/JinaAIServiceTests.java | 2003 +++++++++++++++++ .../JinaAIEmbeddingsModelTests.java | 168 ++ .../JinaAIEmbeddingsServiceSettingsTests.java | 187 ++ .../JinaAIEmbeddingsTaskSettingsTests.java | 193 ++ .../jinaai/rerank/JinaAIRerankModelTests.java | 74 + .../JinaAIRerankServiceSettingsTests.java | 83 + .../rerank/JinaAIRerankTaskSettingsTests.java | 132 ++ 50 files changed, 6785 insertions(+) create mode 100644 docs/changelog/118652.yaml create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIAccount.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIRateLimitServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceFields.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandlerTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtilsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java diff --git a/docs/changelog/118652.yaml b/docs/changelog/118652.yaml new file mode 100644 index 0000000000000..0b08686230405 --- /dev/null +++ b/docs/changelog/118652.yaml @@ -0,0 +1,5 @@ +pr: 118652 +summary: Add Jina AI API to do inference for Embedding and Rerank models +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index ef410e6181c3b..f21840cafe7ac 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -156,6 +156,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_CCS_TELEMETRY_STATS = def(8_816_00_0); public static final TransportVersion TEXT_EMBEDDING_QUERY_VECTOR_BUILDER_INFER_MODEL_ID = def(8_817_00_0); public static final TransportVersion ESQL_ENABLE_NODE_LEVEL_REDUCTION = def(8_818_00_0); + public static final TransportVersion JINA_AI_INTEGRATION_ADDED = def(8_819_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 71fbcf6d8ef49..6fc9870034018 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -75,6 +75,11 @@ import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; @@ -132,6 +137,7 @@ public static List getNamedWriteables() { addAmazonBedrockNamedWriteables(namedWriteables); addEisNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); + addJinaAINamedWriteables(namedWriteables); addUnifiedNamedWriteables(namedWriteables); @@ -569,6 +575,28 @@ private static void addAlibabaCloudSearchNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, JinaAIServiceSettings.NAME, JinaAIServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + JinaAIEmbeddingsServiceSettings.NAME, + JinaAIEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, JinaAIEmbeddingsTaskSettings.NAME, JinaAIEmbeddingsTaskSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(ServiceSettings.class, JinaAIRerankServiceSettings.NAME, JinaAIRerankServiceSettings::new) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(TaskSettings.class, JinaAIRerankTaskSettings.NAME, JinaAIRerankTaskSettings::new) + ); + } + private static void addEisNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3e3960a8475be..22b16e886e405 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -111,6 +111,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; @@ -289,6 +290,7 @@ public List getInferenceServiceFactories() { context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()), context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), + context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionCreator.java new file mode 100644 index 0000000000000..4d5827a3bf266 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionCreator.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.jinaai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.JinaAIEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.JinaAIRerankRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the jinaai model type. + */ +public class JinaAIActionCreator implements JinaAIActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public JinaAIActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(JinaAIEmbeddingsModel model, Map taskSettings, InputType inputType) { + var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings, inputType); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( + overriddenModel.getServiceSettings().getCommonSettings().uri(), + "JinaAI embeddings" + ); + var requestCreator = JinaAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool()); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } + + @Override + public ExecutableAction create(JinaAIRerankModel model, Map taskSettings) { + var overriddenModel = JinaAIRerankModel.of(model, taskSettings); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( + overriddenModel.getServiceSettings().getCommonSettings().uri(), + "JinaAI rerank" + ); + var requestCreator = JinaAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionVisitor.java new file mode 100644 index 0000000000000..c585e68e3d731 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/jinaai/JinaAIActionVisitor.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.jinaai; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; + +import java.util.Map; + +public interface JinaAIActionVisitor { + ExecutableAction create(JinaAIEmbeddingsModel model, Map taskSettings, InputType inputType); + + ExecutableAction create(JinaAIRerankModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpResult.java index 6c79daa2dedc0..04be94f4049ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpResult.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpResult.java @@ -47,4 +47,9 @@ private static byte[] limitBody(ByteSizeValue maxResponseSize, HttpResponse resp public boolean isBodyEmpty() { return body().length == 0; } + + public boolean isSuccessfulResponse() { + var code = response.getStatusLine().getStatusCode(); + return code >= 200 && code < 300; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..c0828224cd1a9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.jinaai.JinaAIResponseHandler; +import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.jinaai.JinaAIEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class JinaAIEmbeddingsRequestManager extends JinaAIRequestManager { + private static final Logger logger = LogManager.getLogger(JinaAIEmbeddingsRequestManager.class); + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new JinaAIResponseHandler("jinaai text embedding", JinaAIEmbeddingsResponseEntity::fromResponse); + } + + public static JinaAIEmbeddingsRequestManager of(JinaAIEmbeddingsModel model, ThreadPool threadPool) { + return new JinaAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final JinaAIEmbeddingsModel model; + + private JinaAIEmbeddingsRequestManager(JinaAIEmbeddingsModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRequestManager.java new file mode 100644 index 0000000000000..3a0d6e4e17f5b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRequestManager.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel; + +import java.util.Objects; + +abstract class JinaAIRequestManager extends BaseRequestManager { + + protected JinaAIRequestManager(ThreadPool threadPool, JinaAIModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int apiKeyHash) { + public static RateLimitGrouping of(JinaAIModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.apiKey().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java new file mode 100644 index 0000000000000..26f134873bca0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.jinaai.JinaAIResponseHandler; +import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIRerankRequest; +import org.elasticsearch.xpack.inference.external.response.jinaai.JinaAIRerankResponseEntity; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class JinaAIRerankRequestManager extends JinaAIRequestManager { + private static final Logger logger = LogManager.getLogger(JinaAIRerankRequestManager.class); + private static final ResponseHandler HANDLER = createJinaAIResponseHandler(); + + private static ResponseHandler createJinaAIResponseHandler() { + return new JinaAIResponseHandler("jinaai rerank", (request, response) -> JinaAIRerankResponseEntity.fromResponse(response)); + } + + public static JinaAIRerankRequestManager of(JinaAIRerankModel model, ThreadPool threadPool) { + return new JinaAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final JinaAIRerankModel model; + + private JinaAIRerankRequestManager(JinaAIRerankModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIAccount.java new file mode 100644 index 0000000000000..722a785db4795 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIAccount.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.jinaai; + +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; + +public record JinaAIAccount(URI uri, SecureString apiKey) { + + public static JinaAIAccount of(JinaAIModel model, CheckedSupplier uriBuilder) { + var uri = buildUri(model.uri(), "JinaAI", uriBuilder); + + return new JinaAIAccount(uri, model.apiKey()); + } + + public JinaAIAccount { + Objects.requireNonNull(uri); + Objects.requireNonNull(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandler.java new file mode 100644 index 0000000000000..66dc85b3bdb6a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandler.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.jinaai; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.jinaai.JinaAIErrorResponseEntity; + +/** + * Defines how to handle various errors returned from the JinaAI integration. + * + */ +public class JinaAIResponseHandler extends BaseResponseHandler { + static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response"; + static final String PAYMENT_ERROR_MESSAGE = "Payment required"; + + public JinaAIResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, JinaAIErrorResponseEntity::fromResponse); + } + + /** + * Validates the status code throws an RetryException if not in the range [200, 300). + * + * @param request The http request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + @Override + protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + if (result.isSuccessfulResponse()) { + return; + } + + // handle error codes + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode > 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 400 || statusCode == 422) { + throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result)); + } else if (statusCode == 401) { + throw new RetryException(false, buildError(AUTHENTICATION, request, result)); + } else if (statusCode == 402) { + throw new RetryException(false, buildError(PAYMENT_ERROR_MESSAGE, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java new file mode 100644 index 0000000000000..d99f15a7703ae --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class JinaAIEmbeddingsRequest extends JinaAIRequest { + + private final JinaAIAccount account; + private final List input; + private final JinaAIEmbeddingsTaskSettings taskSettings; + private final String model; + private final String inferenceEntityId; + + public JinaAIEmbeddingsRequest(List input, JinaAIEmbeddingsModel embeddingsModel) { + Objects.requireNonNull(embeddingsModel); + + account = JinaAIAccount.of(embeddingsModel, JinaAIEmbeddingsRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + taskSettings = embeddingsModel.getTaskSettings(); + model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); + inferenceEntityId = embeddingsModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(JinaAIUtils.HOST) + .setPathSegments(JinaAIUtils.VERSION_1, JinaAIUtils.EMBEDDINGS_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..d4f98f1eb52ca --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.invalidInputTypeMessage; + +public record JinaAIEmbeddingsRequestEntity(List input, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable String model) + implements + ToXContentObject { + + private static final String SEARCH_DOCUMENT = "retrieval.passage"; + private static final String SEARCH_QUERY = "retrieval.query"; + private static final String CLUSTERING = "separation"; + private static final String CLASSIFICATION = "classification"; + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + public static final String TASK_TYPE_FIELD = "task"; + + public JinaAIEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); + Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_FIELD, input); + builder.field(MODEL_FIELD, model); + + if (taskSettings.getInputType() != null) { + builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType())); + } + + builder.endObject(); + return builder; + } + + // default for testing + static String convertToString(InputType inputType) { + return switch (inputType) { + case INGEST -> SEARCH_DOCUMENT; + case SEARCH -> SEARCH_QUERY; + case CLASSIFICATION -> CLASSIFICATION; + case CLUSTERING -> CLUSTERING; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequest.java new file mode 100644 index 0000000000000..8b1e26a36238b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequest.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount; +import org.elasticsearch.xpack.inference.external.request.Request; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public abstract class JinaAIRequest implements Request { + + public static void decorateWithAuthHeader(HttpPost request, JinaAIAccount account) { + request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + request.setHeader(createAuthBearerHeader(account.apiKey())); + request.setHeader(JinaAIUtils.createRequestSourceHeader()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java new file mode 100644 index 0000000000000..93d4ab830c604 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class JinaAIRerankRequest extends JinaAIRequest { + + private final JinaAIAccount account; + private final String query; + private final List input; + private final JinaAIRerankTaskSettings taskSettings; + private final String model; + private final String inferenceEntityId; + + public JinaAIRerankRequest(String query, List input, JinaAIRerankModel model) { + Objects.requireNonNull(model); + + this.account = JinaAIAccount.of(model, JinaAIRerankRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + taskSettings = model.getTaskSettings(); + this.model = model.getServiceSettings().modelId(); + inferenceEntityId = model.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new JinaAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(JinaAIUtils.HOST) + .setPathSegments(JinaAIUtils.VERSION_1, JinaAIUtils.RERANK_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java new file mode 100644 index 0000000000000..7f470d5fa91f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record JinaAIRerankRequestEntity(String model, String query, List documents, JinaAIRerankTaskSettings taskSettings) + implements + ToXContentObject { + + private static final String DOCUMENTS_FIELD = "documents"; + private static final String QUERY_FIELD = "query"; + private static final String MODEL_FIELD = "model"; + + public JinaAIRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + Objects.requireNonNull(model); + Objects.requireNonNull(taskSettings); + } + + public JinaAIRerankRequestEntity(String query, List input, JinaAIRerankTaskSettings taskSettings, String model) { + this(model, query, input, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, model); + builder.field(QUERY_FIELD, query); + builder.field(DOCUMENTS_FIELD, documents); + + if (taskSettings.getTopNDocumentsOnly() != null) { + builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); + } + + var return_documents = taskSettings.getDoesReturnDocuments(); + if (return_documents != null) { + builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, return_documents); + } + + builder.endObject(); + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtils.java new file mode 100644 index 0000000000000..fccbd1e230556 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtils.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.apache.http.Header; +import org.apache.http.message.BasicHeader; + +public class JinaAIUtils { + public static final String HOST = "api.jina.ai"; + public static final String VERSION_1 = "v1"; + public static final String EMBEDDINGS_PATH = "embeddings"; + public static final String RERANK_PATH = "rerank"; + public static final String REQUEST_SOURCE_HEADER = "Request-Source"; + public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch"; + + public static Header createRequestSourceHeader() { + return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); + } + + private JinaAIUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..26bde5f5f48ad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.external.response.jinaai; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.XContentUtils; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class JinaAIEmbeddingsResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response"; + + /** + * Parses the JinaAI json response. + * For a request like: + * + *
+     *     
+     *        {
+     *            "inputs": ["hello this is my name", "I wish I was there!"]
+     *        }
+     *     
+     * 
+ * + * The response would look like: + * + *
+     * 
+     * {
+     *  "object": "list",
+     *  "data": [
+     *      {
+     *          "object": "embedding",
+     *          "embedding": [
+     *              -0.009327292,
+     *              -0.0028842222,
+     *          ],
+     *          "index": 0
+     *      },
+     *      {
+     *          "object": "embedding",
+     *          "embedding": [ ... ],
+     *          "index": 1
+     *      }
+     *  ],
+     *  "model": "jina-embeddings-v3",
+     *  "usage": {
+     *      "prompt_tokens": 8,
+     *      "total_tokens": 8
+     *  }
+     * }
+     * 
+     * 
+ */ + public static InferenceTextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingList = parseList( + jsonParser, + JinaAIEmbeddingsResponseEntity::parseEmbeddingObject + ); + + return new InferenceTextEmbeddingFloatResults(embeddingList); + } + } + + private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseEmbeddingObject(XContentParser parser) + throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + // parse and discard the rest of the object + consumeUntilObjectEnd(parser); + + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); + } + + private JinaAIEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntity.java new file mode 100644 index 0000000000000..99d29b26e4a04 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntity.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.jinaai; + +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +public class JinaAIErrorResponseEntity extends ErrorResponse { + + private JinaAIErrorResponseEntity(String errorMessage) { + super(errorMessage); + } + + /** + * Parse an HTTP response into a JinaAIErrorResponseEntity + * + * @param response The error response + * @return An error entity if the response is JSON with a `detail` field containing the error message + * or null if the response does not contain the message field + */ + public static ErrorResponse fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var message = (String) responseMap.get("detail"); + if (message != null) { + return new JinaAIErrorResponseEntity(message); + } + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntity.java new file mode 100644 index 0000000000000..d22bc875041e0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntity.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.external.response.jinaai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class JinaAIRerankResponseEntity { + + private static final Logger logger = LogManager.getLogger(JinaAIRerankResponseEntity.class); + + /** + * Parses the JinaAI ranked response. + * + * For a request like: + * "model": "jina-reranker-v2-base-multilingual", + * "query": "What is the capital of the United States?", + * "top_n": 3, + * "documents": ["Carson City is the capital city of the American state of Nevada.", + * "The Commonwealth of the Northern Mariana ... Its capital is Saipan.", + * "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.", + * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] + *

+ * The response will look like (without whitespace): + * { + * "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703", + * "results": [ + * { + * "document": { + * "text": "Washington, D.C. is the capital of the United States. It is a federal district." + * }, + * "index": 2, + * "relevance_score": 0.98005307 + * }, + * { + * "document": { + * "text": "Capital punishment (the death penalty) As of 2017, capital punishment is legal in 30 of the 50 states." + * }, + * "index": 3, + * "relevance_score": 0.27904198 + * }, + * { + * "document": { + * "text": "Carson City is the capital city of the American state of Nevada." + * }, + * "index": 0, + * "relevance_score": 0.10194652 + * } + * ], + * "usage": {"total_tokens": 15} + * } + * + * @param response the http response from JinaAI + * @return the parsed response + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); + + token = jsonParser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + return new RankedDocsResults(parseList(jsonParser, JinaAIRerankResponseEntity::parseRankedDocObject)); + } else { + throwUnknownToken(token, jsonParser); + } + + // This should never be reached. The above code should either return successfully or hit the throwUnknownToken + // or throw a parsing exception + throw new IllegalStateException("Reached an invalid state while parsing the JinaAI response"); + } + } + + private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + int index = -1; + float relevanceScore = -1; + String documentText = null; + parser.nextToken(); + while (parser.currentToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "index": + parser.nextToken(); // move to VALUE_NUMBER + index = parser.intValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "relevance_score": + parser.nextToken(); // move to VALUE_NUMBER + relevanceScore = parser.floatValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "document": + parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + do { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { + parser.nextToken(); // move to VALUE_STRING + documentText = parser.text(); + } + } while (parser.nextToken() != XContentParser.Token.END_OBJECT); + parser.nextToken();// move past END_OBJECT + // parser should now be at the next FIELD_NAME or END_OBJECT + break; + default: + throwUnknownField(parser.currentName(), parser); + } + } else { + parser.nextToken(); + } + } + + if (index == -1) { + logger.warn("Failed to find required field [index] in JinaAI rerank response"); + } + if (relevanceScore == -1) { + logger.warn("Failed to find required field [relevance_score] in JinaAI rerank response"); + } + // documentText may or may not be present depending on the request parameter + + return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); + } + + private JinaAIRerankResponseEntity() {} + + static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI rerank response"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIModel.java new file mode 100644 index 0000000000000..bfd8235e3da48 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIModel.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.jinaai.JinaAIActionVisitor; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; + +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +public abstract class JinaAIModel extends Model { + private final SecureString apiKey; + private final JinaAIRateLimitServiceSettings rateLimitServiceSettings; + + public JinaAIModel( + ModelConfigurations configurations, + ModelSecrets secrets, + @Nullable ApiKeySecrets apiKeySecrets, + JinaAIRateLimitServiceSettings rateLimitServiceSettings + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + apiKey = ServiceUtils.apiKey(apiKeySecrets); + } + + protected JinaAIModel(JinaAIModel model, TaskSettings taskSettings) { + super(model, taskSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + + protected JinaAIModel(JinaAIModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + } + + public SecureString apiKey() { + return apiKey; + } + + public JinaAIRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } + + public abstract ExecutableAction accept(JinaAIActionVisitor creator, Map taskSettings, InputType inputType); + + public abstract URI uri(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIRateLimitServiceSettings.java new file mode 100644 index 0000000000000..ac65ad1c9d714 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIRateLimitServiceSettings.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface JinaAIRateLimitServiceSettings { + RateLimitSettings rateLimitSettings(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java new file mode 100644 index 0000000000000..11a72f811e8d3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -0,0 +1,358 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptySettingsConfiguration; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskSettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.jinaai.JinaAIActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; +import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceFields.EMBEDDING_MAX_BATCH_SIZE; + +public class JinaAIService extends SenderService { + public static final String NAME = "jinaai"; + + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK); + + public JinaAIService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + JinaAIModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static JinaAIModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static JinaAIModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case TEXT_EMBEDDING -> new JinaAIEmbeddingsModel( + inferenceEntityId, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); + case RERANK -> new JinaAIRerankModel(inferenceEntityId, NAME, serviceSettings, taskSettings, secretSettings, context); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public JinaAIModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public JinaAIModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + @Override + public void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof JinaAIModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + JinaAIModel jinaaiModel = (JinaAIModel) model; + var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); + + var action = jinaaiModel.accept(actionCreator, taskSettings, inputType); + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + DocumentsOnlyInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof JinaAIModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + JinaAIModel jinaaiModel = (JinaAIModel) model; + var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(model.getServiceSettings().elementType()), + jinaaiModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = jinaaiModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); + } + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof JinaAIEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel; + var maxInputTokens = serviceSettings.maxInputTokens(); + + var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings( + serviceSettings.getCommonSettings().uri(), + serviceSettings.getCommonSettings().modelId(), + serviceSettings.getCommonSettings().rateLimitSettings() + ), + similarityToUse, + embeddingSize, + maxInputTokens + ); + + return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + /** + * Return the default similarity measure for the embedding type. + * JinaAI embeddings are normalized to unit vectors therefore Dot + * Product similarity can be used and is the default for all JinaAI + * models. + * + * @return The default similarity. + */ + static SimilarityMeasure defaultSimilarity() { + return SimilarityMeasure.DOT_PRODUCT; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration()); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration()); + + return new InferenceServiceConfiguration.Builder().setProvider(NAME).setTaskTypes(supportedTaskTypes.stream().map(t -> { + Map taskSettingsConfig; + switch (t) { + case TEXT_EMBEDDING -> taskSettingsConfig = JinaAIEmbeddingsModel.Configuration.get(); + case RERANK -> taskSettingsConfig = JinaAIRerankModel.Configuration.get(); + default -> taskSettingsConfig = EmptySettingsConfiguration.get(); + } + return new TaskSettingsConfiguration.Builder().setTaskType(t).setConfiguration(taskSettingsConfig).build(); + }).toList()).setConfiguration(configurationMap).build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceFields.java new file mode 100644 index 0000000000000..2df8f1440e471 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceFields.java @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai; + +public class JinaAIServiceFields { + + static final int EMBEDDING_MAX_BATCH_SIZE = 2048; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettings.java new file mode 100644 index 0000000000000..66c6193f653f1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettings.java @@ -0,0 +1,159 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class JinaAIServiceSettings extends FilteredXContentObject implements ServiceSettings, JinaAIRateLimitServiceSettings { + + public static final String NAME = "jinaai_service_settings"; + public static final String MODEL_ID = "model_id"; + private static final Logger logger = LogManager.getLogger(JinaAIServiceSettings.class); + // See https://jina.ai/contact-sales/#rate-limit + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(2_000); + + public static JinaAIServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + JinaAIService.NAME, + context + ); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new JinaAIServiceSettings(uri, modelId, rateLimitSettings); + } + + private final URI uri; + private final String modelId; + private final RateLimitSettings rateLimitSettings; + + public JinaAIServiceSettings(@Nullable URI uri, String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = uri; + this.modelId = Objects.requireNonNull(modelId); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public JinaAIServiceSettings(@Nullable String url, String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this(createOptionalUri(url), modelId, rateLimitSettings); + } + + public JinaAIServiceSettings(StreamInput in) throws IOException { + uri = createOptionalUri(in.readOptionalString()); + modelId = in.readOptionalString(); + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + public URI uri() { + return uri; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { + return toXContentFragmentOfExposedFields(builder, params); + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (uri != null) { + builder.field(URL, uri.toString()); + } + if (modelId != null) { + builder.field(MODEL_ID, modelId); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + var uriToWrite = uri != null ? uri.toString() : null; + out.writeOptionalString(uriToWrite); + out.writeOptionalString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + JinaAIServiceSettings that = (JinaAIServiceSettings) o; + return Objects.equals(uri, that.uri) + && Objects.equals(modelId, that.modelId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java new file mode 100644 index 0000000000000..dd479802cdf13 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModel.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.jinaai.JinaAIActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequestEntity.TASK_TYPE_FIELD; + +public class JinaAIEmbeddingsModel extends JinaAIModel { + public static JinaAIEmbeddingsModel of(JinaAIEmbeddingsModel model, Map taskSettings, InputType inputType) { + var requestTaskSettings = JinaAIEmbeddingsTaskSettings.fromMap(taskSettings); + return new JinaAIEmbeddingsModel(model, JinaAIEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType)); + } + + public JinaAIEmbeddingsModel( + String inferenceId, + String service, + Map serviceSettings, + Map taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + service, + JinaAIEmbeddingsServiceSettings.fromMap(serviceSettings, context), + JinaAIEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + JinaAIEmbeddingsModel( + String modelId, + String service, + JinaAIEmbeddingsServiceSettings serviceSettings, + JinaAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings() + ); + } + + private JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, JinaAIEmbeddingsTaskSettings taskSettings) { + super(model, taskSettings); + } + + public JinaAIEmbeddingsModel(JinaAIEmbeddingsModel model, JinaAIEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public JinaAIEmbeddingsServiceSettings getServiceSettings() { + return (JinaAIEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public JinaAIEmbeddingsTaskSettings getTaskSettings() { + return (JinaAIEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(JinaAIActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings, inputType); + } + + @Override + public URI uri() { + return getServiceSettings().getCommonSettings().uri(); + } + + public static class Configuration { + public static Map get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable, RuntimeException> configuration = + new LazyInitializable<>(() -> { + var configurationMap = new HashMap(); + + configurationMap.put( + TASK_TYPE_FIELD, + new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.DROPDOWN) + .setLabel("Task") + .setOrder(1) + .setRequired(false) + .setSensitive(false) + .setTooltip("Specifies the task type passed to the model.") + .setType(SettingsConfigurationFieldType.STRING) + .setOptions( + Stream.of("retrieval.query", "retrieval.passage", "classification", "separation") + .map(v -> new SettingsConfigurationSelectOption.Builder().setLabelAndValue(v).build()) + .toList() + ) + .setValue("") + .build() + ); + + return Collections.unmodifiableMap(configurationMap); + }); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..449da72674be4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "jinaai_embeddings_service_settings"; + + public static JinaAIEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context); + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new JinaAIEmbeddingsServiceSettings(commonServiceSettings, similarity, dims, maxInputTokens); + } + + private final JinaAIServiceSettings commonSettings; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + + public JinaAIEmbeddingsServiceSettings( + JinaAIServiceSettings commonSettings, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens + ) { + this.commonSettings = commonSettings; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + } + + public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.commonSettings = new JinaAIServiceSettings(in); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + } + + public JinaAIServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder = commonSettings.toXContentFragment(builder, params); + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + commonSettings.toXContentFragmentOfExposedFields(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + JinaAIEmbeddingsServiceSettings that = (JinaAIEmbeddingsServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..77150b5097aa6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettings.java @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; + +/** + * Defines the task settings for the JinaAI text embeddings service. + * + */ +public class JinaAIEmbeddingsTaskSettings implements TaskSettings { + + public static final String NAME = "jinaai_embeddings_task_settings"; + public static final JinaAIEmbeddingsTaskSettings EMPTY_SETTINGS = new JinaAIEmbeddingsTaskSettings((InputType) null); + static final String INPUT_TYPE = "input_type"; + static final EnumSet VALID_REQUEST_VALUES = EnumSet.of( + InputType.INGEST, + InputType.SEARCH, + InputType.CLASSIFICATION, + InputType.CLUSTERING + ); + + public static JinaAIEmbeddingsTaskSettings fromMap(Map map) { + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + InputType inputType = extractOptionalEnum( + map, + INPUT_TYPE, + ModelConfigurations.TASK_SETTINGS, + InputType::fromString, + VALID_REQUEST_VALUES, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new JinaAIEmbeddingsTaskSettings(inputType); + } + + /** + * Creates a new {@link JinaAIEmbeddingsTaskSettings} by preferring non-null fields from the provided parameters. + * For the input type, preference is given to requestInputType if it is not null and not UNSPECIFIED. + * Then preference is given to the requestTaskSettings and finally to originalSettings even if the value is null. + * + * Similarly, for the truncation field preference is given to requestTaskSettings if it is not null and then to + * originalSettings. + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @param requestInputType the input type passed in the request parameters + * @return a constructed {@link JinaAIEmbeddingsTaskSettings} + */ + public static JinaAIEmbeddingsTaskSettings of( + JinaAIEmbeddingsTaskSettings originalSettings, + JinaAIEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType); + + return new JinaAIEmbeddingsTaskSettings(inputTypeToUse); + } + + private static InputType getValidInputType( + JinaAIEmbeddingsTaskSettings originalSettings, + JinaAIEmbeddingsTaskSettings requestTaskSettings, + InputType requestInputType + ) { + InputType inputTypeToUse = originalSettings.inputType; + + if (VALID_REQUEST_VALUES.contains(requestInputType)) { + inputTypeToUse = requestInputType; + } else if (requestTaskSettings.inputType != null) { + inputTypeToUse = requestTaskSettings.inputType; + } + + return inputTypeToUse; + } + + private final InputType inputType; + + public JinaAIEmbeddingsTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalEnum(InputType.class)); + } + + public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType) { + validateInputType(inputType); + this.inputType = inputType; + } + + private static void validateInputType(InputType inputType) { + if (inputType == null) { + return; + } + + assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType); + } + + @Override + public boolean isEmpty() { + return inputType == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (inputType != null) { + builder.field(INPUT_TYPE, inputType); + } + + builder.endObject(); + return builder; + } + + public InputType getInputType() { + return inputType; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(inputType); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o; + return Objects.equals(inputType, that.inputType); + } + + @Override + public int hashCode() { + return Objects.hash(inputType); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + JinaAIEmbeddingsTaskSettings updatedSettings = JinaAIEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings)); + return of(this, updatedSettings, updatedSettings.inputType != null ? updatedSettings.inputType : this.inputType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java new file mode 100644 index 0000000000000..2fb9228d3b652 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.rerank; + +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.jinaai.JinaAIActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings.RETURN_DOCUMENTS; +import static org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY; + +public class JinaAIRerankModel extends JinaAIModel { + public static JinaAIRerankModel of(JinaAIRerankModel model, Map taskSettings) { + var requestTaskSettings = JinaAIRerankTaskSettings.fromMap(taskSettings); + return new JinaAIRerankModel(model, JinaAIRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public JinaAIRerankModel( + String inferenceId, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceId, + service, + JinaAIRerankServiceSettings.fromMap(serviceSettings, context), + JinaAIRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + JinaAIRerankModel( + String modelId, + String service, + JinaAIRerankServiceSettings serviceSettings, + JinaAIRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, TaskType.RERANK, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + secretSettings, + serviceSettings.getCommonSettings() + ); + } + + private JinaAIRerankModel(JinaAIRerankModel model, JinaAIRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + public JinaAIRerankModel(JinaAIRerankModel model, JinaAIRerankServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public JinaAIRerankServiceSettings getServiceSettings() { + return (JinaAIRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public JinaAIRerankTaskSettings getTaskSettings() { + return (JinaAIRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + /** + * Accepts a visitor to create an executable action. The returned action will not return documents in the response. + * @param visitor _ + * @param taskSettings _ + * @param inputType ignored for rerank task + * @return the rerank action + */ + @Override + public ExecutableAction accept(JinaAIActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } + + @Override + public URI uri() { + return getServiceSettings().getCommonSettings().uri(); + } + + public static class Configuration { + public static Map get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable, RuntimeException> configuration = + new LazyInitializable<>(() -> { + var configurationMap = new HashMap(); + + configurationMap.put( + RETURN_DOCUMENTS, + new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.TOGGLE) + .setLabel("Return Documents") + .setOrder(1) + .setRequired(false) + .setSensitive(false) + .setTooltip("Specify whether to return doc text within the results.") + .setType(SettingsConfigurationFieldType.BOOLEAN) + .setValue(false) + .build() + ); + configurationMap.put( + TOP_N_DOCS_ONLY, + new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.NUMERIC) + .setLabel("Top N") + .setOrder(2) + .setRequired(false) + .setSensitive(false) + .setTooltip("The number of most relevant documents to return, defaults to the number of the documents.") + .setType(SettingsConfigurationFieldType.INTEGER) + .build() + ); + + return Collections.unmodifiableMap(configurationMap); + }); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java new file mode 100644 index 0000000000000..a9e492c2738c2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettings.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.rerank; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class JinaAIRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, JinaAIRateLimitServiceSettings { + public static final String NAME = "jinaai_rerank_service_settings"; + + private static final Logger logger = LogManager.getLogger(JinaAIRerankServiceSettings.class); + + public static JinaAIRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context); + + return new JinaAIRerankServiceSettings(commonServiceSettings); + } + + private final JinaAIServiceSettings commonSettings; + + public JinaAIRerankServiceSettings(JinaAIServiceSettings commonSettings) { + this.commonSettings = commonSettings; + } + + public JinaAIRerankServiceSettings(StreamInput in) throws IOException { + this.commonSettings = new JinaAIServiceSettings(in); + } + + public JinaAIServiceSettings getCommonSettings() { + return commonSettings; + } + + @Override + public String modelId() { + return commonSettings.modelId(); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return commonSettings.rateLimitSettings(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder = commonSettings.toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + commonSettings.toXContentFragmentOfExposedFields(builder, params); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + commonSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + JinaAIRerankServiceSettings that = (JinaAIRerankServiceSettings) o; + return Objects.equals(commonSettings, that.commonSettings); + } + + @Override + public int hashCode() { + return Objects.hash(commonSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettings.java new file mode 100644 index 0000000000000..8dd93d89b6fb5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettings.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + +/** + * Defines the task settings for the JinaAI rerank service. + * + */ +public class JinaAIRerankTaskSettings implements TaskSettings { + + public static final String NAME = "jinaai_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_N_DOCS_ONLY = "top_n"; + + public static final JinaAIRerankTaskSettings EMPTY_SETTINGS = new JinaAIRerankTaskSettings(null, null); + + public static JinaAIRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topNDocumentsOnly = extractOptionalPositiveInteger( + map, + TOP_N_DOCS_ONLY, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(topNDocumentsOnly, returnDocuments); + } + + /** + * Creates a new {@link JinaAIRerankTaskSettings} by preferring non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link JinaAIRerankTaskSettings} + */ + public static JinaAIRerankTaskSettings of(JinaAIRerankTaskSettings originalSettings, JinaAIRerankTaskSettings requestTaskSettings) { + return new JinaAIRerankTaskSettings( + requestTaskSettings.getTopNDocumentsOnly() != null + ? requestTaskSettings.getTopNDocumentsOnly() + : originalSettings.getTopNDocumentsOnly(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments() + ); + } + + public static JinaAIRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments) { + return new JinaAIRerankTaskSettings(topNDocumentsOnly, returnDocuments); + } + + private final Integer topNDocumentsOnly; + private final Boolean returnDocuments; + + public JinaAIRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalInt(), in.readOptionalBoolean()); + } + + public JinaAIRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) { + this.topNDocumentsOnly = topNDocumentsOnly; + this.returnDocuments = doReturnDocuments; + } + + @Override + public boolean isEmpty() { + return topNDocumentsOnly == null && returnDocuments == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topNDocumentsOnly != null) { + builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly); + } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.JINA_AI_INTEGRATION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(topNDocumentsOnly); + out.writeOptionalBoolean(returnDocuments); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + JinaAIRerankTaskSettings that = (JinaAIRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topNDocumentsOnly, that.topNDocumentsOnly); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topNDocumentsOnly); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } + + public Boolean getDoesReturnDocuments() { + return returnDocuments; + } + + public Integer getTopNDocumentsOnly() { + return topNDocumentsOnly; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + JinaAIRerankTaskSettings updatedSettings = JinaAIRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return JinaAIRerankTaskSettings.of(this, updatedSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandlerTests.java new file mode 100644 index 0000000000000..4c18915e0187b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/jinaai/JinaAIResponseHandlerTests.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.jinaai; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.core.Is.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class JinaAIResponseHandlerTests extends ESTestCase { + public void testCheckForFailureStatusCode_DoesNotThrowForStatusCodesBetween200And299() { + callCheckForFailureStatusCode(randomIntBetween(200, 299), "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor503() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [503]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_WithShouldRetryTrue() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code for request from inference entity id [id] status [429]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(400, "id")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an input validation error response for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor400_InputsTooLarge() { + var exception = expectThrows( + RetryException.class, + () -> callCheckForFailureStatusCode(400, "\"input\" length 2049 is larger than the largest allowed size 2048", "id") + ); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString("Received an input validation error response for request from inference entity id [id] status [400]") + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor401() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(401, "inferenceEntityId")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat( + exception.getCause().getMessage(), + containsString( + "Received an authentication error status code for request from inference entity id [inferenceEntityId] status [401]" + ) + ); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.UNAUTHORIZED)); + } + + public void testCheckForFailureStatusCode_ThrowsFor402() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(402, "inferenceEntityId")); + assertFalse(exception.shouldRetry()); + MatcherAssert.assertThat(exception.getCause().getMessage(), containsString("Payment required")); + MatcherAssert.assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.PAYMENT_REQUIRED)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + callCheckForFailureStatusCode(statusCode, null, modelId); + } + + private static void callCheckForFailureStatusCode(int statusCode, @Nullable String errorMessage, String modelId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + String escapedErrorMessage = errorMessage != null ? errorMessage.replace("\\", "\\\\").replace("\"", "\\\"") : errorMessage; + + String responseJson = Strings.format(""" + { + "detail": "%s" + } + """, escapedErrorMessage); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, errorMessage == null ? new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8)); + var handler = new JinaAIResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..7f3f6e5cdeb82 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity(List.of("abc"), new JinaAIEmbeddingsTaskSettings(InputType.INGEST), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model","task":"retrieval.passage"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = new JinaAIEmbeddingsRequestEntity(List.of("abc"), JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"input":["abc"],"model":"model"}""")); + } + + public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows( + AssertionError.class, + () -> JinaAIEmbeddingsRequestEntity.convertToString(InputType.UNSPECIFIED) + ); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..05194ceb0de9e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class JinaAIEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest_UrlDefined() throws IOException { + var request = createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel("url", "secret", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, "model") + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(JinaAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model"))); + } + + public void testCreateRequest_AllOptionsDefined() throws IOException { + var request = createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel("url", "secret", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model") + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(JinaAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.passage"))); + } + + public void testCreateRequest_InputTypeSearch() throws IOException { + var request = createRequest( + List.of("abc"), + JinaAIEmbeddingsModelTests.createModel("url", "secret", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model") + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), + is(JinaAIUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query"))); + } + + public static JinaAIEmbeddingsRequest createRequest(List input, JinaAIEmbeddingsModel model) { + return new JinaAIEmbeddingsRequest(input, model); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequestTests.java new file mode 100644 index 0000000000000..031b44225628c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRequestTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount; + +import java.net.URI; + +import static org.hamcrest.Matchers.is; + +public class JinaAIRequestTests extends ESTestCase { + + public void testDecorateWithAuthHeader() { + var request = new HttpPost("http://www.abc.com"); + + JinaAIRequest.decorateWithAuthHeader( + request, + new JinaAIAccount(URI.create("http://www.abc.com"), new SecureString(new char[] { 'a', 'b', 'c' })) + ); + + assertThat(request.getFirstHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc")); + assertThat(request.getFirstHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(), is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java new file mode 100644 index 0000000000000..7fd738fa2a8e4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.hamcrest.MatcherAssert.assertThat; + +public class JinaAIRerankRequestEntityTests extends ESTestCase { + public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "top_n": 8 + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsTrue() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, true), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "top_n": 8, + "return_documents": true + } + """)); + } + + public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsFalse() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, false), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "top_n": 8, + "return_documents": false + } + """)); + } + + public void testXContent_SingleRequest_DoesNotWriteTopNIfNull() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ] + } + """)); + } + + public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), new JinaAIRerankTaskSettings(8, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ], + "top_n": 8 + } + """)); + } + + public void testXContent_MultipleRequests_DoesNotWriteTopNIfNull() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ] + } + """)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java new file mode 100644 index 0000000000000..819362d397ba5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class JinaAIRerankRequestTests extends ESTestCase { + + private static final String API_KEY = "foo"; + + public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + var input = "input"; + var query = "query"; + var modelId = "model"; + + var request = createRequest(query, input, modelId, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testCreateRequest_WithTopNSet() throws IOException { + var input = "input"; + var query = "query"; + var topN = 1; + var modelId = "model"; + + var request = createRequest(query, input, modelId, topN); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("top_n"), is(topN)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testCreateRequest_WithModelSet() throws IOException { + var input = "input"; + var query = "query"; + var modelId = "model"; + + var request = createRequest(query, input, modelId, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + API_KEY)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("documents"), is(List.of(input))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("model"), is(modelId)); + } + + public void testTruncate_DoesNotTruncate() { + var request = createRequest("query", "input", "null", null); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + private static JinaAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) { + var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, topN); + return new JinaAIRerankRequest(query, List.of(input), rerankModel); + + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtilsTests.java new file mode 100644 index 0000000000000..e3b4cfbed20ef --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIUtilsTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.jinaai; + +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.is; + +public class JinaAIUtilsTests extends ESTestCase { + + public void testCreateRequestSourceHeader() { + var requestSourceHeader = JinaAIUtils.createRequestSourceHeader(); + + assertThat(requestSourceHeader.getName(), is("Request-Source")); + assertThat(requestSourceHeader.getValue(), is("unspecified:elasticsearch")); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..7dbb9d5441a4a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java @@ -0,0 +1,397 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.jinaai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase { + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + InferenceTextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }))) + ); + } + + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + InferenceTextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.014539449F, -0.015288644F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123F, -0.0123F }) + ) + ) + ); + } + + public void testFromResponse_FailsWhenDataFieldIsNotPresent() { + String responseJson = """ + { + "object": "list", + "not_data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [data] in JinaAI embeddings response")); + } + + public void testFromResponse_FailsWhenDataFieldNotAnArray() { + String responseJson = """ + { + "object": "list", + "data": { + "test": { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + }, + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var thrownException = expectThrows( + ParsingException.class, + () -> JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + } + + public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embeddingzzz": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [embedding] in JinaAI embeddings response")); + } + + public void testFromResponse_FailsWhenEmbeddingValueIsAString() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + "abc" + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var thrownException = expectThrows( + ParsingException.class, + () -> JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [VALUE_STRING]") + ); + } + + public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 1 + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + InferenceTextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 1.0F }))) + ); + } + + public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 40294967295 + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + InferenceTextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is(List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 4.0294965E10F }))) + ); + } + + public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + {} + ] + } + ], + "model": "jina-embeddings-v3", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + + var thrownException = expectThrows( + ParsingException.class, + () -> JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]") + ); + } + + public void testFieldsInDifferentOrderServer() throws IOException { + // The fields of the objects in the data array are reordered + String response = """ + { + "object": "list", + "id": "6667830b-716b-4796-9a61-33b67b5cc81d", + "model": "jina-embeddings-v3", + "data": [ + { + "embedding": [ + -0.9, + 0.5, + 0.3 + ], + "index": 0, + "object": "embedding" + }, + { + "index": 0, + "embedding": [ + 0.1, + 0.5 + ], + "object": "embedding" + }, + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.5, + 0.5 + ] + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + }"""; + + InferenceTextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.9F, 0.5F, 0.3F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.1F, 0.5F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.5F, 0.5F }) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntityTests.java new file mode 100644 index 0000000000000..ce3bd10566cd8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIErrorResponseEntityTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.jinaai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.hamcrest.MatcherAssert; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class JinaAIErrorResponseEntityTests extends ESTestCase { + public void testFromResponse() { + String message = "\"input\" length 2049 is larger than the largest allowed size 2048"; + String escapedMessage = message.replace("\\", "\\\\").replace("\"", "\\\""); + String responseJson = Strings.format(""" + { + "detail": "%s" + } + """, escapedMessage); + + ErrorResponse errorResponse = JinaAIErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + MatcherAssert.assertThat(errorResponse.getErrorMessage(), is(message)); + } + + public void testFromResponse_noMessage() { + String responseJson = """ + { + "error": "abc" + } + """; + + ErrorResponse errorResponse = JinaAIErrorResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + MatcherAssert.assertThat(errorResponse, is(ErrorResponse.UNDEFINED_ERROR)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntityTests.java new file mode 100644 index 0000000000000..33fe9819bd88a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIRerankResponseEntityTests.java @@ -0,0 +1,180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.jinaai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class JinaAIRerankResponseEntityTests extends ESTestCase { + + public void testResponseLiteral() throws IOException { + InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseLiteral.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + List expected = responseLiteralDocs(); + for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { + assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); + } + } + + public void testGeneratedResponse() throws IOException { + int numDocs = randomIntBetween(1, 10); + + List expected = new ArrayList<>(numDocs); + StringBuilder responseBuilder = new StringBuilder(); + + responseBuilder.append("{"); + responseBuilder.append("\"model\": \"model\","); + responseBuilder.append("\"index\":\"").append(randomAlphaOfLength(36)).append("\","); + responseBuilder.append("\"results\": ["); + List indices = linear(numDocs); + List scores = linearFloats(numDocs); + for (int i = 0; i < numDocs; i++) { + int index = indices.remove(randomInt(indices.size() - 1)); + + responseBuilder.append("{"); + responseBuilder.append("\"index\":").append(index).append(","); + responseBuilder.append("\"relevance_score\":").append(scores.get(i).toString()).append("}"); + expected.add(new RankedDocsResults.RankedDoc(index, scores.get(i), null)); + if (i < numDocs - 1) { + responseBuilder.append(","); + } + } + responseBuilder.append("],"); + responseBuilder.append("\"usage\": {"); + responseBuilder.append("\"total_tokens\": 15}"); + responseBuilder.append("}"); + + InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseBuilder.toString().getBytes(StandardCharsets.UTF_8)) + ); + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + for (int i = 0; i < ((RankedDocsResults) parsedResults).getRankedDocs().size(); i++) { + assertEquals(((RankedDocsResults) parsedResults).getRankedDocs().get(i).index(), expected.get(i).index()); + } + } + + private ArrayList responseLiteralDocs() { + var list = new ArrayList(); + + list.add(new RankedDocsResults.RankedDoc(2, 0.98005307F, null)); + list.add(new RankedDocsResults.RankedDoc(3, 0.27904198F, null)); + list.add(new RankedDocsResults.RankedDoc(0, 0.10194652F, null)); + return list; + + }; + + private final String responseLiteral = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 3, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + + public void testResponseLiteralWithDocuments() throws IOException { + InferenceServiceResults parsedResults = JinaAIRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseLiteralWithDocuments.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat(parsedResults, instanceOf(RankedDocsResults.class)); + MatcherAssert.assertThat(((RankedDocsResults) parsedResults).getRankedDocs(), is(responseLiteralDocsWithText)); + } + + private final String responseLiteralWithDocuments = """ + { + "model": "model", + "results": [ + { + "document": { + "text": "Washington, D.C.." + }, + "index": 2, + "relevance_score": 0.98005307 + }, + { + "document": { + "text": "Capital punishment has existed in the United States since beforethe United States was a country. " + }, + "index": 3, + "relevance_score": 0.27904198 + }, + { + "document": { + "text": "Carson City is the capital city of the American state of Nevada." + }, + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + + private final List responseLiteralDocsWithText = List.of( + new RankedDocsResults.RankedDoc(2, 0.98005307F, "Washington, D.C.."), + new RankedDocsResults.RankedDoc( + 3, + 0.27904198F, + "Capital punishment has existed in the United States since beforethe United States was a country. " + ), + new RankedDocsResults.RankedDoc(0, 0.10194652F, "Carson City is the capital city of the American state of Nevada.") + ); + + private ArrayList linear(int n) { + ArrayList list = new ArrayList<>(); + for (int i = 0; i <= n; i++) { + list.add(i); + } + return list; + } + + // creates a list of doubles of monotonically decreasing magnitude + private ArrayList linearFloats(int n) { + ArrayList list = new ArrayList<>(); + float startValue = 1.0f; + float decrement = startValue / n + 1; + for (int i = 0; i <= n; i++) { + list.add(startValue - (i * decrement)); + } + return list; + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java new file mode 100644 index 0000000000000..4729e9e059d93 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceSettingsTests.java @@ -0,0 +1,174 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class JinaAIServiceSettingsTests extends AbstractWireSerializingTestCase { + + public static JinaAIServiceSettings createRandomWithNonNullUrl() { + return createRandom(randomAlphaOfLength(15)); + } + + /** + * The created settings can have a url set to null. + */ + public static JinaAIServiceSettings createRandom() { + var url = randomBoolean() ? randomAlphaOfLength(15) : null; + return createRandom(url); + } + + private static JinaAIServiceSettings createRandom(String url) { + var model = randomAlphaOfLength(15); + + return new JinaAIServiceSettings(ServiceUtils.createOptionalUri(url), model, RateLimitSettingsTests.createRandom()); + } + + public void testFromMap() { + var url = "https://www.abc.com"; + var model = "model"; + var serviceSettings = JinaAIServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, url, JinaAIServiceSettings.MODEL_ID, model)), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat(serviceSettings, is(new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null))); + } + + public void testFromMap_WithRateLimit() { + var url = "https://www.abc.com"; + var model = "model"; + var serviceSettings = JinaAIServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + JinaAIServiceSettings.MODEL_ID, + model, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)) + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is(new JinaAIServiceSettings(ServiceUtils.createUri(url), model, new RateLimitSettings(3))) + ); + } + + public void testFromMap_WhenUsingModelId() { + var url = "https://www.abc.com"; + var model = "model"; + var serviceSettings = JinaAIServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, url, JinaAIServiceSettings.MODEL_ID, model)), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat(serviceSettings, is(new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null))); + } + + public void testFromMap_MissingUrl_DoesNotThrowException() { + var serviceSettings = JinaAIServiceSettings.fromMap( + new HashMap<>(Map.of(JinaAIServiceSettings.MODEL_ID, "model")), + ConfigurationParseContext.PERSISTENT + ); + assertNull(serviceSettings.uri()); + } + + public void testFromMap_EmptyUrl_ThrowsError() { + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;", + ServiceFields.URL + ) + ) + ); + } + + public void testFromMap_InvalidUrl_ThrowsError() { + var url = "https://www.abc^.com"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) + ); + } + + public void testXContent_WritesModelId() throws IOException { + var entity = new JinaAIServiceSettings((String) null, "model", new RateLimitSettings(1)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model","rate_limit":{"requests_per_minute":1}}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return JinaAIServiceSettings::new; + } + + @Override + protected JinaAIServiceSettings createTestInstance() { + return createRandomWithNonNullUrl(); + } + + @Override + protected JinaAIServiceSettings mutateInstance(JinaAIServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, JinaAIServiceSettingsTests::createRandom); + } + + public static Map getServiceSettingsMap(@Nullable String url, String model) { + var map = new HashMap(); + + if (url != null) { + map.put(ServiceFields.URL, url); + } + + map.put(JinaAIServiceSettings.MODEL_ID, model); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java new file mode 100644 index 0000000000000..5a1bf8ec383c1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -0,0 +1,2003 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.services.jinaai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModelTests; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class JinaAIServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModel() throws IOException { + try (var service = createJinaAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createJinaAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createJinaAIService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_OptionalTaskSettings() throws IOException { + try (var service = createJinaAIService()) { + + ActionListener modelListener = ActionListener.wrap(model -> { + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), equalTo(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParseRequestConfig_ThrowsUnsupportedTaskType() throws IOException { + try (var service = createJinaAIService()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [jinaai] service does not support task type [sparse_embedding]" + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ), + failureListener + ); + } + } + + private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + MatcherAssert.assertThat(e, instanceOf(exceptionClass)); + MatcherAssert.assertThat(e.getMessage(), is(expectedMessage)); + }); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createJinaAIService()) { + var config = getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + config.put("extra_key", "value"); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createJinaAIService()) { + var serviceSettings = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createJinaAIService()) { + var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createJinaAIService()) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [jinaai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_CreatesAJinaAIEmbeddingsModelWithoutUrl() throws IOException { + try (var service = createJinaAIService()) { + var modelListener = ActionListener.wrap((model) -> { + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, (e) -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ), + modelListener + ); + + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModel() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "oldmodel"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWithoutUrl() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH), + getSecretSettingsMap("secret") + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try (var service = createJinaAIService()) { + var secretSettingsMap = getSecretSettingsMap("secret"); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + persistedConfig.secrets().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createJinaAIService()) { + var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createJinaAIService()) { + var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); + MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModel() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model_old"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [jinaai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWithoutUrl() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri()); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createJinaAIService()) { + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty() + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS)); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createJinaAIService()) { + var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.SEARCH))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createJinaAIService()) { + var taskSettingsMap = JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"), + taskSettingsMap + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + MatcherAssert.assertThat(model, instanceOf(JinaAIEmbeddingsModel.class)); + + var embeddingsModel = (JinaAIEmbeddingsModel) model; + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url")); + MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model")); + MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new JinaAIService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testCheckModelConfig_UpdatesDimensions() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "jina-clip-v2" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "jina-clip-v2" + ) + ) + ); + } + } + + public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "jina-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "jina-clip-v2", + SimilarityMeasure.DOT_PRODUCT + ) + ) + ); + } + } + + public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 1, + "jina-clip-v2", + SimilarityMeasure.COSINE + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat( + result, + // the dimension is set to 2 because there are 2 embeddings returned from the mock server + is( + JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 10, + 2, + "jina-clip-v2", + SimilarityMeasure.COSINE + ) + ) + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(null); + } + + public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException { + testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values())); + } + + private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + var embeddingSize = randomNonNegativeInt(); + var model = JinaAIEmbeddingsModelTests.createModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + similarityMeasure + ); + + Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize); + + SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? JinaAIService.defaultSimilarity() : similarityMeasure; + assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity()); + assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue()); + } + } + + public void testInfer_Embedding_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "model", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testInfer_Rerank_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "detail": "Unauthorized" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "model", 1024, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + MatcherAssert.assertThat(error.getMessage(), containsString("Error message: [Unauthorized]")); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + } + } + + public void testInfer_Embedding_Get_Response_Ingest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "jina-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.passage"))); + } + } + + public void testInfer_Embedding_Get_Response_Search() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "jina-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.query"))); + } + } + + public void testInfer_Embedding_Get_Response_clustering() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + {"model":"jina-clip-v2","object":"list","usage":{"total_tokens":5,"prompt_tokens":5}, + "data":[{"object":"embedding","index":0,"embedding":[0.123, -0.123]}]} + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "jina-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.CLUSTERING, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "separation"))); + } + } + + public void testInfer_Embedding_Get_Response_NullInputType() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + 1024, + 1024, + "jina-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertEquals(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })), result.asMap()); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2"))); + } + } + + public void testInfer_Rerank_Get_Response_NoReturnDocuments_NoTopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 1, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3"), + "model", + "model", + "return_documents", + false + ) + ) + ); + + } + } + + public void testInfer_Rerank_Get_Response_NoReturnDocuments_TopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307 + }, + { + "index": 1, + "relevance_score": 0.27904198 + }, + { + "index": 0, + "relevance_score": 0.10194652 + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, false); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + "model", + "return_documents", + false, + "top_n", + 3 + ) + ) + ); + + } + + } + + public void testInfer_Rerank_Get_Response_ReturnDocumentsNull_NoTopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307, + "document": { + "text": "candidate3" + } + }, + { + "index": 1, + "relevance_score": 0.27904198, + "document": { + "text": "candidate2" + } + }, + { + "index": 0, + "relevance_score": 0.10194652, + "document": { + "text": "candidate1" + } + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", null, null); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("text", "candidate3", "index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("text", "candidate2", "index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("text", "candidate1", "index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("query", "query", "documents", List.of("candidate1", "candidate2", "candidate3"), "model", "model")) + ); + + } + + } + + public void testInfer_Rerank_Get_Response_ReturnDocuments_TopN() throws IOException { + String responseJson = """ + { + "model": "model", + "results": [ + { + "index": 2, + "relevance_score": 0.98005307, + "document": { + "text": "candidate3" + } + }, + { + "index": 1, + "relevance_score": 0.27904198, + "document": { + "text": "candidate2" + } + }, + { + "index": 0, + "relevance_score": 0.10194652, + "document": { + "text": "candidate1" + } + } + ], + "usage": { + "total_tokens": 15 + } + } + """; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = JinaAIRerankModelTests.createModel(getUrl(webServer), "secret", "model", 3, true); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + "query", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + var resultAsMap = result.asMap(); + assertThat( + resultAsMap, + is( + Map.of( + "rerank", + List.of( + Map.of("ranked_doc", Map.of("text", "candidate3", "index", 2, "relevance_score", 0.98005307F)), + Map.of("ranked_doc", Map.of("text", "candidate2", "index", 1, "relevance_score", 0.27904198F)), + Map.of("ranked_doc", Map.of("text", "candidate1", "index", 0, "relevance_score", 0.10194652F)) + ) + ) + ) + ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "query", + "query", + "documents", + List.of("candidate1", "candidate2", "candidate3", "candidate4"), + "model", + "model", + "return_documents", + true, + "top_n", + 3 + ) + ) + ); + + } + + } + + public void testInfer_Embedding_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new JinaAIEmbeddingsTaskSettings((InputType) null), + 1024, + 1024, + "jina-clip-v2", + null + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, -0.123F })))); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2"))); + } + } + + public void test_Embedding_ChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new JinaAIEmbeddingsTaskSettings((InputType) null), + createRandomChunkingSettings(), + 1024, + 1024, + "jina-clip-v2" + ); + + test_Embedding_ChunkedInfer_BatchesCalls(model); + } + + public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = JinaAIEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new JinaAIEmbeddingsTaskSettings((InputType) null), + null, + 1024, + 1024, + "jina-clip-v2" + ); + + test_Embedding_ChunkedInfer_BatchesCalls(model); + } + + private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { + + // Batching will call the service with 2 input + String responseJson = """ + { + "model": "jina-clip-v2", + "object": "list", + "usage": { + "total_tokens": 5, + "prompt_tokens": 5 + }, + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.223, + -0.223 + ] + } + ] + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 input + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("foo", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("bar", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding(), 0.0f); + } + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + MatcherAssert.assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaType()) + ); + MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("foo", "bar"), "model", "jina-clip-v2"))); + } + } + + public void testDefaultSimilarity() { + assertEquals(SimilarityMeasure.DOT_PRODUCT, JinaAIService.defaultSimilarity()); + } + + @SuppressWarnings("checkstyle:LineLength") + public void testGetConfiguration() throws Exception { + try (var service = createJinaAIService()) { + String content = XContentHelper.stripWhitespace( + """ + { + "provider": "jinaai", + "task_types": [ + { + "task_type": "text_embedding", + "configuration": { + "task": { + "default_value": null, + "depends_on": [], + "display": "dropdown", + "label": "Task", + "options": [ + { + "label": "retrieval.query", + "value": "retrieval.query" + }, + { + "label": "retrieval.passage", + "value": "retrieval.passage" + }, + { + "label": "classification", + "value": "classification" + }, + { + "label": "separation", + "value": "separation" + } + ], + "order": 1, + "required": false, + "sensitive": false, + "tooltip": "Specifies the task type passed to the model.", + "type": "str", + "ui_restrictions": [], + "validations": [], + "value": "" + } + } + }, + { + "task_type": "rerank", + "configuration": { + "top_n": { + "default_value": null, + "depends_on": [], + "display": "numeric", + "label": "Top N", + "order": 2, + "required": false, + "sensitive": false, + "tooltip": "The number of most relevant documents to return, defaults to the number of the documents.", + "type": "int", + "ui_restrictions": [], + "validations": [], + "value": null + }, + "return_documents": { + "default_value": null, + "depends_on": [], + "display": "toggle", + "label": "Return Documents", + "order": 1, + "required": false, + "sensitive": false, + "tooltip": "Specify whether to return doc text within the results.", + "type": "bool", + "ui_restrictions": [], + "validations": [], + "value": false + } + } + } + ], + "configuration": { + "api_key": { + "default_value": null, + "depends_on": [], + "display": "textbox", + "label": "API Key", + "order": 1, + "required": true, + "sensitive": true, + "tooltip": "API Key for the provider you're connecting to.", + "type": "str", + "ui_restrictions": [], + "validations": [], + "value": null + }, + "rate_limit.requests_per_minute": { + "default_value": null, + "depends_on": [], + "display": "numeric", + "label": "Rate Limit", + "order": 6, + "required": false, + "sensitive": false, + "tooltip": "Minimize the number of rate limit errors.", + "type": "int", + "ui_restrictions": [], + "validations": [], + "value": null + } + } + } + """ + ); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = service.getConfiguration(); + assertToXContentEquivalent( + originalBytes, + toXContent(serviceConfiguration, XContentType.JSON, humanReadable), + XContentType.JSON + ); + } + } + + public void testDoesNotSupportsStreaming() throws IOException { + try (var service = new JinaAIService(mock(), createWithEmptySettings(mock()))) { + assertFalse(service.canStream(TaskType.COMPLETION)); + assertFalse(service.canStream(TaskType.ANY)); + } + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); + } + + private JinaAIService createJinaAIService() { + return new JinaAIService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java new file mode 100644 index 0000000000000..58455bb1f54ea --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.hamcrest.MatcherAssert; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap; +import static org.hamcrest.Matchers.is; + +public class JinaAIEmbeddingsModelTests extends ESTestCase { + + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() { + var model = createModel("url", "api_key", null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, null, InputType.UNSPECIFIED); + MatcherAssert.assertThat(overriddenModel, is(model)); + } + + public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() { + var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST); + var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() { + var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH); + var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() { + var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH); + var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() { + var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() { + var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() { + var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + + var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED); + var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model"); + MatcherAssert.assertThat(overriddenModel, is(expectedModel)); + } + + public static JinaAIEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable String model) { + return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model); + } + + public static JinaAIEmbeddingsModel createModel( + String url, + String apiKey, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model); + } + + public static JinaAIEmbeddingsModel createModel( + String url, + String apiKey, + JinaAIEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return new JinaAIEmbeddingsModel( + "id", + "service", + new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings(url, model, null), + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit + ), + taskSettings, + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static JinaAIEmbeddingsModel createModel( + String url, + String apiKey, + JinaAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model + ) { + return new JinaAIEmbeddingsModel( + "id", + "service", + new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings(url, model, null), + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit + ), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static JinaAIEmbeddingsModel createModel( + String url, + String apiKey, + JinaAIEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model, + @Nullable SimilarityMeasure similarityMeasure + ) { + return new JinaAIEmbeddingsModel( + "id", + "service", + new JinaAIEmbeddingsServiceSettings(new JinaAIServiceSettings(url, model, null), similarityMeasure, dimensions, tokenLimit), + taskSettings, + null, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..6847d249a57a0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + public static JinaAIEmbeddingsServiceSettings createRandom() { + SimilarityMeasure similarityMeasure = null; + Integer dims = null; + similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + dims = 1024; + Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256); + + var commonSettings = JinaAIServiceSettingsTests.createRandom(); + + return new JinaAIEmbeddingsServiceSettings(commonSettings, similarityMeasure, dims, maxInputTokens); + } + + public void testFromMap() { + var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + JinaAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens + ) + ) + ); + } + + public void testFromMap_WithModelId() { + var url = "https://www.abc.com"; + var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); + var dims = 1536; + var maxInputTokens = 512; + var model = "model"; + var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + url, + ServiceFields.SIMILARITY, + similarity, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + JinaAIServiceSettings.MODEL_ID, + model + ) + ), + ConfigurationParseContext.REQUEST + ); + + MatcherAssert.assertThat( + serviceSettings, + is( + new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens + ) + ) + ); + } + + public void testFromMap_InvalidSimilarity_ThrowsError() { + var similarity = "by_size"; + var thrownException = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(JinaAIServiceSettings.MODEL_ID, "model", ServiceFields.SIMILARITY, similarity)), + ConfigurationParseContext.PERSISTENT + ) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + is( + "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] " + + "must be one of [cosine, dot_product, l2_norm];" + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = new JinaAIEmbeddingsServiceSettings( + new JinaAIServiceSettings("url", "model", new RateLimitSettings(3)), + SimilarityMeasure.COSINE, + 5, + 10 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat(xContentResult, is(""" + {"url":"url","model_id":"model",""" + """ + "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return JinaAIEmbeddingsServiceSettings::new; + } + + @Override + protected JinaAIEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected JinaAIEmbeddingsServiceSettings mutateInstance(JinaAIEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, JinaAIEmbeddingsServiceSettingsTests::createRandom); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public static Map getServiceSettingsMap(@Nullable String url, String model) { + var map = new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(url, model)); + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java new file mode 100644 index 0000000000000..8535381b43adc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsTaskSettingsTests.java @@ -0,0 +1,193 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified; +import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.VALID_REQUEST_VALUES; +import static org.hamcrest.Matchers.is; + +public class JinaAIEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static JinaAIEmbeddingsTaskSettings createRandom() { + var inputType = randomBoolean() ? randomWithoutUnspecified() : null; + + return new JinaAIEmbeddingsTaskSettings(inputType); + } + + public void testIsEmpty() { + var randomSettings = createRandom(); + var stringRep = Strings.toString(randomSettings); + assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); + } + + public void testUpdatedTaskSettings_NotUpdated_UseInitialSettings() { + var initialSettings = createRandom(); + var newSettings = new JinaAIEmbeddingsTaskSettings((InputType) null); + Map newSettingsMap = new HashMap<>(); + JinaAIEmbeddingsTaskSettings updatedSettings = (JinaAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + assertEquals(initialSettings.getInputType(), updatedSettings.getInputType()); + } + + public void testUpdatedTaskSettings_Updated_UseNewSettings() { + var initialSettings = createRandom(); + var newSettings = new JinaAIEmbeddingsTaskSettings(randomWithoutUnspecified()); + Map newSettingsMap = new HashMap<>(); + newSettingsMap.put(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString()); + JinaAIEmbeddingsTaskSettings updatedSettings = (JinaAIEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( + Collections.unmodifiableMap(newSettingsMap) + ); + assertEquals(newSettings.getInputType(), updatedSettings.getInputType()); + } + + public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() { + MatcherAssert.assertThat( + JinaAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())), + is(new JinaAIEmbeddingsTaskSettings((InputType) null)) + ); + } + + public void testFromMap_CreatesEmptySettings_WhenMapIsNull() { + MatcherAssert.assertThat(JinaAIEmbeddingsTaskSettings.fromMap(null), is(new JinaAIEmbeddingsTaskSettings((InputType) null))); + } + + public void testFromMap_CreatesSettings_WhenAllFieldsOfSettingsArePresent() { + MatcherAssert.assertThat( + JinaAIEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString())) + ), + is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST)) + ); + } + + public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() { + var exception = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, "abc"))) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [%s];", + getValidValuesSortedAndCombined(VALID_REQUEST_VALUES) + ) + ) + ); + } + + public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() { + var exception = expectThrows( + ValidationException.class, + () -> JinaAIEmbeddingsTaskSettings.fromMap( + new HashMap<>(Map.of(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString())) + ) + ); + + MatcherAssert.assertThat( + exception.getMessage(), + is( + Strings.format( + "Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [%s];", + getValidValuesSortedAndCombined(VALID_REQUEST_VALUES) + ) + ) + ); + } + + private static > String getValidValuesSortedAndCombined(EnumSet validValues) { + var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new); + Arrays.sort(validValuesAsStrings); + + return String.join(", ", validValuesAsStrings); + } + + public void testXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows(AssertionError.class, () -> new JinaAIEmbeddingsTaskSettings(InputType.UNSPECIFIED)); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } + + public void testOf_KeepsOriginalValuesWhenRequestSettingsAreNull_AndRequestInputTypeIsInvalid() { + var taskSettings = new JinaAIEmbeddingsTaskSettings(InputType.INGEST); + var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of( + taskSettings, + JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, + InputType.UNSPECIFIED + ); + MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); + } + + public void testOf_UsesRequestTaskSettings() { + var taskSettings = new JinaAIEmbeddingsTaskSettings((InputType) null); + var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of( + taskSettings, + new JinaAIEmbeddingsTaskSettings(InputType.INGEST), + InputType.UNSPECIFIED + ); + + MatcherAssert.assertThat(overriddenTaskSettings, is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + } + + public void testOf_UsesRequestTaskSettings_AndRequestInputType() { + var taskSettings = new JinaAIEmbeddingsTaskSettings(InputType.SEARCH); + var overriddenTaskSettings = JinaAIEmbeddingsTaskSettings.of( + taskSettings, + new JinaAIEmbeddingsTaskSettings((InputType) null), + InputType.INGEST + ); + + MatcherAssert.assertThat(overriddenTaskSettings, is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST))); + } + + @Override + protected Writeable.Reader instanceReader() { + return JinaAIEmbeddingsTaskSettings::new; + } + + @Override + protected JinaAIEmbeddingsTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected JinaAIEmbeddingsTaskSettings mutateInstance(JinaAIEmbeddingsTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, JinaAIEmbeddingsTaskSettingsTests::createRandom); + } + + public static Map getTaskSettingsMapEmpty() { + return new HashMap<>(); + } + + public static Map getTaskSettingsMap(@Nullable InputType inputType) { + var map = new HashMap(); + + if (inputType != null) { + map.put(JinaAIEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString()); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModelTests.java new file mode 100644 index 0000000000000..d6b3df5fd3717 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModelTests.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.rerank; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class JinaAIRerankModelTests { + + public static JinaAIRerankModel createModel(String apiKey, String modelId, @Nullable Integer topN) { + return new JinaAIRerankModel( + "id", + "service", + new JinaAIRerankServiceSettings(new JinaAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + new JinaAIRerankTaskSettings(topN, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static JinaAIRerankModel createModel(String modelId, @Nullable Integer topN) { + return new JinaAIRerankModel( + "id", + "service", + new JinaAIRerankServiceSettings(new JinaAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + new JinaAIRerankTaskSettings(topN, null), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static JinaAIRerankModel createModel(String modelId, @Nullable Integer topN, Boolean returnDocuments) { + return new JinaAIRerankModel( + "id", + "service", + new JinaAIRerankServiceSettings(new JinaAIServiceSettings(ESTestCase.randomAlphaOfLength(10), modelId, null)), + new JinaAIRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static JinaAIRerankModel createModel(String url, String modelId, @Nullable Integer topN, Boolean returnDocuments) { + return new JinaAIRerankModel( + "id", + "service", + new JinaAIRerankServiceSettings(new JinaAIServiceSettings(url, modelId, null)), + new JinaAIRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(ESTestCase.randomSecureStringOfLength(8)) + ); + } + + public static JinaAIRerankModel createModel( + String url, + String apiKey, + String modelId, + @Nullable Integer topN, + Boolean returnDocuments + ) { + return new JinaAIRerankModel( + "id", + "service", + new JinaAIRerankServiceSettings(new JinaAIServiceSettings(url, modelId, null)), + new JinaAIRerankTaskSettings(topN, returnDocuments), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java new file mode 100644 index 0000000000000..47f67bd8cefb8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankServiceSettingsTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings; +import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class JinaAIRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + public static JinaAIRerankServiceSettings createRandom() { + return new JinaAIRerankServiceSettings( + new JinaAIServiceSettings( + randomFrom(new String[] { null, Strings.format("http://%s.com", randomAlphaOfLength(8)) }), + randomAlphaOfLength(10), + RateLimitSettingsTests.createRandom() + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var url = "http://www.abc.com"; + var model = "model"; + + var serviceSettings = new JinaAIRerankServiceSettings(new JinaAIServiceSettings(url, model, null)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "url":"http://www.abc.com", + "model_id":"model", + "rate_limit": { + "requests_per_minute": 2000 + } + } + """)); + } + + @Override + protected Writeable.Reader instanceReader() { + return JinaAIRerankServiceSettings::new; + } + + @Override + protected JinaAIRerankServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected JinaAIRerankServiceSettings mutateInstance(JinaAIRerankServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, JinaAIRerankServiceSettingsTests::createRandom); + } + + @Override + protected JinaAIRerankServiceSettings mutateInstanceForVersion(JinaAIRerankServiceSettings instance, TransportVersion version) { + return instance; + } + + public static Map getServiceSettingsMap(@Nullable String url, @Nullable String model) { + return new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(url, model)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java new file mode 100644 index 0000000000000..fa70248d01513 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankTaskSettingsTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.jinaai.rerank; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class JinaAIRerankTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static JinaAIRerankTaskSettings createRandom() { + var returnDocuments = randomBoolean() ? randomBoolean() : null; + var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null; + + return new JinaAIRerankTaskSettings(topNDocsOnly, returnDocuments); + } + + public void testFromMap_WithValidValues_ReturnsSettings() { + Map taskMap = Map.of(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, true, JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, 5); + var settings = JinaAIRerankTaskSettings.fromMap(new HashMap<>(taskMap)); + assertTrue(settings.getReturnDocuments()); + assertEquals(5, settings.getTopNDocumentsOnly().intValue()); + } + + public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() { + var settings = JinaAIRerankTaskSettings.fromMap(Map.of()); + assertNull(settings.getReturnDocuments()); + assertNull(settings.getTopNDocumentsOnly()); + } + + public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() { + Map taskMap = Map.of( + JinaAIRerankTaskSettings.RETURN_DOCUMENTS, + "invalid", + JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, + 5 + ); + var thrownException = expectThrows(ValidationException.class, () -> JinaAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type")); + } + + public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() { + Map taskMap = Map.of( + JinaAIRerankTaskSettings.RETURN_DOCUMENTS, + true, + JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, + "invalid" + ); + var thrownException = expectThrows(ValidationException.class, () -> JinaAIRerankTaskSettings.fromMap(new HashMap<>(taskMap))); + assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type")); + } + + public void testUpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() { + var initialSettings = new JinaAIRerankTaskSettings(5, true); + JinaAIRerankTaskSettings updatedSettings = (JinaAIRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of()); + assertEquals(initialSettings, updatedSettings); + } + + public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() { + var initialSettings = new JinaAIRerankTaskSettings(5, true); + Map newSettings = Map.of(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, false); + JinaAIRerankTaskSettings updatedSettings = (JinaAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly()); + } + + public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() { + var initialSettings = new JinaAIRerankTaskSettings(5, true); + Map newSettings = Map.of(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, 7); + JinaAIRerankTaskSettings updatedSettings = (JinaAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments()); + } + + public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() { + var initialSettings = new JinaAIRerankTaskSettings(5, true); + Map newSettings = Map.of( + JinaAIRerankTaskSettings.RETURN_DOCUMENTS, + false, + JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, + 7 + ); + JinaAIRerankTaskSettings updatedSettings = (JinaAIRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings); + assertFalse(updatedSettings.getReturnDocuments()); + assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue()); + } + + @Override + protected Writeable.Reader instanceReader() { + return JinaAIRerankTaskSettings::new; + } + + @Override + protected JinaAIRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected JinaAIRerankTaskSettings mutateInstance(JinaAIRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, JinaAIRerankTaskSettingsTests::createRandom); + } + + public static Map getTaskSettingsMapEmpty() { + return new HashMap<>(); + } + + public static Map getTaskSettingsMap(@Nullable Integer topNDocumentsOnly, Boolean returnDocuments) { + var map = new HashMap(); + + if (topNDocumentsOnly != null) { + map.put(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, topNDocumentsOnly.toString()); + } + + if (returnDocuments != null) { + map.put(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments.toString()); + } + + return map; + } +}