From cf7c557ae3f8aef34384c4f84f9bce3079fb242d Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 15 Oct 2025 14:43:23 -0400 Subject: [PATCH 01/18] Refactor createInferenceEndpoint into a util class --- ...actSemanticCrossClusterSearchTestCase.java | 36 +----------- .../inference/integration/InferenceUtils.java | 58 +++++++++++++++++++ .../SemanticTextIndexOptionsIT.java | 30 +--------- .../SemanticTextInferenceFieldsIT.java | 32 +--------- 4 files changed, 61 insertions(+), 95 deletions(-) create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceUtils.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java index e1f72ddec5aab..2106e524f0ea2 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java @@ -39,18 +39,12 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractMultiClustersTestCase; import org.elasticsearch.transport.RemoteConnectionInfo; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction; import java.io.IOException; @@ -66,6 +60,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.elasticsearch.xpack.inference.integration.InferenceUtils.createInferenceEndpoint; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -165,35 +160,6 @@ protected BytesReference openPointInTime(String[] indices, TimeValue keepAlive) return response.getPointInTimeId(); } - protected static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) - throws IOException { - final String service = switch (taskType) { - case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; - case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; - default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); - }; - - final BytesReference content; - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - builder.startObject(); - builder.field("service", service); - builder.field("service_settings", serviceSettings); - builder.endObject(); - - content = BytesReference.bytes(builder); - } - - PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( - taskType, - inferenceId, - content, - XContentType.JSON, - TEST_REQUEST_TIMEOUT - ); - var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); - assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); - } - protected void assertSearchResponse(QueryBuilder queryBuilder, List indices, List expectedSearchResults) throws Exception { assertSearchResponse(queryBuilder, indices, expectedSearchResults, null, null); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceUtils.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceUtils.java new file mode 100644 index 0000000000000..a79ad218254ea --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceUtils.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; + +import java.io.IOException; +import java.util.Map; + +import static org.elasticsearch.test.ESTestCase.TEST_REQUEST_TIMEOUT; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +public class InferenceUtils { + private InferenceUtils() {} + + public static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) + throws IOException { + final String service = switch (taskType) { + case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; + case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; + default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); + }; + + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", service); + builder.field("service_settings", serviceSettings); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( + taskType, + inferenceId, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java index 08d3f0a6a9b9f..54055a9bc20da 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java @@ -37,13 +37,10 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; -import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.junit.After; import org.junit.Before; @@ -153,32 +150,7 @@ public void testSetDefaultBBQIndexOptionsWithBasicLicense() throws Exception { } private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { - final String service = switch (taskType) { - case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; - case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; - default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); - }; - - final BytesReference content; - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - builder.startObject(); - builder.field("service", service); - builder.field("service_settings", serviceSettings); - builder.endObject(); - - content = BytesReference.bytes(builder); - } - - PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( - taskType, - inferenceId, - content, - XContentType.JSON, - TEST_REQUEST_TIMEOUT - ); - var responseFuture = client().execute(PutInferenceModelAction.INSTANCE, request); - assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); - + InferenceUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); inferenceIds.put(inferenceId, taskType); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java index bb26153965499..b7f940163c01b 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; @@ -33,15 +32,11 @@ import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; -import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.junit.After; @@ -168,32 +163,7 @@ private void excludeInferenceFieldsFromSourceTestCase(IndexVersion minIndexVersi } private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { - final String service = switch (taskType) { - case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; - case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; - default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); - }; - - final BytesReference content; - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - builder.startObject(); - builder.field("service", service); - builder.field("service_settings", serviceSettings); - builder.endObject(); - - content = BytesReference.bytes(builder); - } - - PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( - taskType, - inferenceId, - content, - XContentType.JSON, - TEST_REQUEST_TIMEOUT - ); - var responseFuture = client().execute(PutInferenceModelAction.INSTANCE, request); - assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); - + InferenceUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); inferenceIds.put(inferenceId, taskType); } From 40e87f581858bf4a0946b614911725b6ef29ddc5 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 15 Oct 2025 15:49:30 -0400 Subject: [PATCH 02/18] Added integration test --- .../integration/SemanticQueryInferenceIT.java | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java new file mode 100644 index 0000000000000..e5d842b68334d --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.elasticsearch.xpack.inference.integration.InferenceUtils.createInferenceEndpoint; +import static org.hamcrest.CoreMatchers.equalTo; + +public class SemanticQueryInferenceIT extends ESIntegTestCase { + private static final Map SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins() { + return List.of( + LocalStateInferencePlugin.class, + TestInferenceServicePlugin.class, + ReindexPlugin.class, + SemanticTextInferenceFieldsIT.FakeMlPlugin.class + ); + } + + public void testManyInferenceRequests() throws Exception { + final String semanticTextFieldName = randomAlphaOfLength(10); + int indexCount = randomIntBetween(16, 20); + List indices = new ArrayList<>(indexCount); + for (int i = 0; i < indexCount; i++) { + String indexName = randomIdentifier(); + String inferenceId = randomIdentifier(); + XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); + + createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); + assertAcked(prepareCreate(indexName).setMapping(mapping)); + + indices.add(indexName); + } + + SemanticQueryBuilder query = new SemanticQueryBuilder(semanticTextFieldName, randomAlphaOfLength(10)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); + assertResponse(client().search(searchRequest), response -> { + assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); + assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); + }); + } + + private static XContentBuilder generateMapping(String semanticTextFieldName, String inferenceId) throws IOException { + return XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(semanticTextFieldName) + .field("type", SemanticTextFieldMapper.CONTENT_TYPE) + .field("inference_id", inferenceId) + .endObject() + .endObject() + .endObject(); + } +} From e56b16346bf76992d49a9fb75c2456cf2e90f02a Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Wed, 15 Oct 2025 16:36:02 -0400 Subject: [PATCH 03/18] Added tests for intercepted queries --- .../integration/SemanticQueryInferenceIT.java | 113 +++++++++++++++++- 1 file changed, 110 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java index e5d842b68334d..65ec23cad1687 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java @@ -8,15 +8,24 @@ package org.elasticsearch.xpack.inference.integration; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; @@ -33,6 +42,7 @@ import static org.elasticsearch.xpack.inference.integration.InferenceUtils.createInferenceEndpoint; import static org.hamcrest.CoreMatchers.equalTo; +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST) public class SemanticQueryInferenceIT extends ESIntegTestCase { private static final Map SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); @@ -47,13 +57,13 @@ protected Collection> nodePlugins() { LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class, - SemanticTextInferenceFieldsIT.FakeMlPlugin.class + FakeMlPlugin.class ); } - public void testManyInferenceRequests() throws Exception { + public void testManyInferenceRequests_SemanticQuery() throws Exception { final String semanticTextFieldName = randomAlphaOfLength(10); - int indexCount = randomIntBetween(16, 20); + int indexCount = 20; List indices = new ArrayList<>(indexCount); for (int i = 0; i < indexCount; i++) { String indexName = randomIdentifier(); @@ -75,6 +85,85 @@ public void testManyInferenceRequests() throws Exception { }); } + public void testManyInferenceRequests_KnnQuery() throws Exception { + final String semanticTextFieldName = randomAlphaOfLength(10); + int indexCount = 20; + List indices = new ArrayList<>(indexCount); + for (int i = 0; i < indexCount; i++) { + String indexName = randomIdentifier(); + String inferenceId = randomIdentifier(); + XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); + + createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); + assertAcked(prepareCreate(indexName).setMapping(mapping)); + + indices.add(indexName); + } + + QueryBuilder query = new KnnVectorQueryBuilder( + semanticTextFieldName, + new TextEmbeddingQueryVectorBuilder(null, randomAlphanumericOfLength(10)), + 10, + 100, + 10f, + null + ); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); + assertResponse(client().search(searchRequest), response -> { + assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); + assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); + }); + } + + public void testManyInferenceRequests_SparseVectorQuery() throws Exception { + final String semanticTextFieldName = randomAlphaOfLength(10); + int indexCount = 20; + List indices = new ArrayList<>(indexCount); + for (int i = 0; i < indexCount; i++) { + String indexName = randomIdentifier(); + String inferenceId = randomIdentifier(); + XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); + + createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); + assertAcked(prepareCreate(indexName).setMapping(mapping)); + + indices.add(indexName); + } + + QueryBuilder query = new SparseVectorQueryBuilder(semanticTextFieldName, null, randomAlphanumericOfLength(10)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); + assertResponse(client().search(searchRequest), response -> { + assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); + assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); + }); + } + + public void testManyInferenceRequests_MatchQuery() throws Exception { + final String semanticTextFieldName = randomAlphaOfLength(10); + int indexCount = 20; + List indices = new ArrayList<>(indexCount); + for (int i = 0; i < indexCount; i++) { + String indexName = randomIdentifier(); + String inferenceId = randomIdentifier(); + XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); + + createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); + assertAcked(prepareCreate(indexName).setMapping(mapping)); + + indices.add(indexName); + } + + QueryBuilder query = new MatchQueryBuilder(semanticTextFieldName, randomAlphaOfLength(10)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); + assertResponse(client().search(searchRequest), response -> { + assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); + assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); + }); + } + private static XContentBuilder generateMapping(String semanticTextFieldName, String inferenceId) throws IOException { return XContentFactory.jsonBuilder() .startObject() @@ -86,4 +175,22 @@ private static XContentBuilder generateMapping(String semanticTextFieldName, Str .endObject() .endObject(); } + + public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { + @Override + public List getNamedWriteables() { + return new MlInferenceNamedXContentProvider().getNamedWriteables(); + } + + @Override + public List> getQueryVectorBuilders() { + return List.of( + new QueryVectorBuilderSpec<>( + TextEmbeddingQueryVectorBuilder.NAME, + TextEmbeddingQueryVectorBuilder::new, + TextEmbeddingQueryVectorBuilder.PARSER + ) + ); + } + } } From 7630cc4f115989e6da311808bea0c4ba14156a01 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 12:36:44 -0400 Subject: [PATCH 04/18] Updated semantic query builder to use a supplier to get additional inference results --- .../queries/SemanticQueryBuilder.java | 219 ++++++++++++++++-- 1 file changed, 202 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 97d0caef98d0e..97e83637e358e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -10,12 +10,16 @@ import org.apache.lucene.search.Query; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Tuple; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -40,6 +44,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashSet; @@ -48,7 +53,9 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -93,6 +100,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsMap; + private final Supplier> inferenceResultsMapSupplier; private final Boolean lenient; // ccsRequest is only used on the local cluster coordinator node to detect when: @@ -135,6 +143,7 @@ protected SemanticQueryBuilder( this.fieldName = fieldName; this.query = query; this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null; + this.inferenceResultsMapSupplier = null; this.lenient = lenient; this.ccsRequest = ccsRequest; } @@ -143,6 +152,7 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.query = in.readString(); + if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) { this.inferenceResultsMap = in.readOptional( i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> i2.readNamedWriteable(InferenceResults.class)) @@ -156,20 +166,30 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { this.inferenceResultsMap = inferenceResults != null ? buildSingleResultInferenceResultsMap(inferenceResults) : null; in.readBoolean(); // Discard noInferenceResults, it is no longer necessary } + if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) { this.lenient = in.readOptionalBoolean(); } else { this.lenient = null; } + if (in.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) { this.ccsRequest = in.readBoolean(); } else { this.ccsRequest = false; } + + this.inferenceResultsMapSupplier = null; } @Override protected void doWriteTo(StreamOutput out) throws IOException { + if (inferenceResultsMapSupplier != null) { + throw new IllegalStateException( + "inferenceResultsMapSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?" + ); + } + out.writeString(fieldName); out.writeString(query); if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) { @@ -216,6 +236,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { private SemanticQueryBuilder( SemanticQueryBuilder other, Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, boolean ccsRequest ) { this.fieldName = other.fieldName; @@ -224,6 +245,7 @@ private SemanticQueryBuilder( this.queryName = other.queryName; // No need to copy the map here since this is only called internally. We can safely assume that the caller will not modify the map. this.inferenceResultsMap = inferenceResultsMap; + this.inferenceResultsMapSupplier = inferenceResultsMapSupplier; this.lenient = other.lenient; this.ccsRequest = ccsRequest; } @@ -310,6 +332,41 @@ static Map getInferenceResults( return currentInferenceResultsMap; } + static Supplier> getInferenceResultsNew( + QueryRewriteContext queryRewriteContext, + Set fullyQualifiedInferenceIds, + @Nullable Map inferenceResultsMap, + @Nullable String query + ) { + List inferenceIds = new ArrayList<>(fullyQualifiedInferenceIds.size()); + if (query != null) { + for (FullyQualifiedInferenceId fullyQualifiedInferenceId : fullyQualifiedInferenceIds) { + if (inferenceResultsMap == null || inferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) { + if (fullyQualifiedInferenceId.clusterAlias().equals(queryRewriteContext.getLocalClusterAlias()) == false) { + // Catch if we are missing inference results that should have been generated on another cluster + throw new IllegalStateException( + "Cannot get inference results for inference endpoint [" + + fullyQualifiedInferenceId + + "] on cluster [" + + queryRewriteContext.getLocalClusterAlias() + + "]" + ); + } + + inferenceIds.add(fullyQualifiedInferenceId.inferenceId()); + } + } + } + + InferenceResultsMapSupplier inferenceResultsMapSupplier = null; + if (inferenceIds.isEmpty() == false) { + inferenceResultsMapSupplier = new InferenceResultsMapSupplier(); + registerInferenceAsyncActions(queryRewriteContext, inferenceResultsMapSupplier, query, inferenceIds); + } + + return inferenceResultsMapSupplier; + } + static void registerInferenceAsyncAction( QueryRewriteContext queryRewriteContext, ConcurrentHashMap inferenceResultsMap, @@ -346,6 +403,70 @@ static void registerInferenceAsyncAction( ); } + static void registerInferenceAsyncActions( + QueryRewriteContext queryRewriteContext, + InferenceResultsMapSupplier inferenceResultsMapSupplier, + String query, + List inferenceIds + ) { + List inferenceRequests = inferenceIds.stream() + .map( + i -> new InferenceAction.Request( + TaskType.ANY, + i, + null, + null, + null, + List.of(query), + Map.of(), + InputType.INTERNAL_SEARCH, + null, + false + ) + ) + .toList(); + + queryRewriteContext.registerAsyncAction((client, listener) -> { + GroupedActionListener> gal = createGroupedActionListener( + inferenceResultsMapSupplier, + inferenceRequests.size(), + listener + ); + for (InferenceAction.Request inferenceRequest : inferenceRequests) { + FullyQualifiedInferenceId fullyQualifiedInferenceId = new FullyQualifiedInferenceId( + queryRewriteContext.getLocalClusterAlias(), + inferenceRequest.getInferenceEntityId() + ); + executeAsyncWithOrigin( + client, + ML_ORIGIN, + InferenceAction.INSTANCE, + inferenceRequest, + gal.delegateFailureAndWrap((l, inferenceResponse) -> { + InferenceResults inferenceResults = validateAndConvertInferenceResults( + inferenceResponse.getResults(), + fullyQualifiedInferenceId.inferenceId() + ); + l.onResponse(Tuple.tuple(fullyQualifiedInferenceId, inferenceResults)); + }) + ); + } + }); + } + + private static GroupedActionListener> createGroupedActionListener( + InferenceResultsMapSupplier inferenceResultsMapSupplier, + int inferenceRequestCount, + ActionListener listener + ) { + return new GroupedActionListener<>(inferenceRequestCount, listener.delegateFailureAndWrap((l, responses) -> { + ImmutableOpenMap.Builder mapBuilder = ImmutableOpenMap.builder(); + responses.forEach(r -> mapBuilder.put(r.v1(), r.v2())); + inferenceResultsMapSupplier.set(mapBuilder.build()); + l.onResponse(null); + })); + } + static Map convertFromBwcInferenceResultsMap( Map inferenceResultsMap ) { @@ -459,26 +580,48 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu ); } - SemanticQueryBuilder rewritten = this; - if (queryRewriteContext.hasAsyncActions() == false) { - Set fullyQualifiedInferenceIds = getInferenceIdsForField( - resolvedIndices.getConcreteLocalIndicesMetadata().values(), - queryRewriteContext.getLocalClusterAlias(), - fieldName - ); - Map modifiedInferenceResultsMap = getInferenceResults( - queryRewriteContext, - fullyQualifiedInferenceIds, - inferenceResultsMap, - query - ); + if (inferenceResultsMapSupplier != null) { + // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process + SemanticQueryBuilder rewritten = this; + + Map newInferenceResultsMap = inferenceResultsMapSupplier.get(); + if (newInferenceResultsMap != null) { + Map mergedInferenceResultsMap = mergeInferenceResultsMaps( + inferenceResultsMap, + newInferenceResultsMap + ); + rewritten = new SemanticQueryBuilder(this, mergedInferenceResultsMap, null, ccsRequest); + } + + return rewritten; + } - if (modifiedInferenceResultsMap == inferenceResultsMap) { + Set fullyQualifiedInferenceIds = getInferenceIdsForField( + resolvedIndices.getConcreteLocalIndicesMetadata().values(), + queryRewriteContext.getLocalClusterAlias(), + fieldName + ); + Supplier> newInferenceResultsMapSupplier = getInferenceResultsNew( + queryRewriteContext, + fullyQualifiedInferenceIds, + inferenceResultsMap, + query + ); + + SemanticQueryBuilder rewritten = this; + if (newInferenceResultsMapSupplier == null) { + // No additional inference results are required + if (inferenceResultsMap != null) { // The inference results map is fully populated, so we can perform error checking - inferenceResultsErrorCheck(modifiedInferenceResultsMap); + inferenceResultsErrorCheck(inferenceResultsMap); } else { - rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap, ccsRequest); + // No inference results have been collected yet, indicating we don't need any to rewrite this query. + // This can happen when querying an unsupported field type or an unavailable index. Set an empty inference results map so + // that rewriting can continue. + rewritten = new SemanticQueryBuilder(this, Map.of(), null, ccsRequest); } + } else { + rewritten = new SemanticQueryBuilder(this, inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest); } return rewritten; @@ -526,6 +669,21 @@ private static InferenceResults validateAndConvertInferenceResults( return inferenceResults; } + private static Map mergeInferenceResultsMaps( + @Nullable Map originalInferenceResultsMap, + Map newInferenceResultsMap + ) { + Map mergedInferenceResultsMap = newInferenceResultsMap; + if (originalInferenceResultsMap != null && originalInferenceResultsMap.isEmpty() == false) { + mergedInferenceResultsMap = Stream.concat( + originalInferenceResultsMap.entrySet().stream(), + newInferenceResultsMap.entrySet().stream() + ).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + return mergedInferenceResultsMap; + } + private void inferenceResultsErrorCheck(Map inferenceResultsMap) { for (var entry : inferenceResultsMap.entrySet()) { String inferenceId = entry.getKey().inferenceId(); @@ -577,11 +735,38 @@ protected boolean doEquals(SemanticQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query) && Objects.equals(inferenceResultsMap, other.inferenceResultsMap) + && Objects.equals(inferenceResultsMapSupplier, other.inferenceResultsMapSupplier) && Objects.equals(ccsRequest, other.ccsRequest); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResultsMap, ccsRequest); + return Objects.hash(fieldName, query, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); + } + + private static class InferenceResultsMapSupplier implements Supplier> { + private Map inferenceResultsMap = null; + + private void set(Map inferenceResultsMap) { + this.inferenceResultsMap = inferenceResultsMap; + } + + @Override + public Map get() { + return inferenceResultsMap; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceResultsMapSupplier that = (InferenceResultsMapSupplier) o; + return Objects.equals(inferenceResultsMap, that.inferenceResultsMap); + } + + @Override + public int hashCode() { + return Objects.hashCode(inferenceResultsMap); + } } } From 2a03a5651a541f60734e84a9615de913da2abf47 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 14:00:43 -0400 Subject: [PATCH 05/18] Updated intercepted queries to use a supplier to get additional inference results --- .../queries/InferenceResultsMapSupplier.java | 40 ++++++++++ ...rceptedInferenceKnnVectorQueryBuilder.java | 12 ++- ...InterceptedInferenceMatchQueryBuilder.java | 12 ++- .../InterceptedInferenceQueryBuilder.java | 74 ++++++++++++++----- ...ptedInferenceSparseVectorQueryBuilder.java | 12 ++- .../queries/SemanticQueryBuilder.java | 56 ++++---------- 6 files changed, 139 insertions(+), 67 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java new file mode 100644 index 0000000000000..16558e9df579e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.queries; + +import org.elasticsearch.inference.InferenceResults; + +import java.util.Map; +import java.util.Objects; +import java.util.function.Supplier; + +class InferenceResultsMapSupplier implements Supplier> { + private Map inferenceResultsMap = null; + + void set(Map inferenceResultsMap) { + this.inferenceResultsMap = inferenceResultsMap; + } + + @Override + public Map get() { + return inferenceResultsMap; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceResultsMapSupplier that = (InferenceResultsMapSupplier) o; + return Objects.equals(inferenceResultsMap, that.inferenceResultsMap); + } + + @Override + public int hashCode() { + return Objects.hashCode(inferenceResultsMap); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java index 210ab2b67f9c9..aaaec0e64206b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java @@ -33,6 +33,7 @@ import java.io.IOException; import java.util.Collection; import java.util.Map; +import java.util.function.Supplier; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -62,9 +63,10 @@ public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOExcept private InterceptedInferenceKnnVectorQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, boolean ccsRequest ) { - super(other, inferenceResultsMap, ccsRequest); + super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -127,8 +129,12 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { - return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, ccsRequest); + protected QueryBuilder copy( + Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, + boolean ccsRequest + ) { + return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java index a689c8a3a7f1e..8388ad44f53a0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.Map; +import java.util.function.Supplier; public class InterceptedInferenceMatchQueryBuilder extends InterceptedInferenceQueryBuilder { public static final String NAME = "intercepted_inference_match"; @@ -47,9 +48,10 @@ public InterceptedInferenceMatchQueryBuilder(StreamInput in) throws IOException private InterceptedInferenceMatchQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, boolean ccsRequest ) { - super(other, inferenceResultsMap, ccsRequest); + super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -73,8 +75,12 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { - return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap, ccsRequest); + protected QueryBuilder copy( + Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, + boolean ccsRequest + ) { + return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index 8774d35f17ade..fe3c98f64cf07 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -37,11 +37,14 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.getInferenceResultsNew; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.mergeInferenceResultsMaps; /** *

@@ -67,6 +70,7 @@ public abstract class InterceptedInferenceQueryBuilder inferenceResultsMap; + protected final Supplier> inferenceResultsMapSupplier; protected final boolean ccsRequest; protected InterceptedInferenceQueryBuilder(T originalQuery) { @@ -77,6 +81,7 @@ protected InterceptedInferenceQueryBuilder(T originalQuery, Map other, Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, boolean ccsRequest ) { this.originalQuery = other.originalQuery; this.inferenceResultsMap = inferenceResultsMap; + this.inferenceResultsMapSupplier = inferenceResultsMapSupplier; this.ccsRequest = ccsRequest; } @@ -144,13 +153,18 @@ protected InterceptedInferenceQueryBuilder( protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext); /** - * Generate a copy of {@code this} using the provided inference results map. + * Generate a copy of {@code this}. * * @param inferenceResultsMap The inference results map + * @param inferenceResultsMapSupplier The inference results map supplier * @param ccsRequest Flag indicating if this is a CCS request * @return A copy of {@code this} with the provided inference results map */ - protected abstract QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest); + protected abstract QueryBuilder copy( + Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, + boolean ccsRequest + ); /** * Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use @@ -195,6 +209,12 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {} @Override protected void doWriteTo(StreamOutput out) throws IOException { + if (inferenceResultsMapSupplier != null) { + throw new IllegalStateException( + "inferenceResultsMapSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?" + ); + } + out.writeNamedWriteable(originalQuery); if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) { out.writeOptional( @@ -238,12 +258,13 @@ protected Query doToQuery(SearchExecutionContext context) { protected boolean doEquals(InterceptedInferenceQueryBuilder other) { return Objects.equals(originalQuery, other.originalQuery) && Objects.equals(inferenceResultsMap, other.inferenceResultsMap) + && Objects.equals(inferenceResultsMapSupplier, other.inferenceResultsMapSupplier) && Objects.equals(ccsRequest, other.ccsRequest); } @Override protected int doHashCode() { - return Objects.hash(originalQuery, inferenceResultsMap, ccsRequest); + return Objects.hash(originalQuery, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -321,29 +342,48 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri ); } + if (inferenceResultsMapSupplier != null) { + // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process + QueryBuilder rewritten = this; + + Map newInferenceResultsMap = inferenceResultsMapSupplier.get(); + if (newInferenceResultsMap != null) { + Map mergedInferenceResultsMap = mergeInferenceResultsMaps( + inferenceResultsMap, + newInferenceResultsMap + ); + rewritten = copy(mergedInferenceResultsMap, null, ccsRequest); + } + + return rewritten; + } + FullyQualifiedInferenceId inferenceIdOverride = getInferenceIdOverride(); if (inferenceIdOverride != null) { inferenceIds = Set.of(inferenceIdOverride); } - QueryBuilder rewritten = this; - if (queryRewriteContext.hasAsyncActions() == false) { - // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results - // are provided by the user. Ensure that we set an empty inference results map in this case so that it is always non-null after - // coordinator node rewrite. - Map modifiedInferenceResultsMap = SemanticQueryBuilder.getInferenceResults( - queryRewriteContext, - inferenceIds, - this.inferenceResultsMap, - getQuery() - ); + Supplier> newInferenceResultsMapSupplier = getInferenceResultsNew( + queryRewriteContext, + inferenceIds, + inferenceResultsMap, + getQuery() + ); - if (modifiedInferenceResultsMap == this.inferenceResultsMap) { + QueryBuilder rewritten = this; + if (newInferenceResultsMapSupplier == null) { + // No additional inference results are required + if (inferenceResultsMap != null) { // The inference results map is fully populated, so we can perform error checking - inferenceResultsErrorCheck(modifiedInferenceResultsMap); + inferenceResultsErrorCheck(inferenceResultsMap); } else { - rewritten = copy(modifiedInferenceResultsMap, ccsRequest); + // No inference results have been collected yet, indicating we don't need any to rewrite this query. + // This can happen when pre-computed inference results are provided by the user. + // Set an empty inference results map so that rewriting can continue. + rewritten = copy(Map.of(), null, ccsRequest); } + } else { + rewritten = copy(inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest); } return rewritten; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java index d2f94593f4340..809250a8410c3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java @@ -32,6 +32,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -61,9 +62,10 @@ public InterceptedInferenceSparseVectorQueryBuilder(StreamInput in) throws IOExc private InterceptedInferenceSparseVectorQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, boolean ccsRequest ) { - super(other, inferenceResultsMap, ccsRequest); + super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -114,8 +116,12 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { - return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, ccsRequest); + protected QueryBuilder copy( + Map inferenceResultsMap, + Supplier> inferenceResultsMapSupplier, + boolean ccsRequest + ) { + return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 97e83637e358e..577bffff3629e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -367,6 +367,21 @@ static Supplier> getInferenceRe return inferenceResultsMapSupplier; } + static Map mergeInferenceResultsMaps( + @Nullable Map originalInferenceResultsMap, + Map newInferenceResultsMap + ) { + Map mergedInferenceResultsMap = newInferenceResultsMap; + if (originalInferenceResultsMap != null && originalInferenceResultsMap.isEmpty() == false) { + mergedInferenceResultsMap = Stream.concat( + originalInferenceResultsMap.entrySet().stream(), + newInferenceResultsMap.entrySet().stream() + ).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + return mergedInferenceResultsMap; + } + static void registerInferenceAsyncAction( QueryRewriteContext queryRewriteContext, ConcurrentHashMap inferenceResultsMap, @@ -669,21 +684,6 @@ private static InferenceResults validateAndConvertInferenceResults( return inferenceResults; } - private static Map mergeInferenceResultsMaps( - @Nullable Map originalInferenceResultsMap, - Map newInferenceResultsMap - ) { - Map mergedInferenceResultsMap = newInferenceResultsMap; - if (originalInferenceResultsMap != null && originalInferenceResultsMap.isEmpty() == false) { - mergedInferenceResultsMap = Stream.concat( - originalInferenceResultsMap.entrySet().stream(), - newInferenceResultsMap.entrySet().stream() - ).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - - return mergedInferenceResultsMap; - } - private void inferenceResultsErrorCheck(Map inferenceResultsMap) { for (var entry : inferenceResultsMap.entrySet()) { String inferenceId = entry.getKey().inferenceId(); @@ -743,30 +743,4 @@ protected boolean doEquals(SemanticQueryBuilder other) { protected int doHashCode() { return Objects.hash(fieldName, query, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } - - private static class InferenceResultsMapSupplier implements Supplier> { - private Map inferenceResultsMap = null; - - private void set(Map inferenceResultsMap) { - this.inferenceResultsMap = inferenceResultsMap; - } - - @Override - public Map get() { - return inferenceResultsMap; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceResultsMapSupplier that = (InferenceResultsMapSupplier) o; - return Objects.equals(inferenceResultsMap, that.inferenceResultsMap); - } - - @Override - public int hashCode() { - return Objects.hashCode(inferenceResultsMap); - } - } } From 40c318f086a1930207d44b350d0bc78998eefb5a Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 14:11:39 -0400 Subject: [PATCH 06/18] Remove obsolete methods --- .../InterceptedInferenceQueryBuilder.java | 4 +- .../queries/SemanticQueryBuilder.java | 92 +------------------ 2 files changed, 7 insertions(+), 89 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index fe3c98f64cf07..d23de8ae615f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -43,7 +43,7 @@ import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap; -import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.getInferenceResultsNew; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.getInferenceResults; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.mergeInferenceResultsMaps; /** @@ -363,7 +363,7 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri inferenceIds = Set.of(inferenceIdOverride); } - Supplier> newInferenceResultsMapSupplier = getInferenceResultsNew( + Supplier> newInferenceResultsMapSupplier = getInferenceResults( queryRewriteContext, inferenceIds, inferenceResultsMap, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 577bffff3629e..05d08962b95f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -52,7 +52,6 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -277,62 +276,17 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO * Get inference results for the provided query using the provided fully qualified inference IDs. *

*

- * This method will return an inference results map that will be asynchronously populated with inference results. If the provided - * inference results map already contains all required inference results, the same map instance will be returned. Otherwise, a new map - * instance will be returned. It is guaranteed that a non-null map instance will be returned. + * This method will return an inference results map supplier that will provide a complete map of additional inference results required. + * If the provided inference results map already contains all required inference results, a null supplier will be returned. *

* * @param queryRewriteContext The query rewrite context * @param fullyQualifiedInferenceIds The fully qualified inference IDs to use to generate inference results * @param inferenceResultsMap The initial inference results map * @param query The query to generate inference results for - * @return An inference results map + * @return An inference results map supplier */ - static Map getInferenceResults( - QueryRewriteContext queryRewriteContext, - Set fullyQualifiedInferenceIds, - @Nullable Map inferenceResultsMap, - @Nullable String query - ) { - boolean modifiedInferenceResultsMap = false; - Map currentInferenceResultsMap = inferenceResultsMap != null - ? inferenceResultsMap - : Map.of(); - - if (query != null) { - for (FullyQualifiedInferenceId fullyQualifiedInferenceId : fullyQualifiedInferenceIds) { - if (currentInferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) { - if (fullyQualifiedInferenceId.clusterAlias().equals(queryRewriteContext.getLocalClusterAlias()) == false) { - // Catch if we are missing inference results that should have been generated on another cluster - throw new IllegalStateException( - "Cannot get inference results for inference endpoint [" - + fullyQualifiedInferenceId - + "] on cluster [" - + queryRewriteContext.getLocalClusterAlias() - + "]" - ); - } - - if (modifiedInferenceResultsMap == false) { - // Copy the inference results map to ensure it is mutable and thread safe - currentInferenceResultsMap = new ConcurrentHashMap<>(currentInferenceResultsMap); - modifiedInferenceResultsMap = true; - } - - registerInferenceAsyncAction( - queryRewriteContext, - ((ConcurrentHashMap) currentInferenceResultsMap), - query, - fullyQualifiedInferenceId.inferenceId() - ); - } - } - } - - return currentInferenceResultsMap; - } - - static Supplier> getInferenceResultsNew( + static Supplier> getInferenceResults( QueryRewriteContext queryRewriteContext, Set fullyQualifiedInferenceIds, @Nullable Map inferenceResultsMap, @@ -382,42 +336,6 @@ static Map mergeInferenceResultsMap return mergedInferenceResultsMap; } - static void registerInferenceAsyncAction( - QueryRewriteContext queryRewriteContext, - ConcurrentHashMap inferenceResultsMap, - String query, - String inferenceId - ) { - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.ANY, - inferenceId, - null, - null, - null, - List.of(query), - Map.of(), - InputType.INTERNAL_SEARCH, - null, - false - ); - - queryRewriteContext.registerAsyncAction( - (client, listener) -> executeAsyncWithOrigin( - client, - ML_ORIGIN, - InferenceAction.INSTANCE, - inferenceRequest, - listener.delegateFailureAndWrap((l, inferenceResponse) -> { - inferenceResultsMap.put( - new FullyQualifiedInferenceId(queryRewriteContext.getLocalClusterAlias(), inferenceId), - validateAndConvertInferenceResults(inferenceResponse.getResults(), inferenceId) - ); - l.onResponse(null); - }) - ) - ); - } - static void registerInferenceAsyncActions( QueryRewriteContext queryRewriteContext, InferenceResultsMapSupplier inferenceResultsMapSupplier, @@ -616,7 +534,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu queryRewriteContext.getLocalClusterAlias(), fieldName ); - Supplier> newInferenceResultsMapSupplier = getInferenceResultsNew( + Supplier> newInferenceResultsMapSupplier = getInferenceResults( queryRewriteContext, fullyQualifiedInferenceIds, inferenceResultsMap, From c0242fca98ee0078575d8600019b3490f07834b0 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 15:04:37 -0400 Subject: [PATCH 07/18] Added tests --- ...actSemanticCrossClusterSearchTestCase.java | 2 +- ...ceUtils.java => IntegrationTestUtils.java} | 47 ++++- .../integration/SemanticQueryInferenceIT.java | 179 ++++++++---------- .../SemanticTextIndexOptionsIT.java | 2 +- .../SemanticTextInferenceFieldsIT.java | 48 +---- 5 files changed, 129 insertions(+), 149 deletions(-) rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{InferenceUtils.java => IntegrationTestUtils.java} (56%) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java index 2106e524f0ea2..f4d9d45d00c87 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java @@ -60,7 +60,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; -import static org.elasticsearch.xpack.inference.integration.InferenceUtils.createInferenceEndpoint; +import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.createInferenceEndpoint; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceUtils.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java similarity index 56% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceUtils.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java index a79ad218254ea..14f48526824e3 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceUtils.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.inference.integration; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; @@ -21,11 +24,13 @@ import java.util.Map; import static org.elasticsearch.test.ESTestCase.TEST_REQUEST_TIMEOUT; +import static org.elasticsearch.test.ESTestCase.safeGet; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; -public class InferenceUtils { - private InferenceUtils() {} +public class IntegrationTestUtils { + private IntegrationTestUtils() {} public static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { @@ -55,4 +60,42 @@ public static void createInferenceEndpoint(Client client, TaskType taskType, Str var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); } + + public static void deleteInferenceEndpoint(Client client, TaskType taskType, String inferenceId) { + assertAcked( + safeGet( + client.execute( + DeleteInferenceEndpointAction.INSTANCE, + new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false) + ) + ) + ); + } + + public static void deleteIndex(Client client, String indexName) { + assertAcked( + safeGet( + client.admin() + .indices() + .prepareDelete(indexName) + .setIndicesOptions( + IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build() + ) + .execute() + ) + ); + } + + public static XContentBuilder generateMapping(Map semanticTextFields) throws IOException { + XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties"); + for (var entry : semanticTextFields.entrySet()) { + mapping.startObject(entry.getKey()); + mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mapping.field("inference_id", entry.getValue()); + mapping.endObject(); + } + mapping.endObject().endObject(); + + return mapping; + } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java index 65ec23cad1687..52dbc8eeff21b 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java @@ -7,11 +7,14 @@ package org.elasticsearch.xpack.inference.integration; +import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.ActionPlugin; @@ -22,29 +25,43 @@ import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; +import org.junit.After; import java.io.IOException; -import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; -import static org.elasticsearch.xpack.inference.integration.InferenceUtils.createInferenceEndpoint; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; -@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST) +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 1) public class SemanticQueryInferenceIT extends ESIntegTestCase { + private static final String INDEX_NAME = "test_index"; + private static final Map SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); + private static final Map TEXT_EMBEDDING_SERVICE_SETTINGS = Map.of( + "model", + "my_model", + "dimensions", + 256, + "similarity", + "cosine", + "api_key", + "my_api_key" + ); + + private final Map inferenceIds = new HashMap<>(); @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { @@ -53,127 +70,85 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { @Override protected Collection> nodePlugins() { - return List.of( - LocalStateInferencePlugin.class, - TestInferenceServicePlugin.class, - ReindexPlugin.class, - FakeMlPlugin.class - ); + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class, FakeMlPlugin.class); } - public void testManyInferenceRequests_SemanticQuery() throws Exception { - final String semanticTextFieldName = randomAlphaOfLength(10); - int indexCount = 20; - List indices = new ArrayList<>(indexCount); - for (int i = 0; i < indexCount; i++) { - String indexName = randomIdentifier(); - String inferenceId = randomIdentifier(); - XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); - - createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); - assertAcked(prepareCreate(indexName).setMapping(mapping)); - - indices.add(indexName); + @After + public void cleanUp() { + IntegrationTestUtils.deleteIndex(client(), INDEX_NAME); + for (var entry : inferenceIds.entrySet()) { + IntegrationTestUtils.deleteInferenceEndpoint(client(), entry.getValue(), entry.getKey()); } - - SemanticQueryBuilder query = new SemanticQueryBuilder(semanticTextFieldName, randomAlphaOfLength(10)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); - assertResponse(client().search(searchRequest), response -> { - assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); - assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); - }); } - public void testManyInferenceRequests_KnnQuery() throws Exception { - final String semanticTextFieldName = randomAlphaOfLength(10); - int indexCount = 20; - List indices = new ArrayList<>(indexCount); - for (int i = 0; i < indexCount; i++) { - String indexName = randomIdentifier(); - String inferenceId = randomIdentifier(); - XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); + public void testManySemanticQueryClauses() throws Exception { + manyQueryClausesTestCase(randomIntBetween(18, 24), SemanticQueryBuilder::new, TaskType.SPARSE_EMBEDDING); + } - createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); - assertAcked(prepareCreate(indexName).setMapping(mapping)); + public void testManyMatchQueryClauses() throws Exception { + manyQueryClausesTestCase(randomIntBetween(18, 24), MatchQueryBuilder::new, TaskType.SPARSE_EMBEDDING); + } - indices.add(indexName); - } + public void testManySparseVectorQueryClauses() throws Exception { + manyQueryClausesTestCase(randomIntBetween(18, 24), (f, q) -> new SparseVectorQueryBuilder(f, null, q), TaskType.SPARSE_EMBEDDING); + } - QueryBuilder query = new KnnVectorQueryBuilder( - semanticTextFieldName, - new TextEmbeddingQueryVectorBuilder(null, randomAlphanumericOfLength(10)), - 10, - 100, - 10f, - null + public void testManyKnnQueryClauses() throws Exception { + int clauseCount = randomIntBetween(18, 24); + manyQueryClausesTestCase( + clauseCount, + (f, q) -> new KnnVectorQueryBuilder(f, new TextEmbeddingQueryVectorBuilder(null, q), clauseCount, clauseCount * 10, null, null), + TaskType.TEXT_EMBEDDING ); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); - assertResponse(client().search(searchRequest), response -> { - assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); - assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); - }); } - public void testManyInferenceRequests_SparseVectorQuery() throws Exception { - final String semanticTextFieldName = randomAlphaOfLength(10); - int indexCount = 20; - List indices = new ArrayList<>(indexCount); - for (int i = 0; i < indexCount; i++) { - String indexName = randomIdentifier(); + private void manyQueryClausesTestCase( + int clauseCount, + BiFunction clauseGenerator, + TaskType semanticTextFieldTaskType + ) throws Exception { + Map inferenceEndpointServiceSettings = getServiceSettings(semanticTextFieldTaskType); + Map semanticTextFields = new HashMap<>(clauseCount); + for (int i = 0; i < clauseCount; i++) { + String fieldName = randomAlphaOfLength(10); String inferenceId = randomIdentifier(); - XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); - createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); - assertAcked(prepareCreate(indexName).setMapping(mapping)); - - indices.add(indexName); + createInferenceEndpoint(semanticTextFieldTaskType, inferenceId, inferenceEndpointServiceSettings); + semanticTextFields.put(fieldName, inferenceId); } - QueryBuilder query = new SparseVectorQueryBuilder(semanticTextFieldName, null, randomAlphanumericOfLength(10)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); - assertResponse(client().search(searchRequest), response -> { - assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); - assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); - }); - } - - public void testManyInferenceRequests_MatchQuery() throws Exception { - final String semanticTextFieldName = randomAlphaOfLength(10); - int indexCount = 20; - List indices = new ArrayList<>(indexCount); - for (int i = 0; i < indexCount; i++) { - String indexName = randomIdentifier(); - String inferenceId = randomIdentifier(); - XContentBuilder mapping = generateMapping(semanticTextFieldName, inferenceId); + XContentBuilder mapping = IntegrationTestUtils.generateMapping(semanticTextFields); + assertAcked(prepareCreate(INDEX_NAME).setMapping(mapping)); - createInferenceEndpoint(client(), TaskType.SPARSE_EMBEDDING, inferenceId, SPARSE_EMBEDDING_SERVICE_SETTINGS); - assertAcked(prepareCreate(indexName).setMapping(mapping)); + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + for (String semanticTextField : semanticTextFields.keySet()) { + Map source = Map.of(semanticTextField, randomAlphaOfLength(10)); + DocWriteResponse docWriteResponse = client().prepareIndex(INDEX_NAME).setSource(source).get(TEST_REQUEST_TIMEOUT); + assertThat(docWriteResponse.getResult(), is(DocWriteResponse.Result.CREATED)); - indices.add(indexName); + boolQuery.should(clauseGenerator.apply(semanticTextField, randomAlphaOfLength(10))); } + client().admin().indices().prepareRefresh(INDEX_NAME).get(); - QueryBuilder query = new MatchQueryBuilder(semanticTextFieldName, randomAlphaOfLength(10)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(indices.toArray(new String[0]), searchSourceBuilder); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery).size(clauseCount); + SearchRequest searchRequest = new SearchRequest(new String[] { INDEX_NAME }, searchSourceBuilder); assertResponse(client().search(searchRequest), response -> { assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); - assertThat(response.getHits().getTotalHits().value(), equalTo(0L)); + assertThat(response.getHits().getTotalHits().value(), equalTo((long) clauseCount)); }); } - private static XContentBuilder generateMapping(String semanticTextFieldName, String inferenceId) throws IOException { - return XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(semanticTextFieldName) - .field("type", SemanticTextFieldMapper.CONTENT_TYPE) - .field("inference_id", inferenceId) - .endObject() - .endObject() - .endObject(); + private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { + IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); + inferenceIds.put(inferenceId, taskType); + } + + private static Map getServiceSettings(TaskType taskType) { + return switch (taskType) { + case SPARSE_EMBEDDING -> SPARSE_EMBEDDING_SERVICE_SETTINGS; + case TEXT_EMBEDDING -> TEXT_EMBEDDING_SERVICE_SETTINGS; + default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); + }; } public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java index 54055a9bc20da..1bd79aab95a4f 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java @@ -150,7 +150,7 @@ public void testSetDefaultBBQIndexOptionsWithBasicLicense() throws Exception { } private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { - InferenceUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); + IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); inferenceIds.put(inferenceId, taskType); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java index b7f940163c01b..8403c13014f92 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -31,11 +30,8 @@ import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.junit.After; @@ -89,16 +85,9 @@ protected boolean forbidPrivateIndexSettings() { @After public void cleanUp() { - deleteIndex(indexName); + IntegrationTestUtils.deleteIndex(client(), indexName); for (var entry : inferenceIds.entrySet()) { - assertAcked( - safeGet( - client().execute( - DeleteInferenceEndpointAction.INSTANCE, - new DeleteInferenceEndpointAction.Request(entry.getKey(), entry.getValue(), true, false) - ) - ) - ); + IntegrationTestUtils.deleteInferenceEndpoint(client(), entry.getValue(), entry.getKey()); } } @@ -127,7 +116,7 @@ private void excludeInferenceFieldsFromSourceTestCase(IndexVersion minIndexVersi for (int i = 0; i < iterations; i++) { final IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(random(), minIndexVersion, maxIndexVersion); final Settings indexSettings = generateIndexSettings(indexVersion); - XContentBuilder mappings = generateMapping( + XContentBuilder mappings = IntegrationTestUtils.generateMapping( Map.of(sparseEmbeddingField, sparseEmbeddingInferenceId, textEmbeddingField, textEmbeddingInferenceId) ); assertAcked(prepareCreate(indexName).setSettings(indexSettings).setMapping(mappings)); @@ -158,12 +147,12 @@ private void excludeInferenceFieldsFromSourceTestCase(IndexVersion minIndexVersi } }); - deleteIndex(indexName); + IntegrationTestUtils.deleteIndex(client(), indexName); } } private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { - InferenceUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); + IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); inferenceIds.put(inferenceId, taskType); } @@ -292,33 +281,6 @@ private static FetchSourceContext generateRandomFetchSourceContext() { return fetchSourceContext; } - private static XContentBuilder generateMapping(Map semanticTextFields) throws IOException { - XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties"); - for (var entry : semanticTextFields.entrySet()) { - mapping.startObject(entry.getKey()); - mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mapping.field("inference_id", entry.getValue()); - mapping.endObject(); - } - mapping.endObject().endObject(); - - return mapping; - } - - private static void deleteIndex(String indexName) { - assertAcked( - safeGet( - client().admin() - .indices() - .prepareDelete(indexName) - .setIndicesOptions( - IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build() - ) - .execute() - ) - ); - } - private enum ExpectedSource { NONE, INFERENCE_FIELDS_EXCLUDED, From 4a6ccff1c81541240fc4a81f51798d0d4e688084 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 15:22:13 -0400 Subject: [PATCH 08/18] Refactoring --- ...actSemanticCrossClusterSearchTestCase.java | 31 +------------ .../integration/SemanticQueryInferenceIT.java | 23 +--------- .../SemanticTextIndexVersionIT.java | 10 +---- .../SemanticTextInferenceFieldsIT.java | 10 +---- .../xpack/inference/FakeMlPlugin.java | 43 +++++++++++++++++++ .../queries/SemanticQueryBuilderTests.java | 10 +---- 6 files changed, 48 insertions(+), 79 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/FakeMlPlugin.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java index f4d9d45d00c87..6c268a318549b 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.support.broadcast.BroadcastResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; @@ -31,21 +30,16 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; -import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractMultiClustersTestCase; import org.elasticsearch.transport.RemoteConnectionInfo; -import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction; import java.io.IOException; import java.util.Collection; @@ -273,29 +267,6 @@ protected static String[] convertToArray(List indices) { return indices.stream().map(IndexWithBoost::index).toArray(String[]::new); } - public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - - @Override - public List> getQueryVectorBuilders() { - return List.of( - new QueryVectorBuilderSpec<>( - TextEmbeddingQueryVectorBuilder.NAME, - TextEmbeddingQueryVectorBuilder::new, - TextEmbeddingQueryVectorBuilder.PARSER - ) - ); - } - - @Override - public Collection getActions() { - return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)); - } - } - protected record TestIndexInfo( String name, Map inferenceEndpoints, diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java index 52dbc8eeff21b..e081ec73f0953 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java @@ -9,7 +9,6 @@ import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchQueryBuilder; @@ -17,17 +16,15 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; -import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; @@ -150,22 +147,4 @@ private static Map getServiceSettings(TaskType taskType) { default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); }; } - - public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - - @Override - public List> getQueryVectorBuilders() { - return List.of( - new QueryVectorBuilderSpec<>( - TextEmbeddingQueryVectorBuilder.NAME, - TextEmbeddingQueryVectorBuilder::new, - TextEmbeddingQueryVectorBuilder.PARSER - ) - ); - } - } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java index ba63402158e6e..caac505bacf9d 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexVersion; @@ -28,7 +27,7 @@ import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; @@ -194,11 +193,4 @@ public void testSemanticText() throws Exception { assertAcked(client().admin().indices().prepareDelete(indexName)); } } - - public static class FakeMlPlugin extends Plugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java index 8403c13014f92..93dc8fb3e89ca 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.document.DocumentField; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexVersion; @@ -30,7 +29,7 @@ import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; @@ -286,11 +285,4 @@ private enum ExpectedSource { INFERENCE_FIELDS_EXCLUDED, INFERENCE_FIELDS_INCLUDED } - - public static class FakeMlPlugin extends Plugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/FakeMlPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/FakeMlPlugin.java new file mode 100644 index 0000000000000..2a521bfaaac41 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/FakeMlPlugin.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; +import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction; + +import java.util.Collection; +import java.util.List; + +public class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { + @Override + public List getNamedWriteables() { + return new MlInferenceNamedXContentProvider().getNamedWriteables(); + } + + @Override + public List> getQueryVectorBuilders() { + return List.of( + new QueryVectorBuilderSpec<>( + TextEmbeddingQueryVectorBuilder.NAME, + TextEmbeddingQueryVectorBuilder::new, + TextEmbeddingQueryVectorBuilder.PARSER + ) + ); + } + + @Override + public Collection getActions() { + return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b2d7218720a57..bb4ebbb434025 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -33,7 +33,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; @@ -69,9 +68,9 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -607,13 +606,6 @@ private static MinimalServiceSettings getModelSettingsForInferenceResultType( }; } - public static class FakeMlPlugin extends Plugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - } - private static TestThreadPool threadPool; private static ModelRegistry modelRegistry; From a9e79a31bf627e19362e86842940926a921c57f6 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 15:23:48 -0400 Subject: [PATCH 09/18] Renamed test class --- ...icQueryInferenceIT.java => ManyInferenceQueryClausesIT.java} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/{SemanticQueryInferenceIT.java => ManyInferenceQueryClausesIT.java} (99%) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java similarity index 99% rename from x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java rename to x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java index e081ec73f0953..af05276a9ea14 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticQueryInferenceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java @@ -43,7 +43,7 @@ import static org.hamcrest.CoreMatchers.is; @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 1) -public class SemanticQueryInferenceIT extends ESIntegTestCase { +public class ManyInferenceQueryClausesIT extends ESIntegTestCase { private static final String INDEX_NAME = "test_index"; private static final Map SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); From 23c67f7c630acf31e7992afad199e97048e4a162 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 16:01:40 -0400 Subject: [PATCH 10/18] Update docs/changelog/136720.yaml --- docs/changelog/136720.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/changelog/136720.yaml diff --git a/docs/changelog/136720.yaml b/docs/changelog/136720.yaml new file mode 100644 index 0000000000000..f5688b4a14792 --- /dev/null +++ b/docs/changelog/136720.yaml @@ -0,0 +1,6 @@ +pr: 136720 +summary: Use Suppliers To Get Inference Results In Semantic Queries +area: Relevance +type: bug +issues: + - 136621 From 33ebe3b41f703ecb808820b6eec47a2a57ef1a47 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Thu, 16 Oct 2025 16:08:44 -0400 Subject: [PATCH 11/18] Updated changelog --- docs/changelog/136720.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog/136720.yaml b/docs/changelog/136720.yaml index f5688b4a14792..72f395f8153a0 100644 --- a/docs/changelog/136720.yaml +++ b/docs/changelog/136720.yaml @@ -1,6 +1,6 @@ pr: 136720 summary: Use Suppliers To Get Inference Results In Semantic Queries -area: Relevance +area: Vector Search type: bug issues: - 136621 From c6ac186a258b890d730c966723c37a41b0093fd2 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 17 Oct 2025 08:53:58 -0400 Subject: [PATCH 12/18] Factor out common logic --- .../InterceptedInferenceQueryBuilder.java | 20 ++++------ .../queries/SemanticQueryBuilder.java | 38 +++++++++++++------ 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index d23de8ae615f1..485bccc9ebb4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -44,7 +44,7 @@ import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.getInferenceResults; -import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.mergeInferenceResultsMaps; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.getNewInferenceResultsFromSupplier; /** *

@@ -344,18 +344,12 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri if (inferenceResultsMapSupplier != null) { // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process - QueryBuilder rewritten = this; - - Map newInferenceResultsMap = inferenceResultsMapSupplier.get(); - if (newInferenceResultsMap != null) { - Map mergedInferenceResultsMap = mergeInferenceResultsMaps( - inferenceResultsMap, - newInferenceResultsMap - ); - rewritten = copy(mergedInferenceResultsMap, null, ccsRequest); - } - - return rewritten; + return getNewInferenceResultsFromSupplier( + inferenceResultsMapSupplier, + this, + inferenceResultsMap, + m -> copy(m, null, ccsRequest) + ); } FullyQualifiedInferenceId inferenceIdOverride = getInferenceIdOverride(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index dfb00899b1fb6..626cdeea3cdca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -52,6 +52,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -387,6 +388,25 @@ static void registerInferenceAsyncActions( }); } + static T getNewInferenceResultsFromSupplier( + Supplier> supplier, + T currentQueryBuilder, + Map currentInferenceResultsMap, + Function, T> copyGenerator + ) { + T rewritten = currentQueryBuilder; + Map newInferenceResultsMap = supplier.get(); + if (newInferenceResultsMap != null) { + Map mergedInferenceResultsMap = mergeInferenceResultsMaps( + currentInferenceResultsMap, + newInferenceResultsMap + ); + rewritten = copyGenerator.apply(mergedInferenceResultsMap); + } + + return rewritten; + } + private static GroupedActionListener> createGroupedActionListener( InferenceResultsMapSupplier inferenceResultsMapSupplier, int inferenceRequestCount, @@ -515,18 +535,12 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu if (inferenceResultsMapSupplier != null) { // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process - SemanticQueryBuilder rewritten = this; - - Map newInferenceResultsMap = inferenceResultsMapSupplier.get(); - if (newInferenceResultsMap != null) { - Map mergedInferenceResultsMap = mergeInferenceResultsMaps( - inferenceResultsMap, - newInferenceResultsMap - ); - rewritten = new SemanticQueryBuilder(this, mergedInferenceResultsMap, null, ccsRequest); - } - - return rewritten; + return getNewInferenceResultsFromSupplier( + inferenceResultsMapSupplier, + this, + inferenceResultsMap, + m -> new SemanticQueryBuilder(this, m, null, ccsRequest) + ); } Set fullyQualifiedInferenceIds = getInferenceIdsForField( From 27bbd55d52c1442dc4992020417c16d6ce9b97c8 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 17 Oct 2025 08:54:32 -0400 Subject: [PATCH 13/18] Make inference results map volatile --- .../xpack/inference/queries/InferenceResultsMapSupplier.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java index 16558e9df579e..68f6923e0c3db 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java @@ -14,7 +14,7 @@ import java.util.function.Supplier; class InferenceResultsMapSupplier implements Supplier> { - private Map inferenceResultsMap = null; + private volatile Map inferenceResultsMap = null; void set(Map inferenceResultsMap) { this.inferenceResultsMap = inferenceResultsMap; From 868dd8a3a3248c43b5d25d9d8709a922664fb34b Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Fri, 17 Oct 2025 08:55:51 -0400 Subject: [PATCH 14/18] Rename method --- .../xpack/inference/integration/IntegrationTestUtils.java | 2 +- .../inference/integration/ManyInferenceQueryClausesIT.java | 2 +- .../inference/integration/SemanticTextInferenceFieldsIT.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java index 14f48526824e3..70a7ad9077102 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java @@ -86,7 +86,7 @@ public static void deleteIndex(Client client, String indexName) { ); } - public static XContentBuilder generateMapping(Map semanticTextFields) throws IOException { + public static XContentBuilder generateSemanticTextMapping(Map semanticTextFields) throws IOException { XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties"); for (var entry : semanticTextFields.entrySet()) { mapping.startObject(entry.getKey()); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java index af05276a9ea14..f9849076c7a47 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java @@ -114,7 +114,7 @@ private void manyQueryClausesTestCase( semanticTextFields.put(fieldName, inferenceId); } - XContentBuilder mapping = IntegrationTestUtils.generateMapping(semanticTextFields); + XContentBuilder mapping = IntegrationTestUtils.generateSemanticTextMapping(semanticTextFields); assertAcked(prepareCreate(INDEX_NAME).setMapping(mapping)); BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java index 93dc8fb3e89ca..15a524ffa68ed 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java @@ -115,7 +115,7 @@ private void excludeInferenceFieldsFromSourceTestCase(IndexVersion minIndexVersi for (int i = 0; i < iterations; i++) { final IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(random(), minIndexVersion, maxIndexVersion); final Settings indexSettings = generateIndexSettings(indexVersion); - XContentBuilder mappings = IntegrationTestUtils.generateMapping( + XContentBuilder mappings = IntegrationTestUtils.generateSemanticTextMapping( Map.of(sparseEmbeddingField, sparseEmbeddingInferenceId, textEmbeddingField, textEmbeddingInferenceId) ); assertAcked(prepareCreate(indexName).setSettings(indexSettings).setMapping(mappings)); From 9fd4e57c9cba5eeb3fb38d74f9ae29d3cf9b695d Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 20 Oct 2025 13:21:28 -0400 Subject: [PATCH 15/18] Use SetOnce --- ...rceptedInferenceKnnVectorQueryBuilder.java | 6 +++--- ...InterceptedInferenceMatchQueryBuilder.java | 6 +++--- .../InterceptedInferenceQueryBuilder.java | 10 +++++----- ...ptedInferenceSparseVectorQueryBuilder.java | 6 +++--- .../queries/SemanticQueryBuilder.java | 20 +++++++++---------- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java index bfdbd71989198..808afeb6b3c33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -33,7 +34,6 @@ import java.io.IOException; import java.util.Collection; import java.util.Map; -import java.util.function.Supplier; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -63,7 +63,7 @@ public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOExcept private InterceptedInferenceKnnVectorQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); @@ -131,7 +131,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { @Override protected QueryBuilder copy( Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java index 8388ad44f53a0..018fdca7fabdb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.queries; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.index.mapper.MappedFieldType; @@ -20,7 +21,6 @@ import java.io.IOException; import java.util.Map; -import java.util.function.Supplier; public class InterceptedInferenceMatchQueryBuilder extends InterceptedInferenceQueryBuilder { public static final String NAME = "intercepted_inference_match"; @@ -48,7 +48,7 @@ public InterceptedInferenceMatchQueryBuilder(StreamInput in) throws IOException private InterceptedInferenceMatchQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); @@ -77,7 +77,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { @Override protected QueryBuilder copy( Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index 485bccc9ebb4e..02cb81d04e811 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -37,7 +38,6 @@ import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -70,7 +70,7 @@ public abstract class InterceptedInferenceQueryBuilder inferenceResultsMap; - protected final Supplier> inferenceResultsMapSupplier; + protected final SetOnce> inferenceResultsMapSupplier; protected final boolean ccsRequest; protected InterceptedInferenceQueryBuilder(T originalQuery) { @@ -110,7 +110,7 @@ protected InterceptedInferenceQueryBuilder(StreamInput in) throws IOException { protected InterceptedInferenceQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { this.originalQuery = other.originalQuery; @@ -162,7 +162,7 @@ protected InterceptedInferenceQueryBuilder( */ protected abstract QueryBuilder copy( Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ); @@ -357,7 +357,7 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri inferenceIds = Set.of(inferenceIdOverride); } - Supplier> newInferenceResultsMapSupplier = getInferenceResults( + SetOnce> newInferenceResultsMapSupplier = getInferenceResults( queryRewriteContext, inferenceIds, inferenceResultsMap, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java index 809250a8410c3..48a9d3910b01e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -32,7 +33,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.function.Supplier; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -62,7 +62,7 @@ public InterceptedInferenceSparseVectorQueryBuilder(StreamInput in) throws IOExc private InterceptedInferenceSparseVectorQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); @@ -118,7 +118,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { @Override protected QueryBuilder copy( Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index b53956a1096ad..5b0dad0a8d6cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; @@ -53,7 +54,6 @@ import java.util.Objects; import java.util.Set; import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -100,7 +100,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsMap; - private final Supplier> inferenceResultsMapSupplier; + private final SetOnce> inferenceResultsMapSupplier; private final Boolean lenient; // ccsRequest is only used on the local cluster coordinator node to detect when: @@ -236,7 +236,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { private SemanticQueryBuilder( SemanticQueryBuilder other, Map inferenceResultsMap, - Supplier> inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { this.fieldName = other.fieldName; @@ -287,7 +287,7 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO * @param query The query to generate inference results for * @return An inference results map supplier */ - static Supplier> getInferenceResults( + static SetOnce> getInferenceResults( QueryRewriteContext queryRewriteContext, Set fullyQualifiedInferenceIds, @Nullable Map inferenceResultsMap, @@ -313,9 +313,9 @@ static Supplier> getInferenceRe } } - InferenceResultsMapSupplier inferenceResultsMapSupplier = null; + SetOnce> inferenceResultsMapSupplier = null; if (inferenceIds.isEmpty() == false) { - inferenceResultsMapSupplier = new InferenceResultsMapSupplier(); + inferenceResultsMapSupplier = new SetOnce<>(); registerInferenceAsyncActions(queryRewriteContext, inferenceResultsMapSupplier, query, inferenceIds); } @@ -339,7 +339,7 @@ static Map mergeInferenceResultsMap static void registerInferenceAsyncActions( QueryRewriteContext queryRewriteContext, - InferenceResultsMapSupplier inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, String query, List inferenceIds ) { @@ -389,7 +389,7 @@ static void registerInferenceAsyncActions( } static T getNewInferenceResultsFromSupplier( - Supplier> supplier, + SetOnce> supplier, T currentQueryBuilder, Map currentInferenceResultsMap, Function, T> copyGenerator @@ -408,7 +408,7 @@ static T getNewInferenceResultsFromSupplier( } private static GroupedActionListener> createGroupedActionListener( - InferenceResultsMapSupplier inferenceResultsMapSupplier, + SetOnce> inferenceResultsMapSupplier, int inferenceRequestCount, ActionListener listener ) { @@ -548,7 +548,7 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu queryRewriteContext.getLocalClusterAlias(), fieldName ); - Supplier> newInferenceResultsMapSupplier = getInferenceResults( + SetOnce> newInferenceResultsMapSupplier = getInferenceResults( queryRewriteContext, fullyQualifiedInferenceIds, inferenceResultsMap, From 13f80cfd77fa5067d61106c6d496b6a73606b6a3 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 20 Oct 2025 13:22:23 -0400 Subject: [PATCH 16/18] Remove InferenceResultsMapSupplier --- .../queries/InferenceResultsMapSupplier.java | 40 ------------------- 1 file changed, 40 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java deleted file mode 100644 index 68f6923e0c3db..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InferenceResultsMapSupplier.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.queries; - -import org.elasticsearch.inference.InferenceResults; - -import java.util.Map; -import java.util.Objects; -import java.util.function.Supplier; - -class InferenceResultsMapSupplier implements Supplier> { - private volatile Map inferenceResultsMap = null; - - void set(Map inferenceResultsMap) { - this.inferenceResultsMap = inferenceResultsMap; - } - - @Override - public Map get() { - return inferenceResultsMap; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceResultsMapSupplier that = (InferenceResultsMapSupplier) o; - return Objects.equals(inferenceResultsMap, that.inferenceResultsMap); - } - - @Override - public int hashCode() { - return Objects.hashCode(inferenceResultsMap); - } -} From 05d0a85f1b13525a523c53b18eaa6fe027de2258 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 20 Oct 2025 13:27:07 -0400 Subject: [PATCH 17/18] Use hash map --- .../xpack/inference/queries/SemanticQueryBuilder.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 5b0dad0a8d6cc..46467b93d71d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -16,7 +16,6 @@ import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; -import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; @@ -48,6 +47,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -413,9 +413,9 @@ private static GroupedActionListener listener ) { return new GroupedActionListener<>(inferenceRequestCount, listener.delegateFailureAndWrap((l, responses) -> { - ImmutableOpenMap.Builder mapBuilder = ImmutableOpenMap.builder(); - responses.forEach(r -> mapBuilder.put(r.v1(), r.v2())); - inferenceResultsMapSupplier.set(mapBuilder.build()); + Map inferenceResultsMap = new HashMap<>(responses.size()); + responses.forEach(r -> inferenceResultsMap.put(r.v1(), r.v2())); + inferenceResultsMapSupplier.set(inferenceResultsMap); l.onResponse(null); })); } From e8be765e48e2da6371201baefd58248e4e12be06 Mon Sep 17 00:00:00 2001 From: Mike Pellegrini Date: Mon, 20 Oct 2025 13:53:12 -0400 Subject: [PATCH 18/18] Don't merge inference results maps --- .../InterceptedInferenceQueryBuilder.java | 7 +--- .../queries/SemanticQueryBuilder.java | 34 ++++--------------- 2 files changed, 7 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index 02cb81d04e811..89fbf94f2f0de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -344,12 +344,7 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri if (inferenceResultsMapSupplier != null) { // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process - return getNewInferenceResultsFromSupplier( - inferenceResultsMapSupplier, - this, - inferenceResultsMap, - m -> copy(m, null, ccsRequest) - ); + return getNewInferenceResultsFromSupplier(inferenceResultsMapSupplier, this, m -> copy(m, null, ccsRequest)); } FullyQualifiedInferenceId inferenceIdOverride = getInferenceIdOverride(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 46467b93d71d8..4060d1c6bc4a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -55,7 +55,6 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; @@ -322,21 +321,6 @@ static SetOnce> getInferenceRes return inferenceResultsMapSupplier; } - static Map mergeInferenceResultsMaps( - @Nullable Map originalInferenceResultsMap, - Map newInferenceResultsMap - ) { - Map mergedInferenceResultsMap = newInferenceResultsMap; - if (originalInferenceResultsMap != null && originalInferenceResultsMap.isEmpty() == false) { - mergedInferenceResultsMap = Stream.concat( - originalInferenceResultsMap.entrySet().stream(), - newInferenceResultsMap.entrySet().stream() - ).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - } - - return mergedInferenceResultsMap; - } - static void registerInferenceAsyncActions( QueryRewriteContext queryRewriteContext, SetOnce> inferenceResultsMapSupplier, @@ -391,20 +375,15 @@ static void registerInferenceAsyncActions( static T getNewInferenceResultsFromSupplier( SetOnce> supplier, T currentQueryBuilder, - Map currentInferenceResultsMap, Function, T> copyGenerator ) { - T rewritten = currentQueryBuilder; Map newInferenceResultsMap = supplier.get(); - if (newInferenceResultsMap != null) { - Map mergedInferenceResultsMap = mergeInferenceResultsMaps( - currentInferenceResultsMap, - newInferenceResultsMap - ); - rewritten = copyGenerator.apply(mergedInferenceResultsMap); - } - - return rewritten; + // It's safe to use only the new inference results map (once set) because we can enumerate the scenarios where we need to get + // inference results: + // - On the local coordinating node, getting inference results for the first time. The previous inference results map is null. + // - On the remote coordinating node, getting inference results for remote cluster inference IDs. In this case, we can guarantee + // that only remote cluster inference results are required to handle the query. + return newInferenceResultsMap != null ? copyGenerator.apply(newInferenceResultsMap) : currentQueryBuilder; } private static GroupedActionListener> createGroupedActionListener( @@ -538,7 +517,6 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu return getNewInferenceResultsFromSupplier( inferenceResultsMapSupplier, this, - inferenceResultsMap, m -> new SemanticQueryBuilder(this, m, null, ccsRequest) ); }