From c65132afb9ea4eff9beee07b9d497482ae7bbcd8 Mon Sep 17 00:00:00 2001 From: Leonardo Hoet <55866308+leo-hoet@users.noreply.github.com> Date: Mon, 9 Jun 2025 11:31:56 -0300 Subject: [PATCH 1/3] Implemented completion task for Google VertexAI (#128694) * Google Vertex AI completion model, response entity and tests * Fixed GoogleVertexAiServiceTest for Service configuration * Changelog * Removed downcasting and using `moveToFirstToken` * Create GoogleVertexAiChatCompletionResponseHandler for streaming and non streaming responses * Added unit tests * PR feedback * Removed googlevertexaicompletion model. Using just GoogleVertexAiChatCompletionModel for completion and chat completion * Renamed uri -> nonStreamingUri. Added streamingUri and getters in GoogleVertexAiChatCompletionModel * Moved rateLimitGroupHashing to subclasses of GoogleVertexAiModel * Fixed rate limit has of GoogleVertexAiRerankModel and refactored uri for GoogleVertexAiUnifiedChatCompletionRequest --------- Co-authored-by: lhoet-google Co-authored-by: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> --- docs/changelog/128694.yaml | 5 + .../googlevertexai/GoogleVertexAiModel.java | 19 +-- .../GoogleVertexAiResponseHandler.java | 15 +++ .../GoogleVertexAiSecretSettings.java | 5 +- .../googlevertexai/GoogleVertexAiService.java | 18 +-- .../GoogleVertexAiStreamingProcessor.java | 64 ++++++++++ ...iUnifiedChatCompletionResponseHandler.java | 11 +- .../action/GoogleVertexAiActionCreator.java | 14 ++- .../action/GoogleVertexAiActionVisitor.java | 1 + .../GoogleVertexAiChatCompletionModel.java | 50 +++++++- .../GoogleVertexAiEmbeddingsModel.java | 18 ++- .../GoogleVertexAiEmbeddingsRequest.java | 4 +- .../request/GoogleVertexAiRerankRequest.java | 4 +- ...eVertexAiUnifiedChatCompletionRequest.java | 6 +- .../request/GoogleVertexAiUtils.java | 2 + .../rerank/GoogleVertexAiRerankModel.java | 18 ++- ...oogleVertexAiCompletionResponseEntity.java | 103 +++++++++++++++ 
.../GoogleVertexAiServiceTests.java | 12 +- ...GoogleVertexAiStreamingProcessorTests.java | 119 ++++++++++++++++++ ...texAiUnifiedChatCompletionActionTests.java | 2 +- ...oogleVertexAiChatCompletionModelTests.java | 4 +- .../GoogleVertexAiEmbeddingsModelTests.java | 2 +- ...VertexAiCompletionResponseEntityTests.java | 80 ++++++++++++ 23 files changed, 520 insertions(+), 56 deletions(-) create mode 100644 docs/changelog/128694.yaml create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntityTests.java diff --git a/docs/changelog/128694.yaml b/docs/changelog/128694.yaml new file mode 100644 index 0000000000000..031bec11899e5 --- /dev/null +++ b/docs/changelog/128694.yaml @@ -0,0 +1,5 @@ +pr: 128694 +summary: "Adding Google VertexAI completion integration" +area: Inference +type: enhancement +issues: [ ] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java index 60cd2faa7155b..0ba69b2a34414 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java @@ -24,7 +24,7 @@ public abstract class 
GoogleVertexAiModel extends RateLimitGroupingModel { private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings; - protected URI uri; + protected URI nonStreamingUri; public GoogleVertexAiModel( ModelConfigurations configurations, @@ -39,14 +39,14 @@ public GoogleVertexAiModel( public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) { super(model, serviceSettings); - uri = model.uri(); + nonStreamingUri = model.nonStreamingUri(); rateLimitServiceSettings = model.rateLimitServiceSettings(); } public GoogleVertexAiModel(GoogleVertexAiModel model, TaskSettings taskSettings) { super(model, taskSettings); - uri = model.uri(); + nonStreamingUri = model.nonStreamingUri(); rateLimitServiceSettings = model.rateLimitServiceSettings(); } @@ -56,17 +56,8 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() { return rateLimitServiceSettings; } - public URI uri() { - return uri; - } - - @Override - public int rateLimitGroupingHash() { - // In VertexAI rate limiting is scoped to the project, region and model. URI already has this information so we are using that. 
- // API Key does not affect the quota - // https://ai.google.dev/gemini-api/docs/rate-limits - // https://cloud.google.com/vertex-ai/docs/quotas - return Objects.hash(uri); + public URI nonStreamingUri() { + return nonStreamingUri; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java index 9adefd19ef6d5..bc001f85d5431 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java @@ -7,14 +7,19 @@ package org.elasticsearch.xpack.inference.services.googlevertexai; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity; +import java.util.concurrent.Flow; import java.util.function.Function; import static org.elasticsearch.core.Strings.format; @@ -66,4 +71,14 @@ protected void 
checkForFailureStatusCode(Request request, HttpResult result) thr private static String resourceNotFoundError(Request request) { return format("Resource not found at [%s]", request.getURI()); } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var googleVertexAiProcessor = new GoogleVertexAiStreamingProcessor(); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(googleVertexAiProcessor); + return new StreamingChatCompletionResults(googleVertexAiProcessor); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java index 1abf1db642932..b97fc1e483a92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java @@ -124,8 +124,9 @@ public static Map get() { var configurationMap = new HashMap(); configurationMap.put( SERVICE_ACCOUNT_JSON, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION)) - .setDescription("API Key for the provider you're connecting to.") + new SettingsConfiguration.Builder( + EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION, TaskType.COMPLETION) + ).setDescription("API Key for the provider you're connecting to.") .setLabel("Credentials JSON") .setRequired(true) .setSensitive(true) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index dc91e01322e6e..3b59e999125e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -75,7 +75,8 @@ public class GoogleVertexAiService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of( TaskType.TEXT_EMBEDDING, TaskType.RERANK, - TaskType.CHAT_COMPLETION + TaskType.CHAT_COMPLETION, + TaskType.COMPLETION ); public static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( @@ -87,13 +88,13 @@ public class GoogleVertexAiService extends SenderService { InputType.INTERNAL_SEARCH ); - private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( + public static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( "Google VertexAI chat completion" ); @Override public Set supportedStreamingTasks() { - return EnumSet.of(TaskType.CHAT_COMPLETION); + return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION); } public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { @@ -358,7 +359,7 @@ private static GoogleVertexAiModel createModel( context ); - case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel( + case CHAT_COMPLETION, COMPLETION -> new GoogleVertexAiChatCompletionModel( inferenceEntityId, taskType, NAME, @@ -396,10 +397,11 @@ public static InferenceServiceConfiguration get() { configurationMap.put( LOCATION, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription( - "Please provide the GCP region where the Vertex AI API(s) is enabled. 
" - + "For more information, refer to the {geminiVertexAIDocs}." - ) + new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.COMPLETION)) + .setDescription( + "Please provide the GCP region where the Vertex AI API(s) is enabled. " + + "For more information, refer to the {geminiVertexAIDocs}." + ) .setLabel("GCP Region") .setRequired(true) .setSensitive(false) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessor.java new file mode 100644 index 0000000000000..05fc9216c8916 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessor.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; + +import java.io.IOException; +import java.util.Deque; +import java.util.Objects; +import java.util.stream.Stream; + +public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor, InferenceServiceResults.Result> { + + @Override + protected void next(Deque item) throws Exception { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var results = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, parserConfig); + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(new StreamingChatCompletionResults.Results(results)); + } + } + + public static Stream parse(XContentParserConfiguration parserConfig, ServerSentEvent event) { + String data = event.data(); + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) { + var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser); + + return chunk.choices() + .stream() + .map(choice -> choice.delta()) + .filter(Objects::nonNull) + .map(delta -> delta.content()) + .filter(content -> Strings.isNullOrEmpty(content) == 
false) + .map(StreamingChatCompletionResults.Result::new); + + } catch (IOException e) { + throw new ElasticsearchStatusException( + "Failed to parse event from inference provider: {}", + RestStatus.INTERNAL_SERVER_ERROR, + e, + event + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 8c355c9f67f18..9e6fdb6eb8bb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -23,10 +23,10 @@ import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; +import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; import java.nio.charset.StandardCharsets; import java.util.Locale; @@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe private static final String ERROR_MESSAGE_FIELD = "message"; private static final String ERROR_STATUS_FIELD = "status"; - private static final ResponseParser noopParseFunction = (a, b) -> null; - public 
GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) { - super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true); + super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true); } @Override @@ -64,6 +62,7 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher, Void> ERROR_PARSER = new ConstructingObjectParser<>( "google_vertex_ai_error_wrapper", @@ -138,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse { ); } - static ErrorResponse fromResponse(HttpResult response) { + public static ErrorResponse fromResponse(HttpResult response) { try ( XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(XContentParserConfiguration.EMPTY, response.body()) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java index 2aa42a8ae69c2..80d82df1cac26 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java @@ -18,11 +18,13 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiResponseHandler; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler; import 
org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel; import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity; import java.util.Map; import java.util.Objects; @@ -36,9 +38,13 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor private final ServiceComponents serviceComponents; - static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler( - "Google VertexAI chat completion" + static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler( + "Google VertexAI completion", + GoogleVertexAiCompletionResponseEntity::fromResponse, + GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse, + true ); + static final String USER_ROLE = "user"; public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) { @@ -67,12 +73,12 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings) { - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); + var manager = new GenericRequestManager<>( serviceComponents.threadPool(), model, - COMPLETION_HANDLER, + CHAT_COMPLETION_HANDLER, inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), ChatCompletionInput.class ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java index eaa71f2646efe..fd3691d2981b1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java @@ -21,4 +21,5 @@ public interface GoogleVertexAiActionVisitor { ExecutableAction create(GoogleVertexAiRerankModel model, Map taskSettings); ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map taskSettings); + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java index 301d8f1075502..fdb4ed34d92db 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java @@ -30,6 +30,9 @@ import static org.elasticsearch.core.Strings.format; public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel { + + private final URI streamingURI; + public GoogleVertexAiChatCompletionModel( String inferenceEntityId, TaskType taskType, @@ -63,7 +66,8 @@ public GoogleVertexAiChatCompletionModel( serviceSettings ); try { - this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); + this.streamingURI = buildUriStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); + this.nonStreamingUri = 
buildUriNonStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -114,7 +118,28 @@ public GoogleVertexAiSecretSettings getSecretSettings() { return (GoogleVertexAiSecretSettings) super.getSecretSettings(); } - public static URI buildUri(String location, String projectId, String model) throws URISyntaxException { + public URI streamingURI() { + return this.streamingURI; + } + + public static URI buildUriNonStreaming(String location, String projectId, String model) throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) + .setPathSegments( + GoogleVertexAiUtils.V1, + GoogleVertexAiUtils.PROJECTS, + projectId, + GoogleVertexAiUtils.LOCATIONS, + GoogleVertexAiUtils.GLOBAL, + GoogleVertexAiUtils.PUBLISHERS, + GoogleVertexAiUtils.PUBLISHER_GOOGLE, + GoogleVertexAiUtils.MODELS, + format("%s:%s", model, GoogleVertexAiUtils.GENERATE_CONTENT) + ) + .build(); + } + + public static URI buildUriStreaming(String location, String projectId, String model) throws URISyntaxException { return new URIBuilder().setScheme("https") .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) .setPathSegments( @@ -131,4 +156,25 @@ public static URI buildUri(String location, String projectId, String model) thro .setCustomQuery(GoogleVertexAiUtils.QUERY_PARAM_ALT_SSE) .build(); } + + @Override + public int rateLimitGroupingHash() { + // In VertexAI rate limiting is scoped to the project, region, model and endpoint. 
+ // API Key does not affect the quota + // https://ai.google.dev/gemini-api/docs/rate-limits + // https://cloud.google.com/vertex-ai/docs/quotas + var projectId = getServiceSettings().projectId(); + var location = getServiceSettings().location(); + var modelId = getServiceSettings().modelId(); + + // Since we don't beforehand know which API is going to be used, we take a conservative approach and + // count both endpoint for the rate limit + return Objects.hash( + projectId, + location, + modelId, + GoogleVertexAiUtils.GENERATE_CONTENT, + GoogleVertexAiUtils.STREAM_GENERATE_CONTENT + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java index 2bf9349db83fb..66031f7e5475d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java @@ -23,6 +23,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.core.Strings.format; @@ -81,7 +82,7 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google serviceSettings ); try { - this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); + this.nonStreamingUri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -103,7 +104,7 @@ protected GoogleVertexAiEmbeddingsModel( serviceSettings ); try { - this.uri = new URI(uri); + this.nonStreamingUri = new URI(uri); } catch 
(URISyntaxException e) { throw new RuntimeException(e); } @@ -150,4 +151,17 @@ public static URI buildUri(String location, String projectId, String modelId) th ) .build(); } + + @Override + public int rateLimitGroupingHash() { + // In VertexAI rate limiting is scoped to the project, region, model and endpoint. + // API Key does not affect the quota + // https://ai.google.dev/gemini-api/docs/rate-limits + // https://cloud.google.com/vertex-ai/docs/quotas + var projectId = getServiceSettings().projectId(); + var location = getServiceSettings().location(); + var modelId = getServiceSettings().modelId(); + + return Objects.hash(projectId, location, modelId, GoogleVertexAiUtils.PREDICT); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java index 53898a5f355d0..bf506a08d8268 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java @@ -46,7 +46,7 @@ public GoogleVertexAiEmbeddingsRequest( @Override public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(model.uri()); + HttpPost httpPost = new HttpPost(model.nonStreamingUri()); ByteArrayEntity byteEntity = new ByteArrayEntity( Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), inputType, model.getTaskSettings())) @@ -84,7 +84,7 @@ public String getInferenceEntityId() { @Override public URI getURI() { - return model.uri(); + return model.nonStreamingUri(); } @Override diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiRerankRequest.java index 7939f3b70c21f..1fcdd5189b459 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiRerankRequest.java @@ -50,7 +50,7 @@ public GoogleVertexAiRerankRequest( @Override public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(model.uri()); + HttpPost httpPost = new HttpPost(model.nonStreamingUri()); ByteArrayEntity byteEntity = new ByteArrayEntity( Strings.toString( @@ -87,7 +87,7 @@ public String getInferenceEntityId() { @Override public URI getURI() { - return model.uri(); + return model.nonStreamingUri(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java index 7b20e71099e66..7acc859d26748 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequest.java @@ -25,15 +25,17 @@ public class GoogleVertexAiUnifiedChatCompletionRequest implements GoogleVertexA private final GoogleVertexAiChatCompletionModel model; private final UnifiedChatInput unifiedChatInput; + private final URI uri; public 
GoogleVertexAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, GoogleVertexAiChatCompletionModel model) { this.model = Objects.requireNonNull(model); this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.uri = unifiedChatInput.stream() ? model.streamingURI() : model.nonStreamingUri(); } @Override public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(model.uri()); + HttpPost httpPost = new HttpPost(uri); var requestEntity = new GoogleVertexAiUnifiedChatCompletionRequestEntity(unifiedChatInput); @@ -52,7 +54,7 @@ public void decorateWithAuth(HttpPost httpPost) { @Override public URI getURI() { - return model.uri(); + return this.uri; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java index 7eda9c8b01cae..633b787ff9cf3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUtils.java @@ -37,6 +37,8 @@ public final class GoogleVertexAiUtils { public static final String STREAM_GENERATE_CONTENT = "streamGenerateContent"; + public static final String GENERATE_CONTENT = "generateContent"; + public static final String QUERY_PARAM_ALT_SSE = "alt=sse"; private GoogleVertexAiUtils() {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java index 650402cd7f713..a77756a7c00b1 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java @@ -22,10 +22,12 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.core.Strings.format; public class GoogleVertexAiRerankModel extends GoogleVertexAiModel { + private static final String RERANK_RATE_LIMIT_ENDPOINT_ID = "rerank"; public GoogleVertexAiRerankModel( String inferenceEntityId, @@ -65,7 +67,7 @@ public GoogleVertexAiRerankModel(GoogleVertexAiRerankModel model, GoogleVertexAi serviceSettings ); try { - this.uri = buildUri(serviceSettings.projectId()); + this.nonStreamingUri = buildUri(serviceSettings.projectId()); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -87,7 +89,7 @@ protected GoogleVertexAiRerankModel( serviceSettings ); try { - this.uri = new URI(uri); + this.nonStreamingUri = new URI(uri); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -132,4 +134,16 @@ public static URI buildUri(String projectId) throws URISyntaxException { ) .build(); } + + @Override + public int rateLimitGroupingHash() { + // In VertexAI rate limiting is scoped to the project, region, model and endpoint. 
+ // API Key does not affect the quota + // https://ai.google.dev/gemini-api/docs/rate-limits + // https://cloud.google.com/vertex-ai/docs/quotas + var projectId = getServiceSettings().projectId(); + var modelId = getServiceSettings().modelId(); + + return Objects.hash(projectId, modelId, RERANK_RATE_LIMIT_ENDPOINT_ID); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java new file mode 100644 index 0000000000000..233981699f1da --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.response; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedStreamingProcessor; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class GoogleVertexAiCompletionResponseEntity { + /** + * Parses the response from Google Vertex AI's generateContent endpoint + * For a request like: + * <pre>
+     *     <code>
+     *         {
+     *             "inputs": "Please summarize this text: some text"
+     *         }
+     *     </code>
+     * </pre>
+ * + * The response is a GenerateContentResponse objects that looks like: + * + * <pre>
+     *     <code>
+     *
+     * {
+     *   "candidates": [
+     *     {
+     *       "content": {
+     *         "role": "model",
+     *         "parts": [
+     *           {
+     *             "text": "I am sorry, I cannot summarize the text because I do not have access to the text you are referring to."
+     *           }
+     *         ]
+     *       },
+     *       "finishReason": "STOP",
+     *       "avgLogprobs": -0.19326641248620074
+     *     }
+     *   ],
+     *   "usageMetadata": {
+     *     "promptTokenCount": 71,
+     *     "candidatesTokenCount": 23,
+     *     "totalTokenCount": 94,
+     *     "trafficType": "ON_DEMAND",
+     *     "promptTokensDetails": [
+     *       {
+     *         "modality": "TEXT",
+     *         "tokenCount": 71
+     *       }
+     *     ],
+     *     "candidatesTokensDetails": [
+     *       {
+     *         "modality": "TEXT",
+     *         "tokenCount": 23
+     *       }
+     *     ]
+     *   },
+     *   "modelVersion": "gemini-2.0-flash-001",
+     *   "createTime": "2025-05-28T15:08:20.049493Z",
+     *   "responseId": "5CY3aNWCA6mm4_UPr-zduAE"
+     * }
+     *    </code>
+     * </pre>
+ * + * @param request The original request made to the service. + **/ + public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + var responseJson = new String(response.body(), StandardCharsets.UTF_8); + + // Response from generateContent has the same shape as streamGenerateContent. We reuse the already implemented + // class to avoid code duplication + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk; + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, responseJson) + ) { + moveToFirstToken(parser); + chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(parser); + } + var results = chunk.choices().stream().map(choice -> choice.delta().content()).map(ChatCompletionResults.Result::new).toList(); + + return new ChatCompletionResults(results); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 3d7eb2d76ce47..99a09b983787d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -974,7 +974,7 @@ public void testGetConfiguration() throws Exception { { "service": "googlevertexai", "name": "Google Vertex AI", - "task_types": ["text_embedding", "rerank", "chat_completion"], + "task_types": ["text_embedding", "rerank", "completion", "chat_completion"], "configurations": { "service_account_json": { "description": "API Key for the provider you're connecting to.", @@ -983,7 +983,7 @@ public void testGetConfiguration() throws Exception { 
"sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] }, "project_id": { "description": "The GCP Project ID which has Vertex AI API(s) enabled. For more information on the URL, refer to the {geminiVertexAIDocs}.", @@ -992,7 +992,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] }, "location": { "description": "Please provide the GCP region where the Vertex AI API(s) is enabled. For more information, refer to the {geminiVertexAIDocs}.", @@ -1001,7 +1001,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "chat_completion"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1010,7 +1010,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] }, "model_id": { "description": "ID of the LLM you're using.", @@ -1019,7 +1019,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"] } } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessorTests.java new file mode 100644 index 0000000000000..2a915cd16dc7e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiStreamingProcessorTests.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; + +import java.io.IOException; +import java.util.ArrayDeque; + +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onError; +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onNext; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class GoogleVertexAiStreamingProcessorTests extends ESTestCase { + + public void testParseVertexAiResponse() throws IOException { + var item = new ArrayDeque(); + item.offer(new ServerSentEvent(vertexAiJsonResponse("test", true))); + + var response = onNext(new 
GoogleVertexAiStreamingProcessor(), item); + var json = toJsonString(response); + + assertThat(json, equalTo(""" + {"completion":[{"delta":"test"}]}""")); + } + + public void testParseVertexAiResponseMultiple() throws IOException { + var item = new ArrayDeque(); + item.offer(new ServerSentEvent(vertexAiJsonResponse("hello", false))); + + item.offer(new ServerSentEvent(vertexAiJsonResponse("world", true))); + + var response = onNext(new GoogleVertexAiStreamingProcessor(), item); + var json = toJsonString(response); + + assertThat(json, equalTo(""" + {"completion":[{"delta":"hello"},{"delta":"world"}]}""")); + } + + public void testParseErrorCallsOnError() { + var item = new ArrayDeque(); + item.offer(new ServerSentEvent("not json")); + + var exception = onError(new GoogleVertexAiStreamingProcessor(), item); + assertThat(exception, instanceOf(XContentParseException.class)); + } + + private String vertexAiJsonResponse(String content, boolean includeFinishReason) { + String finishReason = includeFinishReason ? 
"\"finishReason\": \"STOP\"," : ""; + + return Strings.format(""" + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "%s" + } + ] + }, + %s + "avgLogprobs": -0.19326641248620074 + } + ], + "usageMetadata": { + "promptTokenCount": 71, + "candidatesTokenCount": 23, + "totalTokenCount": 94, + "trafficType": "ON_DEMAND", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 71 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 23 + } + ] + }, + "modelVersion": "gemini-2.0-flash-001", + "createTime": "2025-05-28T15:08:20.049493Z", + "responseId": "5CY3aNWCA6mm4_UPr-zduAE" + } + """, content, finishReason); + } + + private String toJsonString(ChunkedToXContent chunkedToXContent) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + chunkedToXContent.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException e) { + logger.error(e.getMessage(), e); + fail(e.getMessage()); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java index 58072b747a0aa..b7547857b8d2e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiUnifiedChatCompletionActionTests.java @@ -38,7 +38,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static 
org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; -import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService.COMPLETION_HANDLER; import static org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionCreator.USER_ROLE; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java index fb5dccf89aa57..6a0ec6edfaa79 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModelTests.java @@ -91,7 +91,7 @@ public void testBuildUri() throws URISyntaxException { "https://us-east1-aiplatform.googleapis.com/v1/projects/my-gcp-project" + "/locations/global/publishers/google/models/gemini-1.5-flash-001:streamGenerateContent?alt=sse" ); - URI actualUri = GoogleVertexAiChatCompletionModel.buildUri(location, projectId, model); + URI actualUri = GoogleVertexAiChatCompletionModel.buildUriStreaming(location, projectId, model); assertThat(actualUri, is(expectedUri)); } @@ -113,6 +113,6 @@ public static GoogleVertexAiChatCompletionModel createCompletionModel( } public static URI buildDefaultUri() throws URISyntaxException { - return 
GoogleVertexAiChatCompletionModel.buildUri(DEFAULT_LOCATION, DEFAULT_PROJECT_ID, DEFAULT_MODEL_ID); + return GoogleVertexAiChatCompletionModel.buildUriStreaming(DEFAULT_LOCATION, DEFAULT_PROJECT_ID, DEFAULT_MODEL_ID); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java index 07169024e1e01..7eaa8e03fee66 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java @@ -83,7 +83,7 @@ public void testOverrideWith_DoesNotOverrideModelUri() { var model = createModel("model", Boolean.FALSE, InputType.SEARCH); var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, Map.of()); - MatcherAssert.assertThat(overriddenModel.uri(), is(model.uri())); + MatcherAssert.assertThat(overriddenModel.nonStreamingUri(), is(model.nonStreamingUri())); } public static GoogleVertexAiEmbeddingsModel createModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntityTests.java new file mode 100644 index 0000000000000..e634f75829743 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntityTests.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleVertexAiCompletionResponseEntityTests extends ESTestCase { + + public void testFromResponse_Javadoc() throws IOException { + var responseText = "I am sorry, I cannot summarize the text because I do not have access to the text you are referring to."; + + String responseJson = Strings.format(""" + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "%s" + } + ] + }, + "finishReason": "STOP", + "avgLogprobs": -0.19326641248620074 + } + ], + "usageMetadata": { + "promptTokenCount": 71, + "candidatesTokenCount": 23, + "totalTokenCount": 94, + "trafficType": "ON_DEMAND", + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 71 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 23 + } + ] + }, + "modelVersion": "gemini-2.0-flash-001", + "createTime": "2025-05-28T15:08:20.049493Z", + "responseId": "5CY3aNWCA6mm4_UPr-zduAE" + } + """, responseText); + + var parsedResults = GoogleVertexAiCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assert parsedResults instanceof ChatCompletionResults; + var results = 
(ChatCompletionResults) parsedResults; + + assertThat(results.isStreaming(), is(false)); + assertThat(results.results().size(), is(1)); + assertThat(results.results().get(0).content(), is(responseText)); + } +} From 4afffbba6130ed7f963ea26b1219aa30569ea72e Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Mon, 9 Jun 2025 13:42:03 -0400 Subject: [PATCH 2/3] Fixing long line --- .../response/GoogleVertexAiCompletionResponseEntity.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java index 233981699f1da..4db219eaf17ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/response/GoogleVertexAiCompletionResponseEntity.java @@ -35,7 +35,9 @@ public class GoogleVertexAiCompletionResponseEntity { * </code> * </pre> * - * The response is a GenerateContentResponse objects that looks like: + * The response is a + * GenerateContentResponse + * object that looks like: * * <pre>
      *     <code>

From e259025fa88c6bab045c649097213eb9bf257959 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner 
Date: Tue, 10 Jun 2025 08:58:14 -0400
Subject: [PATCH 3/3] Fixing test

---
 .../xpack/inference/InferenceGetServicesIT.java              | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
index 4de3c9f31d38e..0ce9b1bf5dc63 100644
--- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
@@ -134,7 +134,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
 
     public void testGetServicesWithCompletionTaskType() throws IOException {
         List services = getServices(TaskType.COMPLETION);
-        assertThat(services.size(), equalTo(13));
+        assertThat(services.size(), equalTo(14));
 
         var providers = providers(services);
 
@@ -154,7 +154,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
                     "openai",
                     "streaming_completion_test_service",
                     "hugging_face",
-                    "amazon_sagemaker"
+                    "amazon_sagemaker",
+                    "googlevertexai"
                 ).toArray()
             )
         );