Skip to content

Commit

Permalink
Address comments from Andriy
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <clingzhi@amazon.com>
  • Loading branch information
noCharger committed Aug 11, 2023
1 parent 1c7f13e commit 68f2680
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.index.query.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
Expand All @@ -28,34 +30,44 @@
* @opensearch.internal
*/
public class TermFrequencyFunctionFactory {

public static TermFrequencyFunction createFunction(
TermFrequencyFunctionName functionName,
Map<Object, Object> context,
String field,
String term,
LeafReaderContext readerContext
LeafReaderContext readerContext,
IndexSearcher indexSearcher
) throws IOException {
switch (functionName) {
case TERM_FREQ:
TermFreqValueSource termFreqValueSource = new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
return docId -> termFreqValueSource.getValues(null, readerContext).intVal(docId);
FunctionValues functionValues = termFreqValueSource.getValues(null, readerContext);
return docId -> functionValues.intVal(docId);
case TF:
TFValueSource tfValueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
return docId -> tfValueSource.getValues(context, readerContext).floatVal(docId);
Map<Object, Object> tfContext = new HashMap<>() {
{
put("searcher", indexSearcher);
}
};
functionValues = tfValueSource.getValues(tfContext, readerContext);
return docId -> functionValues.floatVal(docId);
case TOTAL_TERM_FREQ:
TotalTermFreqValueSource totalTermFreqValueSource = new TotalTermFreqValueSource(
field,
term,
field,
BytesRefs.toBytesRef(term)
);
totalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return docId -> totalTermFreqValueSource.getValues(context, readerContext).longVal(docId);
Map<Object, Object> ttfContext = new HashMap<>();
totalTermFreqValueSource.createWeight(ttfContext, indexSearcher);
functionValues = totalTermFreqValueSource.getValues(ttfContext, readerContext);
return docId -> functionValues.longVal(docId);
case SUM_TOTAL_TERM_FREQ:
SumTotalTermFreqValueSource sumTotalTermFreqValueSource = new SumTotalTermFreqValueSource(field);
sumTotalTermFreqValueSource.createWeight(context, (IndexSearcher) context.get("searcher"));
return docId -> sumTotalTermFreqValueSource.getValues(context, readerContext).longVal(docId);
Map<Object, Object> sttfContext = new HashMap<>();
sumTotalTermFreqValueSource.createWeight(sttfContext, indexSearcher);
functionValues = sumTotalTermFreqValueSource.getValues(sttfContext, readerContext);
return docId -> functionValues.longVal(docId);
default:
throw new IllegalArgumentException("Unsupported function: " + functionName);
}
Expand Down
30 changes: 7 additions & 23 deletions server/src/main/java/org/opensearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName;

import org.opensearch.search.lookup.LeafSearchLookup;
import org.opensearch.search.lookup.LeafTermFrequencyLookup;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.lookup.SourceLookup;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.DoubleSupplier;
import java.util.function.Function;
Expand Down Expand Up @@ -111,6 +111,9 @@ public Explanation get(double score, Explanation subQueryExplanation) {
/** A leaf lookup for the bound segment this script will operate on. */
private final LeafSearchLookup leafLookup;

/** A leaf term frequency lookup for the bound segment this script will operate on. */
private final LeafTermFrequencyLookup leafTermFrequencyLookup;

private DoubleSupplier scoreSupplier = () -> 0.0;

private final int docBase;
Expand All @@ -119,26 +122,22 @@ public Explanation get(double score, Explanation subQueryExplanation) {
private String indexName = null;
private Version indexVersion = null;

private final IndexSearcher indexSearcher;

private final Map<String, Object> termFreqCache = new HashMap<>();

public ScoreScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
if (lookup == null) {
assert params == null;
assert leafContext == null;
this.params = null;
this.leafLookup = null;
this.leafTermFrequencyLookup = null;
this.docBase = 0;
this.indexSearcher = null;
} else {
this.leafLookup = lookup.getLeafSearchLookup(leafContext);
this.leafTermFrequencyLookup = new LeafTermFrequencyLookup(indexSearcher, leafLookup);
params = new HashMap<>(params);
params.putAll(leafLookup.asMap());
this.params = new DynamicMap(params, PARAMS_FUNCTIONS);
this.docBase = leafContext.docBase;
this.indexSearcher = indexSearcher;
}
}

Expand All @@ -155,22 +154,7 @@ public Map<String, ScriptDocValues<?>> getDoc() {
}

public Object getTermFrequency(TermFrequencyFunctionName functionName, String field, String val) throws IOException {
String cacheKey = (val == null)
? String.format(Locale.ROOT, "%s-%s", functionName, field)
: String.format(Locale.ROOT, "%s-%s-%s", functionName, field, val);

if (!termFreqCache.containsKey(cacheKey)) {
Map<Object, Object> context = new HashMap<>() {
{
put("searcher", indexSearcher);
}
};

Object termFrequency = leafLookup.getTermFrequency(functionName, context, field, val, docId);
termFreqCache.put(cacheKey, termFrequency);
}

return termFreqCache.get(cacheKey);
return leafTermFrequencyLookup.getTermFrequency(functionName, field, val, docId);
}

/** Set the current document to run the script on next. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@
package org.opensearch.search.lookup;

import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.index.query.functionscore.TermFrequencyFunction;
import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -90,16 +86,4 @@ public void setDocument(int docId) {
sourceLookup.setSegmentAndDocument(ctx, docId);
fieldsLookup.setDocument(docId);
}

/**
 * Evaluates a term frequency function for one document.
 * <p>
 * A new {@link TermFrequencyFunction} is built through
 * {@link TermFrequencyFunctionFactory#createFunction} on every call, using the
 * {@code ctx} held by this lookup, and is then executed immediately; nothing is
 * cached at this level.
 *
 * @param functionName which term frequency function to build
 * @param context      context map forwarded unchanged to the factory
 * @param field        field name forwarded to the factory
 * @param val          term value forwarded to the factory
 * @param docId        document id passed to the built function
 * @return whatever the built function returns for {@code docId}
 * @throws IOException declared by this method; may propagate from building or executing the function
 */
public Object getTermFrequency(
TermFrequencyFunctionFactory.TermFrequencyFunctionName functionName,
Map<Object, Object> context,
String field,
String val,
int docId
) throws IOException {
TermFrequencyFunction termFreqFunction = TermFrequencyFunctionFactory.createFunction(functionName, context, field, val, ctx);
// execute the function for the requested document and hand back its result
return termFreqFunction.execute(docId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.lookup;

import org.apache.lucene.search.IndexSearcher;
import org.opensearch.index.query.functionscore.TermFrequencyFunction;
import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory;
import org.opensearch.index.query.functionscore.TermFrequencyFunctionFactory.TermFrequencyFunctionName;

import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/**
 * Looks up term frequency per-segment.
 * <p>
 * Each distinct (function, field, term) request builds one
 * {@link TermFrequencyFunction} through {@link TermFrequencyFunctionFactory} and
 * keeps it in a per-instance cache, so later calls for other documents reuse the
 * already-built function instead of creating a new one.
 * <p>
 * The cache is a plain {@link HashMap}; this class adds no synchronization of its own.
 *
 * @opensearch.internal
 */
public class LeafTermFrequencyLookup {

    private final IndexSearcher indexSearcher;
    private final LeafSearchLookup leafLookup;
    private final Map<String, TermFrequencyFunction> termFreqCache;

    /**
     * @param indexSearcher searcher handed to the factory when a function is built
     * @param leafLookup    leaf lookup whose {@code ctx} is handed to the factory
     */
    public LeafTermFrequencyLookup(IndexSearcher indexSearcher, LeafSearchLookup leafLookup) {
        this.indexSearcher = indexSearcher;
        this.leafLookup = leafLookup;
        this.termFreqCache = new HashMap<>();
    }

    /**
     * Returns the value of the requested term frequency function for one document.
     *
     * @param functionName which term frequency function to evaluate
     * @param field        field name passed to the factory
     * @param val          term passed to the factory; may be {@code null}
     * @param docId        document id passed to the function
     * @return the result of executing the (possibly cached) function for {@code docId}
     * @throws IOException declared by this method; may propagate from building or executing the function
     */
    public Object getTermFrequency(TermFrequencyFunctionName functionName, String field, String val, int docId) throws IOException {
        final String cacheKey = buildCacheKey(functionName, field, val);

        // Single lookup: the factory either returns a function or throws, so a null
        // result here can only mean "not cached yet".
        TermFrequencyFunction termFrequencyFunction = termFreqCache.get(cacheKey);
        if (termFrequencyFunction == null) {
            termFrequencyFunction = TermFrequencyFunctionFactory.createFunction(
                functionName,
                field,
                val,
                leafLookup.ctx,
                indexSearcher
            );
            termFreqCache.put(cacheKey, termFrequencyFunction);
        }

        return termFrequencyFunction.execute(docId);
    }

    /**
     * Builds an unambiguous cache key of the form
     * {@code <function>-<field length>-<field>[-<term>]}.
     * <p>
     * Joining the parts with a bare separator is not enough: field {@code "a-b"}
     * with term {@code "c"} and field {@code "a"} with term {@code "b-c"} would
     * share one key and return each other's cached function. Prefixing the field
     * with its length fixes where the field ends, so the remainder is always the
     * term. A {@code null} term (no suffix) stays distinct from an empty term
     * (trailing separator).
     */
    private static String buildCacheKey(TermFrequencyFunctionName functionName, String field, String val) {
        // String.valueOf keeps the previous "null" rendering for a null field instead of
        // failing here; any rejection of a null field is left to the factory.
        final String fieldText = String.valueOf(field);
        final String prefix = String.format(Locale.ROOT, "%s-%d-%s", functionName, fieldText.length(), fieldText);
        return (val == null) ? prefix : prefix + "-" + val;
    }
}

0 comments on commit 68f2680

Please sign in to comment.