Skip to content

Commit

Permalink
Address comments from Froh and Ankit
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <clingzhi@amazon.com>
  • Loading branch information
noCharger committed Aug 4, 2023
1 parent 421dabc commit 854ca55
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 159 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchType;
Expand Down Expand Up @@ -93,15 +94,15 @@ public String getType() {
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
assert scriptSource.equals("explainable_script");
assert context == ScoreScript.CONTEXT;
ScoreScript.Factory factory = (params1, lookup) -> new ScoreScript.LeafFactory() {
ScoreScript.Factory factory = (params1, lookup, indexSearcher) -> new ScoreScript.LeafFactory() {
@Override
public boolean needs_score() {
return false;
}

@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
return new MyScript(params1, lookup, ctx);
return new MyScript(params1, lookup, indexSearcher, ctx);
}
};
return context.factoryClazz.cast(factory);
Expand All @@ -117,8 +118,8 @@ public Set<ScriptContext<?>> getSupportedContexts() {

static class MyScript extends ScoreScript implements ExplainableScoreScript {

MyScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
super(params, lookup, leafContext);
MyScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
super(params, lookup, indexSearcher, leafContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,145 +8,15 @@

package org.opensearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.Map;

/**
* Abstract class representing a term frequency function.
* An interface representing a term frequency function used to compute document scores
* based on specific term frequency calculations. Implementations of this interface should
* provide a way to execute the term frequency function for a given document ID.
*
* @opensearch.internal
*/
public abstract class TermFrequencyFunction {

protected final String field;
protected final String term;
protected final int docId;
protected Map<Object, Object> context;

public TermFrequencyFunction(String field, String term, int docId, Map<Object, Object> context) {
this.field = field;
this.term = term;
this.docId = docId;
this.context = context;
}

public abstract Object execute(LeafReaderContext readerContext) throws IOException;

/**
* Factory class to create term frequency functions.
*/
public static class TermFrequencyFunctionFactory {
public static TermFrequencyFunction createFunction(
TermFrequencyFunctionNamesEnum functionName,
String field,
String term,
int docId,
Map<Object, Object> context
) {
switch (functionName) {
case TERM_FREQ:
return new TermFreqFunction(field, term, docId, context);
case TF:
return new TFFunction(field, term, docId, context);
case TOTAL_TERM_FREQ:
return new TotalTermFreq(field, term, docId, context);
case SUM_TOTAL_TERM_FREQ:
return new SumTotalTermFreq(field, term, docId, context);
default:
throw new IllegalArgumentException("Unsupported function: " + functionName);
}
}
}

/**
* TermFreqFunction computes the term frequency in a field.
*/
public static class TermFreqFunction extends TermFrequencyFunction {

public TermFreqFunction(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Integer execute(LeafReaderContext readerContext) throws IOException {
TermFreqValueSource valueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
return valueSource.getValues(null, readerContext).intVal(docId);
}
}

/**
* TFFunction computes the term frequency-inverse document frequency (tf-idf) in a field.
*/
public static class TFFunction extends TermFrequencyFunction {

public TFFunction(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Float execute(LeafReaderContext readerContext) throws IOException {
TFValueSource valueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
return valueSource.getValues(context, readerContext).floatVal(docId);
}
}

/**
* TotalTermFreq computes the total term frequency in a field.
*/
public static class TotalTermFreq extends TermFrequencyFunction {

public TotalTermFreq(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Long execute(LeafReaderContext readerContext) throws IOException {
TotalTermFreqValueSource valueSource = new TotalTermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
valueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return valueSource.getValues(context, readerContext).longVal(docId);
}
}

/**
* SumTotalTermFreq computes the sum of total term frequencies within a field.
*/
public static class SumTotalTermFreq extends TermFrequencyFunction {

public SumTotalTermFreq(String field, String term, int docId, Map<Object, Object> context) {
super(field, term, docId, context);
}

@Override
public Long execute(LeafReaderContext readerContext) throws IOException {
SumTotalTermFreqValueSource valueSource = new SumTotalTermFreqValueSource(field);
valueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return valueSource.getValues(context, readerContext).longVal(docId);
}
}

/**
* Enum representing the names of term frequency functions.
*/
public enum TermFrequencyFunctionNamesEnum {
TERM_FREQ("termFreq"),
TF("tf"),
TOTAL_TERM_FREQ("totalTermFreq"),
SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");

private final String termFrequencyFunctionName;

private TermFrequencyFunctionNamesEnum(String termFrequencyFunctionName) {
this.termFrequencyFunctionName = termFrequencyFunctionName;
}

public String getTermFrequencyFunctionName() {
return termFrequencyFunctionName;
}
}
public interface TermFrequencyFunction {
Object execute(int docId) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.Map;

/**
* A factory class for creating instances of {@link TermFrequencyFunction}.
* This class provides methods for creating different term frequency functions based on
* the specified function name, field, and term. Each term frequency function is designed
* to compute document scores based on specific term frequency calculations.
*
* @opensearch.internal
*/
public class TermFrequencyFunctionFactory {

public static TermFrequencyFunction createFunction(
TermFrequencyFunctionName functionName,
Map<Object, Object> context,
String field,
String term,
LeafReaderContext readerContext
) throws IOException {
switch (functionName) {
case TERM_FREQ:
TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
return docId -> termFreqValueSource.getValues(null, readerContext).intVal(docId);
case TF:
TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
return docId -> tfValueSource.getValues(context, readerContext).floatVal(docId);
case TOTAL_TERM_FREQ:
TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource(
field,
term,
field,
BytesRefs.toBytesRef(term)
);
totalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return docId -> totalTermFreqValueSource.getValues(context, readerContext).longVal(docId);
case SUM_TOTAL_TERM_FREQ:
SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field);
sumTotalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return docId -> sumTotalTermFreqValueSource.getValues(context, readerContext).longVal(docId);
default:
throw new IllegalArgumentException("Unsupported function: " + functionName);
}
}

/**
* An enumeration representing the names of supported term frequency functions.
*/
public enum TermFrequencyFunctionName {
TERM_FREQ("termFreq"),
TF("tf"),
TOTAL_TERM_FREQ("totalTermFreq"),
SUM_TOTAL_TERM_FREQ("sumTotalTermFreq");

private final String termFrequencyFunctionName;

TermFrequencyFunctionName(String termFrequencyFunctionName) {
this.termFrequencyFunctionName = termFrequencyFunctionName;
}

public String getTermFrequencyFunctionName() {
return termFrequencyFunctionName;
}
}
}
31 changes: 19 additions & 12 deletions server/src/main/java/org/opensearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
import org.opensearch.Version;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum;
import org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionFactory;
import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName;

import org.opensearch.search.lookup.LeafSearchLookup;
import org.opensearch.search.lookup.SearchLookup;
Expand All @@ -48,6 +47,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.DoubleSupplier;
import java.util.function.Function;
Expand Down Expand Up @@ -121,6 +121,8 @@ public Explanation get(double score, Explanation subQueryExplanation) {

private final IndexSearcher indexSearcher;

private final Map<String, Object> termFreqCache = new HashMap<>();

public ScoreScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
if (lookup == null) {
Expand Down Expand Up @@ -152,16 +154,21 @@ public Map<String, ScriptDocValues<?>> getDoc() {
return leafLookup.doc();
}

public Object getTermFrequency(TermFrequencyFunctionNamesEnum functionName, String field, String val) throws IOException {
// Fetch data from local cache
Map<Object, Object> context = new HashMap<>() {
{
put("searcher", indexSearcher);
}
};
return leafLookup.executeTermFrequencyFunction(
TermFrequencyFunctionFactory.createFunction(functionName, field, val, docId, context)
);
public Object getTermFrequency(TermFrequencyFunctionName functionName, String field, String val) throws IOException {
String cacheKey = (val == null) ? String.format(Locale.ROOT, "%s-%s", functionName, field) : String.format(Locale.ROOT, "%s-%s-%s", functionName, field, val);

if (!termFreqCache.containsKey(cacheKey)) {
Map<Object, Object> context = new HashMap<>() {
{
put("searcher", indexSearcher);
}
};

Object termFrequency = leafLookup.getTermFrequency(functionName, context, field, val, docId);
termFreqCache.put(cacheKey, termFrequency);
}

return termFreqCache.get(cacheKey);
}

/** Set the current document to run the script on next. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@
import java.time.ZoneId;

import static org.opensearch.common.util.BitMixer.mix32;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.SUM_TOTAL_TERM_FREQ;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TERM_FREQ;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TF;
import static org.opensearch.index.query.functionscore.TermFrequencyFunction.TermFrequencyFunctionNamesEnum.TOTAL_TERM_FREQ;
import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.SUM_TOTAL_TERM_FREQ;
import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TERM_FREQ;
import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TF;
import static org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName.TOTAL_TERM_FREQ;

/**
* Utilities for scoring scripts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.query.functionscore.TermFrequencyFunction;
import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory;

import java.io.IOException;
import java.util.HashMap;
Expand Down Expand Up @@ -90,7 +91,15 @@ public void setDocument(int docId) {
fieldsLookup.setDocument(docId);
}

public Object executeTermFrequencyFunction(TermFrequencyFunction function) throws IOException {
return function.execute(ctx);
public Object getTermFrequency(
TermFrequencyFunctionFactory.TermFrequencyFunctionName functionName,
Map<Object, Object> context,
String field,
String val,
int docId
) throws IOException {
TermFrequencyFunction termFreqFunction = TermFrequencyFunctionFactory.createFunction(functionName, context, field, val, ctx);
// execute the function
return termFreqFunction.execute(docId);
}
}

0 comments on commit 854ca55

Please sign in to comment.