diff --git a/docs/changelog/138632.yaml b/docs/changelog/138632.yaml new file mode 100644 index 0000000000000..c59c59b13564a --- /dev/null +++ b/docs/changelog/138632.yaml @@ -0,0 +1,5 @@ +pr: 138632 +summary: Correctly handle empty inputs in `chunkedInfer()` +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index e8ff3d98f3776..3c1678b547afc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -149,9 +149,18 @@ public void chunkedInfer( if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - - // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener); + if (supportsChunkedInfer()) { + if (input.isEmpty()) { + chunkedInferListener.onResponse(List.of()); + } else { + // a non-null query is not supported and is dropped by all providers + doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener); + } + } else { + chunkedInferListener.onFailure( + new UnsupportedOperationException(Strings.format("%s service does not support chunked inference", name())) + ); + } }).addListener(listener); } @@ -183,6 +192,10 @@ protected abstract void doChunkedInfer( ActionListener> listener ); + protected boolean supportsChunkedInfer() { + return true; + } + public void start(Model model, ActionListener listener) { SubscribableListener.newForked(this::init) .andThen((doStartListener) -> doStart(model, doStartListener)) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java index 69eddd7ecade2..38e5a33382ac5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -147,9 +147,15 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { + // Should never be called throw new UnsupportedOperationException("AI21 service does not support chunked inference"); } + @Override + protected boolean supportsChunkedInfer() { + return false; + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 29a1582ce6236..fa48b41df09ae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -218,9 +218,15 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { + // Should never be called throw new UnsupportedOperationException("Anthropic service does not support chunked inference"); } + @Override + protected boolean supportsChunkedInfer() { + return false; + } + @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java index 9088de1dacd2d..aa9753cbe9582 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/contextualai/ContextualAiService.java @@ -185,9 +185,15 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { + // Should never be called listener.onFailure(new ElasticsearchStatusException("Chunked inference is not supported for rerank task", RestStatus.BAD_REQUEST)); } + @Override + protected boolean supportsChunkedInfer() { + return false; + } + @Override protected void doUnifiedCompletionInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index ae5928e52f0d4..502e8faf1f3f7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -122,9 +122,15 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { + // Should never be called listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME))); } + @Override + protected boolean supportsChunkedInfer() { + return false; + } + @Override public String name() { return NAME; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java index 5f279e4f60b6b..bec318496217d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -271,6 +271,11 @@ public void chunkedInfer( listener.onFailure(createInvalidModelException(model)); return; } + if (input.isEmpty()) { + // Nothing to chunk: answer with an empty result and return so the request path below does not run for empty input + listener.onResponse(List.of()); + return; + } try { var sageMakerModel = ((SageMakerModel) model).override(taskSettings); var batchedRequests = new EmbeddingRequestChunker<>( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 04241686d453a..32ff5a3bccccc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -71,6 +71,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; @@ -491,6 +492,27 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx testChunkedInfer(TaskType.SPARSE_EMBEDDING, null); } + public void testChunkedInfer_noInputs() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + PlainActionFuture> listener = new PlainActionFuture<>(); + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + 
var model = createModelForTaskType(randomFrom(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING), null); + + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputTypeTests.randomWithIngestAndSearch(), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + } + assertThat(listener.actionGet(TIMEOUT), empty()); + } + private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException { var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 0aede89ac3f35..35913952ce3a0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -85,6 +85,7 @@ import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettingsTests.createEmbeddingsRequestSettingsMap; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.ArgumentMatchers.any; @@ -1323,6 +1324,50 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException { + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + "access", + "secret" + ); + + var sender = createMockSender(); + var factory = 
mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try ( + var service = new AmazonBedrockService( + factory, + amazonBedrockFactory, + createWithEmptySettings(threadPool), + mockClusterServiceEmpty() + ) + ) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + } + } + private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOException { var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 5fbd2993599b8..449c6312f15ce 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -89,6 +89,7 @@ import static org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRequestFields.API_KEY_HEADER; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -1294,6 +1295,27 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws 
IOException { testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException { + var model = AzureAiStudioEmbeddingsModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + "apikey" + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture> listener = new PlainActionFuture<>(); + List input = List.of(); + service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 19d02ef8ec205..adbe0fde84aa2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -81,6 +81,7 @@ import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -953,6 
+954,32 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException, URISyntaxException { + var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", null, "apikey", null, "id"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + + model.setUri(new URI(getUrl(webServer))); + PlainActionFuture> listener = new PlainActionFuture<>(); + List input = List.of(); + service.chunkedInfer( + model, + null, + input, + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f586af497b674..8eadc73ef747d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -92,6 +92,7 @@ import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static 
org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -1377,6 +1378,28 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException { + var model = CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1024, "model", null); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index 40fbe4b47bcf5..7985042779458 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -63,6 +63,7 @@ import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE; import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH; import static 
org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -822,4 +823,29 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { assertThat(requestMap.get("input"), is(List.of("a"))); } } + + public void testChunkedInfer_noInputs() throws IOException { + var model = createInternalEmbeddingModel( + new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT), + getUrl(webServer) + ); + + try (var service = createService(threadPool, clientManager)) { + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 498fbca10b5a8..f216d0977665e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -16,6 +16,8 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import 
org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -349,12 +351,21 @@ public void testUnifiedCompletionMalformedError() throws Exception { }}"""); } - public void testDoChunkedInferAlwaysFails() throws IOException { + public void testChunkedInferFails() throws IOException { try (var service = createService()) { - service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { - assertThat(e, isA(UnsupportedOperationException.class)); - assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion")); - })); + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer(mock(), null, List.of(new ChunkInferenceInput("a")), Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + var exception = expectThrows(UnsupportedOperationException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), is("deepseek service does not support chunked inference")); + } + } + + public void testChunkedInferFails_noInputs() throws IOException { + try (var service = createService()) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer(mock(), null, List.of(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + var exception = expectThrows(UnsupportedOperationException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.getMessage(), is("deepseek service does not support chunked inference")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index cff7c0ef5d97f..3862cf699905b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -90,6 +90,7 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.isA; @@ -912,6 +913,29 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio } } + public void testChunkedInfer_noInputs() throws IOException { + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = createService(senderFactory, getUrl(webServer))) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + public void testHideFromConfigurationApi_ThrowsUnsupported_WithNoAvailableModels() throws Exception { try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 70e1b15be261c..eacfde6a7944d 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -126,6 +126,7 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -1501,6 +1502,23 @@ public void testChunkingLargeDocument() throws InterruptedException { assertTrue("Listener not called with results", gotResults.get()); } + public void testChunkInfer_noInputs() throws IOException { + var model = new MultilingualE5SmallModel( + "foo", + TaskType.TEXT_EMBEDDING, + "e5", + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null), + null + ); + try (var service = createService(mock(Client.class))) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer(model, null, List.of(), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var results = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + assertThat(results, empty()); + } + } + public void testParsePersistedConfig_Rerank() { // with task settings { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index d6e938ac17fdf..8a0eef98233c7 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -84,6 +84,7 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.hasSize; import static org.mockito.ArgumentMatchers.any; @@ -893,6 +894,29 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { testChunkedInfer(modelId, apiKey, model); } + public void testChunkedInfer_noInputs() throws IOException { + var model = GoogleAiStudioEmbeddingsModelTests.createModel("modelId", createRandomChunkingSettings(), "apiKey", getUrl(webServer)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException { var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java 
index c1ed8c6ed996b..fb910d5daaefa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -9,12 +9,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -28,6 +30,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -50,6 +53,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -61,8 +65,10 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; import 
static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -1069,6 +1075,26 @@ public void testGetConfiguration() throws Exception { } } + public void testChunkedInfer_noInputs() throws IOException { + try (var service = createGoogleVertexAiService()) { + + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + GoogleVertexAiEmbeddingsModelTests.createModel(randomAlphaOfLength(10), randomBoolean(), randomSimilarityMeasure()), + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private GoogleVertexAiService createGoogleVertexAiService() { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index f9a67e7216e75..a1964e658f1a2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -52,6 +52,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.instanceOf; import static
org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.mock; @@ -133,6 +134,27 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE } } + public void testChunkedInfer_noInputs() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret"); + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_SEARCH, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + assertThat(listener.actionGet(TIMEOUT), empty()); + assertThat(webServer.requests(), empty()); + } + } + public void testGetConfiguration() throws Exception { try ( var service = new HuggingFaceElserService( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 2c078f154d934..ffd186f6c1e65 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -89,6 +89,7 @@ import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static
org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -1289,6 +1290,28 @@ public void testChunkedInfer() throws IOException { } } + public void testChunkedInfer_noInputs() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + public void testGetConfiguration() throws Exception { try (var service = createHuggingFaceService()) { String content = XContentHelper.stripWhitespace(""" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index af935cf898927..659e5c70c7677 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -87,6 +87,7 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.mockito.ArgumentMatchers.any; import static
org.mockito.Mockito.mock; @@ -732,6 +733,27 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { testChunkedInfer_Batches(createRandomChunkingSettings()); } + public void testChunkedInfer_noInputs() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var model = IbmWatsonxEmbeddingsModelTests.createModel(modelId, projectId, URI.create(url), apiVersion, apiKey, getUrl(webServer)); + try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException { var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index a67701b04a2e8..b4c8d34bff52b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -79,6 +79,7 @@ import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import
static org.hamcrest.Matchers.instanceOf; @@ -1663,6 +1664,28 @@ public void test_Embedding_ChunkedInfer_LateChunkingDisabled() throws IOExceptio test_Embedding_ChunkedInfer_BatchesCalls(model); } + public void test_Embedding_ChunkedInfer_noInputs() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var model = JinaAIEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1024, "jina-clip-v2", JinaAIEmbeddingType.FLOAT); + + try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java index 2011e2378a2dd..42f0848b6f12d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/LlamaServiceTests.java @@ -93,6 +93,7 @@ import static org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettingsTests.buildServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static
org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -656,6 +657,31 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModelWithChunkingSettings("id", "url", "api_key"); + model.setURI(getUrl(webServer)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new LlamaService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + public void testChunkedInfer(LlamaEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index d3cc5ee152bd0..1627270be674a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -92,6 +92,7 @@ import static org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettingsTests.createRequestSettingsMap; import static org.hamcrest.CoreMatchers.is; import static
org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -1082,6 +1083,31 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException { + var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey"); + model.setURI(getUrl(webServer)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 5beda4426983e..0e9ea32a4645e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -996,6 +996,28 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException { + var model =
OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", (ChunkingSettings) null); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java index 131fa98f97563..4695904dd1ca0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java @@ -94,6 +94,7 @@ import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -662,6 +663,30 @@ public void
testChunkedInfer_ChunkingSettingsSet() throws IOException { testChunkedInfer(model); } + public void testChunkedInfer_noInputs() throws IOException { + var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), API_KEY_VALUE, MODEL_VALUE); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.INTERNAL_INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(ESTestCase.TEST_REQUEST_TIMEOUT); + + assertThat(results, empty()); + assertThat(webServer.requests(), empty()); + } + } + public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java index d7cb372dd6b6d..d9830dcc0ab5e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerServiceTests.java @@ -72,6 +72,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.only; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -472,6 +473,30 @@ public void testChunkedInfer() throws Exception { verifyNoMoreInteractions(client, schemas,
schema); } + public void testChunkedInfer_noInputs() throws Exception { + var model = mockModelForChunking(); + + SageMakerSchema schema = mock(); + when(schemas.schemaFor(model)).thenReturn(schema); + mockInvoke(); + + sageMakerService.chunkedInfer( + model, + QUERY, + List.of(), + null, + INPUT_TYPE, + THIRTY_SECONDS, + assertOnce(assertNoFailureListener(chunkedInferences -> { + verify(schemas, never()).schemaFor(any()); + verify(schema, never()).request(any(), any()); + verify(schema, never()).response(any(), any(), any()); + })) + ); + verify(client, never()).invoke(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + private SageMakerModel mockModelForChunking() { var model = mockModel(); when(model.batchSize()).thenReturn(Optional.of(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index f508ca218bee1..d0f4a12d6ba9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -77,6 +77,7 @@ import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -1600,6 +1601,28 @@ public void test_Embedding_ChunkedInfer_ChunkingSettingsNotSet() throws IOExcept test_Embedding_ChunkedInfer_BatchesCalls(model); } + public void test_Embedding_ChunkedInfer_noInputs() throws IOException { + var model = 
VoyageAIEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1024, "voyage-3-large"); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) { + PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + null, + List.of(), + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, empty()); + MatcherAssert.assertThat(webServer.requests(), empty()); + } + } + private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);