Skip to content

Commit

Permalink
Integrate filtering support for ANN (#84734)
Browse files Browse the repository at this point in the history
This PR integrates the filtered approximate nearest neighbor (ANN) search support
added in Lucene 9.1. It adds a new `filter` section to the `_knn_search` endpoint,
which accepts a query (in the Elasticsearch query DSL). The value can be either a
single query or a list of queries, which matches the syntax we use for defining
filter clauses in a `bool` query.

Closes #81788.
  • Loading branch information
jtibshirani committed Mar 10, 2022
1 parent 9ec849b commit 15708d5
Show file tree
Hide file tree
Showing 12 changed files with 400 additions and 28 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/84734.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 84734
summary: Integrate filtering support for ANN
area: Search
type: enhancement
issues:
- 81788
108 changes: 107 additions & 1 deletion docs/reference/search/knn-search.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ GET my-index/_knn_search
"k": 10,
"num_candidates": 100
},
"_source": ["name", "date"]
"_source": ["name", "file_type"]
}
----
// TEST[continued]
Expand Down Expand Up @@ -122,6 +122,14 @@ shard, then merges them to find the top `k` results. Increasing
`num_candidates` tends to improve the accuracy of the final `k` results.
====

`filter`::
(Optional, <<query-dsl,Query DSL object>>) Query to filter the documents that
can match. The kNN search will return the top `k` documents that also match
this filter. The value can be a single query or a list of queries. If `filter`
is not provided, all documents are allowed to match.



include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def]
include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def]
include::{es-repo-dir}/search/search.asciidoc[tag=source-filtering-def]
Expand All @@ -141,3 +149,101 @@ the similarity between the query and document vector. See
* The `hits.total` object contains the total number of nearest neighbor
candidates considered, which is at most `num_candidates * num_shards` (fewer
when not enough documents are eligible, for example because of a `filter`). The
`hits.total.relation` will always be `eq`, indicating an exact value.

[[knn-search-api-example]]
==== {api-examples-title}

The following requests create a `dense_vector` field with indexing enabled and
add sample documents:

[source,console]
----
PUT my-index
{
"mappings": {
"properties": {
"image_vector": {
"type": "dense_vector",
"dims": 3,
"index": true,
"similarity": "l2_norm"
},
"name": {
"type": "keyword"
},
"file_type": {
"type": "keyword"
}
}
}
}
PUT my-index/_doc/1?refresh
{
"image_vector" : [0.5, 0.1, 2.6],
"name": "moose family",
"file_type": "jpeg"
}
PUT my-index/_doc/2?refresh
{
"image_vector" : [1.0, 0.8, -0.2],
"name": "alpine lake",
"file_type": "svg"
}
----

The next request performs a kNN search filtered by the `file_type` field:

[source,console]
----
GET my-index/_knn_search
{
"knn": {
"field": "image_vector",
"query_vector": [0.3, 0.1, 1.2],
"k": 5,
"num_candidates": 50
},
"filter": {
"term": {
"file_type": "svg"
}
},
"_source": ["name"]
}
----
// TEST[continued]

[source,console-result]
----
{
"took": 5,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 1,
"relation": "eq"
},
"max_score": 0.2538071,
"hits": [
{
"_index": "my-index",
"_id": "2",
"_score": 0.2538071,
"_source": {
"name": "alpine lake"
}
}
]
}
}
----
// TESTRESPONSE[s/"took": 5/"took": $body.took/]
// TESTRESPONSE[s/,\n \.\.\.//]
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ public String getName() {
return getWriteableName();
}

static void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries) throws IOException {
protected static void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries) throws IOException {
    // Emit the element count first (via writeVInt), followed by every query in list order.
    final int queryCount = queries.size();
    out.writeVInt(queryCount);
    for (final QueryBuilder queryBuilder : queries) {
        // Each query goes out through writeNamedWriteable.
        out.writeNamedWriteable(queryBuilder);
    }
}

