From 9616fb30dbdb4d735e7add7d3c137cea09cb75f9 Mon Sep 17 00:00:00 2001 From: Donal Evans Date: Wed, 17 Sep 2025 13:40:05 -0700 Subject: [PATCH] [ML] Do not convert input Strings to ChunkInferenceInput unless necessary (#134945) The SenderService.infer() method was converting the input variable from a List into a List, but then when that list was passed into SenderService.createInput() it was immediately converted back into a List. To avoid unnecessary work, allow the EmbeddingsInput constructor to convert the list if necessary. (cherry picked from commit 8d2c4c312f073d857be2ef4417e01cbf17f1b00d) --- .../external/http/sender/EmbeddingsInput.java | 14 +++++++++----- .../xpack/inference/services/SenderService.java | 12 +++++------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 1e188d0f7bf5b..55cdb7207e25d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -30,10 +30,6 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) { private final Supplier> listSupplier; private final InputType inputType; - public EmbeddingsInput(List input, @Nullable InputType inputType) { - this(input, inputType, false); - } - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { super(false); this.listSupplier = Objects.requireNonNull(inputSupplier); @@ -41,7 +37,15 @@ public EmbeddingsInput(Supplier> inputSupplier, @Nulla } public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) { - this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false); + this(input, chunkingSettings, inputType, false); + } + + public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) { + this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream); + } + + public EmbeddingsInput(List input, @Nullable InputType inputType) { + this(input, inputType, false); } public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index ff8ae6fd5aac3..657834e6831ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -71,31 +71,29 @@ public void infer( ActionListener listener ) { init(); - var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList(); - var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream); + var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); doInfer(model, inferenceInput, taskSettings, timeout, listener); } private static InferenceInputs createInput( SenderService service, Model model, - List input, + List input, InputType inputType, @Nullable String query, @Nullable Boolean returnDocuments, @Nullable Integer topN, boolean stream ) { - List textInput = ChunkInferenceInput.inputs(input); return switch (model.getTaskType()) { - case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream); + case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); case RERANK -> { ValidationException validationException = new ValidationException(); service.validateRerankParameters(returnDocuments, topN, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream); + yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream); } case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { ValidationException validationException = new ValidationException(); @@ -103,7 +101,7 @@ private static InferenceInputs createInput( if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new EmbeddingsInput(input, inputType, stream); + yield new EmbeddingsInput(input, null, inputType, stream); } default -> throw new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),