Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128694.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128694
summary: "Adding Google VertexAI completion integration"
area: Inference
type: enhancement
issues: [ ]
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(14));
assertThat(services.size(), equalTo(15));

var providers = providers(services);

Expand All @@ -155,6 +155,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker",
"googlevertexai",
"mistral"
).toArray()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {

private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;

protected URI uri;
protected URI nonStreamingUri;

public GoogleVertexAiModel(
ModelConfigurations configurations,
Expand All @@ -39,14 +39,14 @@ public GoogleVertexAiModel(
public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) {
super(model, serviceSettings);

uri = model.uri();
nonStreamingUri = model.nonStreamingUri();
rateLimitServiceSettings = model.rateLimitServiceSettings();
}

public GoogleVertexAiModel(GoogleVertexAiModel model, TaskSettings taskSettings) {
super(model, taskSettings);

uri = model.uri();
nonStreamingUri = model.nonStreamingUri();
rateLimitServiceSettings = model.rateLimitServiceSettings();
}

Expand All @@ -56,17 +56,8 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

public URI uri() {
return uri;
}

@Override
public int rateLimitGroupingHash() {
// In VertexAI rate limiting is scoped to the project, region and model. URI already has this information so we are using that.
// API Key does not affect the quota
// https://ai.google.dev/gemini-api/docs/rate-limits
// https://cloud.google.com/vertex-ai/docs/quotas
return Objects.hash(uri);
public URI nonStreamingUri() {
return nonStreamingUri;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;

import java.util.concurrent.Flow;
import java.util.function.Function;

import static org.elasticsearch.core.Strings.format;
Expand Down Expand Up @@ -66,4 +71,14 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr
private static String resourceNotFoundError(Request request) {
return format("Resource not found at [%s]", request.getURI());
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var googleVertexAiProcessor = new GoogleVertexAiStreamingProcessor();

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(googleVertexAiProcessor);
return new StreamingChatCompletionResults(googleVertexAiProcessor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ public static Map<String, SettingsConfiguration> get() {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
SERVICE_ACCOUNT_JSON,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
.setDescription("API Key for the provider you're connecting to.")
new SettingsConfiguration.Builder(
EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION, TaskType.COMPLETION)
).setDescription("API Key for the provider you're connecting to.")
.setLabel("Credentials JSON")
.setRequired(true)
.setSensitive(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ public class GoogleVertexAiService extends SenderService {
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.RERANK,
TaskType.CHAT_COMPLETION
TaskType.CHAT_COMPLETION,
TaskType.COMPLETION
);

public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
Expand All @@ -87,13 +88,13 @@ public class GoogleVertexAiService extends SenderService {
InputType.INTERNAL_SEARCH
);

private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
public static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
);

@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.CHAT_COMPLETION);
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
}

public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
Expand Down Expand Up @@ -358,7 +359,7 @@ private static GoogleVertexAiModel createModel(
context
);

case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
case CHAT_COMPLETION, COMPLETION -> new GoogleVertexAiChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
Expand Down Expand Up @@ -396,10 +397,11 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
LOCATION,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.COMPLETION))
.setDescription(
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
+ "For more information, refer to the {geminiVertexAIDocs}."
)
.setLabel("GCP Region")
.setRequired(true)
.setSensitive(false)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.googlevertexai;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

import java.io.IOException;
import java.util.Deque;
import java.util.Objects;
import java.util.stream.Stream;

public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> {

@Override
protected void next(Deque<ServerSentEvent> item) throws Exception {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
var results = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, parserConfig);

if (results.isEmpty()) {
upstream().request(1);
} else {
downstream().onNext(new StreamingChatCompletionResults.Results(results));
}
}

public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
String data = event.data();
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser);

return chunk.choices()
.stream()
.map(choice -> choice.delta())
.filter(Objects::nonNull)
.map(delta -> delta.content())
.filter(content -> Strings.isNullOrEmpty(content) == false)
.map(StreamingChatCompletionResults.Result::new);

} catch (IOException e) {
throw new ElasticsearchStatusException(
"Failed to parse event from inference provider: {}",
RestStatus.INTERNAL_SERVER_ERROR,
e,
event
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

import java.nio.charset.StandardCharsets;
import java.util.Locale;
Expand All @@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe
private static final String ERROR_MESSAGE_FIELD = "message";
private static final String ERROR_STATUS_FIELD = "status";

private static final ResponseParser noopParseFunction = (a, b) -> null;

public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true);
}

@Override
Expand All @@ -64,6 +62,7 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";

var responseStatusCode = result.response().getStatusLine().getStatusCode();
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
var restStatus = toRestStatus(responseStatusCode);
Expand Down Expand Up @@ -111,7 +110,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
}
}

private static class GoogleVertexAiErrorResponse extends ErrorResponse {
public static class GoogleVertexAiErrorResponse extends ErrorResponse {
private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class);
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"google_vertex_ai_error_wrapper",
Expand All @@ -138,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse {
);
}

static ErrorResponse fromResponse(HttpResult response) {
public static ErrorResponse fromResponse(HttpResult response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;

import java.util.Map;
import java.util.Objects;
Expand All @@ -36,9 +38,13 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor

private final ServiceComponents serviceComponents;

static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
"Google VertexAI chat completion"
static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler(
"Google VertexAI completion",
GoogleVertexAiCompletionResponseEntity::fromResponse,
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
true
);

static final String USER_ROLE = "user";

public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
Expand Down Expand Up @@ -67,12 +73,12 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje

@Override
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {

var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);

var manager = new GenericRequestManager<>(
serviceComponents.threadPool(),
model,
COMPLETION_HANDLER,
CHAT_COMPLETION_HANDLER,
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ public interface GoogleVertexAiActionVisitor {
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);

ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings);

}
Loading