Skip to content

Commit

Permalink
Add profiling information for knn vector queries (#90200)
Browse files Browse the repository at this point in the history
This adds timers to the dfs phase to profile a knn vector query and provide a breakdown of several 
parts of the query.
  • Loading branch information
jdconrad committed Sep 26, 2022
1 parent 0764ddc commit 94f05da
Show file tree
Hide file tree
Showing 14 changed files with 427 additions and 73 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/90200.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 90200
summary: Add profiling information for knn vector queries
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,119 @@ disabling stored fields removes fetch sub phases:
- match: { hits.hits.0._index: test }
- match: { profile.shards.0.fetch.debug.stored_fields: [] }
- is_false: profile.shards.0.fetch.children

---
# Verifies that a kNN vector search executed in the dfs phase produces a full
# profile breakdown (query timers, rewrite time, and collector) under
# profile.shards.0.dfs.knn. Requires 8.6.0+ where dfs profiling was added.
dfs knn vector profiling:
  - skip:
      version: ' - 8.5.99'
      reason: dfs profiling implemented in 8.6.0

  - do:
      indices.create:
        index: images
        body:
          settings:
            index.number_of_shards: 1
          mappings:
            properties:
              image:
                type: "dense_vector"
                dims: 3
                index: true
                similarity: "l2_norm"

  - do:
      index:
        index: images
        id: "1"
        refresh: true
        body:
          image: [1, 5, -20]

  - do:
      search:
        index: images
        body:
          profile: true
          knn:
            field: "image"
            query_vector: [-5, 9, -12]
            k: 1
            num_candidates: 100

  - match: { hits.total.value: 1 }
  - match: { profile.shards.0.dfs.knn.query.0.type: "DocAndScoreQuery" }
  - match: { profile.shards.0.dfs.knn.query.0.description: "DocAndScore[100]" }
  - gt: { profile.shards.0.dfs.knn.query.0.time_in_nanos: 0 }
  # Timers that a DocAndScoreQuery never exercises stay at zero ...
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.set_min_competitive_score_count: 0 }
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.set_min_competitive_score: 0 }
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.match_count: 0 }
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.match: 0 }
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.shallow_advance_count: 0 }
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.shallow_advance: 0 }
  # ... while the ones it does exercise must be strictly positive.
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.next_doc_count: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.next_doc: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.score_count: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.score: 0 }
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.compute_max_score_count: 0 }
  - match: { profile.shards.0.dfs.knn.query.0.breakdown.compute_max_score: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.advance_count: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.advance: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.build_scorer_count: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.build_scorer: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.create_weight: 0 }
  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.create_weight_count: 0 }
  - gt: { profile.shards.0.dfs.knn.rewrite_time: 0 }
  - match: { profile.shards.0.dfs.knn.collector.0.name: "SimpleTopScoreDocCollector" }
  - match: { profile.shards.0.dfs.knn.collector.0.reason: "search_top_hits" }
  - gt: { profile.shards.0.dfs.knn.collector.0.time_in_nanos: 0 }

