Skip to content

Commit

Permalink
Unwrap TwoPhaseIterator in MinScoreScorer (#80116) (#82508)
Browse files Browse the repository at this point in the history
A ConjunctionScorer can add the approximation of the TwoPhaseIterator of
a MinScoreScorer to its TwoPhaseIterator list after the main
TwoPhaseIterator. This can lead to an undesired state, as the matches()
method is called after the score() method. For example, if the matches()
method of ToParentBlockJoinQuery is called after the score() method,
then we return a wrong result or over-read DocValues. Here, we wrap the
approximation to prevent it from unwrapping as a TwoPhaseIterator.

Closes #79658
  • Loading branch information
dnhatn committed Jan 13, 2022
1 parent 9f44f7e commit 58138ac
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.nested;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
import org.elasticsearch.index.query.MatchPhraseQueryBuilder;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class NestedWithMinScoreIT extends ESIntegTestCase {

    /** Mock script plugin exposing a "score_script" that simply echoes the incoming {@code _score}. */
    public static class ScriptTestPlugin extends MockScriptPlugin {
        @Override
        protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
            return Collections.singletonMap("score_script", params -> {
                final Object score = params.get("_score");
                return (score instanceof Number) ? ((Number) score).doubleValue() : null;
            });
        }
    }

    @Override
    protected Collection<Class<? extends Plugin>> getMockPlugins() {
        final List<Class<? extends Plugin>> mockPlugins = new ArrayList<>(super.getMockPlugins());
        mockPlugins.add(ScriptTestPlugin.class);
        return mockPlugins;
    }

    /** Mapping with a nested {@code toolTracks} object holding a text {@code data} and a double {@code confidence}. */
    private static XContentBuilder nestedMapping() throws Exception {
        final XContentBuilder mapping = XContentFactory.jsonBuilder();
        mapping.startObject();
        {
            mapping.startObject("properties");
            {
                mapping.startObject("toolTracks");
                {
                    mapping.field("type", "nested");
                    mapping.startObject("properties");
                    {
                        mapping.startObject("data");
                        mapping.field("type", "text");
                        mapping.endObject();

                        mapping.startObject("confidence");
                        mapping.field("type", "double");
                        mapping.endObject();
                    }
                    mapping.endObject();
                }
                mapping.endObject();
            }
            mapping.endObject();
        }
        mapping.endObject();
        return mapping;
    }

    /** One document with twelve nested toolTracks entries: identical text, varying confidence values. */
    private static XContentBuilder sampleDoc() throws Exception {
        final double[] confidences = new double[] { 0.3, 0.92, 0.7, 0.85, 0.2, 0.3, 0.75, 0.82, 0.1, 0.6, 0.3, 0.7 };
        final XContentBuilder doc = XContentFactory.jsonBuilder();
        doc.startObject();
        doc.startArray("toolTracks");
        for (double confidence : confidences) {
            doc.startObject();
            doc.field("confidence", confidence);
            doc.field("data", "cash dispenser, automated teller machine, automatic teller machine");
            doc.endObject();
        }
        doc.endArray();
        doc.endObject();
        return doc;
    }

    public void testNestedWithMinScore() throws Exception {
        client().admin().indices().prepareCreate("test").addMapping("_doc", nestedMapping()).get();
        client().prepareIndex("test", "_doc")
            .setId("d1")
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
            .setSource(sampleDoc())
            .get();

        // Nested child query: phrase match on the data field, filtered to high-confidence entries.
        final BoolQueryBuilder childQuery = new BoolQueryBuilder().filter(
            new MatchPhraseQueryBuilder("toolTracks.data", "cash dispenser, automated teller machine, automatic teller machine")
        ).filter(new RangeQueryBuilder("toolTracks.confidence").from(0.8));

        // script_score over the nested query, with a min_score cutoff — the combination under test.
        final ScriptScoreQueryBuilder scriptScoreQuery = new ScriptScoreQueryBuilder(
            new NestedQueryBuilder("toolTracks", new ConstantScoreQueryBuilder(childQuery), ScoreMode.Total),
            new Script(ScriptType.INLINE, MockScriptPlugin.NAME, "score_script", Collections.emptyMap())
        );
        scriptScoreQuery.setMinScore(1.0f);

        final SearchSourceBuilder source = new SearchSourceBuilder();
        source.query(scriptScoreQuery);
        source.profile(randomBoolean());
        if (randomBoolean()) {
            source.trackTotalHitsUpTo(randomBoolean() ? Integer.MAX_VALUE : randomIntBetween(1, 1000));
        }
        final SearchRequest searchRequest = new SearchRequest("test").source(source);
        final SearchResponse searchResponse = client().search(searchRequest).actionGet();
        ElasticsearchAssertions.assertSearchHits(searchResponse, "d1");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,25 @@ public DocIdSetIterator iterator() {

@Override
public TwoPhaseIterator twoPhaseIterator() {
final TwoPhaseIterator inTwoPhase = this.in.twoPhaseIterator();
final DocIdSetIterator approximation = inTwoPhase == null ? in.iterator() : inTwoPhase.approximation();
TwoPhaseIterator inTwoPhase = in.twoPhaseIterator();
DocIdSetIterator approximation;
if (inTwoPhase == null) {
approximation = in.iterator();
if (TwoPhaseIterator.unwrap(approximation) != null) {
inTwoPhase = TwoPhaseIterator.unwrap(approximation);
approximation = inTwoPhase.approximation();
}
} else {
approximation = inTwoPhase.approximation();
}
final TwoPhaseIterator finalTwoPhase = inTwoPhase;
return new TwoPhaseIterator(approximation) {

@Override
public boolean matches() throws IOException {
// we need to check the two-phase iterator first
// otherwise calling score() is illegal
if (inTwoPhase != null && inTwoPhase.matches() == false) {
if (finalTwoPhase != null && finalTwoPhase.matches() == false) {
return false;
}
curScore = in.score();
Expand All @@ -79,7 +89,7 @@ public boolean matches() throws IOException {
@Override
public float matchCost() {
return 1000f // random constant for the score computation
+ (inTwoPhase == null ? 0 : inTwoPhase.matchCost());
+ (finalTwoPhase == null ? 0 : finalTwoPhase.matchCost());
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.Version;
Expand Down Expand Up @@ -265,6 +266,11 @@ public DocIdSetIterator iterator() {
return subQueryScorer.iterator();
}

// Delegate two-phase iteration to the wrapped sub-query scorer so that callers
// (e.g. a conjunction) can drive its cheap approximation and defer the expensive
// matches() confirmation, instead of seeing only the plain doc-id iterator.
@Override
public TwoPhaseIterator twoPhaseIterator() {
return subQueryScorer.twoPhaseIterator();
}

@Override
public float getMaxScore(int upTo) {
return Float.MAX_VALUE; // TODO: what would be a good upper bound?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,36 @@

package org.elasticsearch.common.lucene.search.function;

import com.carrotsearch.randomizedtesting.generators.RandomPicks;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.AssertingScorer;
import org.apache.lucene.search.ConjunctionDISI;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;

public class MinScoreScorerTests extends LuceneTestCase {
public class MinScoreScorerTests extends ESTestCase {

private static DocIdSetIterator iterator(final int... docs) {
return new DocIdSetIterator() {
Expand All @@ -33,11 +46,8 @@ private static DocIdSetIterator iterator(final int... docs) {

@Override
public int nextDoc() throws IOException {
if (i + 1 == docs.length) {
return NO_MORE_DOCS;
} else {
return docs[++i];
}
++i;
return docID();
}

@Override
Expand Down Expand Up @@ -81,11 +91,41 @@ public boolean isCacheable(LeafReaderContext ctx) {
};
}

/**
 * Wraps {@code delegate} in a Scorer whose {@link #twoPhaseIterator()} returns {@code null},
 * exposing the two-phase iterator only indirectly via
 * {@link TwoPhaseIterator#asDocIdSetIterator}. Scoring and positioning are forwarded unchanged.
 */
private static Scorer hideTwoPhaseIterator(Scorer delegate) {
    return new Scorer(delegate.getWeight()) {
        @Override
        public int docID() {
            return delegate.docID();
        }

        @Override
        public float score() throws IOException {
            return delegate.score();
        }

        @Override
        public float getMaxScore(int upTo) throws IOException {
            return delegate.getMaxScore(upTo);
        }

        @Override
        public DocIdSetIterator iterator() {
            // The two-phase iterator is still consulted, but only through this wrapper.
            return TwoPhaseIterator.asDocIdSetIterator(delegate.twoPhaseIterator());
        }

        @Override
        public TwoPhaseIterator twoPhaseIterator() {
            return null; // deliberately hidden
        }
    };
}

private static Scorer scorer(int maxDoc, final int[] docs, final float[] scores, final boolean twoPhase) {
final DocIdSetIterator iterator = twoPhase ? DocIdSetIterator.all(maxDoc) : iterator(docs);
return new Scorer(fakeWeight()) {
final Scorer scorer = new Scorer(fakeWeight()) {

int lastScoredDoc = -1;
final float matchCost = (random().nextBoolean() ? 1000 : 0) + random().nextInt(2000);

public DocIdSetIterator iterator() {
if (twoPhase) {
Expand All @@ -106,7 +146,7 @@ public boolean matches() throws IOException {

@Override
public float matchCost() {
return 10;
return matchCost;
}
};
} else {
Expand All @@ -132,23 +172,33 @@ public float getMaxScore(int upTo) throws IOException {
return Float.MAX_VALUE;
}
};
final ScoreMode scoreMode = RandomPicks.randomFrom(
random(),
new ScoreMode[] { ScoreMode.COMPLETE, ScoreMode.TOP_SCORES, ScoreMode.TOP_DOCS_WITH_SCORES }
);
final Scorer assertingScorer = AssertingScorer.wrap(random(), scorer, scoreMode);
if (twoPhase && randomBoolean()) {
return hideTwoPhaseIterator(assertingScorer);
} else {
return assertingScorer;
}
}

/** Picks {@code numDocs} distinct doc IDs from {@code [0, maxDoc)} and returns them in ascending order. */
private static int[] randomDocs(int maxDoc, int numDocs) {
    final List<Integer> candidates = IntStream.range(0, maxDoc).boxed().collect(Collectors.toList());
    final int[] picked = randomSubsetOf(numDocs, candidates).stream().mapToInt(Integer::intValue).toArray();
    Arrays.sort(picked);
    return picked;
}

public void doTestRandom(boolean twoPhase) throws IOException {
final int maxDoc = TestUtil.nextInt(random(), 10, 10000);
final int numDocs = TestUtil.nextInt(random(), 1, maxDoc / 2);
final int numDocs = TestUtil.nextInt(random(), 1, maxDoc);
final Set<Integer> uniqueDocs = new HashSet<>();
while (uniqueDocs.size() < numDocs) {
uniqueDocs.add(random().nextInt(maxDoc));
}
final int[] docs = new int[numDocs];
int i = 0;
for (int doc : uniqueDocs) {
docs[i++] = doc;
}
Arrays.sort(docs);
final int[] docs = randomDocs(maxDoc, numDocs);
final float[] scores = new float[numDocs];
for (i = 0; i < numDocs; ++i) {
for (int i = 0; i < numDocs; ++i) {
scores[i] = random().nextFloat();
}
Scorer scorer = scorer(maxDoc, docs, scores, twoPhase);
Expand Down Expand Up @@ -193,4 +243,48 @@ public void testTwoPhaseIterator() throws IOException {
doTestRandom(true);
}
}

/**
 * Intersects several scorers — some wrapped in {@link MinScoreScorer}, some exposing a
 * hidden/visible two-phase iterator (see scorer(...)) — and checks that the conjunction
 * returns exactly the docs matched by every scorer with a score above its min cutoff.
 */
public void testConjunction() throws Exception {
final int maxDoc = randomIntBetween(10, 10000);
// doc id -> number of scorers that matched it (above min score where applicable)
final Map<Integer, Integer> matchedDocs = new HashMap<>();
final List<Scorer> scorers = new ArrayList<>();
final int numScorers = randomIntBetween(2, 10);
for (int s = 0; s < numScorers; s++) {
final int numDocs = randomIntBetween(2, maxDoc);
final int[] docs = randomDocs(maxDoc, numDocs);
final float[] scores = new float[numDocs];
for (int i = 0; i < numDocs; ++i) {
scores[i] = randomFloat();
}
final boolean useTwoPhase = randomBoolean();
final Scorer scorer = scorer(maxDoc, docs, scores, useTwoPhase);
final float minScore;
if (randomBoolean()) {
// Randomly apply a min-score cutoff to exercise MinScoreScorer inside the conjunction.
minScore = randomFloat();
MinScoreScorer minScoreScorer = new MinScoreScorer(scorer.getWeight(), scorer, minScore);
scorers.add(minScoreScorer);
} else {
scorers.add(scorer);
minScore = 0.0f;
}
// Record which docs this scorer accepts; scores[i] >= minScore mirrors the cutoff above.
for (int i = 0; i < numDocs; i++) {
if (scores[i] >= minScore) {
matchedDocs.compute(docs[i], (k, v) -> v == null ? 1 : v + 1);
}
}
}

// Drive the conjunction to exhaustion and collect the doc ids it produces.
final DocIdSetIterator disi = ConjunctionDISI.intersectScorers(scorers);
final List<Integer> actualDocs = new ArrayList<>();
while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
actualDocs.add(disi.docID());
}
// Expected: docs matched by all scorers, in ascending doc-id order.
final List<Integer> expectedDocs = matchedDocs.entrySet()
.stream()
.filter(v -> v.getValue() == numScorers)
.map(Map.Entry::getKey)
.sorted()
.collect(Collectors.toList());
assertThat(actualDocs, equalTo(expectedDocs));
}
}

2 comments on commit 58138ac

@marcreichman
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dnhatn Is this included in the released 7.17.0? I can't see it in the release notes.

@dnhatn
Copy link
Member Author

@dnhatn dnhatn commented on 58138ac Feb 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcreichman Yes, it is. It seems the release notes are outdated. I will fix them. Thanks again for reporting the issue.

Please sign in to comment.