diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 7f91d778105dd..c85499411be06 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -110,7 +110,24 @@ public QueryPhase() { @Override public void preProcess(SearchContext context) { - context.preProcess(true); + final Runnable cancellation; + if (context.lowLevelCancellation()) { + SearchShardTask task = context.getTask(); + cancellation = context.searcher().addQueryCancellation(() -> { + if (task.isCancelled()) { + throw new TaskCancelledException("cancelled"); + } + }); + } else { + cancellation = null; + } + try { + context.preProcess(true); + } finally { + if (cancellation != null) { + context.searcher().removeQueryCancellation(cancellation); + } + } } @Override diff --git a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index 4b70153261d3f..95330ea8fd04f 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -53,6 +53,8 @@ import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.MultiTermQuery; +import org.apache.lucene.search.PrefixQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; @@ -76,6 +78,7 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.query.ParsedQuery; +import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardTestCase; @@ -84,6 +87,7 @@ import org.elasticsearch.search.internal.ScrollContext; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.sort.SortAndFormats; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.TestSearchContext; import java.io.IOException; @@ -834,7 +838,59 @@ public void testMinScore() throws Exception { reader.close(); dir.close(); + } + + public void testCancellationDuringPreprocess() throws IOException { + try (Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) { + + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + StringBuilder sb = new StringBuilder(); + for (int j = 0; j < i; j++) { + sb.append('a'); + } + doc.add(new StringField("foo", sb.toString(), Store.NO)); + w.addDocument(doc); + } + w.flush(); + w.close(); + + try (IndexReader reader = DirectoryReader.open(dir)) { + TestSearchContext context = new TestSearchContextWithRewriteAndCancellation( + null, indexShard, newContextSearcher(reader)); + PrefixQuery prefixQuery = new PrefixQuery(new Term("foo", "a")); + prefixQuery.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_REWRITE); + context.parsedQuery(new ParsedQuery(prefixQuery)); + SearchShardTask task = mock(SearchShardTask.class); + when(task.isCancelled()).thenReturn(true); + context.setTask(task); + expectThrows(TaskCancelledException.class, () -> new QueryPhase().preProcess(context)); + } + } + } + + private static class TestSearchContextWithRewriteAndCancellation extends TestSearchContext { + private TestSearchContextWithRewriteAndCancellation(QueryShardContext queryShardContext, + IndexShard indexShard, + ContextIndexSearcher searcher) { + super(queryShardContext, indexShard, searcher); + } + + @Override + public void preProcess(boolean rewrite) { + try { + searcher().rewrite(query()); + } catch (IOException e) { + fail("IOException shouldn't be thrown"); + } + } + + @Override + public boolean lowLevelCancellation() { + return true; + } } private static ContextIndexSearcher newContextSearcher(IndexReader reader) throws IOException {