Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/136720.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 136720
summary: Use suppliers to get inference results in semantic queries
area: Vector Search
type: bug
issues:
- 136621
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
Expand All @@ -31,27 +30,16 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractMultiClustersTestCase;
import org.elasticsearch.transport.RemoteConnectionInfo;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.FakeMlPlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.elasticsearch.xpack.ml.action.TransportCoordinatedInferenceAction;

import java.io.IOException;
import java.util.Collection;
Expand All @@ -66,6 +54,7 @@

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.createInferenceEndpoint;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;

Expand Down Expand Up @@ -165,35 +154,6 @@ protected BytesReference openPointInTime(String[] indices, TimeValue keepAlive)
return response.getPointInTimeId();
}

// Registers a mock inference endpoint (backed by the test inference services) so that
// semantic_text fields and inference queries can run without a real ML model.
// NOTE(review): this helper was moved to IntegrationTestUtils in this PR; kept here for reference.
protected static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map<String, Object> serviceSettings)
throws IOException {
// Pick the fake service implementation that matches the requested task type;
// anything other than dense/sparse embedding is unsupported and fails fast.
final String service = switch (taskType) {
case TEXT_EMBEDDING -> TestDenseInferenceServiceExtension.TestInferenceService.NAME;
case SPARSE_EMBEDDING -> TestSparseInferenceServiceExtension.TestInferenceService.NAME;
default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
};

// Build the PUT _inference request body: {"service": ..., "service_settings": {...}}
final BytesReference content;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
builder.field("service", service);
builder.field("service_settings", serviceSettings);
builder.endObject();

content = BytesReference.bytes(builder);
}

PutInferenceModelAction.Request request = new PutInferenceModelAction.Request(
taskType,
inferenceId,
content,
XContentType.JSON,
TEST_REQUEST_TIMEOUT
);
// Execute synchronously and verify the endpoint was created under the expected id.
var responseFuture = client.execute(PutInferenceModelAction.INSTANCE, request);
assertThat(responseFuture.actionGet(TEST_REQUEST_TIMEOUT).getModel().getInferenceEntityId(), equalTo(inferenceId));
}

protected void assertSearchResponse(QueryBuilder queryBuilder, List<IndexWithBoost> indices, List<SearchResult> expectedSearchResults)
throws Exception {
assertSearchResponse(queryBuilder, indices, expectedSearchResults, null, null);
Expand Down Expand Up @@ -307,29 +267,6 @@ protected static String[] convertToArray(List<IndexWithBoost> indices) {
return indices.stream().map(IndexWithBoost::index).toArray(String[]::new);
}

// Minimal stand-in for the ML plugin: registers just enough named writeables, the
// text_embedding query vector builder, and the coordinated-inference action so that
// knn queries using TextEmbeddingQueryVectorBuilder work in these integration tests.
// NOTE(review): this class was extracted to org.elasticsearch.xpack.inference.FakeMlPlugin in this PR.
public static class FakeMlPlugin extends Plugin implements ActionPlugin, SearchPlugin {
@Override
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
// Needed so ML inference results/configs can be (de)serialized across nodes.
return new MlInferenceNamedXContentProvider().getNamedWriteables();
}

@Override
public List<QueryVectorBuilderSpec<?>> getQueryVectorBuilders() {
// Registers the text_embedding query vector builder used by knn query clauses.
return List.of(
new QueryVectorBuilderSpec<>(
TextEmbeddingQueryVectorBuilder.NAME,
TextEmbeddingQueryVectorBuilder::new,
TextEmbeddingQueryVectorBuilder.PARSER
)
);
}

@Override
public Collection<ActionHandler> getActions() {
// Routes coordinated inference requests to the real transport action implementation.
return List.of(new ActionHandler(CoordinatedInferenceAction.INSTANCE, TransportCoordinatedInferenceAction.class));
}
}

