Skip to content

Commit

Permalink
Test serializing the query when no inference ID could be resolved
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikep86 committed May 17, 2024
1 parent bceeb84 commit 27dff4b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ private static class ServiceHolder implements Closeable {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(IndicesModule.getNamedWriteables());
entries.addAll(searchModule.getNamedWriteables());
pluginsService.forEach(plugin -> entries.addAll(plugin.getNamedWriteables()));
namedWriteableRegistry = new NamedWriteableRegistry(entries);
parserConfiguration = XContentParserConfiguration.EMPTY.withRegistry(
new NamedXContentRegistry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ protected QueryBuilder rewriteQuery(
// The first rewriteAndFetch call simulates rewriting on the coordinator node
// The second rewriteAndFetch call simulates rewriting on the shard
QueryBuilder rewritten = rewriteAndFetch(queryBuilder, coordinatorRewriteContext);
// extra safety to fail fast - serialize the rewritten version to ensure it's serializable.
assertSerialization(rewritten);
rewritten = rewriteAndFetch(rewritten, shardRewriteContext);
// extra safety to fail fast - serialize the rewritten version to ensure it's serializable.
assertSerialization(rewritten);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.ParsedDocument;
import org.elasticsearch.index.mapper.SourceToParse;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.inference.InputType;
Expand All @@ -42,6 +45,7 @@
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.InferencePlugin;
Expand Down Expand Up @@ -98,7 +102,7 @@ public void setUp() throws Exception {

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(InferencePlugin.class);
return List.of(InferencePlugin.class, FakeMlPlugin.class);
}

@Override
Expand Down Expand Up @@ -292,6 +296,20 @@ public void testToXContent() throws IOException {
}""", queryBuilder);
}

public void testSerializingQueryWhenNoInferenceId() throws IOException {
// Test serializing the query after rewriting on the coordinator node when no inference ID could be resolved for the field
SemanticQueryBuilder builder = new SemanticQueryBuilder(SEMANTIC_TEXT_FIELD + "_missing", "query text");

QueryRewriteContext queryRewriteContext = createQueryRewriteContext();
queryRewriteContext.setAllowUnmappedFields(true);

SearchExecutionContext searchExecutionContext = createSearchExecutionContext();
searchExecutionContext.setAllowUnmappedFields(true);

QueryBuilder rewritten = rewriteQuery(builder, queryRewriteContext, searchExecutionContext);
assertThat(rewritten, instanceOf(MatchNoneQueryBuilder.class));
}

private static SourceToParse buildSemanticTextFieldWithInferenceResults(InferenceResultType inferenceResultType) throws IOException {
SemanticTextField.ModelSettings modelSettings = switch (inferenceResultType) {
case NONE -> null;
Expand Down Expand Up @@ -321,4 +339,11 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults(Inferenc

return sourceToParse;
}

public static class FakeMlPlugin extends Plugin {
@Override
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return new MlInferenceNamedXContentProvider().getNamedWriteables();
}
}
}

0 comments on commit 27dff4b

Please sign in to comment.