-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
document_store.py
338 lines (294 loc) · 15.2 KB
/
document_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import re
from typing import Any, Dict, Iterable, List, Literal, Optional
import numpy as np
from haystack_bm25 import rank_bm25
from tqdm.auto import tqdm
from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import expit
from haystack.utils.filters import convert, document_matches_filter
logger = logging.getLogger(__name__)  # module-scoped logger, per haystack.logging convention
# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
# True (default). Scaling uses the expit function (inverse of the logit function) after applying a scaling factor
# (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method).
# Larger scaling factor decreases scaled scores. For example, an input of 10 is scaled to 0.99 with BM25_SCALING_FACTOR=2
# but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically. Increase the default if most
# unscaled scores are larger than expected (>30) and otherwise would incorrectly all be mapped to scores ~1.
BM25_SCALING_FACTOR = 8  # divisor applied to BM25 scores before expit in bm25_retrieval
DOT_PRODUCT_SCALING_FACTOR = 100  # divisor applied to dot-product scores before expit in embedding_retrieval
class InMemoryDocumentStore:
    """
    Stores data in-memory. It's ephemeral and cannot be saved to disk.
    """

    def __init__(
        self,
        bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
        bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
        bm25_parameters: Optional[Dict] = None,
        embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
    ):
        """
        Initializes the DocumentStore.

        :param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval.
        :param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus".
        :param bm25_parameters: Parameters for BM25 implementation in a dictionary format.
            For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25}
            You can learn more about these parameters by visiting https://github.com/dorianbrown/rank_bm25.
            By default, no parameters are set.
        :param embedding_similarity_function: The similarity function used to compare Documents embeddings.
            One of "dot_product" (default) or "cosine".
            To choose the most appropriate function, look for information about your embedding model.
        :raises ValueError: If `bm25_algorithm` does not name a class provided by rank_bm25.
        """
        self.storage: Dict[str, Document] = {}
        self._bm25_tokenization_regex = bm25_tokenization_regex
        self.tokenizer = re.compile(bm25_tokenization_regex).findall
        # Pass a default of None: a bare getattr would raise AttributeError for an unknown
        # algorithm name, making the friendly ValueError below unreachable.
        algorithm_class = getattr(rank_bm25, bm25_algorithm, None)
        if algorithm_class is None:
            raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
        self.bm25_algorithm = algorithm_class
        self.bm25_parameters = bm25_parameters or {}
        self.embedding_similarity_function = embedding_similarity_function

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            bm25_tokenization_regex=self._bm25_tokenization_regex,
            bm25_algorithm=self.bm25_algorithm.__name__,
            bm25_parameters=self.bm25_parameters,
            embedding_similarity_function=self.embedding_similarity_function,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "InMemoryDocumentStore":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)

    def count_documents(self) -> int:
        """
        Returns the number of how many documents are present in the DocumentStore.
        """
        return len(self.storage)

    def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """
        Returns the documents that match the filters provided.

        For a detailed specification of the filters, refer to the DocumentStore.filter_documents() protocol
        documentation.

        :param filters: The filters to apply to the document list.
        :returns: A list of Documents that match the given filters.
        """
        if filters:
            # Legacy-style filters lack the "operator"/"conditions" keys; convert them to the new format.
            if "operator" not in filters and "conditions" not in filters:
                filters = convert(filters)
            return [doc for doc in self.storage.values() if document_matches_filter(filters=filters, document=doc)]
        return list(self.storage.values())

    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
        """
        Refer to the DocumentStore.write_documents() protocol documentation.

        If `policy` is set to `DuplicatePolicy.NONE` defaults to `DuplicatePolicy.FAIL`.

        :param documents: The Documents to write. Any iterable of Documents is accepted.
        :param policy: How to handle Documents whose id already exists in the store.
        :returns: The number of Documents actually written.
        :raises ValueError: If `documents` is not an iterable of Document objects.
        :raises DuplicateDocumentError: If a duplicate id is found and the policy is FAIL (or NONE).
        """
        if not isinstance(documents, Iterable) or isinstance(documents, str):
            raise ValueError("Please provide a list of Documents.")
        # Materialize once: the validation below and the write loop both iterate, and
        # len() is needed — a generator input would otherwise be consumed and then crash.
        documents = list(documents)
        if any(not isinstance(doc, Document) for doc in documents):
            raise ValueError("Please provide a list of Documents.")

        if policy == DuplicatePolicy.NONE:
            policy = DuplicatePolicy.FAIL

        written_documents = len(documents)
        for document in documents:
            if policy != DuplicatePolicy.OVERWRITE and document.id in self.storage:
                if policy == DuplicatePolicy.FAIL:
                    raise DuplicateDocumentError(f"ID '{document.id}' already exists.")
                if policy == DuplicatePolicy.SKIP:
                    logger.warning("ID '{document_id}' already exists", document_id=document.id)
                    written_documents -= 1
                    continue
            self.storage[document.id] = document
        return written_documents

    def delete_documents(self, document_ids: List[str]) -> None:
        """
        Deletes all documents with matching document_ids from the DocumentStore.

        Unknown ids are silently ignored.

        :param document_ids: The object_ids to delete.
        """
        for doc_id in document_ids:
            if doc_id not in self.storage:
                continue
            del self.storage[doc_id]

    def bm25_retrieval(
        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
    ) -> List[Document]:
        """
        Retrieves documents that are most relevant to the query using BM25 algorithm.

        :param query: The query string.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        :raises ValueError: If `query` is empty.
        """
        if not query:
            raise ValueError("Query should be a non-empty string")

        # Only documents with some retrievable content (text or dataframe) can be ranked.
        content_type_filter = {
            "operator": "OR",
            "conditions": [
                {"field": "content", "operator": "!=", "value": None},
                {"field": "dataframe", "operator": "!=", "value": None},
            ],
        }
        if filters:
            if "operator" not in filters:
                filters = convert(filters)
            filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
        else:
            filters = content_type_filter
        all_documents = self.filter_documents(filters=filters)

        # Lowercase all documents
        lower_case_documents = []
        for doc in all_documents:
            if doc.content is None and doc.dataframe is None:
                logger.info(
                    "Document '{document_id}' has no text or dataframe content. Skipping it.", document_id=doc.id
                )
            else:
                if doc.content is not None:
                    lower_case_documents.append(doc.content.lower())
                    if doc.dataframe is not None:
                        logger.warning(
                            "Document '{document_id}' has both text and dataframe content. "
                            "Using text content and skipping dataframe content.",
                            document_id=doc.id,
                        )
                    continue
                if doc.dataframe is not None:
                    # Serialize the dataframe to CSV so BM25 can tokenize it as plain text.
                    str_content = doc.dataframe.astype(str)
                    csv_content = str_content.to_csv(index=False)
                    lower_case_documents.append(csv_content.lower())

        # Tokenize the entire content of the DocumentStore
        tokenized_corpus = [
            self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...")
        ]
        if len(tokenized_corpus) == 0:
            logger.info("No documents found for BM25 retrieval. Returning empty list.")
            return []

        # initialize BM25
        bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
        # tokenize query
        tokenized_query = self.tokenizer(query.lower())
        # get scores for the query against the corpus
        docs_scores = bm25_scorer.get_scores(tokenized_query)
        if scale_score:
            docs_scores = [expit(float(score / BM25_SCALING_FACTOR)) for score in docs_scores]
        # get the last top_k indexes and reverse them
        top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]

        # BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score is False.
        # It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return negative scores.
        # see https://github.com/deepset-ai/haystack/pull/6889 for more context.
        negatives_are_valid = self.bm25_algorithm is rank_bm25.BM25Okapi and not scale_score

        # Create documents with the BM25 score to return them
        return_documents = []
        for i in top_docs_positions:
            doc = all_documents[i]
            score = docs_scores[i]
            if not negatives_are_valid and score <= 0.0:
                continue
            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            return_document = Document.from_dict(doc_fields)
            return_documents.append(return_document)
        return return_documents

    def embedding_retrieval(
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = False,
        return_embedding: bool = False,
    ) -> List[Document]:
        """
        Retrieves documents that are most similar to the query embedding using a vector similarity metric.

        :param query_embedding: Embedding of the query.
        :param filters: A dictionary with filters to narrow down the search space.
        :param top_k: The number of top documents to retrieve. Default is 10.
        :param scale_score: Whether to scale the scores of the retrieved Documents. Default is False.
        :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False.
        :returns: A list of the top_k documents most relevant to the query.
        :raises ValueError: If `query_embedding` is empty or not a list of floats.
        """
        if len(query_embedding) == 0 or not isinstance(query_embedding[0], float):
            raise ValueError("query_embedding should be a non-empty list of floats.")

        filters = filters or {}
        all_documents = self.filter_documents(filters=filters)

        # Documents without an embedding cannot be compared and are skipped.
        documents_with_embeddings = [doc for doc in all_documents if doc.embedding is not None]
        if len(documents_with_embeddings) == 0:
            logger.warning(
                "No Documents found with embeddings. Returning empty list. "
                "To generate embeddings, use a DocumentEmbedder."
            )
            return []
        elif len(documents_with_embeddings) < len(all_documents):
            logger.info(
                "Skipping some Documents that don't have an embedding. "
                "To generate embeddings, use a DocumentEmbedder."
            )

        scores = self._compute_query_embedding_similarity_scores(
            embedding=query_embedding, documents=documents_with_embeddings, scale_score=scale_score
        )

        # create Documents with the similarity score for the top k results
        top_documents = []
        for doc, score in sorted(zip(documents_with_embeddings, scores), key=lambda x: x[1], reverse=True)[:top_k]:
            doc_fields = doc.to_dict()
            doc_fields["score"] = score
            if return_embedding is False:
                doc_fields["embedding"] = None
            top_documents.append(Document.from_dict(doc_fields))
        return top_documents

    def _compute_query_embedding_similarity_scores(
        self, embedding: List[float], documents: List[Document], scale_score: bool = False
    ) -> List[float]:
        """
        Computes the similarity scores between the query embedding and the embeddings of the documents.

        :param embedding: Embedding of the query.
        :param documents: A list of Documents.
        :param scale_score: Whether to scale the scores of the Documents. Default is False.
        :returns: A list of scores.
        :raises DocumentStoreError: If document embeddings have inconsistent sizes, or if the query
            embedding size does not match the document embedding size.
        """
        query_embedding = np.array(embedding)
        if query_embedding.ndim == 1:
            # Promote to a (1, dim) row vector so the dot product below works uniformly.
            query_embedding = np.expand_dims(a=query_embedding, axis=0)

        try:
            document_embeddings = np.array([doc.embedding for doc in documents])
        except ValueError as e:
            # numpy raises this when the per-document embeddings have different lengths.
            if "inhomogeneous shape" in str(e):
                raise DocumentStoreError(
                    "The embedding size of all Documents should be the same. "
                    "Please make sure that the Documents have been embedded with the same model."
                ) from e
            raise e
        if document_embeddings.ndim == 1:
            document_embeddings = np.expand_dims(a=document_embeddings, axis=0)

        if self.embedding_similarity_function == "cosine":
            # cosine similarity is a normed dot product
            query_embedding /= np.linalg.norm(x=query_embedding, axis=1, keepdims=True)
            document_embeddings /= np.linalg.norm(x=document_embeddings, axis=1, keepdims=True)

        try:
            scores = np.dot(a=query_embedding, b=document_embeddings.T)[0].tolist()
        except ValueError as e:
            if "shapes" in str(e) and "not aligned" in str(e):
                raise DocumentStoreError(
                    "The embedding size of the query should be the same as the embedding size of the Documents. "
                    "Please make sure that the query has been embedded with the same model as the Documents."
                ) from e
            raise e

        if scale_score:
            if self.embedding_similarity_function == "dot_product":
                scores = [expit(float(score / DOT_PRODUCT_SCALING_FACTOR)) for score in scores]
            elif self.embedding_similarity_function == "cosine":
                # Map cosine similarity from [-1, 1] to [0, 1].
                scores = [(score + 1) / 2 for score in scores]

        return scores