diff --git a/docs/changelog/96689.yaml b/docs/changelog/96689.yaml new file mode 100644 index 0000000000000..220624b8c1eca --- /dev/null +++ b/docs/changelog/96689.yaml @@ -0,0 +1,5 @@ +pr: 96689 +summary: Use a collector manager in DfsPhase Knn Search +area: Search +type: enhancement +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java index 46ce6c21474e7..fce19a316b34a 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java @@ -111,7 +111,6 @@ public void testProfileDfs() throws Exception { CollectorResult result = queryProfileShardResult.getCollectorResult(); assertThat(result.getName(), is(not(emptyOrNullString()))); assertThat(result.getTime(), greaterThan(0L)); - assertThat(result.getTime(), greaterThan(0L)); } ProfileResult statsResult = searchProfileDfsPhaseResult.getDfsShardResult(); assertThat(statsResult.getQueryName(), equalTo("statistics")); diff --git a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java index 78d2dd5fc9b44..84794b82ad00a 100644 --- a/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java +++ b/server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java @@ -16,19 +16,20 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopScoreDocCollector; import org.elasticsearch.index.query.ParsedQuery; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.internal.ContextIndexSearcher; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.profile.Profilers; import org.elasticsearch.search.profile.Timer; import org.elasticsearch.search.profile.dfs.DfsProfiler; import org.elasticsearch.search.profile.dfs.DfsTimingType; import org.elasticsearch.search.profile.query.CollectorResult; -import org.elasticsearch.search.profile.query.InternalProfileCollector; -import org.elasticsearch.search.profile.query.InternalProfileCollectorManager; +import org.elasticsearch.search.profile.query.ProfileCollectorManager; import org.elasticsearch.search.profile.query.QueryProfiler; -import org.elasticsearch.search.query.SingleThreadCollectorManager; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; @@ -43,7 +44,6 @@ /** * DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase. * The additional information is used to better compare the scores coming from all the shards, which depend on local factors (e.g. idf). - * * When a kNN search is provided alongside the query, the DFS phase is also used to gather the top k candidates from each shard. Then the * global top k hits are passed on to the query phase. */ @@ -189,24 +189,27 @@ private void executeKnnVectorQuery(SearchContext context) throws IOException { List knnResults = new ArrayList<>(knnVectorQueryBuilders.size()); for (int i = 0; i < knnSearch.size(); i++) { Query knnQuery = searchExecutionContext.toQuery(knnVectorQueryBuilders.get(i)).query(); - TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(knnSearch.get(i).k(), Integer.MAX_VALUE); - CollectorManager collectorManager = new SingleThreadCollectorManager(topScoreDocCollector); - if (context.getProfilers() != null) { - InternalProfileCollectorManager ipcm = new InternalProfileCollectorManager( - new InternalProfileCollector(collectorManager.newCollector(), CollectorResult.REASON_SEARCH_TOP_HITS) - ); - QueryProfiler knnProfiler = context.getProfilers().getDfsProfiler().addQueryProfiler(ipcm); - collectorManager = ipcm; - // Set the current searcher profiler to gather query profiling information for gathering top K docs - context.searcher().setProfiler(knnProfiler); - } - context.searcher().search(knnQuery, collectorManager); - knnResults.add(new DfsKnnResults(topScoreDocCollector.topDocs().scoreDocs)); + knnResults.add(singleKnnSearch(knnQuery, knnSearch.get(i).k(), context.getProfilers(), context.searcher())); } + context.dfsResult().knnResults(knnResults); + } + + static DfsKnnResults singleKnnSearch(Query knnQuery, int k, Profilers profilers, ContextIndexSearcher searcher) throws IOException { + CollectorManager cm = TopScoreDocCollector.createSharedManager(k, null, Integer.MAX_VALUE); + + if (profilers != null) { + ProfileCollectorManager ipcm = new ProfileCollectorManager<>(cm, CollectorResult.REASON_SEARCH_TOP_HITS); + QueryProfiler knnProfiler = profilers.getDfsProfiler().addQueryProfiler(ipcm); + cm = ipcm; + // Set the current searcher profiler to gather query profiling information for gathering top K docs + searcher.setProfiler(knnProfiler); + } + TopDocs topDocs = searcher.search(knnQuery, cm); + // Set profiler back after running KNN searches - if (context.getProfilers() != null) { - context.searcher().setProfiler(context.getProfilers().getCurrentQueryProfiler()); + if (profilers != null) { + searcher.setProfiler(profilers.getCurrentQueryProfiler()); } - context.dfsResult().knnResults(knnResults); + return new DfsKnnResults(topDocs.scoreDocs); } } diff --git a/server/src/main/java/org/elasticsearch/search/profile/dfs/DfsProfiler.java b/server/src/main/java/org/elasticsearch/search/profile/dfs/DfsProfiler.java index 4c0c56327813f..c04336fe53096 100644 --- a/server/src/main/java/org/elasticsearch/search/profile/dfs/DfsProfiler.java +++ b/server/src/main/java/org/elasticsearch/search/profile/dfs/DfsProfiler.java @@ -12,7 +12,7 @@ import org.elasticsearch.search.profile.ProfileResult; import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult; import org.elasticsearch.search.profile.Timer; -import org.elasticsearch.search.profile.query.InternalProfileCollectorManager; +import org.elasticsearch.search.profile.query.ProfileCollectorManager; import org.elasticsearch.search.profile.query.QueryProfileShardResult; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -51,7 +51,7 @@ public Timer startTimer(DfsTimingType dfsTimingType) { return newTimer; } - public QueryProfiler addQueryProfiler(InternalProfileCollectorManager collectorManager) { + public QueryProfiler addQueryProfiler(ProfileCollectorManager collectorManager) { QueryProfiler queryProfiler = new QueryProfiler(); queryProfiler.setCollectorManager(collectorManager::getCollectorTree); knnQueryProfilers.add(queryProfiler); diff --git a/server/src/main/java/org/elasticsearch/search/profile/query/ProfileCollectorManager.java b/server/src/main/java/org/elasticsearch/search/profile/query/ProfileCollectorManager.java index 548016a01f9c7..b9f5a36491cd6 100644 --- a/server/src/main/java/org/elasticsearch/search/profile/query/ProfileCollectorManager.java +++ b/server/src/main/java/org/elasticsearch/search/profile/query/ProfileCollectorManager.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -21,15 +22,15 @@ * in an {@link InternalProfileCollector}. It delegates all the profiling to the generated collectors via {@link #getCollectorTree()} * and joins them up when its {@link #reduce} method is called. The profile result can */ -public final class ProfileCollectorManager implements CollectorManager { +public final class ProfileCollectorManager implements CollectorManager { - private final CollectorManager collectorManager; + private final CollectorManager collectorManager; private final String reason; private CollectorResult collectorTree; @SuppressWarnings("unchecked") - public ProfileCollectorManager(CollectorManager collectorManager, String reason) { - this.collectorManager = (CollectorManager) collectorManager; + public ProfileCollectorManager(CollectorManager collectorManager, String reason) { + this.collectorManager = (CollectorManager) collectorManager; this.reason = reason; } @@ -38,22 +39,26 @@ public InternalProfileCollector newCollector() throws IOException { return new InternalProfileCollector(collectorManager.newCollector(), reason); } - public Void reduce(Collection profileCollectors) throws IOException { + public T reduce(Collection profileCollectors) throws IOException { + assert profileCollectors.size() > 0 : "at least one collector expected"; List unwrapped = profileCollectors.stream() .map(InternalProfileCollector::getWrappedCollector) .collect(Collectors.toList()); - collectorManager.reduce(unwrapped); + T returnValue = collectorManager.reduce(unwrapped); List resultsPerProfiler = profileCollectors.stream() .map(ipc -> ipc.getCollectorTree()) .collect(Collectors.toList()); - this.collectorTree = new CollectorResult(this.getClass().getSimpleName(), "segment_search", 0, resultsPerProfiler); - return null; + + long totalTime = resultsPerProfiler.stream().map(CollectorResult::getTime).reduce(0L, Long::sum); + String collectorName = resultsPerProfiler.get(0).getName(); + this.collectorTree = new CollectorResult(collectorName, reason, totalTime, Collections.emptyList()); + return returnValue; } public CollectorResult getCollectorTree() { if (this.collectorTree == null) { - throw new IllegalStateException("A collectorTree hasn't been set yet, call reduce() before attempting to retrieve it"); + throw new IllegalStateException("A collectorTree hasn't been set yet. Call reduce() before attempting to retrieve it"); } return this.collectorTree; } diff --git a/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java new file mode 100644 index 0000000000000..19f1d251d7a9a --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.dfs; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.search.internal.ContextIndexSearcher; +import org.elasticsearch.search.profile.Profilers; +import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult; +import org.elasticsearch.search.profile.query.CollectorResult; +import org.elasticsearch.search.profile.query.QueryProfileShardResult; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.ThreadPoolExecutor; + +public class DfsPhaseTests extends ESTestCase { + + ThreadPoolExecutor threadPoolExecutor; + private TestThreadPool threadPool; + + @Before + public final void init() { + int numThreads = randomIntBetween(2, 4); + threadPool = new TestThreadPool(DfsPhaseTests.class.getName()); + threadPoolExecutor = EsExecutors.newFixed( + "test", + numThreads, + 10, + EsExecutors.daemonThreadFactory("test"), + threadPool.getThreadContext(), + randomBoolean() + ); + } + + @After + public void cleanup() { + threadPoolExecutor.shutdown(); + terminate(threadPool); + } + + public void testSingleKnnSearch() throws IOException { + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) { + int numDocs = randomIntBetween(900, 1000); + for (int i = 0; i < numDocs; i++) { + Document d = new Document(); + d.add(new KnnFloatVectorField("float_vector", new float[] { i, 0, 0 })); + w.addDocument(d); + } + w.flush(); + + IndexReader reader = w.getReader(); + ContextIndexSearcher searcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + randomBoolean(), + this.threadPoolExecutor + ) { + @Override + protected LeafSlice[] slices(List leaves) { + // get a thread per segment + return slices(leaves, 1, 1); + } + }; + + Query query = new KnnFloatVectorQuery("float_vector", new float[] { 0, 0, 0 }, numDocs, null); + + int k = 10; + // run without profiling enabled + DfsKnnResults dfsKnnResults = DfsPhase.singleKnnSearch(query, k, null, searcher); + assertEquals(k, dfsKnnResults.scoreDocs().length); + + // run with profiling enabled + Profilers profilers = new Profilers(searcher); + dfsKnnResults = DfsPhase.singleKnnSearch(query, k, profilers, searcher); + assertEquals(k, dfsKnnResults.scoreDocs().length); + SearchProfileDfsPhaseResult searchProfileDfsPhaseResult = profilers.getDfsProfiler().buildDfsPhaseResults(); + List queryProfileShardResult = searchProfileDfsPhaseResult.getQueryProfileShardResult(); + assertNotNull(queryProfileShardResult); + CollectorResult collectorResult = queryProfileShardResult.get(0).getCollectorResult(); + assertEquals("SimpleTopScoreDocCollector", (collectorResult.getName())); + assertEquals("search_top_hits", (collectorResult.getReason())); + assertTrue(collectorResult.getTime() > 0); + List children = collectorResult.getCollectorResults(); + if (children.size() > 0) { + long totalTime = 0L; + for (CollectorResult child : children) { + assertEquals("SimpleTopScoreDocCollector", (child.getName())); + assertEquals("search_top_hits", (child.getReason())); + totalTime += child.getTime(); + } + assertEquals(totalTime, collectorResult.getTime()); + } + reader.close(); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/profile/query/ProfileCollectorManagerTests.java b/server/src/test/java/org/elasticsearch/search/profile/query/ProfileCollectorManagerTests.java index b018f16442df6..209efe59333c8 100644 --- a/server/src/test/java/org/elasticsearch/search/profile/query/ProfileCollectorManagerTests.java +++ b/server/src/test/java/org/elasticsearch/search/profile/query/ProfileCollectorManagerTests.java @@ -12,8 +12,6 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.StringField; import org.apache.lucene.index.IndexReader; -import org.apache.lucene.sandbox.search.ProfilerCollectorResult; -import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; @@ -27,6 +25,7 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -48,29 +47,41 @@ private static class TestCollector extends DummyTotalHitCountCollector { */ public void testBasic() throws IOException { final SetOnce reduceCalled = new SetOnce<>(); - ProfileCollectorManager pcm = new ProfileCollectorManager(new CollectorManager<>() { + ProfileCollectorManager pcm = new ProfileCollectorManager<>(new CollectorManager() { - private static int counter = 0; + private int counter = 0; @Override - public Collector newCollector() { + public TestCollector newCollector() { return new TestCollector(counter++); } @Override - public Void reduce(Collection collectors) { + public Integer reduce(Collection collectors) { reduceCalled.set(true); - return null; + return counter; } }, CollectorResult.REASON_SEARCH_TOP_HITS); - for (int i = 0; i < randomIntBetween(5, 10); i++) { - InternalProfileCollector internalProfileCollector = pcm.newCollector(); - assertEquals(i, ((TestCollector) internalProfileCollector.getWrappedCollector()).id); + int runs = randomIntBetween(5, 10); + List collectors = new ArrayList<>(); + for (int i = 0; i < runs; i++) { + collectors.add(pcm.newCollector()); + assertEquals(i, ((TestCollector) collectors.get(i).getWrappedCollector()).id); } - pcm.reduce(Collections.emptyList()); + Integer returnValue = pcm.reduce(collectors); + assertEquals(runs, returnValue.intValue()); assertTrue(reduceCalled.get()); } + public void testReduceEmpty() { + ProfileCollectorManager pcm = new ProfileCollectorManager<>( + TopScoreDocCollector.createSharedManager(10, null, 1000), + CollectorResult.REASON_SEARCH_TOP_HITS + ); + AssertionError ae = expectThrows(AssertionError.class, () -> pcm.reduce(Collections.emptyList())); + assertEquals("at least one collector expected", ae.getMessage()); + } + /** * This test checks functionality with potentially more than one slice on a real searcher, * wrapping a {@link TopScoreDocCollector} into {@link ProfileCollectorManager} and checking the @@ -88,7 +99,6 @@ public void testManagerWithSearcher() throws IOException { writer.flush(); IndexReader reader = writer.getReader(); IndexSearcher searcher = newSearcher(reader); - int numSlices = searcher.getSlices() == null ? 1 : searcher.getSlices().length; searcher.setSimilarity(new BM25Similarity()); CollectorManager topDocsManager = TopScoreDocCollector.createSharedManager(10, null, 1000); @@ -96,21 +106,15 @@ public void testManagerWithSearcher() throws IOException { assertEquals(numDocs, topDocs.totalHits.value); String profileReason = "profiler_reason"; - ProfileCollectorManager profileCollectorManager = new ProfileCollectorManager(topDocsManager, profileReason); + ProfileCollectorManager profileCollectorManager = new ProfileCollectorManager<>(topDocsManager, profileReason); searcher.search(new MatchAllDocsQuery(), profileCollectorManager); - CollectorResult parent = profileCollectorManager.getCollectorTree(); - assertEquals("ProfileCollectorManager", parent.getName()); - assertEquals("segment_search", parent.getReason()); - assertEquals(0, parent.getTime()); - List delegateCollectorResults = parent.getProfiledChildren(); - assertEquals(numSlices, delegateCollectorResults.size()); - for (ProfilerCollectorResult pcr : delegateCollectorResults) { - assertEquals("SimpleTopScoreDocCollector", pcr.getName()); - assertEquals(profileReason, pcr.getReason()); - assertTrue(pcr.getTime() > 0); - } + CollectorResult result = profileCollectorManager.getCollectorTree(); + assertEquals("profiler_reason", result.getReason()); + assertEquals("SimpleTopScoreDocCollector", result.getName()); + assertTrue(result.getTime() > 0); + reader.close(); } directory.close();