/
fastembed_document_embedder.py
175 lines (153 loc) · 6.95 KB
/
fastembed_document_embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import Any, Dict, List, Optional
from haystack import Document, component, default_to_dict
from .embedding_backend.fastembed_backend import _FastembedEmbeddingBackendFactory
@component
class FastembedDocumentEmbedder:
    """
    FastembedDocumentEmbedder computes Document embeddings using Fastembed embedding models.

    The embedding of each Document is stored in the `embedding` field of the Document.

    Usage example:
    ```python
    # To use this component, install the "fastembed-haystack" package.
    # pip install fastembed-haystack

    from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder
    from haystack.dataclasses import Document

    doc_embedder = FastembedDocumentEmbedder(
        model="BAAI/bge-small-en-v1.5",
        batch_size=256,
    )

    doc_embedder.warm_up()

    # Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa)
    document_list = [
        Document(
            content=("Oxidative stress generated within inflammatory joints can produce autoimmune phenomena and joint "
                     "destruction. Radical species with oxidative activity, including reactive nitrogen species, "
                     "represent mediators of inflammation and cartilage damage."),
            meta={
                "pubid": "25,445,628",
                "long_answer": "yes",
            },
        ),
        Document(
            content=("Plasma levels of pancreatic polypeptide (PP) rise upon food intake. Although other pancreatic "
                     "islet hormones, such as insulin and glucagon, have been extensively investigated, PP secretion "
                     "and actions are still poorly understood."),
            meta={
                "pubid": "25,445,712",
                "long_answer": "yes",
            },
        ),
    ]

    result = doc_embedder.run(document_list)
    print(f"Document Text: {result['documents'][0].content}")
    print(f"Document Embedding: {result['documents'][0].embedding}")
    print(f"Embedding Dimension: {len(result['documents'][0].embedding)}")
    ```
    """

    def __init__(
        self,
        model: str = "BAAI/bge-small-en-v1.5",
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 256,
        progress_bar: bool = True,
        parallel: Optional[int] = None,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a FastembedDocumentEmbedder component.

        :param model: Local path or name of the model in Hugging Face's model hub,
            such as `BAAI/bge-small-en-v1.5`.
        :param cache_dir: The path to the cache directory.
            Can be set using the `FASTEMBED_CACHE_PATH` env variable.
            Defaults to `fastembed_cache` in the system's temp directory.
        :param threads: The number of threads a single onnxruntime session can use. Defaults to None.
        :param prefix: A string to add to the beginning of each text.
        :param suffix: A string to add to the end of each text.
        :param batch_size: Number of strings to encode at once.
        :param progress_bar: If true, displays progress bar during embedding.
        :param parallel:
            If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
            If 0, use all available cores.
            If None, don't use data-parallel processing, use default onnxruntime threading instead.
        :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        """
        self.model_name = model
        self.cache_dir = cache_dir
        self.threads = threads
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.parallel = parallel
        # Normalise None to an empty list so the rest of the class can iterate unconditionally.
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            model=self.model_name,
            cache_dir=self.cache_dir,
            threads=self.threads,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            parallel=self.parallel,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    def warm_up(self):
        """
        Initializes the component by loading the embedding backend.

        Calling this more than once is a no-op: the backend is only created if it
        has not been loaded yet.
        """
        if not hasattr(self, "embedding_backend"):
            self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
                model_name=self.model_name, cache_dir=self.cache_dir, threads=self.threads
            )

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Builds the text that will be embedded for each Document.

        For every Document, the string form of each configured meta field that is present
        and not None (in `meta_fields_to_embed` order) is joined with the Document content
        (an empty string if the content is None) using `embedding_separator`, and the
        result is wrapped with `prefix` and `suffix`.

        :param documents: Documents to prepare.
        :returns: One text per Document, in the same order as `documents`.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if doc.meta.get(key) is not None
            ]
            text_to_embed = (
                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
            )
            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
        """
        Embeds a list of Documents.

        :param documents: List of Documents to embed.
        :returns: A dictionary with the following keys:
            - `documents`: List of Documents with each Document's `embedding` field set to the computed embeddings.
        :raises TypeError: If `documents` is not a list or contains anything other than Documents.
        :raises RuntimeError: If `warm_up()` has not been called before running.
        """
        # Check every element, not just the first: a non-Document later in the list would
        # otherwise surface as an opaque AttributeError while preparing the texts.
        if not isinstance(documents, list) or not all(isinstance(doc, Document) for doc in documents):
            msg = (
                "FastembedDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the FastembedTextEmbedder."
            )
            raise TypeError(msg)
        if not hasattr(self, "embedding_backend"):
            msg = "The embedding model has not been loaded. Please call warm_up() before running."
            raise RuntimeError(msg)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            parallel=self.parallel,
        )

        # Documents are updated in place and also returned for pipeline wiring.
        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents}