Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/118652.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118652
summary: Add Jina AI API to do inference for Embedding and Rerank models
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
Expand Down Expand Up @@ -132,6 +137,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAmazonBedrockNamedWriteables(namedWriteables);
addEisNamedWriteables(namedWriteables);
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand Down Expand Up @@ -569,6 +575,28 @@ private static void addAlibabaCloudSearchNamedWriteables(List<NamedWriteableRegi

}

private static void addJinaAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, JinaAIServiceSettings.NAME, JinaAIServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
JinaAIEmbeddingsServiceSettings.NAME,
JinaAIEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, JinaAIEmbeddingsTaskSettings.NAME, JinaAIEmbeddingsTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, JinaAIRerankServiceSettings.NAME, JinaAIRerankServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, JinaAIRerankTaskSettings.NAME, JinaAIRerankTaskSettings::new)
);
}

private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
Expand Down Expand Up @@ -289,6 +290,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.jinaai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.JinaAIEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.JinaAIRerankRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

/**
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the jinaai model type.
*/
public class JinaAIActionCreator implements JinaAIActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;

public JinaAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}

@Override
public ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
var overriddenModel = JinaAIEmbeddingsModel.of(model, taskSettings, inputType);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"JinaAI embeddings"
);
var requestCreator = JinaAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

@Override
public ExecutableAction create(JinaAIRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = JinaAIRerankModel.of(model, taskSettings);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().getCommonSettings().uri(),
"JinaAI rerank"
);
var requestCreator = JinaAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.jinaai;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;

import java.util.Map;

public interface JinaAIActionVisitor {
ExecutableAction create(JinaAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);

ExecutableAction create(JinaAIRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,9 @@ private static byte[] limitBody(ByteSizeValue maxResponseSize, HttpResponse resp
public boolean isBodyEmpty() {
return body().length == 0;
}

public boolean isSuccessfulResponse() {
var code = response.getStatusLine().getStatusCode();
return code >= 200 && code < 300;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIResponseHandler;
import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.jinaai.JinaAIEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

public class JinaAIEmbeddingsRequestManager extends JinaAIRequestManager {
private static final Logger logger = LogManager.getLogger(JinaAIEmbeddingsRequestManager.class);
private static final ResponseHandler HANDLER = createEmbeddingsHandler();

private static ResponseHandler createEmbeddingsHandler() {
return new JinaAIResponseHandler("jinaai text embedding", JinaAIEmbeddingsResponseEntity::fromResponse);
}

public static JinaAIEmbeddingsRequestManager of(JinaAIEmbeddingsModel model, ThreadPool threadPool) {
return new JinaAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final JinaAIEmbeddingsModel model;

private JinaAIEmbeddingsRequestManager(JinaAIEmbeddingsModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = Objects.requireNonNull(model);
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel;

import java.util.Objects;

abstract class JinaAIRequestManager extends BaseRequestManager {

protected JinaAIRequestManager(ThreadPool threadPool, JinaAIModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
}

record RateLimitGrouping(int apiKeyHash) {
public static RateLimitGrouping of(JinaAIModel model) {
Objects.requireNonNull(model);

return new RateLimitGrouping(model.apiKey().hashCode());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIResponseHandler;
import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIRerankRequest;
import org.elasticsearch.xpack.inference.external.response.jinaai.JinaAIRerankResponseEntity;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;

import java.util.Objects;
import java.util.function.Supplier;

public class JinaAIRerankRequestManager extends JinaAIRequestManager {
private static final Logger logger = LogManager.getLogger(JinaAIRerankRequestManager.class);
private static final ResponseHandler HANDLER = createJinaAIResponseHandler();

private static ResponseHandler createJinaAIResponseHandler() {
return new JinaAIResponseHandler("jinaai rerank", (request, response) -> JinaAIRerankResponseEntity.fromResponse(response));
}

public static JinaAIRerankRequestManager of(JinaAIRerankModel model, ThreadPool threadPool) {
return new JinaAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final JinaAIRerankModel model;

private JinaAIRerankRequestManager(JinaAIRerankModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = model;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.jinaai;

import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIModel;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;

public record JinaAIAccount(URI uri, SecureString apiKey) {

public static JinaAIAccount of(JinaAIModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
var uri = buildUri(model.uri(), "JinaAI", uriBuilder);

return new JinaAIAccount(uri, model.apiKey());
}

public JinaAIAccount {
Objects.requireNonNull(uri);
Objects.requireNonNull(apiKey);
}
}
Loading