[Feature] Expose term frequency in Painless script score context #9081

Merged
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -88,6 +88,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Make SearchTemplateRequest implement IndicesRequest.Replaceable ([#9122](https://github.com/opensearch-project/OpenSearch/pull/9122))
- [BWC and API enforcement] Define the initial set of annotations, their meaning and relations between them ([#9223](https://github.com/opensearch-project/OpenSearch/pull/9223))
- [Segment Replication] Support realtime reads for GET requests ([#9212](https://github.com/opensearch-project/OpenSearch/pull/9212))
+- [Feature] Expose term frequency in Painless script score context ([#9081](https://github.com/opensearch-project/OpenSearch/pull/9081))

### Dependencies
- Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307))
@@ -164,4 +165,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Security

[Unreleased 3.0]: https://github.com/opensearch-project/OpenSearch/compare/2.x...HEAD
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.10...2.x
modules/lang-expression/src/main/java/org/opensearch/script/expression/ExpressionScoreScript.java
@@ -66,7 +66,7 @@

@Override
public ScoreScript newInstance(final LeafReaderContext leaf) throws IOException {
-return new ScoreScript(null, null, null) {
+return new ScoreScript(null, null, null, null) {
// Fake the scorer until setScorer is called.
DoubleValues values = source.getValues(leaf, new DoubleValues() {
@Override
@@ -37,6 +37,7 @@
import org.apache.lucene.expressions.js.JavascriptCompiler;
import org.apache.lucene.expressions.js.VariableContext;
import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.IndexSearcher;
import org.opensearch.SpecialPermission;
import org.opensearch.common.Nullable;
import org.opensearch.index.fielddata.IndexFieldData;
@@ -110,7 +111,7 @@ public FilterScript.LeafFactory newFactory(Map<String, Object> params, SearchLoo

contexts.put(ScoreScript.CONTEXT, (Expression expr) -> new ScoreScript.Factory() {
@Override
-public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
+public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
return newScoreScript(expr, lookup, params);
}

@@ -558,7 +558,11 @@ static Response innerShardOperation(Request request, ScriptService scriptService
} else if (scriptContext == ScoreScript.CONTEXT) {
return prepareRamIndex(request, (context, leafReaderContext) -> {
ScoreScript.Factory factory = scriptService.compile(request.script, ScoreScript.CONTEXT);
-ScoreScript.LeafFactory leafFactory = factory.newFactory(request.getScript().getParams(), context.lookup());
+ScoreScript.LeafFactory leafFactory = factory.newFactory(
+request.getScript().getParams(),
+context.lookup(),
+context.searcher()
+);
ScoreScript scoreScript = leafFactory.newInstance(leafReaderContext);
scoreScript.setDocument(0);

@@ -23,6 +23,10 @@ class org.opensearch.script.ScoreScript @no_import {
}

static_import {
+int termFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TermFreq
+float tf(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TF
+long totalTermFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TotalTermFreq
+long sumTotalTermFreq(org.opensearch.script.ScoreScript, String) bound_to org.opensearch.script.ScoreScriptUtils$SumTotalTermFreq
double saturation(double, double) from_class org.opensearch.script.ScoreScriptUtils
double sigmoid(double, double, double) from_class org.opensearch.script.ScoreScriptUtils
double randomScore(org.opensearch.script.ScoreScript, int, String) bound_to org.opensearch.script.ScoreScriptUtils$RandomScoreField
@@ -0,0 +1,95 @@
---
setup:
- skip:
version: " - 2.9.99"
reason: "termFreq functions for script_score were introduced in 2.10.0"
- do:
indices.create:
index: test
body:
settings:
number_of_shards: 1
mappings:
properties:
f1:
type: keyword
f2:
type: text
- do:
bulk:
refresh: true
body:
- '{"index": {"_index": "test", "_id": "doc1"}}'
- '{"f1": "v0", "f2": "v1"}'
- '{"index": {"_index": "test", "_id": "doc2"}}'
- '{"f2": "v2"}'

---
"Script score function using the termFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "termFreq(params.field, params.term)"
params:
field: "f1"
term: "v0"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.0 }

---
"Script score function using the totalTermFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "if (doc[params.field].size() == 0) return params.default_value; else { return totalTermFreq(params.field, params.term); }"
params:
default_value: 0.5
field: "f1"
term: "v0"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.5 }

---
"Script score function using the sumTotalTermFreq function":
- do:
search:
index: test
rest_total_hits_as_int: true
body:
query:
function_score:
query:
match_all: {}
script_score:
script:
source: "if (doc[params.field].size() == 0) return params.default_value; else { return sumTotalTermFreq(params.field); }"
params:
default_value: 0.5
field: "f1"
- match: { hits.total: 2 }
- match: { hits.hits.0._id: "doc1" }
- match: { hits.hits.1._id: "doc2" }
- match: { hits.hits.0._score: 1.0 }
- match: { hits.hits.1._score: 0.5 }
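The expected scores in these tests follow from the index contents: only doc1 has `f1`, and it holds the single term `v0`. A minimal sketch in plain Java reproduces the asserted values by modelling the two test documents in memory. The class `ToyTermStats` and its methods are hypothetical and do not call OpenSearch or Lucene.

```java
import java.util.List;
import java.util.Map;

// Toy in-memory model of the two test documents; not OpenSearch or Lucene code.
final class ToyTermStats {
    // Each document maps a field name to its list of indexed terms.
    static final List<Map<String, List<String>>> DOCS = List.of(
        Map.of("f1", List.of("v0"), "f2", List.of("v1")), // doc1
        Map.of("f2", List.of("v2"))                        // doc2 has no f1
    );

    // Occurrences of the term in one document's field.
    static int termFreq(int doc, String field, String term) {
        return (int) DOCS.get(doc).getOrDefault(field, List.of()).stream().filter(term::equals).count();
    }

    // Occurrences of the term in the field, summed over all documents.
    static long totalTermFreq(String field, String term) {
        long total = 0;
        for (int doc = 0; doc < DOCS.size(); doc++) {
            total += termFreq(doc, field, term);
        }
        return total;
    }

    // Occurrences of any term in the field, summed over all documents.
    static long sumTotalTermFreq(String field) {
        return DOCS.stream().mapToLong(d -> d.getOrDefault(field, List.of()).size()).sum();
    }

    // Mirrors the guarded scripts: use defaultValue when the document lacks the field.
    static double guarded(int doc, String field, double defaultValue, long stat) {
        return DOCS.get(doc).containsKey(field) ? stat : defaultValue;
    }
}
```

Under this model `termFreq` gives 1 for doc1 and 0 for doc2, while the two guarded variants give 1 for doc1 and fall back to `default_value` (0.5) for doc2, which matches the `_score` assertions above.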
plugins/examples/script-expert-scoring/src/main/java/org/opensearch/example/expertscript/ExpertScriptPlugin.java
@@ -35,6 +35,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
+import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.ScriptPlugin;
@@ -120,20 +121,22 @@
@Override
public LeafFactory newFactory(
Map<String, Object> params,
-SearchLookup lookup
+SearchLookup lookup,
+IndexSearcher indexSearcher
) {
-return new PureDfLeafFactory(params, lookup);
+return new PureDfLeafFactory(params, lookup, indexSearcher);
}
}

private static class PureDfLeafFactory implements LeafFactory {
private final Map<String, Object> params;
private final SearchLookup lookup;
+private final IndexSearcher indexSearcher;
private final String field;
private final String term;

private PureDfLeafFactory(
-Map<String, Object> params, SearchLookup lookup) {
+Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
if (params.containsKey("field") == false) {
throw new IllegalArgumentException(
"Missing parameter [field]");
@@ -144,6 +147,7 @@
}
this.params = params;
this.lookup = lookup;
+this.indexSearcher = indexSearcher;
field = params.get("field").toString();
term = params.get("term").toString();
}
@@ -163,7 +167,7 @@
* the field and/or term don't exist in this segment,
* so always return 0
*/
-return new ScoreScript(params, lookup, context) {
+return new ScoreScript(params, lookup, indexSearcher, context) {
@Override
public double execute(
ExplanationHolder explanation
@@ -172,7 +176,7 @@
}
};
}
-return new ScoreScript(params, lookup, context) {
+return new ScoreScript(params, lookup, indexSearcher, context) {
int currentDocid = -1;
@Override
public void setDocument(int docid) {
@@ -34,6 +34,7 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchType;
@@ -93,15 +94,15 @@ public String getType() {
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
assert scriptSource.equals("explainable_script");
assert context == ScoreScript.CONTEXT;
-ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() {
+ScoreScript.Factory factory = (params1, lookup, indexSearcher) -> new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return false;
}

@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
-return new MyScript(params1, lookup, ctx);
+return new MyScript(params1, lookup, indexSearcher, ctx);
}
};
return context.factoryClazz.cast(factory);
@@ -117,8 +118,8 @@ public Set<ScriptContext<?>> getSupportedContexts() {

static class MyScript extends ScoreScript implements ExplainableScoreScript {

-MyScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
-super(params, lookup, leafContext);
+MyScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
+super(params, lookup, indexSearcher, leafContext);
}

@Override
@@ -114,7 +114,7 @@ protected int doHashCode() {
protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
-ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
+ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
return new ScriptScoreFunction(
script,
searchScript,
@@ -187,7 +187,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
);
}
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
-ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup());
+ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
final QueryBuilder queryBuilder = this.query;
Query query = queryBuilder.toQuery(context);
return new ScriptScoreQuery(
@@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import java.io.IOException;

/**
* An interface representing a term frequency function used to compute document scores
* based on specific term frequency calculations. Implementations of this interface should
* provide a way to execute the term frequency function for a given document ID.
*
* @opensearch.internal
*/
public interface TermFrequencyFunction {
Object execute(int docId) throws IOException;
}
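Because `execute` returns `Object`, one interface can carry the `int`, `float`, and `long` results that the different functions produce, and the caller converts the boxed value. A minimal sketch of that pattern follows, assuming a local copy of the interface. The class `TermFrequencySketch`, the helper `asScore`, and the two stand-in functions are hypothetical and not part of the PR.

```java
import java.io.IOException;
import java.io.UncheckedIOException;

// Local copy of the interface from the diff so the sketch compiles on its own.
interface TermFrequencyFunction {
    Object execute(int docId) throws IOException;
}

// Hypothetical helper class, not part of the PR.
final class TermFrequencySketch {
    // Runs the function for one document and widens the boxed result to a double.
    static double asScore(TermFrequencyFunction function, int docId) {
        try {
            Object value = function.execute(docId);
            if (value instanceof Number) {
                return ((Number) value).doubleValue();
            }
            throw new IllegalStateException("Expected a numeric result but got: " + value);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Stand-ins for Lucene-backed functions: one boxes to Integer, the other to Long.
    static final TermFrequencyFunction PER_DOC = docId -> docId == 0 ? 1 : 0;
    static final TermFrequencyFunction INDEX_WIDE = docId -> 1L;
}
```

Going through `Number` avoids a `ClassCastException` when a caller does not know which boxed type a given function returns.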
server/src/main/java/org/opensearch/index/query/functionscore/TermFrequencyFunctionFactory.java
@@ -0,0 +1,95 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
* A factory class for creating instances of {@link TermFrequencyFunction}.
* This class provides methods for creating different term frequency functions based on
* the specified function name, field, and term. Each term frequency function is designed
* to compute document scores based on specific term frequency calculations.
*
* @opensearch.internal
*/
public class TermFrequencyFunctionFactory {

public static TermFrequencyFunction createFunction(
TermFrequencyFunctionName functionName,
String field,
String term,
LeafReaderContext readerContext,
IndexSearcher indexSearcher
) throws IOException {
switch (functionName) {
case TERM_FREQ:
TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
FunctionValues functionValues = termFreqValueSource.getValues(null, readerContext);
return docId -> functionValues.intVal(docId);

case TF:
TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
Map<Object, Object> tfContext = new HashMap<>() {

{
put("searcher", indexSearcher);
}

};
functionValues = tfValueSource.getValues(tfContext, readerContext);
return docId -> functionValues.floatVal(docId);

case TOTAL_TERM_FREQ:
TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource(

field,
term,
field,
BytesRefs.toBytesRef(term)

);
Map<Object, Object> ttfContext = new HashMap<>();
totalTermFreqValueSource.createWeight(ttfContext, indexSearcher);
functionValues = totalTermFreqValueSource.getValues(ttfContext, readerContext);
return docId -> functionValues.longVal(docId);

case SUM_TOTAL_TERM_FREQ:
SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field);
Map<Object, Object> sttfContext = new HashMap<>();
sumTotalTermFreqValueSource.createWeight(sttfContext, indexSearcher);
functionValues = sumTotalTermFreqValueSource.getValues(sttfContext, readerContext);
return docId -> functionValues.longVal(docId);

default:
throw new IllegalArgumentException("Unsupported function: " + functionName);

}
}

/**
* An enumeration representing the names of supported term frequency functions.
*/
public enum TermFrequencyFunctionName {
TERM_FREQ("termFreq"),
TF("tf"),
TOTAL_TERM_FREQ("totalTermFreq"),
SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");

private final String termFrequencyFunctionName;

TermFrequencyFunctionName(String termFrequencyFunctionName) {
this.termFrequencyFunctionName = termFrequencyFunctionName;
}

public String getTermFrequencyFunctionName() {
return termFrequencyFunctionName;

}
}
}
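The enum stores each function's script-facing name, but the diff shows no reverse lookup from a name back to a constant. If one were needed, a minimal sketch could look like the following. The enum `FunctionNameSketch` and the method `fromScriptName` are hypothetical and only mirror the constants above.

```java
import java.util.Arrays;

// Standalone mirror of the enum in the diff; fromScriptName is a hypothetical addition.
enum FunctionNameSketch {
    TERM_FREQ("termFreq"),
    TF("tf"),
    TOTAL_TERM_FREQ("totalTermFreq"),
    SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");

    private final String scriptName;

    FunctionNameSketch(String scriptName) {
        this.scriptName = scriptName;
    }

    String getScriptName() {
        return scriptName;
    }

    // Resolves a script-facing name such as "termFreq" to its constant.
    static FunctionNameSketch fromScriptName(String name) {
        return Arrays.stream(values())
            .filter(candidate -> candidate.scriptName.equals(name))
            .findFirst()
            .orElseThrow(() -> new IllegalArgumentException("Unsupported function: " + name));
    }
}
```

The error message reuses the wording of the factory's `default` branch so that both failure paths read the same way.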