Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ProfileCollectorManager in DfsPhase #96689

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit. Hold shift + click to select a range.
9a410bc
WIP introduce MultiProfileCollectorManager
cbuescher Jun 8, 2023
4b94099
Iter, adding tests
cbuescher Jun 8, 2023
14797fd
iter
cbuescher Jun 12, 2023
c34a877
fix docs test
cbuescher Jun 12, 2023
82f8242
Merge branch 'main' into dfs-phase-multithreadedcollectors-wip
cbuescher Jun 12, 2023
7c66bdc
Update docs/changelog/96689.yaml
cbuescher Jun 13, 2023
55a172c
Merge branch 'main' into dfs-phase-multithreadedcollectors-wip
cbuescher Jun 15, 2023
c77664c
wip dfsPhase test
cbuescher Jun 15, 2023
cf9d380
Merge branch 'main' into dfs-phase-multithreadedcollectors-wip
cbuescher Jun 21, 2023
98c25a9
Switching AbstractProfileBreakdown to manage multiple timers per timi…
cbuescher Jun 21, 2023
b988bf7
iter
cbuescher Jun 21, 2023
60cc3aa
fix more tests
cbuescher Jun 21, 2023
c5e5be4
fix more tests
cbuescher Jun 21, 2023
dd1b242
Merge branch 'main' into dfs-phase-multithreadedcollectors-wip
cbuescher Jun 26, 2023
b366795
multiple timers PR 97013 squashed
cbuescher Jun 26, 2023
20537e5
update skip version in yaml test
cbuescher Jun 27, 2023
22aae35
Merge branch 'main' into dfs-phase-multithreadedcollectors-wip
cbuescher Jun 27, 2023
971fb58
cleanup
cbuescher Jun 27, 2023
200e6e2
small test changes
cbuescher Jun 28, 2023
8fa91bf
Changes to CollectorResult output
cbuescher Jun 28, 2023
39ea2ee
spotless
cbuescher Jun 29, 2023
748dbb7
fixing tests
cbuescher Jun 29, 2023
33187c2
Don't report detailed collectors per manager
cbuescher Jun 30, 2023
5ee6598
Merge branch 'main' into dfs-phase-multithreadedcollectors-wip
cbuescher Jun 30, 2023
e141a97
make the profile collector manager typed
cbuescher Jun 30, 2023
0cdd332
assert newCollector has been called
cbuescher Jun 30, 2023
0f6958c
iter test
cbuescher Jul 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/96689.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96689
summary: Use a collector manager in DfsPhase Knn Search
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ public void testProfileDfs() throws Exception {
CollectorResult result = queryProfileShardResult.getCollectorResult();
assertThat(result.getName(), is(not(emptyOrNullString())));
assertThat(result.getTime(), greaterThan(0L));
assertThat(result.getTime(), greaterThan(0L));
}
ProfileResult statsResult = searchProfileDfsPhaseResult.getDfsShardResult();
assertThat(statsResult.getQueryName(), equalTo("statistics"));
Expand Down
43 changes: 23 additions & 20 deletions server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.elasticsearch.index.query.ParsedQuery;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.profile.Timer;
import org.elasticsearch.search.profile.dfs.DfsProfiler;
import org.elasticsearch.search.profile.dfs.DfsTimingType;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.InternalProfileCollector;
import org.elasticsearch.search.profile.query.InternalProfileCollectorManager;
import org.elasticsearch.search.profile.query.ProfileCollectorManager;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.query.SingleThreadCollectorManager;
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
Expand All @@ -43,7 +44,6 @@
/**
* DFS phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
* The additional information is used to better compare the scores coming from all the shards, which depend on local factors (e.g. idf).
*
* When a kNN search is provided alongside the query, the DFS phase is also used to gather the top k candidates from each shard. Then the
* global top k hits are passed on to the query phase.
*/
Expand Down Expand Up @@ -189,24 +189,27 @@ private void executeKnnVectorQuery(SearchContext context) throws IOException {
List<DfsKnnResults> knnResults = new ArrayList<>(knnVectorQueryBuilders.size());
for (int i = 0; i < knnSearch.size(); i++) {
Query knnQuery = searchExecutionContext.toQuery(knnVectorQueryBuilders.get(i)).query();
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(knnSearch.get(i).k(), Integer.MAX_VALUE);
CollectorManager<Collector, Void> collectorManager = new SingleThreadCollectorManager(topScoreDocCollector);
if (context.getProfilers() != null) {
InternalProfileCollectorManager ipcm = new InternalProfileCollectorManager(
new InternalProfileCollector(collectorManager.newCollector(), CollectorResult.REASON_SEARCH_TOP_HITS)
);
QueryProfiler knnProfiler = context.getProfilers().getDfsProfiler().addQueryProfiler(ipcm);
collectorManager = ipcm;
// Set the current searcher profiler to gather query profiling information for gathering top K docs
context.searcher().setProfiler(knnProfiler);
}
context.searcher().search(knnQuery, collectorManager);
knnResults.add(new DfsKnnResults(topScoreDocCollector.topDocs().scoreDocs));
knnResults.add(singleKnnSearch(knnQuery, knnSearch.get(i).k(), context.getProfilers(), context.searcher()));
}
context.dfsResult().knnResults(knnResults);
}

