Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
eaa3054
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 15, 2024
496aa0c
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 15, 2024
3601fa1
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 16, 2024
d287e38
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 16, 2024
c29a17e
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 16, 2024
df7d7ed
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
023a5c7
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
fc6f437
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
454f3c6
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
b2f7b0d
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
f7b821c
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
d256853
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
762409d
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
f4c2c33
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
91f0a73
Display error for text_expansion if the queried field does not have t…
saikatsarkar056 Feb 21, 2024
6113c8c
Clean up the code
saikatsarkar056 Feb 22, 2024
d536e07
Optimize the code for pruning config
saikatsarkar056 Feb 22, 2024
3fb95e7
Optimize the code for pruning config
saikatsarkar056 Feb 22, 2024
cb071b6
Write findBestWeightFor for clear code
saikatsarkar056 Feb 22, 2024
ee3233b
Run Spotless
saikatsarkar056 Feb 22, 2024
8969c31
Remove findBestWeightFor method
saikatsarkar056 Feb 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -32,8 +31,10 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
Expand All @@ -51,6 +52,34 @@ public class TextExpansionQueryBuilder extends AbstractQueryBuilder<TextExpansio
private SetOnce<TextExpansionResults> weightedTokensSupplier;
private final TokenPruningConfig tokenPruningConfig;

public enum AllowedFieldType {
RANK_FEATURES("rank_features"),
SPARSE_VECTOR("sparse_vector");

private final String typeName;

AllowedFieldType(String typeName) {
this.typeName = typeName;
}

public String getTypeName() {
return typeName;
}

public static boolean isFieldTypeAllowed(String typeName) {
for (AllowedFieldType fieldType : values()) {
if (fieldType.getTypeName().equals(typeName)) {
return true;
}
}
return false;
}

public static String getAllowedFieldTypesAsString() {
return Arrays.stream(values()).map(value -> value.typeName).collect(Collectors.joining(", "));
}
}

public TextExpansionQueryBuilder(String fieldName, String modelText, String modelId) {
this(fieldName, modelText, modelId, null);
}
Expand Down Expand Up @@ -198,24 +227,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
}

private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResults textExpansionResults) {
if (tokenPruningConfig != null) {
WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
fieldName,
textExpansionResults.getWeightedTokens(),
tokenPruningConfig
);
weightedTokensQueryBuilder.queryName(queryName);
weightedTokensQueryBuilder.boost(boost);
return weightedTokensQueryBuilder;
}
var boolQuery = QueryBuilders.boolQuery();
for (var weightedToken : textExpansionResults.getWeightedTokens()) {
boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight()));
}
boolQuery.minimumShouldMatch(1);
boolQuery.boost(this.boost);
boolQuery.queryName(this.queryName);
return boolQuery;
WeightedTokensQueryBuilder weightedTokensQueryBuilder = new WeightedTokensQueryBuilder(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nice and clean. I think we could think about another optimization here - Right now WeightedTokensQueryBuilder.toToQuery pulls the field document count and calculates the token frequency ratio for every query. Since token pruning is an opt in feature for text expansion queries, we might want to update that method in the WeightedTokenBuilder to short-circuit this and only get these values if we have a non-null token pruning configuration. WDYT?

Copy link
Contributor Author

@saikatsarkar056 saikatsarkar056 Feb 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kderusso I think you mean doToQuery method here. My understanding is that we should only calculate the token frequency ratio and return BooleanQuery if we have non-null token pruning configuration. Am I right? So, we should go to the following direction:

if (this.tokenPruningConfig == null) {
   return new MatchNoDocsQuery("The \"" + getName() + "\" query does not have any pruning configuration");
}

var qb = new BooleanQuery.Builder();
int fieldDocCount = context.getIndexReader().getDocCount(fieldName);
...

Please let me know if my understanding is correct about token pruning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I had a typo - doToQuery - here's the link to the method I'm talking about.

Your suggestion to return a MatchNoDocsQuery is not what we want here. If we did that, every time we sent in a text expansion query without a pruning configuration, no documents would ever be returned. Since pruning configuration is opt-in and optional this is very undesireable behavior.

No, what I'm suggesting is altering how we determine whether we want to keep tokens in the WeightedTokensQueryBuilder.

Today, we do the following:

  1. Get the document count for the field name
  2. Calculate the best weight for each token
  3. Get the average token frequency ratio
  4. Start building a boolean query, determining if we should keep each token.

If the pruning configuration is null, two things are true:

  1. We want to keep every token, because we don't want any tokens to be pruned
  2. Because we want to keep every token there is no need to calculate the counts or frequency ratios.

So I propose we short-circuit this and only calculate those ratios if there exists a pruning configuration.

Does this make sense to you?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the explanation. Now, I got the idea about this optimization. I will change the code and notify you for another review.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kderusso I did some optimization and code clean-up around pruning configuration. Can you please review the changes again? Thank you.

fieldName,
textExpansionResults.getWeightedTokens(),
tokenPruningConfig
);
weightedTokensQueryBuilder.queryName(queryName);
weightedTokensQueryBuilder.boost(boost);
return weightedTokensQueryBuilder;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder.AllowedFieldType;
import static org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder.PRUNING_CONFIG;

