Skip to content

Commit

Permalink
Added inner_hits option to kNN search (#1777)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Apr 15, 2024
1 parent b9c8343 commit 9ade575
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions elasticsearch_dsl/search_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ def knn(
boost=None,
filter=None,
similarity=None,
inner_hits=None,
):
"""
Add a k-nearest neighbor (kNN) search.
Expand All @@ -526,6 +527,7 @@ def knn(
:arg boost: A floating-point boost factor for kNN scores
:arg filter: query to filter the documents that can match
:arg similarity: the minimum similarity required for a document to be considered a match, as a float value
:arg inner_hits: retrieve hits from nested field
Example::
Expand Down Expand Up @@ -560,6 +562,8 @@ def knn(
s._knn[-1]["filter"] = filter
if similarity is not None:
s._knn[-1]["similarity"] = similarity
if inner_hits is not None:
s._knn[-1]["inner_hits"] = inner_hits
return s

def rank(self, rrf=None):
Expand Down
2 changes: 2 additions & 0 deletions tests/_async/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def test_knn():
query_vector_builder={
"text_embedding": {"model_id": "foo", "model_text": "search text"}
},
inner_hits={"size": 1},
)
assert {
"knn": [
Expand All @@ -283,6 +284,7 @@ def test_knn():
"text_embedding": {"model_id": "foo", "model_text": "search text"}
},
"boost": 0.8,
"inner_hits": {"size": 1},
},
]
} == s.to_dict()
Expand Down
2 changes: 2 additions & 0 deletions tests/_sync/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def test_knn():
query_vector_builder={
"text_embedding": {"model_id": "foo", "model_text": "search text"}
},
inner_hits={"size": 1},
)
assert {
"knn": [
Expand All @@ -283,6 +284,7 @@ def test_knn():
"text_embedding": {"model_id": "foo", "model_text": "search text"}
},
"boost": 0.8,
"inner_hits": {"size": 1},
},
]
} == s.to_dict()
Expand Down

0 comments on commit 9ade575

Please sign in to comment.