protected record TestIndexInfo(
String name,
Map<String, MinimalServiceSettings> inferenceEndpoints,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceEndpointAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;

import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.test.ESTestCase.TEST_REQUEST_TIMEOUT;
import static org.elasticsearch.test.ESTestCase.safeGet;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;

/**
 * Shared helpers for inference integration tests: test inference endpoint lifecycle,
 * lenient index deletion, and semantic_text mapping generation.
 */
public class IntegrationTestUtils {
    private IntegrationTestUtils() {}

    /**
     * Registers a test inference endpoint backed by the mock dense or sparse service and
     * asserts that the created endpoint reports the expected inference id.
     *
     * @param client          client to execute the PUT inference request with
     * @param taskType        TEXT_EMBEDDING or SPARSE_EMBEDDING; anything else throws
     * @param inferenceId     id to register the endpoint under
     * @param serviceSettings settings map serialized as the "service_settings" object
     * @throws IOException if building the request body fails
     */
    public static void createInferenceEndpoint(Client client, TaskType taskType, String inferenceId, Map<String, Object> serviceSettings)
        throws IOException {
        // Select the mock service implementation matching the task type.
        final String serviceName;
        if (taskType == TaskType.TEXT_EMBEDDING) {
            serviceName = TestDenseInferenceServiceExtension.TestInferenceService.NAME;
        } else if (taskType == TaskType.SPARSE_EMBEDDING) {
            serviceName = TestSparseInferenceServiceExtension.TestInferenceService.NAME;
        } else {
            throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
        }

        // Request body: {"service": <name>, "service_settings": {...}}
        final BytesReference requestBody;
        try (XContentBuilder jsonBuilder = XContentFactory.jsonBuilder()) {
            jsonBuilder.startObject().field("service", serviceName).field("service_settings", serviceSettings).endObject();
            requestBody = BytesReference.bytes(jsonBuilder);
        }

        PutInferenceModelAction.Request putRequest = new PutInferenceModelAction.Request(
            taskType,
            inferenceId,
            requestBody,
            XContentType.JSON,
            TEST_REQUEST_TIMEOUT
        );
        var putResponse = client.execute(PutInferenceModelAction.INSTANCE, putRequest).actionGet(TEST_REQUEST_TIMEOUT);
        assertThat(putResponse.getModel().getInferenceEntityId(), equalTo(inferenceId));
    }

    /**
     * Deletes an inference endpoint, forcing deletion and skipping the usage check,
     * and asserts the deletion was acknowledged.
     */
    public static void deleteInferenceEndpoint(Client client, TaskType taskType, String inferenceId) {
        var deleteRequest = new DeleteInferenceEndpointAction.Request(inferenceId, taskType, true, false);
        assertAcked(safeGet(client.execute(DeleteInferenceEndpointAction.INSTANCE, deleteRequest)));
    }

    /**
     * Deletes an index, tolerating the case where it does not exist, and asserts the
     * deletion was acknowledged.
     */
    public static void deleteIndex(Client client, String indexName) {
        // allow_unavailable_targets=true so cleanup is a no-op for a missing index.
        var lenientOptions = IndicesOptions.builder()
            .concreteTargetOptions(new IndicesOptions.ConcreteTargetOptions(true))
            .build();
        assertAcked(safeGet(client.admin().indices().prepareDelete(indexName).setIndicesOptions(lenientOptions).execute()));
    }

    /**
     * Builds an index mapping with one semantic_text property per entry of
     * {@code semanticTextFields} (field name -> inference id).
     */
    public static XContentBuilder generateSemanticTextMapping(Map<String, String> semanticTextFields) throws IOException {
        XContentBuilder mappingBuilder = XContentFactory.jsonBuilder();
        mappingBuilder.startObject().startObject("properties");
        for (Map.Entry<String, String> field : semanticTextFields.entrySet()) {
            mappingBuilder.startObject(field.getKey())
                .field("type", SemanticTextFieldMapper.CONTENT_TYPE)
                .field("inference_id", field.getValue())
                .endObject();
        }
        return mappingBuilder.endObject().endObject();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.integration;

import org.elasticsearch.action.DocWriteResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.reindex.ReindexPlugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.FakeMlPlugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.junit.After;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;

/**
 * Verifies that a bool query containing many inference-backed clauses (semantic, match,
 * sparse_vector, and knn over semantic_text fields) resolves and matches correctly.
 * Regression coverage for issue #136621 (supplier-based inference result resolution).
 */
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 1)
public class ManyInferenceQueryClausesIT extends ESIntegTestCase {
private static final String INDEX_NAME = "test_index";

// Service settings for the mock sparse and dense test inference services.
private static final Map<String, Object> SPARSE_EMBEDDING_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key");
private static final Map<String, Object> TEXT_EMBEDDING_SERVICE_SETTINGS = Map.of(
"model",
"my_model",
"dimensions",
256,
"similarity",
"cosine",
"api_key",
"my_api_key"
);

// Endpoints created during a test (inference id -> task type), torn down in cleanUp().
private final Map<String, TaskType> inferenceIds = new HashMap<>();

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
// Trial license enables the inference/ML features used by these queries.
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, ReindexPlugin.class, FakeMlPlugin.class);
}

@After
public void cleanUp() {
// Delete the index first so no endpoint is still referenced by a mapping.
IntegrationTestUtils.deleteIndex(client(), INDEX_NAME);
for (var entry : inferenceIds.entrySet()) {
IntegrationTestUtils.deleteInferenceEndpoint(client(), entry.getValue(), entry.getKey());
}
}

public void testManySemanticQueryClauses() throws Exception {
manyQueryClausesTestCase(randomIntBetween(18, 24), SemanticQueryBuilder::new, TaskType.SPARSE_EMBEDDING);
}

public void testManyMatchQueryClauses() throws Exception {
manyQueryClausesTestCase(randomIntBetween(18, 24), MatchQueryBuilder::new, TaskType.SPARSE_EMBEDDING);
}

public void testManySparseVectorQueryClauses() throws Exception {
manyQueryClausesTestCase(randomIntBetween(18, 24), (f, q) -> new SparseVectorQueryBuilder(f, null, q), TaskType.SPARSE_EMBEDDING);
}

public void testManyKnnQueryClauses() throws Exception {
int clauseCount = randomIntBetween(18, 24);
// knn uses the text_embedding query vector builder registered by FakeMlPlugin;
// k = clauseCount ensures every indexed doc can be returned.
manyQueryClausesTestCase(
clauseCount,
(f, q) -> new KnnVectorQueryBuilder(f, new TextEmbeddingQueryVectorBuilder(null, q), clauseCount, clauseCount * 10, null, null),
TaskType.TEXT_EMBEDDING
);
}

// Creates clauseCount semantic_text fields (each with its own inference endpoint), indexes
// one doc per field, then runs a bool query with one generated clause per field and expects
// every doc to be a hit.
private void manyQueryClausesTestCase(
int clauseCount,
BiFunction<String, String, QueryBuilder> clauseGenerator,
TaskType semanticTextFieldTaskType
) throws Exception {
Map<String, Object> inferenceEndpointServiceSettings = getServiceSettings(semanticTextFieldTaskType);
Map<String, String> semanticTextFields = new HashMap<>(clauseCount);
for (int i = 0; i < clauseCount; i++) {
String fieldName = randomAlphaOfLength(10);
String inferenceId = randomIdentifier();

createInferenceEndpoint(semanticTextFieldTaskType, inferenceId, inferenceEndpointServiceSettings);
semanticTextFields.put(fieldName, inferenceId);
}

XContentBuilder mapping = IntegrationTestUtils.generateSemanticTextMapping(semanticTextFields);
assertAcked(prepareCreate(INDEX_NAME).setMapping(mapping));

// One doc per field, and one should-clause per field targeting that doc.
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
for (String semanticTextField : semanticTextFields.keySet()) {
Map<String, Object> source = Map.of(semanticTextField, randomAlphaOfLength(10));
DocWriteResponse docWriteResponse = client().prepareIndex(INDEX_NAME).setSource(source).get(TEST_REQUEST_TIMEOUT);
assertThat(docWriteResponse.getResult(), is(DocWriteResponse.Result.CREATED));

boolQuery.should(clauseGenerator.apply(semanticTextField, randomAlphaOfLength(10)));
}
// Refresh so all docs are searchable before querying.
client().admin().indices().prepareRefresh(INDEX_NAME).get();

// size = clauseCount so the assertion on total hits covers every doc.
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQuery).size(clauseCount);
SearchRequest searchRequest = new SearchRequest(new String[] { INDEX_NAME }, searchSourceBuilder);
assertResponse(client().search(searchRequest), response -> {
assertThat(response.getSuccessfulShards(), equalTo(response.getTotalShards()));
assertThat(response.getHits().getTotalHits().value(), equalTo((long) clauseCount));
});
}

// Wraps IntegrationTestUtils.createInferenceEndpoint and records the endpoint for cleanup.
private void createInferenceEndpoint(TaskType taskType, String inferenceId, Map<String, Object> serviceSettings) throws IOException {
IntegrationTestUtils.createInferenceEndpoint(client(), taskType, inferenceId, serviceSettings);
inferenceIds.put(inferenceId, taskType);
}

private static Map<String, Object> getServiceSettings(TaskType taskType) {
return switch (taskType) {
case SPARSE_EMBEDDING -> SPARSE_EMBEDDING_SERVICE_SETTINGS;
case TEXT_EMBEDDING -> TEXT_EMBEDDING_SERVICE_SETTINGS;
default -> throw new IllegalArgumentException("Unhandled task type [" + taskType + "]");
};
}
}
Loading