Skip to content

Commit 0682b46

Browse files
committed
HSEARCH-5010 Use scorer suppliers in custom weights
1 parent ebde482 commit 0682b46

File tree

5 files changed

+100
-14
lines changed

5 files changed

+100
-14
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Hibernate Search, full-text search for your domain model
3+
*
4+
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
5+
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
6+
*/
7+
package org.hibernate.search.backend.lucene.lowlevel.query.impl;
8+
9+
import java.io.IOException;
10+
11+
import org.apache.lucene.search.ConstantScoreScorer;
12+
import org.apache.lucene.search.DocIdSetIterator;
13+
import org.apache.lucene.search.ScoreMode;
14+
import org.apache.lucene.search.Scorer;
15+
import org.apache.lucene.search.ScorerSupplier;
16+
import org.apache.lucene.search.Weight;
17+
18+
class ConstantScorerSupplier extends ScorerSupplier {
19+
private final Weight weight;
20+
private final float score;
21+
private final ScoreMode scoreMode;
22+
private final DocIdSetIterator matchingDocs;
23+
24+
public ConstantScorerSupplier(Weight weight, float score, ScoreMode scoreMode, DocIdSetIterator matchingDocs) {
25+
this.weight = weight;
26+
this.score = score;
27+
this.scoreMode = scoreMode;
28+
this.matchingDocs = matchingDocs;
29+
}
30+
31+
@Override
32+
public Scorer get(long leadCost) throws IOException {
33+
return new ConstantScoreScorer( weight, score, scoreMode, matchingDocs );
34+
}
35+
36+
@Override
37+
public long cost() {
38+
return matchingDocs.cost();
39+
}
40+
}

backend/lucene/src/main/java/org/hibernate/search/backend/lucene/lowlevel/query/impl/ExplicitDocIdsQuery.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
import java.util.Arrays;
1010

1111
import org.apache.lucene.index.LeafReaderContext;
12-
import org.apache.lucene.search.ConstantScoreScorer;
1312
import org.apache.lucene.search.ConstantScoreWeight;
1413
import org.apache.lucene.search.DocIdSetIterator;
1514
import org.apache.lucene.search.IndexSearcher;
1615
import org.apache.lucene.search.Query;
1716
import org.apache.lucene.search.QueryVisitor;
1817
import org.apache.lucene.search.ScoreDoc;
1918
import org.apache.lucene.search.ScoreMode;
20-
import org.apache.lucene.search.Scorer;
19+
import org.apache.lucene.search.ScorerSupplier;
2120
import org.apache.lucene.search.Weight;
2221