public class WeightedTokensQueryBuilder extends AbstractQueryBuilder<WeightedTokensQueryBuilder> {
Expand Down Expand Up @@ -152,27 +153,53 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
if (ft == null) {
return new MatchNoDocsQuery("The \"" + getName() + "\" query is against a field that does not exist");
}

final String fieldTypeName = ft.typeName();
if (AllowedFieldType.isFieldTypeAllowed(fieldTypeName) == false) {
throw new ElasticsearchParseException(
"["
+ fieldTypeName
+ "]"
+ " is not an appropriate field type for this query. "
+ "Allowed field types are ["
+ AllowedFieldType.getAllowedFieldTypesAsString()
+ "]."
);
}

return (this.tokenPruningConfig == null)
? queryBuilderWithAllTokens(tokens, ft, context)
: queryBuilderWithPrunedTokens(tokens, ft, context);
}

private Query queryBuilderWithAllTokens(List<WeightedToken> tokens, MappedFieldType ft, SearchExecutionContext context) {
var qb = new BooleanQuery.Builder();
int fieldDocCount = context.getIndexReader().getDocCount(fieldName);
float bestWeight = 0f;
for (var t : tokens) {
bestWeight = Math.max(t.weight(), bestWeight);

for (var token : tokens) {
qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD);
}
return qb.setMinimumNumberShouldMatch(1).build();
}

private Query queryBuilderWithPrunedTokens(List<WeightedToken> tokens, MappedFieldType ft, SearchExecutionContext context)
throws IOException {
var qb = new BooleanQuery.Builder();
int fieldDocCount = context.getIndexReader().getDocCount(fieldName);
float bestWeight = tokens.stream().map(WeightedToken::weight).reduce(0f, Math::max);
float averageTokenFreqRatio = getAverageTokenFreqRatio(context.getIndexReader(), fieldDocCount);
if (averageTokenFreqRatio == 0) {
return new MatchNoDocsQuery("The \"" + getName() + "\" query is against an empty field");
}

for (var token : tokens) {
boolean keep = shouldKeepToken(context.getIndexReader(), token, fieldDocCount, averageTokenFreqRatio, bestWeight);
if (this.tokenPruningConfig != null) {
keep ^= this.tokenPruningConfig.isOnlyScorePrunedTokens();
}
keep ^= this.tokenPruningConfig.isOnlyScorePrunedTokens();
if (keep) {
qb.add(new BoostQuery(ft.termQuery(token.token(), context), token.weight()), BooleanClause.Occur.SHOULD);
}
}
qb.setMinimumNumberShouldMatch(1);
return qb.build();

return qb.setMinimumNumberShouldMatch(1).build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.elasticsearch.common.compress.CompressedXContent;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.Plugin;
Expand Down Expand Up @@ -260,10 +259,6 @@ public void testThatTokensAreCorrectlyPruned() {
SearchExecutionContext searchExecutionContext = createSearchExecutionContext();
TextExpansionQueryBuilder queryBuilder = createTestQueryBuilder();
QueryBuilder rewrittenQueryBuilder = rewriteAndFetch(queryBuilder, searchExecutionContext);
if (queryBuilder.getTokenPruningConfig() == null) {
assertTrue(rewrittenQueryBuilder instanceof BoolQueryBuilder);
} else {
assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder);
}
assertTrue(rewrittenQueryBuilder instanceof WeightedTokensQueryBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,20 @@ setup:
tokens_weight_threshold: 0.4
only_score_pruned_tokens: true
- match: { hits.total.value: 0 }

---
"Test text-expansion that displays error for invalid queried field type":
- skip:
version: " - 8.13.99"
reason: "validation for invalid field type introduced in 8.14.0"

- do:
catch: /\[keyword\] is not an appropriate field type for this query/
search:
index: index-with-rank-features
body:
query:
text_expansion:
source_text:
model_id: text_expansion_model
model_text: "octopus comforter smells"