---
# Verifies that profiling a search WITHOUT a kNn section never emits a dfs
# profile entry, for both dfs_query_then_fetch and query_then_fetch.
# Requires 8.6.0+ where dfs profiling was added.
dfs without knn vector profiling:
  - skip:
      version: ' - 8.5.99'
      reason: dfs profiling implemented in 8.6.0

  - do:
      indices.create:
        index: keywords
        body:
          settings:
            index.number_of_shards: 1
          mappings:
            properties:
              keyword:
                type: "keyword"

  - do:
      index:
        index: keywords
        id: "1"
        refresh: true
        body:
          keyword: "a"

  - do:
      search:
        index: keywords
        search_type: dfs_query_then_fetch
        body:
          profile: true
          query:
            term:
              keyword: "a"

  - match: { hits.total.value: 1 }
  - is_false: profile.shards.0.dfs

  - do:
      search:
        index: keywords
        search_type: query_then_fetch
        body:
          profile: true
          query:
            term:
              keyword: "a"

  - match: { hits.total.value: 1 }
  - is_false: profile.shards.0.dfs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public void run() throws IOException {
@Override
protected void innerOnResponse(QuerySearchResult response) {
try {
response.setSearchProfileDfsPhaseResult(dfsResult.searchProfileDfsPhaseResult());
counter.onResult(response);
} catch (Exception e) {
context.onPhaseFailure(DfsQueryPhase.this, "", e);
Expand Down
140 changes: 84 additions & 56 deletions server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,25 @@

import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.index.query.ParsedQuery;
import org.apache.lucene.search.TopScoreDocCollector;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.InternalProfileCollector;
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.tasks.TaskCancelledException;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
Expand All @@ -39,71 +42,96 @@ public class DfsPhase {

public void execute(SearchContext context) {
try {
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
Map<Term, TermStatistics> stats = new HashMap<>();
IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
@Override
public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
TermStatistics ts = super.termStatistics(term, docFreq, totalTermFreq);
if (ts != null) {
stats.put(term, ts);
}
return ts;
}
collectStatistics(context);
executeKnnVectorQuery(context);
} catch (Exception e) {
throw new DfsPhaseExecutionException(context.shardTarget(), "Exception during dfs phase", e);
}
}

private void collectStatistics(SearchContext context) throws IOException {
Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
Map<Term, TermStatistics> stats = new HashMap<>();

@Override
public CollectionStatistics collectionStatistics(String field) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
CollectionStatistics cs = super.collectionStatistics(field);
if (cs != null) {
fieldStatistics.put(field, cs);
}
return cs;
IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
@Override
public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
TermStatistics ts = super.termStatistics(term, docFreq, totalTermFreq);
if (ts != null) {
stats.put(term, ts);
}
};
return ts;
}

searcher.createWeight(context.rewrittenQuery(), ScoreMode.COMPLETE, 1);
for (RescoreContext rescoreContext : context.rescore()) {
for (Query query : rescoreContext.getQueries()) {
searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
@Override
public CollectionStatistics collectionStatistics(String field) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
CollectionStatistics cs = super.collectionStatistics(field);
if (cs != null) {
fieldStatistics.put(field, cs);
}
return cs;
}
};

Term[] terms = stats.keySet().toArray(new Term[0]);
TermStatistics[] termStatistics = new TermStatistics[terms.length];
for (int i = 0; i < terms.length; i++) {
termStatistics[i] = stats.get(terms[i]);
searcher.createWeight(context.rewrittenQuery(), ScoreMode.COMPLETE, 1);
for (RescoreContext rescoreContext : context.rescore()) {
for (Query query : rescoreContext.getQueries()) {
searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
}
}

context.dfsResult()
.termsStatistics(terms, termStatistics)
.fieldStatistics(fieldStatistics)
.maxDoc(context.searcher().getIndexReader().maxDoc());
Term[] terms = stats.keySet().toArray(new Term[0]);
TermStatistics[] termStatistics = new TermStatistics[terms.length];
for (int i = 0; i < terms.length; i++) {
termStatistics[i] = stats.get(terms[i]);
}

// If kNN search is requested, perform kNN query and gather top docs
SearchSourceBuilder source = context.request().source();
if (source != null && source.knnSearch() != null) {
SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
KnnSearchBuilder knnSearch = source.knnSearch();
context.dfsResult()
.termsStatistics(terms, termStatistics)
.fieldStatistics(fieldStatistics)
.maxDoc(context.searcher().getIndexReader().maxDoc());
}

KnnVectorQueryBuilder knnVectorQueryBuilder = knnSearch.toQueryBuilder();
if (context.request().getAliasFilter().getQueryBuilder() != null) {
knnVectorQueryBuilder.addFilterQuery(context.request().getAliasFilter().getQueryBuilder());
}
ParsedQuery query = searchExecutionContext.toQuery(knnVectorQueryBuilder);
private void executeKnnVectorQuery(SearchContext context) throws IOException {
SearchSourceBuilder source = context.request().source();
if (source == null || source.knnSearch() == null) {
return;
}

TopDocs topDocs = searcher.search(query.query(), knnSearch.k());
DfsKnnResults knnResults = new DfsKnnResults(topDocs.scoreDocs);
context.dfsResult().knnResults(knnResults);
}
} catch (Exception e) {
throw new DfsPhaseExecutionException(context.shardTarget(), "Exception during dfs phase", e);
SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
KnnSearchBuilder knnSearch = context.request().source().knnSearch();
KnnVectorQueryBuilder knnVectorQueryBuilder = knnSearch.toQueryBuilder();

if (context.request().getAliasFilter().getQueryBuilder() != null) {
knnVectorQueryBuilder.addFilterQuery(context.request().getAliasFilter().getQueryBuilder());
}
}

Query query = searchExecutionContext.toQuery(knnVectorQueryBuilder).query();
TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(knnSearch.k(), Integer.MAX_VALUE);
Collector collector = topScoreDocCollector;

if (context.getProfilers() != null) {
InternalProfileCollector ipc = new InternalProfileCollector(
topScoreDocCollector,
CollectorResult.REASON_SEARCH_TOP_HITS,
List.of()
);
context.getProfilers().getCurrentQueryProfiler().setCollector(ipc);
collector = ipc;
}

context.searcher().search(query, collector);
DfsKnnResults knnResults = new DfsKnnResults(topScoreDocCollector.topDocs().scoreDocs);
context.dfsResult().knnResults(knnResults);

if (context.getProfilers() != null) {
context.dfsResult().profileResult(context.getProfilers().buildDfsPhaseResults());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;

import java.io.IOException;
import java.util.HashMap;
Expand All @@ -33,6 +34,7 @@ public class DfsSearchResult extends SearchPhaseResult {
private Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
private DfsKnnResults knnResults;
private int maxDoc;
private SearchProfileDfsPhaseResult searchProfileDfsPhaseResult;

public DfsSearchResult(StreamInput in) throws IOException {
super(in);
Expand All @@ -56,6 +58,9 @@ public DfsSearchResult(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
knnResults = in.readOptionalWriteable(DfsKnnResults::new);
}
if (in.getVersion().onOrAfter(Version.V_8_6_0)) {
searchProfileDfsPhaseResult = in.readOptionalWriteable(SearchProfileDfsPhaseResult::new);
}
}

public DfsSearchResult(ShardSearchContextId contextId, SearchShardTarget shardTarget, ShardSearchRequest shardSearchRequest) {
Expand Down Expand Up @@ -89,6 +94,11 @@ public DfsSearchResult knnResults(DfsKnnResults knnResults) {
return this;
}

/**
 * Sets the dfs phase profiling results collected on this shard.
 *
 * @param searchProfileDfsPhaseResult the dfs phase profile breakdown, or {@code null} when profiling is disabled
 * @return this result, to allow builder-style chaining with the other setters
 */
public DfsSearchResult profileResult(SearchProfileDfsPhaseResult searchProfileDfsPhaseResult) {
this.searchProfileDfsPhaseResult = searchProfileDfsPhaseResult;
return this;
}

public Term[] terms() {
return terms;
}
Expand All @@ -105,6 +115,10 @@ public DfsKnnResults knnResults() {
return knnResults;
}

/**
 * Returns the dfs phase profiling results for this shard, or {@code null}
 * when profiling was not enabled for the request.
 */
public SearchProfileDfsPhaseResult searchProfileDfsPhaseResult() {
return searchProfileDfsPhaseResult;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
contextId.writeTo(out);
Expand All @@ -121,6 +135,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
out.writeOptionalWriteable(knnResults);
}
if (out.getVersion().onOrAfter(Version.V_8_6_0)) {
out.writeOptionalWriteable(searchProfileDfsPhaseResult);
}
}

public static void writeFieldStats(StreamOutput out, Map<String, CollectionStatistics> fieldStatistics) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ public static FetchProfiler startProfilingFetchPhase() {
return new FetchProfiler();
}

/**
 * Builds the dfs phase profile results from the current query profiler's
 * accumulated tree, rewrite time, and collector.
 */
public SearchProfileDfsPhaseResult buildDfsPhaseResults() {
    final QueryProfiler profiler = getCurrentQueryProfiler();
    return new SearchProfileDfsPhaseResult(
        new QueryProfileShardResult(profiler.getTree(), profiler.getRewriteTime(), profiler.getCollector())
    );
}

/**
* Build the results for the query phase.
*/
Expand Down

0 comments on commit 94f05da

Please sign in to comment.