diff --git a/docs/changelog/136720.yaml b/docs/changelog/136720.yaml new file mode 100644 index 0000000000000..72f395f8153a0 --- /dev/null +++ b/docs/changelog/136720.yaml @@ -0,0 +1,6 @@ +pr: 136720 +summary: Use Suppliers To Get Inference Results In Semantic Queries +area: Vector Search +type: bug +issues: + - 136621 diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java index e1f72ddec5aab..6c268a318549b 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/search/ccs/AbstractSemanticCrossClusterSearchTestCase.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.support.broadcast.BroadcastResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; @@ -31,27 +30,16 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.license.LicenseSettings; -import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractMultiClustersTestCase; import org.elasticsearch.transport.RemoteConnectionInfo; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; -import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; -import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction; import java.io.IOException; import java.util.Collection; @@ -66,6 +54,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.createInferenceEndpoint; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -165,35 +154,6 @@ protected BytesReference openPointInTime(String[] indices, TimeValue keepAlive) return response.getPointInTimeId(); } - protected static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) - throws IOException { - final String service = switch (taskType) { - case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; - case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; - default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); - }; - - final BytesReference content; - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - builder.startObject(); - builder.field("service", service); - builder.field("service_settings", serviceSettings); - builder.endObject(); - - content = BytesReference.bytes(builder); - } - - PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( - taskType, - inferenceId, - content, - XContentType.JSON, - TEST_REQUEST_TIMEOUT - ); - var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); - assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); - } - protected void assertSearchResponse(QueryBuilder queryBuilder, List indices, List expectedSearchResults) throws Exception { assertSearchResponse(queryBuilder, indices, expectedSearchResults, null, null); @@ -307,29 +267,6 @@ protected static String[] convertToArray(List indices) { return indices.stream().map(IndexWithBoost::index).toArray(String[]::new); } - public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - - @Override - public List> getQueryVectorBuilders() { - return List.of( - new QueryVectorBuilderSpec<>( - TextEmbeddingQueryVectorBuilder.NAME, - TextEmbeddingQueryVectorBuilder::new, - TextEmbeddingQueryVectorBuilder.PARSER - ) - ); - } - - @Override - public Collection getActions() { - return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)); - } - } - protected record TestIndexInfo( String name, Map inferenceEndpoints, diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java new file mode 100644 index 0000000000000..70a7ad9077102 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/IntegrationTestUtils.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; + +import java.io.IOException; +import java.util.Map; + +import static org.elasticsearch.test.ESTestCase.TEST_REQUEST_TIMEOUT; +import static org.elasticsearch.test.ESTestCase.safeGet; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +public class IntegrationTestUtils { + private IntegrationTestUtils() {} + + public static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map serviceSettings) + throws IOException { + final String service = switch (taskType) { + case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; + case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; + default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); + }; + + final BytesReference content; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("service", service); + builder.field("service_settings", serviceSettings); + builder.endObject(); + + content = BytesReference.bytes(builder); + } + + PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( + taskType, + inferenceId, + content, + XContentType.JSON, + TEST_REQUEST_TIMEOUT + ); + var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request); + assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); + } + + public static void deleteInferenceEndpoint(Client client, TaskType taskType, String inferenceId) { + assertAcked( + safeGet( + client.execute( + DeleteInferenceEndpointAction.INSTANCE, + new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false) + ) + ) + ); + } + + public static void deleteIndex(Client client, String indexName) { + assertAcked( + safeGet( + client.admin() + .indices() + .prepareDelete(indexName) + .setIndicesOptions( + IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build() + ) + .execute() + ) + ); + } + + public static XContentBuilder generateSemanticTextMapping(Map semanticTextFields) throws IOException { + XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties"); + for (var entry : semanticTextFields.entrySet()) { + mapping.startObject(entry.getKey()); + mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE); + mapping.field("inference_id", entry.getValue()); + mapping.endObject(); + } + mapping.endObject().endObject(); + + return mapping; + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java new file mode 100644 index 0000000000000..f9849076c7a47 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ManyInferenceQueryClausesIT.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.MatchQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; +import org.elasticsearch.xpack.inference.FakeMlPlugin; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; +import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; +import org.junit.After; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 1) +public class ManyInferenceQueryClausesIT extends ESIntegTestCase { + private static final String INDEX_NAME = "test_index"; + + private static final Map SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key"); + private static final Map TEXT_EMBEDDING_SERVICE_SETTINGS = Map.of( + "model", + "my_model", + "dimensions", + 256, + "similarity", + "cosine", + "api_key", + "my_api_key" + ); + + private final Map inferenceIds = new HashMap<>(); + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + + @Override + protected Collection> nodePlugins() { + return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class, FakeMlPlugin.class); + } + + @After + public void cleanUp() { + IntegrationTestUtils.deleteIndex(client(), INDEX_NAME); + for (var entry : inferenceIds.entrySet()) { + IntegrationTestUtils.deleteInferenceEndpoint(client(), entry.getValue(), entry.getKey()); + } + } + + public void testManySemanticQueryClauses() throws Exception { + manyQueryClausesTestCase(randomIntBetween(18, 24), SemanticQueryBuilder::new, TaskType.SPARSE_EMBEDDING); + } + + public void testManyMatchQueryClauses() throws Exception { + manyQueryClausesTestCase(randomIntBetween(18, 24), MatchQueryBuilder::new, TaskType.SPARSE_EMBEDDING); + } + + public void testManySparseVectorQueryClauses() throws Exception { + manyQueryClausesTestCase(randomIntBetween(18, 24), (f, q) -> new SparseVectorQueryBuilder(f, null, q), TaskType.SPARSE_EMBEDDING); + } + + public void testManyKnnQueryClauses() throws Exception { + int clauseCount = randomIntBetween(18, 24); + manyQueryClausesTestCase( + clauseCount, + (f, q) -> new KnnVectorQueryBuilder(f, new TextEmbeddingQueryVectorBuilder(null, q), clauseCount, clauseCount * 10, null, null), + TaskType.TEXT_EMBEDDING + ); + } + + private void manyQueryClausesTestCase( + int clauseCount, + BiFunction clauseGenerator, + TaskType semanticTextFieldTaskType + ) throws Exception { + Map inferenceEndpointServiceSettings = getServiceSettings(semanticTextFieldTaskType); + Map semanticTextFields = new HashMap<>(clauseCount); + for (int i = 0; i < clauseCount; i++) { + String fieldName = randomAlphaOfLength(10); + String inferenceId = randomIdentifier(); + + createInferenceEndpoint(semanticTextFieldTaskType, inferenceId, inferenceEndpointServiceSettings); + semanticTextFields.put(fieldName, inferenceId); + } + + XContentBuilder mapping = IntegrationTestUtils.generateSemanticTextMapping(semanticTextFields); + assertAcked(prepareCreate(INDEX_NAME).setMapping(mapping)); + + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + for (String semanticTextField : semanticTextFields.keySet()) { + Map source = Map.of(semanticTextField, randomAlphaOfLength(10)); + DocWriteResponse docWriteResponse = client().prepareIndex(INDEX_NAME).setSource(source).get(TEST_REQUEST_TIMEOUT); + assertThat(docWriteResponse.getResult(), is(DocWriteResponse.Result.CREATED)); + + boolQuery.should(clauseGenerator.apply(semanticTextField, randomAlphaOfLength(10))); + } + client().admin().indices().prepareRefresh(INDEX_NAME).get(); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery).size(clauseCount); + SearchRequest searchRequest = new SearchRequest(new String[] { INDEX_NAME }, searchSourceBuilder); + assertResponse(client().search(searchRequest), response -> { + assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards())); + assertThat(response.getHits().getTotalHits().value(), equalTo((long) clauseCount)); + }); + } + + private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { + IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); + inferenceIds.put(inferenceId, taskType); + } + + private static Map getServiceSettings(TaskType taskType) { + return switch (taskType) { + case SPARSE_EMBEDDING -> SPARSE_EMBEDDING_SERVICE_SETTINGS; + case TEXT_EMBEDDING -> TEXT_EMBEDDING_SERVICE_SETTINGS; + default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); + }; + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java index 08d3f0a6a9b9f..1bd79aab95a4f 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexOptionsIT.java @@ -37,13 +37,10 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; -import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.junit.After; import org.junit.Before; @@ -153,32 +150,7 @@ public void testSetDefaultBBQIndexOptionsWithBasicLicense() throws Exception { } private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { - final String service = switch (taskType) { - case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; - case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; - default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); - }; - - final BytesReference content; - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - builder.startObject(); - builder.field("service", service); - builder.field("service_settings", serviceSettings); - builder.endObject(); - - content = BytesReference.bytes(builder); - } - - PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( - taskType, - inferenceId, - content, - XContentType.JSON, - TEST_REQUEST_TIMEOUT - ); - var responseFuture = client().execute(PutInferenceModelAction.INSTANCE, request); - assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); - + IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); inferenceIds.put(inferenceId, taskType); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java index ba63402158e6e..caac505bacf9d 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextIndexVersionIT.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexVersion; @@ -28,7 +27,7 @@ import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; @@ -194,11 +193,4 @@ public void testSemanticText() throws Exception { assertAcked(client().admin().indices().prepareDelete(indexName)); } } - - public static class FakeMlPlugin extends Plugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - } } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java index bb26153965499..15a524ffa68ed 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/SemanticTextInferenceFieldsIT.java @@ -10,11 +10,8 @@ import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.document.DocumentField; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexVersion; @@ -32,16 +29,9 @@ import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction; -import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin; -import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.junit.After; @@ -94,16 +84,9 @@ protected boolean forbidPrivateIndexSettings() { @After public void cleanUp() { - deleteIndex(indexName); + IntegrationTestUtils.deleteIndex(client(), indexName); for (var entry : inferenceIds.entrySet()) { - assertAcked( - safeGet( - client().execute( - DeleteInferenceEndpointAction.INSTANCE, - new DeleteInferenceEndpointAction.Request(entry.getKey(), entry.getValue(), true, false) - ) - ) - ); + IntegrationTestUtils.deleteInferenceEndpoint(client(), entry.getValue(), entry.getKey()); } } @@ -132,7 +115,7 @@ private void excludeInferenceFieldsFromSourceTestCase(IndexVersion minIndexVersi for (int i = 0; i < iterations; i++) { final IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(random(), minIndexVersion, maxIndexVersion); final Settings indexSettings = generateIndexSettings(indexVersion); - XContentBuilder mappings = generateMapping( + XContentBuilder mappings = IntegrationTestUtils.generateSemanticTextMapping( Map.of(sparseEmbeddingField, sparseEmbeddingInferenceId, textEmbeddingField, textEmbeddingInferenceId) ); assertAcked(prepareCreate(indexName).setSettings(indexSettings).setMapping(mappings)); @@ -163,37 +146,12 @@ private void excludeInferenceFieldsFromSourceTestCase(IndexVersion minIndexVersi } }); - deleteIndex(indexName); + IntegrationTestUtils.deleteIndex(client(), indexName); } } private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map serviceSettings) throws IOException { - final String service = switch (taskType) { - case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME; - case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME; - default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]"); - }; - - final BytesReference content; - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - builder.startObject(); - builder.field("service", service); - builder.field("service_settings", serviceSettings); - builder.endObject(); - - content = BytesReference.bytes(builder); - } - - PutInferenceModelAction.Request request = new PutInferenceModelAction.Request( - taskType, - inferenceId, - content, - XContentType.JSON, - TEST_REQUEST_TIMEOUT - ); - var responseFuture = client().execute(PutInferenceModelAction.INSTANCE, request); - assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId)); - + IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings); inferenceIds.put(inferenceId, taskType); } @@ -322,43 +280,9 @@ private static FetchSourceContext generateRandomFetchSourceContext() { return fetchSourceContext; } - private static XContentBuilder generateMapping(Map semanticTextFields) throws IOException { - XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties"); - for (var entry : semanticTextFields.entrySet()) { - mapping.startObject(entry.getKey()); - mapping.field("type", SemanticTextFieldMapper.CONTENT_TYPE); - mapping.field("inference_id", entry.getValue()); - mapping.endObject(); - } - mapping.endObject().endObject(); - - return mapping; - } - - private static void deleteIndex(String indexName) { - assertAcked( - safeGet( - client().admin() - .indices() - .prepareDelete(indexName) - .setIndicesOptions( - IndicesOptions.builder().concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true)).build() - ) - .execute() - ) - ); - } - private enum ExpectedSource { NONE, INFERENCE_FIELDS_EXCLUDED, INFERENCE_FIELDS_INCLUDED } - - public static class FakeMlPlugin extends Plugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java index cd590b8410234..808afeb6b3c33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -62,9 +63,10 @@ public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOExcept private InterceptedInferenceKnnVectorQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { - super(other, inferenceResultsMap, ccsRequest); + super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -127,8 +129,12 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { - return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, ccsRequest); + protected QueryBuilder copy( + Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, + boolean ccsRequest + ) { + return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java index a689c8a3a7f1e..018fdca7fabdb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceMatchQueryBuilder.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.queries; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.index.mapper.MappedFieldType; @@ -47,9 +48,10 @@ public InterceptedInferenceMatchQueryBuilder(StreamInput in) throws IOException private InterceptedInferenceMatchQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { - super(other, inferenceResultsMap, ccsRequest); + super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -73,8 +75,12 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { - return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap, ccsRequest); + protected QueryBuilder copy( + Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, + boolean ccsRequest + ) { + return new InterceptedInferenceMatchQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java index 8774d35f17ade..89fbf94f2f0de 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceQueryBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -42,6 +43,8 @@ import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.SEMANTIC_SEARCH_CCS_SUPPORT; import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.convertFromBwcInferenceResultsMap; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.getInferenceResults; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder.getNewInferenceResultsFromSupplier; /** *

@@ -67,6 +70,7 @@ public abstract class InterceptedInferenceQueryBuilder inferenceResultsMap; + protected final SetOnce> inferenceResultsMapSupplier; protected final boolean ccsRequest; protected InterceptedInferenceQueryBuilder(T originalQuery) { @@ -77,6 +81,7 @@ protected InterceptedInferenceQueryBuilder(T originalQuery, Map other, Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { this.originalQuery = other.originalQuery; this.inferenceResultsMap = inferenceResultsMap; + this.inferenceResultsMapSupplier = inferenceResultsMapSupplier; this.ccsRequest = ccsRequest; } @@ -144,13 +153,18 @@ protected InterceptedInferenceQueryBuilder( protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext); /** - * Generate a copy of {@code this} using the provided inference results map. + * Generate a copy of {@code this}. * * @param inferenceResultsMap The inference results map + * @param inferenceResultsMapSupplier The inference results map supplier * @param ccsRequest Flag indicating if this is a CCS request * @return A copy of {@code this} with the provided inference results map */ - protected abstract QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest); + protected abstract QueryBuilder copy( + Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, + boolean ccsRequest + ); /** * Rewrite to a {@link QueryBuilder} appropriate for a specific index's mappings. The implementation can use @@ -195,6 +209,12 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {} @Override protected void doWriteTo(StreamOutput out) throws IOException { + if (inferenceResultsMapSupplier != null) { + throw new IllegalStateException( + "inferenceResultsMapSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?" + ); + } + out.writeNamedWriteable(originalQuery); if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) { out.writeOptional( @@ -238,12 +258,13 @@ protected Query doToQuery(SearchExecutionContext context) { protected boolean doEquals(InterceptedInferenceQueryBuilder other) { return Objects.equals(originalQuery, other.originalQuery) && Objects.equals(inferenceResultsMap, other.inferenceResultsMap) + && Objects.equals(inferenceResultsMapSupplier, other.inferenceResultsMapSupplier) && Objects.equals(ccsRequest, other.ccsRequest); } @Override protected int doHashCode() { - return Objects.hash(originalQuery, inferenceResultsMap, ccsRequest); + return Objects.hash(originalQuery, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -321,29 +342,37 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri ); } + if (inferenceResultsMapSupplier != null) { + // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process + return getNewInferenceResultsFromSupplier(inferenceResultsMapSupplier, this, m -> copy(m, null, ccsRequest)); + } + FullyQualifiedInferenceId inferenceIdOverride = getInferenceIdOverride(); if (inferenceIdOverride != null) { inferenceIds = Set.of(inferenceIdOverride); } - QueryBuilder rewritten = this; - if (queryRewriteContext.hasAsyncActions() == false) { - // If the query is null, there's nothing to generate inference results for. This can happen if pre-computed inference results - // are provided by the user. Ensure that we set an empty inference results map in this case so that it is always non-null after - // coordinator node rewrite. - Map modifiedInferenceResultsMap = SemanticQueryBuilder.getInferenceResults( - queryRewriteContext, - inferenceIds, - this.inferenceResultsMap, - getQuery() - ); + SetOnce> newInferenceResultsMapSupplier = getInferenceResults( + queryRewriteContext, + inferenceIds, + inferenceResultsMap, + getQuery() + ); - if (modifiedInferenceResultsMap == this.inferenceResultsMap) { + QueryBuilder rewritten = this; + if (newInferenceResultsMapSupplier == null) { + // No additional inference results are required + if (inferenceResultsMap != null) { // The inference results map is fully populated, so we can perform error checking - inferenceResultsErrorCheck(modifiedInferenceResultsMap); + inferenceResultsErrorCheck(inferenceResultsMap); } else { - rewritten = copy(modifiedInferenceResultsMap, ccsRequest); + // No inference results have been collected yet, indicating we don't need any to rewrite this query. + // This can happen when pre-computed inference results are provided by the user. + // Set an empty inference results map so that rewriting can continue. + rewritten = copy(Map.of(), null, ccsRequest); } + } else { + rewritten = copy(inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest); } return rewritten; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java index d2f94593f4340..48a9d3910b01e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceSparseVectorQueryBuilder.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ResolvedIndices; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -61,9 +62,10 @@ public InterceptedInferenceSparseVectorQueryBuilder(StreamInput in) throws IOExc private InterceptedInferenceSparseVectorQueryBuilder( InterceptedInferenceQueryBuilder other, Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { - super(other, inferenceResultsMap, ccsRequest); + super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override @@ -114,8 +116,12 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) { } @Override - protected QueryBuilder copy(Map inferenceResultsMap, boolean ccsRequest) { - return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, ccsRequest); + protected QueryBuilder copy( + Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, + boolean ccsRequest + ) { + return new InterceptedInferenceSparseVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 8ebdd3d119733..4060d1c6bc4a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -8,14 +8,18 @@ package org.elasticsearch.xpack.inference.queries; import org.apache.lucene.search.Query; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Tuple; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -40,14 +44,16 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; import java.util.stream.Collectors; import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; @@ -93,6 +99,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder inferenceResultsMap; + private final SetOnce> inferenceResultsMapSupplier; private final Boolean lenient; // ccsRequest is only used on the local cluster coordinator node to detect when: @@ -135,6 +142,7 @@ protected SemanticQueryBuilder( this.fieldName = fieldName; this.query = query; this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null; + this.inferenceResultsMapSupplier = null; this.lenient = lenient; this.ccsRequest = ccsRequest; } @@ -143,6 +151,7 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.query = in.readString(); + if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) { this.inferenceResultsMap = in.readOptional( i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> i2.readNamedWriteable(InferenceResults.class)) @@ -156,20 +165,30 @@ public SemanticQueryBuilder(StreamInput in) throws IOException { this.inferenceResultsMap = inferenceResults != null ? buildSingleResultInferenceResultsMap(inferenceResults) : null; in.readBoolean(); // Discard noInferenceResults, it is no longer necessary } + if (in.getTransportVersion().supports(TransportVersions.V_8_18_0)) { this.lenient = in.readOptionalBoolean(); } else { this.lenient = null; } + if (in.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) { this.ccsRequest = in.readBoolean(); } else { this.ccsRequest = false; } + + this.inferenceResultsMapSupplier = null; } @Override protected void doWriteTo(StreamOutput out) throws IOException { + if (inferenceResultsMapSupplier != null) { + throw new IllegalStateException( + "inferenceResultsMapSupplier must be null, can't serialize suppliers, missing a rewriteAndFetch?" + ); + } + out.writeString(fieldName); out.writeString(query); if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) { @@ -216,6 +235,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { private SemanticQueryBuilder( SemanticQueryBuilder other, Map inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, boolean ccsRequest ) { this.fieldName = other.fieldName; @@ -224,6 +244,7 @@ private SemanticQueryBuilder( this.queryName = other.queryName; // No need to copy the map here since this is only called internally. We can safely assume that the caller will not modify the map. this.inferenceResultsMap = inferenceResultsMap; + this.inferenceResultsMapSupplier = inferenceResultsMapSupplier; this.lenient = other.lenient; this.ccsRequest = ccsRequest; } @@ -255,31 +276,26 @@ public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IO * Get inference results for the provided query using the provided fully qualified inference IDs. *

*

- * This method will return an inference results map that will be asynchronously populated with inference results. If the provided - * inference results map already contains all required inference results, the same map instance will be returned. Otherwise, a new map - * instance will be returned. It is guaranteed that a non-null map instance will be returned. + * This method will return an inference results map supplier that will provide a complete map of additional inference results required. + * If the provided inference results map already contains all required inference results, a null supplier will be returned. *

* * @param queryRewriteContext The query rewrite context * @param fullyQualifiedInferenceIds The fully qualified inference IDs to use to generate inference results * @param inferenceResultsMap The initial inference results map * @param query The query to generate inference results for - * @return An inference results map + * @return An inference results map supplier */ - static Map getInferenceResults( + static SetOnce> getInferenceResults( QueryRewriteContext queryRewriteContext, Set fullyQualifiedInferenceIds, @Nullable Map inferenceResultsMap, @Nullable String query ) { - boolean modifiedInferenceResultsMap = false; - Map currentInferenceResultsMap = inferenceResultsMap != null - ? inferenceResultsMap - : Map.of(); - + List inferenceIds = new ArrayList<>(fullyQualifiedInferenceIds.size()); if (query != null) { for (FullyQualifiedInferenceId fullyQualifiedInferenceId : fullyQualifiedInferenceIds) { - if (currentInferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) { + if (inferenceResultsMap == null || inferenceResultsMap.containsKey(fullyQualifiedInferenceId) == false) { if (fullyQualifiedInferenceId.clusterAlias().equals(queryRewriteContext.getLocalClusterAlias()) == false) { // Catch if we are missing inference results that should have been generated on another cluster throw new IllegalStateException( @@ -291,59 +307,96 @@ static Map getInferenceResults( ); } - if (modifiedInferenceResultsMap == false) { - // Copy the inference results map to ensure it is mutable and thread safe - currentInferenceResultsMap = new ConcurrentHashMap<>(currentInferenceResultsMap); - modifiedInferenceResultsMap = true; - } - - registerInferenceAsyncAction( - queryRewriteContext, - ((ConcurrentHashMap) currentInferenceResultsMap), - query, - fullyQualifiedInferenceId.inferenceId() - ); + inferenceIds.add(fullyQualifiedInferenceId.inferenceId()); } } } - return currentInferenceResultsMap; + SetOnce> inferenceResultsMapSupplier = null; + if (inferenceIds.isEmpty() == false) { + inferenceResultsMapSupplier = new SetOnce<>(); + registerInferenceAsyncActions(queryRewriteContext, inferenceResultsMapSupplier, query, inferenceIds); + } + + return inferenceResultsMapSupplier; } - static void registerInferenceAsyncAction( + static void registerInferenceAsyncActions( QueryRewriteContext queryRewriteContext, - ConcurrentHashMap inferenceResultsMap, + SetOnce> inferenceResultsMapSupplier, String query, - String inferenceId + List inferenceIds ) { - InferenceAction.Request inferenceRequest = new InferenceAction.Request( - TaskType.ANY, - inferenceId, - null, - null, - null, - List.of(query), - Map.of(), - InputType.INTERNAL_SEARCH, - null, - false - ); - - queryRewriteContext.registerAsyncAction( - (client, listener) -> executeAsyncWithOrigin( - client, - ML_ORIGIN, - InferenceAction.INSTANCE, - inferenceRequest, - listener.delegateFailureAndWrap((l, inferenceResponse) -> { - inferenceResultsMap.put( - new FullyQualifiedInferenceId(queryRewriteContext.getLocalClusterAlias(), inferenceId), - validateAndConvertInferenceResults(inferenceResponse.getResults(), inferenceId) - ); - l.onResponse(null); - }) + List inferenceRequests = inferenceIds.stream() + .map( + i -> new InferenceAction.Request( + TaskType.ANY, + i, + null, + null, + null, + List.of(query), + Map.of(), + InputType.INTERNAL_SEARCH, + null, + false + ) ) - ); + .toList(); + + queryRewriteContext.registerAsyncAction((client, listener) -> { + GroupedActionListener> gal = createGroupedActionListener( + inferenceResultsMapSupplier, + inferenceRequests.size(), + listener + ); + for (InferenceAction.Request inferenceRequest : inferenceRequests) { + FullyQualifiedInferenceId fullyQualifiedInferenceId = new FullyQualifiedInferenceId( + queryRewriteContext.getLocalClusterAlias(), + inferenceRequest.getInferenceEntityId() + ); + executeAsyncWithOrigin( + client, + ML_ORIGIN, + InferenceAction.INSTANCE, + inferenceRequest, + gal.delegateFailureAndWrap((l, inferenceResponse) -> { + InferenceResults inferenceResults = validateAndConvertInferenceResults( + inferenceResponse.getResults(), + fullyQualifiedInferenceId.inferenceId() + ); + l.onResponse(Tuple.tuple(fullyQualifiedInferenceId, inferenceResults)); + }) + ); + } + }); + } + + static T getNewInferenceResultsFromSupplier( + SetOnce> supplier, + T currentQueryBuilder, + Function, T> copyGenerator + ) { + Map newInferenceResultsMap = supplier.get(); + // It's safe to use only the new inference results map (once set) because we can enumerate the scenarios where we need to get + // inference results: + // - On the local coordinating node, getting inference results for the first time. The previous inference results map is null. + // - On the remote coordinating node, getting inference results for remote cluster inference IDs. In this case, we can guarantee + // that only remote cluster inference results are required to handle the query. + return newInferenceResultsMap != null ? copyGenerator.apply(newInferenceResultsMap) : currentQueryBuilder; + } + + private static GroupedActionListener> createGroupedActionListener( + SetOnce> inferenceResultsMapSupplier, + int inferenceRequestCount, + ActionListener listener + ) { + return new GroupedActionListener<>(inferenceRequestCount, listener.delegateFailureAndWrap((l, responses) -> { + Map inferenceResultsMap = new HashMap<>(responses.size()); + responses.forEach(r -> inferenceResultsMap.put(r.v1(), r.v2())); + inferenceResultsMapSupplier.set(inferenceResultsMap); + l.onResponse(null); + })); } static Map convertFromBwcInferenceResultsMap( @@ -459,26 +512,41 @@ private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext qu ); } - SemanticQueryBuilder rewritten = this; - if (queryRewriteContext.hasAsyncActions() == false) { - Set fullyQualifiedInferenceIds = getInferenceIdsForField( - resolvedIndices.getConcreteLocalIndicesMetadata().values(), - queryRewriteContext.getLocalClusterAlias(), - fieldName - ); - Map modifiedInferenceResultsMap = getInferenceResults( - queryRewriteContext, - fullyQualifiedInferenceIds, - inferenceResultsMap, - query + if (inferenceResultsMapSupplier != null) { + // Additional inference results have already been requested, and we are waiting for them to continue the rewrite process + return getNewInferenceResultsFromSupplier( + inferenceResultsMapSupplier, + this, + m -> new SemanticQueryBuilder(this, m, null, ccsRequest) ); + } - if (modifiedInferenceResultsMap == inferenceResultsMap) { + Set fullyQualifiedInferenceIds = getInferenceIdsForField( + resolvedIndices.getConcreteLocalIndicesMetadata().values(), + queryRewriteContext.getLocalClusterAlias(), + fieldName + ); + SetOnce> newInferenceResultsMapSupplier = getInferenceResults( + queryRewriteContext, + fullyQualifiedInferenceIds, + inferenceResultsMap, + query + ); + + SemanticQueryBuilder rewritten = this; + if (newInferenceResultsMapSupplier == null) { + // No additional inference results are required + if (inferenceResultsMap != null) { // The inference results map is fully populated, so we can perform error checking - inferenceResultsErrorCheck(modifiedInferenceResultsMap); + inferenceResultsErrorCheck(inferenceResultsMap); } else { - rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap, ccsRequest); + // No inference results have been collected yet, indicating we don't need any to rewrite this query. + // This can happen when querying an unsupported field type or an unavailable index. Set an empty inference results map so + // that rewriting can continue. + rewritten = new SemanticQueryBuilder(this, Map.of(), null, ccsRequest); } + } else { + rewritten = new SemanticQueryBuilder(this, inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest); } return rewritten; @@ -577,11 +645,12 @@ protected boolean doEquals(SemanticQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Objects.equals(query, other.query) && Objects.equals(inferenceResultsMap, other.inferenceResultsMap) + && Objects.equals(inferenceResultsMapSupplier, other.inferenceResultsMapSupplier) && Objects.equals(ccsRequest, other.ccsRequest); } @Override protected int doHashCode() { - return Objects.hash(fieldName, query, inferenceResultsMap, ccsRequest); + return Objects.hash(fieldName, query, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/FakeMlPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/FakeMlPlugin.java new file mode 100644 index 0000000000000..2a521bfaaac41 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/FakeMlPlugin.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder; +import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction; + +import java.util.Collection; +import java.util.List; + +public class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin { + @Override + public List getNamedWriteables() { + return new MlInferenceNamedXContentProvider().getNamedWriteables(); + } + + @Override + public List> getQueryVectorBuilders() { + return List.of( + new QueryVectorBuilderSpec<>( + TextEmbeddingQueryVectorBuilder.NAME, + TextEmbeddingQueryVectorBuilder::new, + TextEmbeddingQueryVectorBuilder.PARSER + ) + ); + } + + @Override + public Collection getActions() { + return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 160261eb1b1cc..cb5d1d40e2c2a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -33,7 +33,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; @@ -69,9 +68,9 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.DenseEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.MlDenseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.FakeMlPlugin; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -607,13 +606,6 @@ private static MinimalServiceSettings getModelSettingsForInferenceResultType( }; } - public static class FakeMlPlugin extends Plugin { - @Override - public List getNamedWriteables() { - return new MlInferenceNamedXContentProvider().getNamedWriteables(); - } - } - private static TestThreadPool threadPool; private static ModelRegistry modelRegistry;