Skip to content

Commit

Permalink
Add a knn method to elasticsearch_dsl.search.Search (#1691)
Browse files Browse the repository at this point in the history
* Add a `knn` method to `elasticsearch_dsl.search.Search`

* add knn's boost option
  • Loading branch information
miguelgrinberg committed Jan 3, 2024
1 parent f0c5045 commit baed085
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 1 deletion.
27 changes: 27 additions & 0 deletions docs/search_dsl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The ``Search`` object represents the entire search request:

* aggregations

* k-nearest neighbor searches

* sort

* pagination
Expand Down Expand Up @@ -352,6 +354,31 @@ As opposed to other methods on the ``Search`` objects, defining aggregations is
done in-place (does not return a copy).


K-Nearest Neighbor Searches
~~~~~~~~~~~~~~~~~~~~~~~~~~~

To issue a kNN search, use the ``.knn()`` method:

.. code:: python

    s = Search()
    vector = get_embedding("search text")
    s = s.knn(
        field="embedding",
        k=5,
        num_candidates=10,
        query_vector=vector
    )

The ``field``, ``k`` and ``num_candidates`` arguments can be given as
positional or keyword arguments and are required. In addition to these,
``query_vector`` or ``query_vector_builder`` must be given as well.

The ``.knn()`` method can be invoked multiple times to include multiple kNN
searches in the request.


Sorting
~~~~~~~

Expand Down
72 changes: 71 additions & 1 deletion elasticsearch_dsl/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .aggs import A, AggBase
from .connections import get_connection
from .exceptions import IllegalOperation
from .query import Bool, Q
from .query import Bool, Q, Query
from .response import Hit, Response
from .utils import AttrDict, DslBase, recursive_to_dict

Expand Down Expand Up @@ -319,6 +319,7 @@ def __init__(self, **kwargs):
self.aggs = AggsProxy(self)
self._sort = []
self._collapse = {}
self._knn = []
self._source = None
self._highlight = {}
self._highlight_opts = {}
Expand Down Expand Up @@ -406,6 +407,7 @@ def _clone(self):
s = super()._clone()

s._response_class = self._response_class
s._knn = [knn.copy() for knn in self._knn]
s._collapse = self._collapse.copy()
s._sort = self._sort[:]
s._source = copy.copy(self._source) if self._source is not None else None
Expand Down Expand Up @@ -445,6 +447,10 @@ def update_from_dict(self, d):
self.aggs._params = {
"aggs": {name: A(value) for (name, value) in aggs.items()}
}
if "knn" in d:
self._knn = d.pop("knn")
if isinstance(self._knn, dict):
self._knn = [self._knn]
if "collapse" in d:
self._collapse = d.pop("collapse")
if "sort" in d:
Expand Down Expand Up @@ -494,6 +500,64 @@ def script_fields(self, **kwargs):
s._script_fields.update(kwargs)
return s

def knn(
    self,
    field,
    k,
    num_candidates,
    query_vector=None,
    query_vector_builder=None,
    boost=None,
    filter=None,
    similarity=None,
):
    """
    Add a k-nearest neighbor (kNN) search.

    :arg field: the name of the vector field to search against
    :arg k: number of nearest neighbors to return as top hits
    :arg num_candidates: number of nearest neighbor candidates to consider per shard
    :arg query_vector: the vector to search for
    :arg query_vector_builder: A dictionary indicating how to build a query vector
    :arg boost: A floating-point boost factor for kNN scores
    :arg similarity: the minimum similarity required for a document to be
        considered a match, as a float value
    :arg filter: query to filter the documents that can match

    Example::

        s = Search()
        s = s.knn(field='embedding', k=5, num_candidates=10,
                  query_vector=vector, filter=Q('term', category='blog'))
    """
    # Validate the mutually exclusive vector arguments up front, before
    # cloning, so invalid calls fail without doing any copying work.
    if query_vector is None and query_vector_builder is None:
        raise ValueError("one of query_vector and query_vector_builder is required")
    if query_vector is not None and query_vector_builder is not None:
        raise ValueError(
            "only one of query_vector and query_vector_builder must be given"
        )
    s = self._clone()
    knn = {
        "field": field,
        "k": k,
        "num_candidates": num_candidates,
    }
    if query_vector is not None:
        knn["query_vector"] = query_vector
    if query_vector_builder is not None:
        knn["query_vector_builder"] = query_vector_builder
    if boost is not None:
        knn["boost"] = boost
    if filter is not None:
        # Accept either a Query object or an already-serialized dict.
        knn["filter"] = filter.to_dict() if isinstance(filter, Query) else filter
    if similarity is not None:
        knn["similarity"] = similarity
    s._knn.append(knn)
    return s

def source(self, fields=None, **kwargs):
"""
Selectively control how the _source field is returned.
Expand Down Expand Up @@ -677,6 +741,12 @@ def to_dict(self, count=False, **kwargs):
if self.query:
d["query"] = self.query.to_dict()

if self._knn:
if len(self._knn) == 1:
d["knn"] = self._knn[0]
else:
d["knn"] = self._knn

# count request doesn't care for sorting and other things
if not count:
if self.post_filter:
Expand Down
54 changes: 54 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,60 @@ class MyDocument(Document):
assert s._doc_type_map == {}


def test_knn():
    s = search.Search()

    # field, k and num_candidates are all mandatory positional arguments
    for bad_args in [(), ("field",), ("field", 5)]:
        with raises(TypeError):
            s.knn(*bad_args)
    # exactly one of query_vector / query_vector_builder must be supplied
    with raises(ValueError):
        s.knn("field", 5, 100)
    with raises(ValueError):
        s.knn("field", 5, 100, query_vector=[1, 2, 3], query_vector_builder={})

    # a single kNN clause is serialized as a bare object
    s = s.knn("field", 5, 100, query_vector=[1, 2, 3])
    first = {
        "field": "field",
        "k": 5,
        "num_candidates": 100,
        "query_vector": [1, 2, 3],
    }
    assert s.to_dict() == {"knn": first}

    # a second clause turns the serialization into a list, in call order
    s = s.knn(
        k=4,
        num_candidates=40,
        boost=0.8,
        field="name",
        query_vector_builder={
            "text_embedding": {"model_id": "foo", "model_text": "search text"}
        },
    )
    second = {
        "field": "name",
        "k": 4,
        "num_candidates": 40,
        "query_vector_builder": {
            "text_embedding": {"model_id": "foo", "model_text": "search text"}
        },
        "boost": 0.8,
    }
    assert s.to_dict() == {"knn": [first, second]}


def test_sort():
s = search.Search()
s = s.sort("fielda", "-fieldb")
Expand Down

0 comments on commit baed085

Please sign in to comment.