Skip to content

Commit

Permalink
feat: introduce SparseEmbedding (#7382)
Browse files Browse the repository at this point in the history
* introduce SparseEmbedding

* reno

* add to pydoc config
  • Loading branch information
anakin87 committed Mar 19, 2024
1 parent 610ad6f commit dbfd351
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/pydoc/config/data_classess_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/dataclasses]
modules:
["answer", "byte_stream", "chat_message", "document", "streaming_chunk"]
["answer", "byte_stream", "chat_message", "document", "streaming_chunk", "sparse_embedding"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
2 changes: 2 additions & 0 deletions haystack/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.dataclasses.document import Document
from haystack.dataclasses.sparse_embedding import SparseEmbedding
from haystack.dataclasses.streaming_chunk import StreamingChunk

__all__ = [
Expand All @@ -13,4 +14,5 @@
"ChatMessage",
"ChatRole",
"StreamingChunk",
"SparseEmbedding",
]
16 changes: 14 additions & 2 deletions haystack/dataclasses/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from haystack import logging
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.sparse_embedding import SparseEmbedding

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,7 +58,8 @@ class Document(metaclass=_BackwardCompatible):
:param blob: Binary data associated with the document, if the document has any binary data associated with it.
:param meta: Additional custom metadata for the document. Must be JSON-serializable.
:param score: Score of the document. Used for ranking, usually assigned by retrievers.
:param embedding: Vector representation of the document.
:param embedding: Dense vector representation of the document.
:param sparse_embedding: Sparse vector representation of the document.
"""

id: str = field(default="")
Expand All @@ -67,6 +69,7 @@ class Document(metaclass=_BackwardCompatible):
meta: Dict[str, Any] = field(default_factory=dict)
score: Optional[float] = field(default=None)
embedding: Optional[List[float]] = field(default=None)
sparse_embedding: Optional[SparseEmbedding] = field(default=None)

def __repr__(self):
fields = []
Expand All @@ -84,6 +87,8 @@ def __repr__(self):
fields.append(f"score: {self.score}")
if self.embedding is not None:
fields.append(f"embedding: vector of size {len(self.embedding)}")
if self.sparse_embedding is not None:
fields.append(f"sparse_embedding: vector with {len(self.sparse_embedding.indices)} non-zero elements")
fields_str = ", ".join(fields)
return f"{self.__class__.__name__}(id={self.id}, {fields_str})"

Expand Down Expand Up @@ -114,7 +119,8 @@ def _create_id(self):
mime_type = self.blob.mime_type if self.blob is not None else None
meta = self.meta or {}
embedding = self.embedding if self.embedding is not None else None
data = f"{text}{dataframe}{blob}{mime_type}{meta}{embedding}"
sparse_embedding = self.sparse_embedding.to_dict() if self.sparse_embedding is not None else ""
data = f"{text}{dataframe}{blob}{mime_type}{meta}{embedding}{sparse_embedding}"
return hashlib.sha256(data.encode("utf-8")).hexdigest()

def to_dict(self, flatten=True) -> Dict[str, Any]:
Expand All @@ -132,6 +138,9 @@ def to_dict(self, flatten=True) -> Dict[str, Any]:
if (blob := data.get("blob")) is not None:
data["blob"] = {"data": list(blob["data"]), "mime_type": blob["mime_type"]}

if (sparse_embedding := data.get("sparse_embedding")) is not None:
data["sparse_embedding"] = sparse_embedding.to_dict()

if flatten:
meta = data.pop("meta")
return {**data, **meta}
Expand All @@ -149,6 +158,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "Document":
data["dataframe"] = read_json(io.StringIO(dataframe))
if blob := data.get("blob"):
data["blob"] = ByteStream(data=bytes(blob["data"]), mime_type=blob["mime_type"])
if sparse_embedding := data.get("sparse_embedding"):
data["sparse_embedding"] = SparseEmbedding.from_dict(sparse_embedding)

# Store metadata for a moment while we try un-flattening allegedly flatten metadata.
# We don't expect both a `meta=` keyword and flatten metadata keys so we'll raise a
# ValueError later if this is the case.
Expand Down
26 changes: 26 additions & 0 deletions haystack/dataclasses/sparse_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import List


class SparseEmbedding:
    """
    Class representing a sparse embedding.

    The embedding is stored as two parallel lists: ``indices`` holds the positions of the
    non-zero elements and ``values`` holds the value found at each of those positions.
    Two instances compare equal when both lists are equal.
    """

    def __init__(self, indices: List[int], values: List[float]):
        """
        :param indices: List of indices of non-zero elements in the embedding.
        :param values: List of values of non-zero elements in the embedding.
        :raises ValueError: If the indices and values lists are not of the same length.
        """
        if len(indices) != len(values):
            raise ValueError("Length of indices and values must be the same.")
        self.indices = indices
        self.values = values

    def __eq__(self, other: object) -> bool:
        """
        Compare two sparse embeddings by content rather than by identity.

        :param other: Object to compare with.
        :returns: True if `other` is a SparseEmbedding with the same indices and values.
        """
        # Defining __eq__ makes instances unhashable (Python sets __hash__ to None);
        # that is intended, since the underlying lists are mutable.
        if not isinstance(other, SparseEmbedding):
            return NotImplemented
        return self.indices == other.indices and self.values == other.values

    def __repr__(self) -> str:
        return f"{type(self).__name__}(indices={self.indices!r}, values={self.values!r})"

    def to_dict(self) -> dict:
        """
        Convert the sparse embedding to a dictionary.

        :returns: Dictionary with the keys `indices` and `values`.
        """
        return {"indices": self.indices, "values": self.values}

    @classmethod
    def from_dict(cls, sparse_embedding_dict: dict) -> "SparseEmbedding":
        """
        Create a sparse embedding from a dictionary.

        :param sparse_embedding_dict: Dictionary with the keys `indices` and `values`.
        :returns: The new SparseEmbedding.
        :raises KeyError: If `indices` or `values` is missing from the dictionary.
        :raises ValueError: If the indices and values lists are not of the same length.
        """
        return cls(indices=sparse_embedding_dict["indices"], values=sparse_embedding_dict["values"])
7 changes: 7 additions & 0 deletions releasenotes/notes/sparse-embedding-fd55b670437492be.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Introduce a new `SparseEmbedding` class which can be used to store a sparse
vector representation of a Document.
It will be instrumental in supporting Sparse Embedding Retrieval with
the subsequent introduction of Sparse Embedders and Sparse Embedding Retrievers.
19 changes: 18 additions & 1 deletion test/dataclasses/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from haystack import Document
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.sparse_embedding import SparseEmbedding


@pytest.mark.parametrize(
Expand Down Expand Up @@ -37,6 +38,7 @@ def test_init():
assert doc.meta == {}
assert doc.score == None
assert doc.embedding == None
assert doc.sparse_embedding == None


def test_init_with_wrong_parameters():
Expand All @@ -46,15 +48,17 @@ def test_init_with_wrong_parameters():

def test_init_with_parameters():
blob_data = b"some bytes"
sparse_embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
doc = Document(
content="test text",
dataframe=pd.DataFrame([0]),
blob=ByteStream(data=blob_data, mime_type="text/markdown"),
meta={"text": "test text"},
score=0.812,
embedding=[0.1, 0.2, 0.3],
sparse_embedding=sparse_embedding,
)
assert doc.id == "ec92455f3f4576d40031163c89b1b4210b34ea1426ee0ff68ebed86cb7ba13f8"
assert doc.id == "967b7bd4a21861ad9e863f638cefcbdd6bf6306bebdd30aa3fedf0c26bc636ed"
assert doc.content == "test text"
assert doc.dataframe is not None
assert doc.dataframe.equals(pd.DataFrame([0]))
Expand All @@ -63,6 +67,7 @@ def test_init_with_parameters():
assert doc.meta == {"text": "test text"}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]
assert doc.sparse_embedding == sparse_embedding


def test_init_with_legacy_fields():
Expand All @@ -76,6 +81,7 @@ def test_init_with_legacy_fields():
assert doc.meta == {}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]
assert doc.sparse_embedding == None


def test_init_with_legacy_field():
Expand All @@ -93,6 +99,7 @@ def test_init_with_legacy_field():
assert doc.meta == {"date": "10-10-2023", "type": "article"}
assert doc.score == 0.812
assert doc.embedding == [0.1, 0.2, 0.3]
assert doc.sparse_embedding == None


def test_basic_equality_type_mismatch():
Expand Down Expand Up @@ -121,6 +128,7 @@ def test_to_dict():
"blob": None,
"score": None,
"embedding": None,
"sparse_embedding": None,
}


Expand All @@ -134,6 +142,7 @@ def test_to_dict_without_flattening():
"meta": {},
"score": None,
"embedding": None,
"sparse_embedding": None,
}


Expand All @@ -145,6 +154,7 @@ def test_to_dict_with_custom_parameters():
meta={"some": "values", "test": 10},
score=0.99,
embedding=[10.0, 10.0],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
)

assert doc.to_dict() == {
Expand All @@ -156,6 +166,7 @@ def test_to_dict_with_custom_parameters():
"test": 10,
"score": 0.99,
"embedding": [10.0, 10.0],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
}


Expand All @@ -167,6 +178,7 @@ def test_to_dict_with_custom_parameters_without_flattening():
meta={"some": "values", "test": 10},
score=0.99,
embedding=[10.0, 10.0],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
)

assert doc.to_dict(flatten=False) == {
Expand All @@ -177,6 +189,7 @@ def test_to_dict_with_custom_parameters_without_flattening():
"meta": {"some": "values", "test": 10},
"score": 0.99,
"embedding": [10, 10],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
}


Expand All @@ -194,6 +207,7 @@ def from_from_dict_with_parameters():
"meta": {"text": "test text"},
"score": 0.812,
"embedding": [0.1, 0.2, 0.3],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
}
) == Document(
content="test text",
Expand All @@ -202,6 +216,7 @@ def from_from_dict_with_parameters():
meta={"text": "test text"},
score=0.812,
embedding=[0.1, 0.2, 0.3],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
)


Expand Down Expand Up @@ -249,6 +264,7 @@ def test_from_dict_with_flat_meta():
"blob": {"data": list(blob_data), "mime_type": "text/markdown"},
"score": 0.812,
"embedding": [0.1, 0.2, 0.3],
"sparse_embedding": {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]},
"date": "10-10-2023",
"type": "article",
}
Expand All @@ -258,6 +274,7 @@ def test_from_dict_with_flat_meta():
blob=ByteStream(blob_data, mime_type="text/markdown"),
score=0.812,
embedding=[0.1, 0.2, 0.3],
sparse_embedding=SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]),
meta={"date": "10-10-2023", "type": "article"},
)

Expand Down
23 changes: 23 additions & 0 deletions test/dataclasses/test_sparse_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from haystack.dataclasses.sparse_embedding import SparseEmbedding


class TestSparseEmbedding:
    """Unit tests for construction and dict (de)serialization of ``SparseEmbedding``."""

    def test_init(self):
        # The constructor must store both lists unchanged.
        embedding = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
        assert (embedding.indices, embedding.values) == ([0, 2, 4], [0.1, 0.2, 0.3])

    def test_init_with_wrong_parameters(self):
        # Two indices against four values: the length check must reject this.
        mismatched = {"indices": [0, 2], "values": [0.1, 0.2, 0.3, 0.4]}
        with pytest.raises(ValueError):
            SparseEmbedding(**mismatched)

    def test_to_dict(self):
        expected = {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]}
        serialized = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]).to_dict()
        assert serialized == expected

    def test_from_dict(self):
        payload = {"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]}
        restored = SparseEmbedding.from_dict(payload)
        assert restored.indices == [0, 2, 4]
        assert restored.values == [0.1, 0.2, 0.3]
6 changes: 3 additions & 3 deletions test/tracing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class TestTypeCoercion:
(NonSerializableClass(), "NonSerializableClass"),
(
Document(id="1", content="text"),
'{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}',
'{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}',
),
(
[Document(id="1", content="text")],
'[{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}]',
'[{"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}]',
),
(
{"key": Document(id="1", content="text")},
'{"key": {"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null}}',
'{"key": {"id": "1", "content": "text", "dataframe": null, "blob": null, "score": null, "embedding": null, "sparse_embedding": null}}',
),
],
)
Expand Down

0 comments on commit dbfd351

Please sign in to comment.