Skip to content

Commit

Permalink
rename QdrantSparseRetriever to QdrantSparseEmbeddingRetriever
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Apr 23, 2024
1 parent 70df967 commit 7e29719
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

from .retriever import QdrantEmbeddingRetriever, QdrantSparseRetriever
from .retriever import QdrantEmbeddingRetriever, QdrantSparseEmbeddingRetriever

__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseRetriever")
__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseEmbeddingRetriever")
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ def run(


@component
class QdrantSparseRetriever:
class QdrantSparseEmbeddingRetriever:
"""
A component for retrieving documents from an QdrantDocumentStore using sparse vectors.
Usage example:
```python
from haystack_integrations.components.retrievers.qdrant import QdrantSparseRetriever
from haystack_integrations.components.retrievers.qdrant import QdrantSparseEmbeddingRetriever
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from haystack.dataclasses.sparse_embedding import SparseEmbedding
Expand All @@ -140,7 +140,7 @@ class QdrantSparseRetriever:
return_embedding=True,
wait_result_from_api=True,
)
retriever = QdrantSparseRetriever(document_store=document_store)
retriever = QdrantSparseEmbeddingRetriever(document_store=document_store)
sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])
retriever.run(query_sparse_embedding=sparse_embedding)
```
Expand All @@ -155,7 +155,7 @@ def __init__(
return_embedding: bool = False,
):
"""
Create a QdrantSparseRetriever component.
Create a QdrantSparseEmbeddingRetriever component.
:param document_store: An instance of QdrantDocumentStore.
:param filters: A dictionary with filters to narrow down the search space. Default is None.
Expand Down
16 changes: 8 additions & 8 deletions integrations/qdrant/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)
from haystack_integrations.components.retrievers.qdrant import (
QdrantEmbeddingRetriever,
QdrantSparseRetriever,
QdrantSparseEmbeddingRetriever,
)
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

Expand Down Expand Up @@ -135,21 +135,21 @@ def test_run_with_sparse_activated(self, filterable_docs: List[Document]):
assert document.embedding is None


class TestQdrantSparseRetriever(FilterableDocsFixtureMixin):
class TestQdrantSparseEmbeddingRetriever(FilterableDocsFixtureMixin):
def test_init_default(self):
document_store = QdrantDocumentStore(location=":memory:", index="test")
retriever = QdrantSparseRetriever(document_store=document_store)
retriever = QdrantSparseEmbeddingRetriever(document_store=document_store)
assert retriever._document_store == document_store
assert retriever._filters is None
assert retriever._top_k == 10
assert retriever._return_embedding is False

def test_to_dict(self):
document_store = QdrantDocumentStore(location=":memory:", index="test")
retriever = QdrantSparseRetriever(document_store=document_store)
retriever = QdrantSparseEmbeddingRetriever(document_store=document_store)
res = retriever.to_dict()
assert res == {
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseRetriever",
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseEmbeddingRetriever",
"init_parameters": {
"document_store": {
"type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore",
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_to_dict(self):

def test_from_dict(self):
data = {
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseRetriever",
"type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantSparseEmbeddingRetriever",
"init_parameters": {
"document_store": {
"init_parameters": {"location": ":memory:", "index": "test"},
Expand All @@ -214,7 +214,7 @@ def test_from_dict(self):
"return_embedding": True,
},
}
retriever = QdrantSparseRetriever.from_dict(data)
retriever = QdrantSparseEmbeddingRetriever.from_dict(data)
assert isinstance(retriever._document_store, QdrantDocumentStore)
assert retriever._document_store.index == "test"
assert retriever._filters is None
Expand All @@ -241,7 +241,7 @@ def test_run(self, filterable_docs: List[Document]):
doc.sparse_embedding = SparseEmbedding.from_dict(self._generate_mocked_sparse_embedding(1)[0])

document_store.write_documents(filterable_docs)
retriever = QdrantSparseRetriever(document_store=document_store)
retriever = QdrantSparseEmbeddingRetriever(document_store=document_store)
sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])

results: List[Document] = retriever.run(query_sparse_embedding=sparse_embedding)["documents"]
Expand Down

0 comments on commit 7e29719

Please sign in to comment.