static DfsKnnResults singleKnnSearch(Query knnQuery, int k, Profilers profilers, ContextIndexSearcher searcher) throws IOException {
CollectorManager<? extends Collector, TopDocs> cm = TopScoreDocCollector.createSharedManager(k, null, Integer.MAX_VALUE);

if (profilers != null) {
ProfileCollectorManager<TopDocs> ipcm = new ProfileCollectorManager<>(cm, CollectorResult.REASON_SEARCH_TOP_HITS);
QueryProfiler knnProfiler = profilers.getDfsProfiler().addQueryProfiler(ipcm);
cm = ipcm;
// Set the current searcher profiler to gather query profiling information for gathering top K docs
searcher.setProfiler(knnProfiler);
}
TopDocs topDocs = searcher.search(knnQuery, cm);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks better than before. I think that you may need to update existing usages of the profile collector manager to address warnings around their unchecked usage now that it's typed?


// Set profiler back after running KNN searches
if (context.getProfilers() != null) {
context.searcher().setProfiler(context.getProfilers().getCurrentQueryProfiler());
if (profilers != null) {
searcher.setProfiler(profilers.getCurrentQueryProfiler());
}
context.dfsResult().knnResults(knnResults);
return new DfsKnnResults(topDocs.scoreDocs);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.elasticsearch.search.profile.ProfileResult;
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
import org.elasticsearch.search.profile.Timer;
import org.elasticsearch.search.profile.query.InternalProfileCollectorManager;
import org.elasticsearch.search.profile.query.ProfileCollectorManager;
import org.elasticsearch.search.profile.query.QueryProfileShardResult;
import org.elasticsearch.search.profile.query.QueryProfiler;

Expand Down Expand Up @@ -51,7 +51,7 @@ public Timer startTimer(DfsTimingType dfsTimingType) {
return newTimer;
}

public QueryProfiler addQueryProfiler(InternalProfileCollectorManager collectorManager) {
public QueryProfiler addQueryProfiler(ProfileCollectorManager<?> collectorManager) {
QueryProfiler queryProfiler = new QueryProfiler();
queryProfiler.setCollectorManager(collectorManager::getCollectorTree);
knnQueryProfilers.add(queryProfiler);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -21,15 +22,15 @@
* in an {@link InternalProfileCollector}. It delegates all the profiling to the generated collectors via {@link #getCollectorTree()}
and joins them up when its {@link #reduce} method is called. The profile result can then be retrieved via {@link #getCollectorTree()} once reduce has been called.
*/
public final class ProfileCollectorManager implements CollectorManager<InternalProfileCollector, Void> {
public final class ProfileCollectorManager<T> implements CollectorManager<InternalProfileCollector, T> {

private final CollectorManager<Collector, ?> collectorManager;
private final CollectorManager<Collector, T> collectorManager;
private final String reason;
private CollectorResult collectorTree;

@SuppressWarnings("unchecked")
public ProfileCollectorManager(CollectorManager<? extends Collector, ?> collectorManager, String reason) {
this.collectorManager = (CollectorManager<Collector, ?>) collectorManager;
public ProfileCollectorManager(CollectorManager<? extends Collector, T> collectorManager, String reason) {
this.collectorManager = (CollectorManager<Collector, T>) collectorManager;
this.reason = reason;
}

Expand All @@ -38,22 +39,26 @@ public InternalProfileCollector newCollector() throws IOException {
return new InternalProfileCollector(collectorManager.newCollector(), reason);
}

public Void reduce(Collection<InternalProfileCollector> profileCollectors) throws IOException {
public T reduce(Collection<InternalProfileCollector> profileCollectors) throws IOException {
assert profileCollectors.size() > 0 : "at least one collector expected";
List<Collector> unwrapped = profileCollectors.stream()
.map(InternalProfileCollector::getWrappedCollector)
.collect(Collectors.toList());
collectorManager.reduce(unwrapped);
T returnValue = collectorManager.reduce(unwrapped);

List<CollectorResult> resultsPerProfiler = profileCollectors.stream()
.map(ipc -> ipc.getCollectorTree())
.collect(Collectors.toList());
this.collectorTree = new CollectorResult(this.getClass().getSimpleName(), "segment_search", 0, resultsPerProfiler);
return null;

long totalTime = resultsPerProfiler.stream().map(CollectorResult::getTime).reduce(0L, Long::sum);
String collectorName = resultsPerProfiler.get(0).getName();
this.collectorTree = new CollectorResult(collectorName, reason, totalTime, Collections.emptyList());
return returnValue;
}

public CollectorResult getCollectorTree() {
if (this.collectorTree == null) {
throw new IllegalStateException("A collectorTree hasn't been set yet, call reduce() before attempting to retrieve it");
throw new IllegalStateException("A collectorTree hasn't been set yet. Call reduce() before attempting to retrieve it");
}
return this.collectorTree;
}
Expand Down
117 changes: 117 additions & 0 deletions server/src/test/java/org/elasticsearch/search/dfs/DfsPhaseTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.dfs;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.QueryProfileShardResult;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.ThreadPoolExecutor;

public class DfsPhaseTests extends ESTestCase {

ThreadPoolExecutor threadPoolExecutor;
private TestThreadPool threadPool;

@Before
public final void init() {
int numThreads = randomIntBetween(2, 4);
threadPool = new TestThreadPool(DfsPhaseTests.class.getName());
threadPoolExecutor = EsExecutors.newFixed(
"test",
numThreads,
10,
EsExecutors.daemonThreadFactory("test"),
threadPool.getThreadContext(),
randomBoolean()
);
}

@After
public void cleanup() {
threadPoolExecutor.shutdown();
terminate(threadPool);
}

public void testSingleKnnSearch() throws IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this test! Do we need to test this specific method, or could we instead test executeKnnVectorQuery directly? Maybe as a follow-up we can also unit test the term statistics collection? We should have totally written these tests 10 years ago, but it seems like we haven't :)

try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
int numDocs = randomIntBetween(900, 1000);
for (int i = 0; i < numDocs; i++) {
Document d = new Document();
d.add(new KnnFloatVectorField("float_vector", new float[] { i, 0, 0 }));
w.addDocument(d);
}
w.flush();

IndexReader reader = w.getReader();
ContextIndexSearcher searcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
randomBoolean(),
this.threadPoolExecutor
) {
@Override
protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
// get a thread per segment
return slices(leaves, 1, 1);
}
};

Query query = new KnnFloatVectorQuery("float_vector", new float[] { 0, 0, 0 }, numDocs, null);

int k = 10;
// run without profiling enabled
DfsKnnResults dfsKnnResults = DfsPhase.singleKnnSearch(query, k, null, searcher);
assertEquals(k, dfsKnnResults.scoreDocs().length);

// run with profiling enabled
Profilers profilers = new Profilers(searcher);
dfsKnnResults = DfsPhase.singleKnnSearch(query, k, profilers, searcher);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we have two tests, one with profiling and one without?

assertEquals(k, dfsKnnResults.scoreDocs().length);
SearchProfileDfsPhaseResult searchProfileDfsPhaseResult = profilers.getDfsProfiler().buildDfsPhaseResults();
List<QueryProfileShardResult> queryProfileShardResult = searchProfileDfsPhaseResult.getQueryProfileShardResult();
assertNotNull(queryProfileShardResult);
CollectorResult collectorResult = queryProfileShardResult.get(0).getCollectorResult();
assertEquals("SimpleTopScoreDocCollector", (collectorResult.getName()));
assertEquals("search_top_hits", (collectorResult.getReason()));
assertTrue(collectorResult.getTime() > 0);
List<CollectorResult> children = collectorResult.getCollectorResults();
if (children.size() > 0) {
long totalTime = 0L;
for (CollectorResult child : children) {
assertEquals("SimpleTopScoreDocCollector", (child.getName()));
assertEquals("search_top_hits", (child.getReason()));
totalTime += child.getTime();
}
assertEquals(totalTime, collectorResult.getTime());
}
reader.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.sandbox.search.ProfilerCollectorResult;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
Expand All @@ -27,6 +25,7 @@
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
Expand All @@ -48,29 +47,41 @@ private static class TestCollector extends DummyTotalHitCountCollector {
*/
public void testBasic() throws IOException {
final SetOnce<Boolean> reduceCalled = new SetOnce<>();
ProfileCollectorManager pcm = new ProfileCollectorManager(new CollectorManager<>() {
ProfileCollectorManager<Integer> pcm = new ProfileCollectorManager<>(new CollectorManager<TestCollector, Integer>() {

private static int counter = 0;
private int counter = 0;

@Override
public Collector newCollector() {
public TestCollector newCollector() {
return new TestCollector(counter++);
}

@Override
public Void reduce(Collection<Collector> collectors) {
public Integer reduce(Collection<TestCollector> collectors) {
reduceCalled.set(true);
return null;
return counter;
}
}, CollectorResult.REASON_SEARCH_TOP_HITS);
for (int i = 0; i < randomIntBetween(5, 10); i++) {
InternalProfileCollector internalProfileCollector = pcm.newCollector();
assertEquals(i, ((TestCollector) internalProfileCollector.getWrappedCollector()).id);
int runs = randomIntBetween(5, 10);
List<InternalProfileCollector> collectors = new ArrayList<>();
for (int i = 0; i < runs; i++) {
collectors.add(pcm.newCollector());
assertEquals(i, ((TestCollector) collectors.get(i).getWrappedCollector()).id);
}
pcm.reduce(Collections.emptyList());
Integer returnValue = pcm.reduce(collectors);
assertEquals(runs, returnValue.intValue());
assertTrue(reduceCalled.get());
}

public void testReduceEmpty() {
ProfileCollectorManager<TopDocs> pcm = new ProfileCollectorManager<>(
TopScoreDocCollector.createSharedManager(10, null, 1000),
CollectorResult.REASON_SEARCH_TOP_HITS
);
AssertionError ae = expectThrows(AssertionError.class, () -> pcm.reduce(Collections.emptyList()));
assertEquals("at least one collector expected", ae.getMessage());
}

/**
* This test checks functionality with potentially more than one slice on a real searcher,
* wrapping a {@link TopScoreDocCollector} into {@link ProfileCollectorManager} and checking the
Expand All @@ -88,29 +99,22 @@ public void testManagerWithSearcher() throws IOException {
writer.flush();
IndexReader reader = writer.getReader();
IndexSearcher searcher = newSearcher(reader);
int numSlices = searcher.getSlices() == null ? 1 : searcher.getSlices().length;
searcher.setSimilarity(new BM25Similarity());

CollectorManager<TopScoreDocCollector, TopDocs> topDocsManager = TopScoreDocCollector.createSharedManager(10, null, 1000);
TopDocs topDocs = searcher.search(new MatchAllDocsQuery(), topDocsManager);
assertEquals(numDocs, topDocs.totalHits.value);

String profileReason = "profiler_reason";
ProfileCollectorManager profileCollectorManager = new ProfileCollectorManager(topDocsManager, profileReason);
ProfileCollectorManager<TopDocs> profileCollectorManager = new ProfileCollectorManager<>(topDocsManager, profileReason);

searcher.search(new MatchAllDocsQuery(), profileCollectorManager);

CollectorResult parent = profileCollectorManager.getCollectorTree();
assertEquals("ProfileCollectorManager", parent.getName());
assertEquals("segment_search", parent.getReason());
assertEquals(0, parent.getTime());
List<ProfilerCollectorResult> delegateCollectorResults = parent.getProfiledChildren();
assertEquals(numSlices, delegateCollectorResults.size());
for (ProfilerCollectorResult pcr : delegateCollectorResults) {
assertEquals("SimpleTopScoreDocCollector", pcr.getName());
assertEquals(profileReason, pcr.getReason());
assertTrue(pcr.getTime() > 0);
}
CollectorResult result = profileCollectorManager.getCollectorTree();
assertEquals("profiler_reason", result.getReason());
assertEquals("SimpleTopScoreDocCollector", result.getName());
assertTrue(result.getTime() > 0);

reader.close();
}
directory.close();
Expand Down