diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java index 7d5197b9e9ba0..675d062fdb3af 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilder.java @@ -18,7 +18,6 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.xcontent.ParseField; @@ -32,8 +31,10 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; import java.io.IOException; +import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -51,6 +52,34 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder weightedTokensSupplier; private final TokenPruningConfig tokenPruningConfig; + public enum AllowedFieldType { + RANK_FEATURES("rank_features"), + SPARSE_VECTOR("sparse_vector"); + + private final String typeName; + + AllowedFieldType(String typeName) { + this.typeName = typeName; + } + + public String getTypeName() { + return typeName; + } + + public static boolean isFieldTypeAllowed(String typeName) { + for (AllowedFieldType fieldType : values()) { + if (fieldType.getTypeName().equals(typeName)) { + return true; + } + } + return false; + } + + public static String getAllowedFieldTypesAsString() { + return Arrays.stream(values()).map(value -> value.typeName).collect(Collectors.joining(", ")); + } + } + public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId) { this(fieldName, modelText, modelId, null); } @@ -198,24 +227,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws } private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) { - if (tokenPruningConfig != null) { - WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder( - fieldName, - textExpansionResults.getWeightedTokens(), - tokenPruningConfig - ); - weightedTokensQueryBuilder.queryName(queryName); - weightedTokensQueryBuilder.boost(boost); - return weightedTokensQueryBuilder; - } - var boolQuery = QueryBuilders.boolQuery(); - for (var weightedToken : textExpansionResults.getWeightedTokens()) { - boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight())); - } - boolQuery.minimumShouldMatch(1); - boolQuery.boost(this.boost); - boolQuery.queryName(this.queryName); - return boolQuery; + WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder( + fieldName, + textExpansionResults.getWeightedTokens(), + tokenPruningConfig + ); + weightedTokensQueryBuilder.queryName(queryName); + weightedTokensQueryBuilder.boost(boost); + return weightedTokensQueryBuilder; } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java index a09bcadaacfc0..51139881fc2e4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/queries/WeightedTokensQueryBuilder.java @@ -34,6 +34,7 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder.AllowedFieldType; import static org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder.PRUNING_CONFIG; public class WeightedTokensQueryBuilder extends AbstractQueryBuilder { @@ -152,27 +153,53 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { if (ft == null) { return new MatchNoDocsQuery("The \"" + getName() + "\" query is against a field that does not exist"); } + + final String fieldTypeName = ft.typeName(); + if (AllowedFieldType.isFieldTypeAllowed(fieldTypeName) == false) { + throw new ElasticsearchParseException( + "[" + + fieldTypeName + + "]" + + " is not an appropriate field type for this query. " + + "Allowed field types are [" + + AllowedFieldType.getAllowedFieldTypesAsString() + + "]." + ); + } + + return (this.tokenPruningConfig == null) + ? queryBuilderWithAllTokens(tokens, ft, context) + : queryBuilderWithPrunedTokens(tokens, ft, context); + } + + private Query queryBuilderWithAllTokens(List tokens, MappedFieldType ft, SearchExecutionContext context) { var qb = new BooleanQuery.Builder(); - int fieldDocCount = context.getIndexReader().getDocCount(fieldName); - float bestWeight = 0f; - for (var t : tokens) { - bestWeight = Math.max(t.weight(), bestWeight); + + for (var token : tokens) { + qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD); } + return qb.setMinimumNumberShouldMatch(1).build(); + } + + private Query queryBuilderWithPrunedTokens(List tokens, MappedFieldType ft, SearchExecutionContext context) + throws IOException { + var qb = new BooleanQuery.Builder(); + int fieldDocCount = context.getIndexReader().getDocCount(fieldName); + float bestWeight = tokens.stream().map(WeightedToken::weight).reduce(0f, Math::max); float averageTokenFreqRatio = getAverageTokenFreqRatio(context.getIndexReader(), fieldDocCount); if (averageTokenFreqRatio == 0) { return new MatchNoDocsQuery("The \"" + getName() + "\" query is against an empty field"); } + for (var token : tokens) { boolean keep = shouldKeepToken(context.getIndexReader(), token, fieldDocCount, averageTokenFreqRatio, bestWeight); - if (this.tokenPruningConfig != null) { - keep ^= this.tokenPruningConfig.isOnlyScorePrunedTokens(); - } + keep ^= this.tokenPruningConfig.isOnlyScorePrunedTokens(); if (keep) { qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD); } } - qb.setMinimumNumberShouldMatch(1); - return qb.build(); + + return qb.setMinimumNumberShouldMatch(1).build(); } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java index 13f12f3cdc1e1..50561d92f5d37 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/queries/TextExpansionQueryBuilderTests.java @@ -25,7 +25,6 @@ import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.Plugin; @@ -260,10 +259,6 @@ public void testThatTokensAreCorrectlyPruned() { SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder(); QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext); - if (queryBuilder.getTokenPruningConfig() == null) { - assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder); - } else { - assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); - } + assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder); } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml index 5e29d3cdf2ae6..dc4e1751ccdee 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/text_expansion_search.yml @@ -287,3 +287,20 @@ setup: tokens_weight_threshold: 0.4 only_score_pruned_tokens: true - match: { hits.total.value: 0 } + +--- +"Test text-expansion that displays error for invalid queried field type": + - skip: + version: " - 8.13.99" + reason: "validation for invalid field type introduced in 8.14.0" + + - do: + catch: /\[keyword\] is not an appropriate field type for this query/ + search: + index: index-with-rank-features + body: + query: + text_expansion: + source_text: + model_id: text_expansion_model + model_text: "octopus comforter smells"