static List<QueryBuilder> readQueries(StreamInput in) throws IOException {
protected static List<QueryBuilder> readQueries(StreamInput in) throws IOException {
int size = in.readVInt();
List<QueryBuilder> queries = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.indices.IndicesRequestCache;
import org.elasticsearch.indices.TermsLookup;
import org.elasticsearch.join.ParentJoinPlugin;
Expand Down Expand Up @@ -889,8 +890,8 @@ public void testKnnSearch() throws Exception {
assertAcked(client().admin().indices().prepareCreate("test").setSettings(indexSettings).setMapping(builder));

for (int i = 0; i < 5; i++) {
client().prepareIndex("test").setSource("field1", "value1", "vector", new float[] { i, i, i }).get();
client().prepareIndex("test").setSource("field2", "value2", "vector", new float[] { i, i, i }).get();
client().prepareIndex("test").setSource("field1", "value1", "other", "valueA", "vector", new float[] { i, i, i }).get();
client().prepareIndex("test").setSource("field2", "value2", "other", "valueB", "vector", new float[] { i, i, i }).get();
}

client().admin().indices().prepareRefresh("test").get();
Expand All @@ -900,6 +901,10 @@ public void testKnnSearch() throws Exception {
float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f };
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50);

if (randomBoolean()) {
query.addFilterQuery(new WildcardQueryBuilder("other", "value*"));
}

// user1 should only be able to see docs with field1: value1
SearchResponse response = client().filterWithHeader(
Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ public void testQuery() {
.get();
assertHitCount(response, 1);

// user1 has no access to field1, so the query should not match with the document:
// user1 has no access to field2, so the query should not match with the document:
response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareSearch("test")
.setQuery(matchQuery("field2", "value2"))
Expand Down Expand Up @@ -399,7 +399,7 @@ public void testKnnSearch() throws IOException {
assertAcked(client().admin().indices().prepareCreate("test").setMapping(builder));

client().prepareIndex("test")
.setSource("field1", "value1", "vector", new float[] { 0.0f, 0.0f, 0.0f })
.setSource("field1", "value1", "field2", "value2", "vector", new float[] { 0.0f, 0.0f, 0.0f })
.setRefreshPolicy(IMMEDIATE)
.get();

Expand Down Expand Up @@ -430,6 +430,26 @@ public void testKnnSearch() throws IOException {
.get();
assertHitCount(response, 1);
assertNull(response.getHits().getAt(0).field("vector"));

// user1 can access field1, so the filtered query should match with the document:
KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery(
QueryBuilders.matchQuery("field1", "value1")
);
response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareSearch("test")
.setQuery(filterQuery1)
.get();
assertHitCount(response, 1);

// user1 cannot access field2, so the filtered query should not match with the document:
KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery(
QueryBuilders.matchQuery("field2", "value2")
);
response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
.prepareSearch("test")
.setQuery(filterQuery2)
.get();
assertHitCount(response, 0);
}

public void testPercolateQueryWithIndexedDocWithFLS() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,44 @@ setup:
- match: {hits.hits.1._id: "3"}
- match: {hits.hits.1.fields.name.0: "rabbit.jpg"}

---
"kNN search with filter":
- do:
knn_search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
k: 2
num_candidates: 3
filter:
term:
name: "rabbit.jpg"

- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

- do:
knn_search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
k: 2
num_candidates: 3
filter:
- term:
name: "rabbit.jpg"
- term:
_id: 2

- match: {hits.total.value: 0}

---
"Test nonexistent field":
- do:
Expand All @@ -81,7 +119,6 @@ setup:
- do:
catch: bad_request
search:
rest_total_hits_as_int: true
index: test-index
body:
query:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
Expand Down Expand Up @@ -37,11 +39,18 @@ class KnnSearchRequestBuilder {
static final String ROUTING_PARAM = "routing";

static final ParseField KNN_SECTION_FIELD = new ParseField("knn");
static final ParseField FILTER_FIELD = new ParseField("filter");
private static final ObjectParser<KnnSearchRequestBuilder, Void> PARSER;

static {
PARSER = new ObjectParser<>("knn-search");
PARSER.declareField(KnnSearchRequestBuilder::knnSearch, KnnSearch::parse, KNN_SECTION_FIELD, ObjectParser.ValueType.OBJECT);
PARSER.declareFieldArray(
KnnSearchRequestBuilder::filter,
(p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p),
FILTER_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY
);
PARSER.declareField(
(p, request, c) -> request.fetchSource(FetchSourceContext.fromXContent(p)),
SearchSourceBuilder._SOURCE_FIELD,
Expand Down Expand Up @@ -86,6 +95,7 @@ static KnnSearchRequestBuilder parseRestRequest(RestRequest restRequest) throws
private final String[] indices;
private String routing;
private KnnSearch knnSearch;
private List<QueryBuilder> filters;

private FetchSourceContext fetchSource;
private List<FieldAndFormat> fields;
Expand All @@ -103,6 +113,10 @@ private void knnSearch(KnnSearch knnSearch) {
this.knnSearch = knnSearch;
}

/**
 * Stores the filter queries parsed from the [filter] section of the request body.
 * Registered as the setter for that section on the request parser; the stored list
 * is later added to the kNN query in build() when it is non-null.
 */
private void filter(List<QueryBuilder> filter) {
this.filters = filter;
}

/**
* A comma separated list of routing values to control the shards the search will be executed on.
*/
Expand Down Expand Up @@ -152,17 +166,22 @@ public void build(SearchRequestBuilder builder) {
if (knnSearch == null) {
throw new IllegalArgumentException("missing required [" + KNN_SECTION_FIELD.getPreferredName() + "] section in search body");
}
knnSearch.build(sourceBuilder);

KnnVectorQueryBuilder queryBuilder = knnSearch.buildQuery();
if (filters != null) {
queryBuilder.addFilterQueries(this.filters);
}

sourceBuilder.query(queryBuilder);
sourceBuilder.size(knnSearch.k);

sourceBuilder.fetchSource(fetchSource);
sourceBuilder.storedFields(storedFields);

if (fields != null) {
for (FieldAndFormat field : fields) {
sourceBuilder.fetchField(field);
}
}

if (docValueFields != null) {
for (FieldAndFormat field : docValueFields) {
sourceBuilder.docValueField(field.field, field.format);
Expand Down Expand Up @@ -221,7 +240,7 @@ public static KnnSearch parse(XContentParser parser) throws IOException {
this.numCands = numCands;
}

void build(SearchSourceBuilder builder) {
public KnnVectorQueryBuilder buildQuery() {
// We perform validation here instead of the constructor because it makes the errors
// much clearer. Otherwise, the error message is deeply nested under parsing exceptions.
if (k < 1) {
Expand All @@ -236,8 +255,7 @@ void build(SearchSourceBuilder builder) {
throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
}

builder.query(new KnnVectorQueryBuilder(field, queryVector, numCands));
builder.size(k);
return new KnnVectorQueryBuilder(field, queryVector, numCands);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ public Query termQuery(Object value, SearchExecutionContext context) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
}

public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) {
public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands, Query filter) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
Expand All @@ -321,7 +321,7 @@ public KnnVectorQuery createKnnQuery(float[] queryVector, int numCands) {
}
checkVectorMagnitude(queryVector, squaredMagnitude);
}
return new KnnVectorQuery(name(), queryVector, numCands);
return new KnnVectorQuery(name(), queryVector, numCands, filter);
}

private void checkVectorMagnitude(float[] vector, float squaredMagnitude) {
Expand Down

0 comments on commit 15708d5

Please sign in to comment.