diff --git a/sql/src/main/java/io/crate/operation/collect/LuceneDocCollector.java b/sql/src/main/java/io/crate/operation/collect/LuceneDocCollector.java index c2bff370c693..8a538d168093 100644 --- a/sql/src/main/java/io/crate/operation/collect/LuceneDocCollector.java +++ b/sql/src/main/java/io/crate/operation/collect/LuceneDocCollector.java @@ -21,6 +21,8 @@ package io.crate.operation.collect; +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; import io.crate.Constants; import io.crate.action.sql.query.CrateSearchContext; import io.crate.action.sql.query.LuceneSortGenerator; @@ -30,10 +32,7 @@ import io.crate.lucene.QueryBuilderHelper; import io.crate.metadata.Functions; import io.crate.operation.*; -import io.crate.operation.reference.doc.lucene.CollectorContext; -import io.crate.operation.reference.doc.lucene.LuceneCollectorExpression; -import io.crate.operation.reference.doc.lucene.LuceneDocLevelReferenceResolver; -import io.crate.operation.reference.doc.lucene.OrderByCollectorExpression; +import io.crate.operation.reference.doc.lucene.*; import io.crate.planner.node.dql.CollectNode; import io.crate.planner.symbol.Reference; import io.crate.planner.symbol.Symbol; @@ -42,9 +41,11 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.index.fieldvisitor.FieldsVisitor; import org.elasticsearch.index.mapper.internal.SourceFieldMapper; +import org.elasticsearch.search.internal.ContextIndexSearcher; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.concurrent.CancellationException; @@ -102,7 +103,6 @@ public void required(boolean required) { private RamAccountingContext ramAccountingContext; private boolean producedRows = false; private boolean failed = false; - private Scorer scorer; private int rowCount = 0; private int pageSize; @@ -131,7 +131,6 @@ public LuceneDocCollector(List> inputs, @Override public void setScorer(Scorer scorer) throws IOException { - this.scorer = scorer; for (LuceneCollectorExpression expr : collectorExpressions) { expr.setScorer(scorer); } @@ -142,19 +141,14 @@ public void collect(int doc) throws IOException { if (shardContext.isKilled()) { throw new CancellationException(); } - - rowCount++; if (ramAccountingContext != null && ramAccountingContext.trippedBreaker()) { // stop collecting because breaker limit was reached throw new UnexpectedCollectionTerminatedException( CrateCircuitBreakerService.breakingExceptionMessage(ramAccountingContext.contextId(), ramAccountingContext.limit())); } - // validate minimum score - if (searchContext.minimumScore() != null && scorer.score() < searchContext.minimumScore()) { - return; - } + rowCount++; producedRows = true; if (visitorEnabled) { fieldsVisitor.reset(); @@ -202,12 +196,13 @@ public void doCollect(JobCollectContext jobCollectContext) { } visitorEnabled = fieldsVisitor.required(); shardContext.acquireContext(); + searchContext.searcher().inStage(ContextIndexSearcher.Stage.MAIN_QUERY); Query query = searchContext.query(); try { assert query != null : "query must not be null"; - if( orderBy != null) { + if(orderBy != null) { searchWithOrderBy(jobCollectContext, query); } else { searchContext.searcher().search(query, this); @@ -219,6 +214,7 @@ public void doCollect(JobCollectContext jobCollectContext) { failed = true; downstream.fail(shardContext.isKilled() ? new CancellationException() : e); } finally { + searchContext().searcher().finishStage(ContextIndexSearcher.Stage.MAIN_QUERY); shardContext.releaseContext(); shardContext.close(); } @@ -229,7 +225,9 @@ private void searchWithOrderBy(JobCollectContext jobCollectContext, Query query) Sort sort = LuceneSortGenerator.generateLuceneSort(searchContext, orderBy, inputSymbolVisitor); TopFieldDocs topFieldDocs = searchContext.searcher().search(query, batchSize, sort); int collected = topFieldDocs.scoreDocs.length; - ScoreDoc lastCollected = collectTopFields(topFieldDocs); + + Collection scoreExpressions = getScoreExpressions(); + ScoreDoc lastCollected = collectTopFields(topFieldDocs, scoreExpressions); while ((limit == null || collected < limit) && topFieldDocs.scoreDocs.length >= batchSize && lastCollected != null) { jobCollectContext.interruptIfKilled(); @@ -244,10 +242,20 @@ private void searchWithOrderBy(JobCollectContext jobCollectContext, Query query) topFieldDocs = (TopFieldDocs)searchContext.searcher().searchAfter(lastCollected, query, batchSize, sort); } collected += topFieldDocs.scoreDocs.length; - lastCollected = collectTopFields(topFieldDocs); + lastCollected = collectTopFields(topFieldDocs, scoreExpressions); } } + private Collection getScoreExpressions() { + List scoreCollectorExpressions = new ArrayList<>(); + for (LuceneCollectorExpression expression : collectorExpressions) { + if (expression instanceof ScoreCollectorExpression) { + scoreCollectorExpressions.add((ScoreCollectorExpression) expression); + } + } + return scoreCollectorExpressions; + } + public CrateSearchContext searchContext() { return searchContext; } @@ -264,7 +272,7 @@ public void pageSize(int pageSize) { this.pageSize = pageSize; } - private ScoreDoc collectTopFields(TopFieldDocs topFieldDocs) throws IOException{ + private ScoreDoc collectTopFields(TopFieldDocs topFieldDocs, Collection scoreExpressions) throws IOException{ IndexReaderContext indexReaderContext = searchContext.searcher().getTopReaderContext(); ScoreDoc lastDoc = null; if(!indexReaderContext.leaves().isEmpty()) { @@ -274,6 +282,9 @@ private ScoreDoc collectTopFields(TopFieldDocs topFieldDocs) throws IOException{ int subDoc = scoreDoc.doc - subReaderContext.docBase; setNextReader(subReaderContext); setNextOrderByValues(scoreDoc); + for (LuceneCollectorExpression scoreExpression : scoreExpressions) { + ((ScoreCollectorExpression) scoreExpression).score(scoreDoc.score); + } collect(subDoc); lastDoc = scoreDoc; } diff --git a/sql/src/main/java/io/crate/operation/reference/doc/lucene/ScoreCollectorExpression.java b/sql/src/main/java/io/crate/operation/reference/doc/lucene/ScoreCollectorExpression.java index 483121bd0dd4..742ee29fd83c 100644 --- a/sql/src/main/java/io/crate/operation/reference/doc/lucene/ScoreCollectorExpression.java +++ b/sql/src/main/java/io/crate/operation/reference/doc/lucene/ScoreCollectorExpression.java @@ -45,8 +45,15 @@ public Float value() { return score; } + public void score(float score) { + this.score = score; + } + @Override public void setNextDocId(int doc) { + if (scorer == null) { + return; + } try { score = scorer.score(); } catch (IOException e) { diff --git a/sql/src/test/java/io/crate/integrationtests/TransportSQLActionTest.java b/sql/src/test/java/io/crate/integrationtests/TransportSQLActionTest.java index 02cd8537653f..025d23ab323a 100644 --- a/sql/src/test/java/io/crate/integrationtests/TransportSQLActionTest.java +++ b/sql/src/test/java/io/crate/integrationtests/TransportSQLActionTest.java @@ -946,8 +946,6 @@ public void testSelectWhereScore() throws Exception { execute("create table quotes (quote string, " + "index quote_ft using fulltext(quote)) with (number_of_replicas = 0)"); ensureYellow(); - assertTrue(client().admin().indices().exists(new IndicesExistsRequest("quotes")) - .actionGet().isExists()); execute("insert into quotes values (?), (?)", new Object[]{"Would it save you a lot of time if I just gave up and went mad now?", @@ -959,6 +957,11 @@ public void testSelectWhereScore() throws Exception { "and \"_score\" >= 0.98"); assertEquals(1L, response.rowCount()); assertThat((Float) response.rows()[0][1], greaterThanOrEqualTo(0.98f)); + + execute("select quote, \"_score\" from quotes where match(quote_ft, 'time') " + + "and \"_score\" >= 0.98 order by quote "); + assertEquals(1L, response.rowCount()); + assertThat((Float) response.rows()[0][1], greaterThanOrEqualTo(0.98f)); } @Test