2322
public final class ExplicitDocIdsQuery extends Query {
@@ -58,15 +57,16 @@ public int hashCode() {
5857
@Override
5958
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {
6059
return new ConstantScoreWeight( this, 1.0f ) {
60+
6161
@Override
62-
public Scorer scorer(LeafReaderContext context) {
62+
public ScorerSupplier scorerSupplier(LeafReaderContext context) {
6363
DocIdSetIterator matchingDocs = ExplicitDocIdSetIterator.of(
6464
sortedDocIds, context.docBase, context.reader().maxDoc()
6565
);
6666
if ( matchingDocs == null ) {
6767
return null; // Skip this leaf
6868
}
69-
return new ConstantScoreScorer( this, this.score(), scoreMode, matchingDocs );
69+
return new ConstantScorerSupplier( this, this.score(), scoreMode, matchingDocs );
7070
}
7171

7272
@Override

backend/lucene/src/main/java/org/hibernate/search/backend/lucene/lowlevel/query/impl/MappedTypeNameQuery.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
*/
77
package org.hibernate.search.backend.lucene.lowlevel.query.impl;
88

9+
import java.io.IOException;
10+
911
import org.hibernate.search.backend.lucene.lowlevel.reader.impl.IndexReaderMetadataResolver;
1012

1113
import org.apache.lucene.index.LeafReaderContext;
@@ -17,6 +19,7 @@
1719
import org.apache.lucene.search.QueryVisitor;
1820
import org.apache.lucene.search.ScoreMode;
1921
import org.apache.lucene.search.Scorer;
22+
import org.apache.lucene.search.ScorerSupplier;
2023
import org.apache.lucene.search.Weight;
2124

2225
public final class MappedTypeNameQuery extends Query {
@@ -53,8 +56,9 @@ public int hashCode() {
5356
@Override
5457
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {
5558
return new ConstantScoreWeight( this, 1.0f ) {
59+
5660
@Override
57-
public Scorer scorer(LeafReaderContext context) {
61+
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
5862
String leafMappedTypeName = metadataResolver.resolveMappedTypeName( context );
5963
DocIdSetIterator matchingDocs;
6064
if ( mappedTypeName.equals( leafMappedTypeName ) ) {
@@ -63,7 +67,8 @@ public Scorer scorer(LeafReaderContext context) {
6367
else {
6468
matchingDocs = DocIdSetIterator.empty();
6569
}
66-
return new ConstantScoreScorer( this, this.score(), scoreMode, matchingDocs );
70+
71+
return new ConstantScorerSupplier( this, this.score(), scoreMode, matchingDocs );
6772
}
6873

6974
@Override

backend/lucene/src/main/java/org/hibernate/search/backend/lucene/lowlevel/query/impl/VectorSimilarityFilterQuery.java

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.search.QueryVisitor;
2424
import org.apache.lucene.search.ScoreMode;
2525
import org.apache.lucene.search.Scorer;
26+
import org.apache.lucene.search.ScorerSupplier;
2627
import org.apache.lucene.search.TwoPhaseIterator;
2728
import org.apache.lucene.search.Weight;
2829
import org.apache.lucene.util.VectorUtil;
@@ -148,12 +149,39 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
148149
}
149150

150151
@Override
151-
public Scorer scorer(LeafReaderContext context) throws IOException {
152-
Scorer scorer = super.scorer( context );
152+
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
153+
ScorerSupplier scorerSupplier = super.scorerSupplier( context );
154+
if ( scorerSupplier == null ) {
155+
return null;
156+
}
157+
return new MinScoreScorerSupplier( this, scorerSupplier, similarityAsScore );
158+
}
159+
}
160+
161+
private static class MinScoreScorerSupplier extends ScorerSupplier {
162+
163+
private final Weight weight;
164+
private final ScorerSupplier delegate;
165+
private final float similarityAsScore;
166+
167+
private MinScoreScorerSupplier(Weight weight, ScorerSupplier delegate, float similarityAsScore) {
168+
this.weight = weight;
169+
this.delegate = delegate;
170+
this.similarityAsScore = similarityAsScore;
171+
}
172+
173+
@Override
174+
public Scorer get(long leadCost) throws IOException {
175+
Scorer scorer = delegate.get( leadCost );
153176
if ( scorer == null ) {
154177
return null;
155178
}
156-
return new MinScoreScorer( this, scorer, similarityAsScore );
179+
return new MinScoreScorer( weight, scorer, similarityAsScore );
180+
}
181+
182+
@Override
183+
public long cost() {
184+
return delegate.cost();
157185
}
158186
}
159187

util/internal/integrationtest/backend/lucene/src/main/java/org/hibernate/search/util/impl/integrationtest/backend/lucene/query/SlowQuery.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
*/
77
package org.hibernate.search.util.impl.integrationtest.backend.lucene.query;
88

9+
import java.io.IOException;
10+
911
import org.apache.lucene.index.LeafReaderContext;
1012
import org.apache.lucene.search.ConstantScoreScorer;
1113
import org.apache.lucene.search.ConstantScoreWeight;
@@ -16,6 +18,7 @@
1618
import org.apache.lucene.search.QueryVisitor;
1719
import org.apache.lucene.search.ScoreMode;
1820
import org.apache.lucene.search.Scorer;
21+
import org.apache.lucene.search.ScorerSupplier;
1922
import org.apache.lucene.search.Weight;
2023

2124
/**
@@ -55,11 +58,21 @@ public String toString() {
5558
}
5659

5760
@Override
58-
public Scorer scorer(LeafReaderContext context) {
59-
return new ConstantScoreScorer(
60-
this, score(), scoreMode,
61-
new SlowDocIdSetIterator( DocIdSetIterator.all( context.reader().maxDoc() ) )
62-
);
61+
public ScorerSupplier scorerSupplier(LeafReaderContext context) {
62+
Weight weight = this;
63+
float score = score();
64+
SlowDocIdSetIterator iterator = new SlowDocIdSetIterator( DocIdSetIterator.all( context.reader().maxDoc() ) );
65+
return new ScorerSupplier() {
66+
@Override
67+
public Scorer get(long leadCost) throws IOException {
68+
return new ConstantScoreScorer( weight, score, scoreMode, iterator );
69+
}
70+
71+
@Override
72+
public long cost() {
73+
return iterator.cost();
74+
}
75+
};
6376
}
6477

6578
@Override

0 commit comments

Comments
 (0)