-
Notifications
You must be signed in to change notification settings - Fork 62
/
document_store.py
497 lines (439 loc) · 21.9 KB
/
document_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import base64
import datetime
import json
import logging
from dataclasses import asdict
from typing import Any, Dict, List, Optional
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types.policy import DuplicatePolicy
import weaviate
from weaviate.collections.classes.data import DataObject
from weaviate.config import AdditionalConfig
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5
from ._filters import convert_filters
from .auth import AuthCredentials
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# These are the default collection properties for Weaviate.
# It's the list of properties that will be created on the collection.
# They are very similar to the fields of the Document dataclass, with a few differences:
# - `id` is renamed to `_original_id` as the `id` field is reserved by Weaviate.
# - `blob` is split into `blob_data` and `blob_mime_type` as it's more efficient to store them separately.
# Blob meta is missing as it's not usually serialized when saving a Document; we rely on the Document's own meta.
#
# The Document `meta` fields are also omitted as we can't make assumptions about the structure of the meta field.
# We recommend that users create a proper collection with the correct meta properties for their use case.
# We mostly rely on these defaults for testing purposes using Weaviate automatic schema generation, but that's not
# recommended for production use.
DOCUMENT_COLLECTION_PROPERTIES = [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
]
# This is the default limit used when querying documents with WeaviateDocumentStore.
#
# We picked this value because QUERY_MAXIMUM_RESULTS defaults to 10000: trying to get that many
# documents at once will fail, even if the query is paginated.
# This value ensures we get as many documents as possible without hitting that limit. It would
# still fail if the user lowers the QUERY_MAXIMUM_RESULTS environment variable for their Weaviate instance.
#
# See WeaviateDocumentStore._query_with_filters() for more information.
DEFAULT_QUERY_LIMIT = 9999
class WeaviateDocumentStore:
"""
WeaviateDocumentStore is a Document Store for Weaviate.
It can be used with Weaviate Cloud Services or self-hosted instances.
Usage example with Weaviate Cloud Services:
```python
import os
from haystack_integrations.document_stores.weaviate.auth import AuthApiKey
from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
os.environ["WEAVIATE_API_KEY"] = "MY_API_KEY"
document_store = WeaviateDocumentStore(
url="rAnD0mD1g1t5.something.weaviate.cloud",
auth_client_secret=AuthApiKey(),
)
```
Usage example with self-hosted Weaviate:
```python
from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
document_store = WeaviateDocumentStore(url="http://localhost:8080")
```
"""
def __init__(
    self,
    *,
    url: Optional[str] = None,
    collection_settings: Optional[Dict[str, Any]] = None,
    auth_client_secret: Optional[AuthCredentials] = None,
    additional_headers: Optional[Dict] = None,
    embedded_options: Optional[EmbeddedOptions] = None,
    additional_config: Optional[AdditionalConfig] = None,
    grpc_port: int = 50051,
    grpc_secure: bool = False,
):
    """
    Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance.

    :param url:
        The URL to the weaviate instance.
    :param collection_settings:
        The collection settings to use. If `None`, it will use a collection named `Default` with the following
        properties:
        - _original_id: text
        - content: text
        - dataframe: text
        - blob_data: blob
        - blob_mime_type: text
        - score: number
        The Document `meta` fields are omitted in the default collection settings as we can't make assumptions
        on the structure of the meta field.
        We heavily recommend to create a custom collection with the correct meta properties
        for your use case.
        Another option is relying on the automatic schema generation, but that's not recommended for
        production use.
        See the official `Weaviate documentation<https://weaviate.io/developers/weaviate/manage-data/collections>`_
        for more information on collections and their properties.
        If a `class` name is given, only its first letter is upper-cased; the rest of the name is kept as is.
        The dictionary passed in is not modified.
    :param auth_client_secret:
        Authentication credentials. Can be one of the following types depending on the authentication mode:
        - `AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
        - `AuthClientPassword` to use username and password for oidc Resource Owner Password flow
        - `AuthClientCredentials` to use a client secret for oidc client credential flow
        - `AuthApiKey` to use an API key
    :param additional_headers:
        Additional headers to include in the requests. Can be used to set OpenAI/HuggingFace keys.
        OpenAI/HuggingFace key looks like this:
        ```
        {"X-OpenAI-Api-Key": "<THE-KEY>"}, {"X-HuggingFace-Api-Key": "<THE-KEY>"}
        ```
    :param embedded_options:
        If set, create an embedded Weaviate cluster inside the client. For a full list of options see
        `weaviate.embedded.EmbeddedOptions`.
    :param additional_config:
        Additional and advanced configuration options for weaviate.
    :param grpc_port:
        The port to use for the gRPC connection.
    :param grpc_secure:
        Whether to use a secure channel for the underlying gRPC API.
    """
    # proxies, timeout_config, trust_env are part of additional_config now
    # startup_period has been removed
    self._client = weaviate.WeaviateClient(
        connection_params=(
            weaviate.connect.base.ConnectionParams.from_url(url=url, grpc_port=grpc_port, grpc_secure=grpc_secure)
            if url
            else None
        ),
        auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None,
        additional_config=additional_config,
        additional_headers=additional_headers,
        embedded_options=embedded_options,
        skip_init_checks=False,
    )
    self._client.connect()

    # Test connection, it will raise an exception if it fails.
    self._client.collections._get_all(simple=True)

    if collection_settings is None:
        collection_settings = {
            "class": "Default",
            "invertedIndexConfig": {"indexNullState": True},
            "properties": DOCUMENT_COLLECTION_PROPERTIES,
        }
    else:
        # Work on a shallow copy so the dictionary owned by the caller is left untouched.
        collection_settings = dict(collection_settings)
        # Set the class if not set. Only the first letter is upper-cased: `str.capitalize()`
        # would also lower-case the remaining characters and turn "MyCollection" into "Mycollection".
        class_name = collection_settings.get("class", "Default")
        collection_settings["class"] = class_name[:1].upper() + class_name[1:]
        # Set the properties if they're not set
        collection_settings.setdefault("properties", DOCUMENT_COLLECTION_PROPERTIES)

    if not self._client.collections.exists(collection_settings["class"]):
        self._client.collections.create_from_dict(collection_settings)

    # Keep every init parameter around so that `to_dict()` can serialize all of them.
    self._url = url
    self._collection_settings = collection_settings
    self._auth_client_secret = auth_client_secret
    self._additional_headers = additional_headers
    self._embedded_options = embedded_options
    self._additional_config = additional_config
    self._grpc_port = grpc_port
    self._grpc_secure = grpc_secure
    self._collection = self._client.collections.get(collection_settings["class"])


def to_dict(self) -> Dict[str, Any]:
    """
    Serializes the component to a dictionary.

    All init parameters are included, gRPC settings as well, so that a store rebuilt from
    the dictionary connects the same way as this one.

    :returns:
        Dictionary with serialized data.
    """
    embedded_options = asdict(self._embedded_options) if self._embedded_options else None
    # The additional config is dumped to JSON and loaded back to get a plain, JSON-compatible dictionary.
    additional_config = (
        json.loads(self._additional_config.model_dump_json(by_alias=True)) if self._additional_config else None
    )
    return default_to_dict(
        self,
        url=self._url,
        collection_settings=self._collection_settings,
        auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None,
        additional_headers=self._additional_headers,
        embedded_options=embedded_options,
        additional_config=additional_config,
        grpc_port=self._grpc_port,
        grpc_secure=self._grpc_secure,
    )
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore":
    """
    Deserializes the component from a dictionary.

    The nested `auth_client_secret`, `embedded_options` and `additional_config` entries of
    `data["init_parameters"]` are replaced in place with their deserialized objects before
    the instance is created.

    :param data:
        The dictionary to deserialize from.
    :returns:
        The deserialized component.
    """
    init_params = data["init_parameters"]

    serialized_auth = init_params.get("auth_client_secret")
    if serialized_auth is not None:
        init_params["auth_client_secret"] = AuthCredentials.from_dict(serialized_auth)

    serialized_embedded = init_params.get("embedded_options")
    if serialized_embedded is not None:
        init_params["embedded_options"] = EmbeddedOptions(**serialized_embedded)

    serialized_config = init_params.get("additional_config")
    if serialized_config is not None:
        init_params["additional_config"] = AdditionalConfig(**serialized_config)

    return default_from_dict(cls, data)
def count_documents(self) -> int:
    """
    Returns the number of documents present in the DocumentStore.

    :returns:
        The total count reported by the collection aggregation, or 0 when no count is reported.
    """
    aggregation = self._collection.aggregate.over_all(total_count=True)
    return aggregation.total_count or 0
def _to_data_object(self, document: Document) -> Dict[str, Any]:
    """
    Converts a Document to a Weaviate data object ready to be saved.

    :param document:
        The Document to convert.
    :returns:
        The dictionary of properties to store. It doesn't contain the embedding, which is
        stored separately, nor the sparse embedding, which is dropped.
    """
    data = document.to_dict()
    # Weaviate forces a UUID as an id.
    # We don't know if the id of our Document is a UUID or not, so we save it on a different field
    # and let Weaviate use a UUID that we're going to ignore completely.
    data["_original_id"] = data.pop("id")
    if (blob := data.pop("blob")) is not None:
        # Weaviate wants the blob data as a base64 encoded string
        # See the official docs for more information:
        # https://weaviate.io/developers/weaviate/config-refs/datatypes#datatype-blob
        data["blob_data"] = base64.b64encode(bytes(blob.pop("data"))).decode()
        data["blob_mime_type"] = blob.pop("mime_type")
    # The embedding vector is stored separately from the rest of the data
    del data["embedding"]

    # Always drop the field; warn only when it actually carried a value.
    if data.pop("sparse_embedding", None):
        logger.warning(
            "Document %s has the `sparse_embedding` field set, "
            "but storing sparse embeddings in Weaviate is not currently supported. "
            "The `sparse_embedding` field will be ignored.",
            data["_original_id"],
        )
    return data
def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document:
    """
    Converts a data object read from Weaviate into a Document.

    The properties dictionary of `data` is modified in place while it's being converted.

    :param data:
        The object returned by Weaviate.
    :returns:
        The Document built from the object properties, vector and metadata.
    """
    props = data.properties
    props["id"] = props.pop("_original_id")

    # The vector is either a plain list or a dictionary of named vectors.
    vector = data.vector
    if isinstance(vector, list):
        embedding = vector
    elif isinstance(vector, dict):
        embedding = vector.get("default")
    else:
        embedding = None
    props["embedding"] = embedding

    encoded_blob = props.get("blob_data")
    if encoded_blob is not None:
        props["blob"] = {
            "data": base64.b64decode(encoded_blob),
            "mime_type": props.get("blob_mime_type"),
        }
    # These fields are not part of the Document dataclass, so they're always removed
    for storage_only_field in ("blob_data", "blob_mime_type"):
        props.pop(storage_only_field, None)

    # Dates are converted to strings
    props.update(
        {
            name: value.strftime("%Y-%m-%dT%H:%M:%SZ")
            for name, value in props.items()
            if isinstance(value, datetime.datetime)
        }
    )

    weaviate_meta = getattr(data, "metadata", None)
    if weaviate_meta:
        # Depending on the type of retrieval we get score from different fields.
        # score is returned when using BM25 retrieval.
        # certainty is returned when using embedding retrieval.
        score = weaviate_meta.score if weaviate_meta.score is not None else weaviate_meta.certainty
        if score is not None:
            props["score"] = score

    return Document.from_dict(props)
def _query(self) -> List[Dict[str, Any]]:
    """
    Returns all the objects stored in the collection, with their vector.

    :returns:
        The list of objects read from Weaviate.
    :raises DocumentStoreError:
        If Weaviate fails to run the query.
    """
    properties = [p.name for p in self._collection.config.get().properties]
    try:
        # The collection iterator is expected to fetch objects lazily while it's being consumed,
        # so it's materialized here: query errors raised during iteration are caught below and
        # the returned value is an actual list.
        return list(self._collection.iterator(include_vector=True, return_properties=properties))
    except weaviate.exceptions.WeaviateQueryError as e:
        msg = f"Failed to query documents in Weaviate. Error: {e.message}"
        raise DocumentStoreError(msg) from e
def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Returns all the objects matching the filters, with their vector.

    :param filters:
        The filters to apply, they're converted with `convert_filters()`.
    :returns:
        The list of objects read from Weaviate.
    :raises DocumentStoreError:
        If Weaviate fails to run the query.
    """
    properties = [p.name for p in self._collection.config.get().properties]
    # The converted filters are the same for every page, so they're built only once.
    weaviate_filters = convert_filters(filters)

    # When querying with filters we need to paginate using limit and offset as using
    # a cursor with after is not possible. See the official docs:
    # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#cursor-with-after
    #
    # Nonetheless there's also another issue, paginating with limit and offset is not efficient
    # and it's still restricted by the QUERY_MAXIMUM_RESULTS environment variable.
    # If the sum of limit and offset is greater than QUERY_MAXIMUM_RESULTS an error is raised.
    # See the official docs for more:
    # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#performance-considerations
    offset = 0
    result = []
    # Keep querying until we get all documents matching the filters
    while True:
        try:
            page = self._collection.query.fetch_objects(
                filters=weaviate_filters,
                include_vector=True,
                limit=DEFAULT_QUERY_LIMIT,
                offset=offset,
                return_properties=properties,
            )
        except weaviate.exceptions.WeaviateQueryError as e:
            msg = f"Failed to query documents in Weaviate. Error: {e.message}"
            raise DocumentStoreError(msg) from e
        result.extend(page.objects)
        # A page that is not full is the last one.
        if len(page.objects) != DEFAULT_QUERY_LIMIT:
            break
        offset += DEFAULT_QUERY_LIMIT
    return result
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
    """
    Returns the documents that match the filters provided.

    For a detailed specification of the filters, refer to the
    DocumentStore.filter_documents() protocol documentation.

    :param filters: The filters to apply to the document list. When missing or empty,
        all the documents are returned.
    :returns: A list of Documents that match the given filters.
    """
    matching_objects = self._query_with_filters(filters) if filters else self._query()
    return [self._to_document(obj) for obj in matching_objects]
def _batch_write(self, documents: List[Document]) -> int:
    """
    Writes document to Weaviate in batches.
    Documents with the same id will be overwritten.
    Raises in case of errors.

    :param documents:
        The Documents to write.
    :returns:
        The number of Documents passed in, as they're all assumed to be written.
    :raises ValueError:
        If any element of `documents` is not a Document. Nothing is written in that case.
    :raises DocumentStoreError:
        If Weaviate reports that some objects couldn't be written.
    """
    # Validate everything before opening the batch, so that an invalid element
    # doesn't leave us with only part of the documents queued for writing.
    for doc in documents:
        if not isinstance(doc, Document):
            msg = f"Expected a Document, got '{type(doc)}' instead."
            raise ValueError(msg)

    with self._client.batch.dynamic() as batch:
        for doc in documents:
            batch.add_object(
                properties=self._to_data_object(doc),
                collection=self._collection.name,
                uuid=generate_uuid5(doc.id),
                vector=doc.embedding,
            )

    if failed_objects := self._client.batch.failed_objects:
        # Map each failed object to its error message, keyed by the original Document id.
        mapped_objects = {}
        for obj in failed_objects:
            properties = obj.object_.properties or {}
            # We fall back to the object uuid just in case the _original_id is not present.
            # That's extremely unlikely to happen but let's stay on the safe side.
            id_ = properties.get("_original_id", obj.object_.uuid)
            mapped_objects[id_] = obj.message

        msg = "\n".join(
            [
                f"Failed to write object with id '{id_}'. Error: '{message}'"
                for id_, message in mapped_objects.items()
            ]
        )
        raise DocumentStoreError(msg)

    # If the document already exists we get no status message back from Weaviate.
    # So we assume that all Documents were written.
    return len(documents)
def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int:
    """
    Writes documents to Weaviate using the specified policy.
    This doesn't use the batch API, so it's slower than _batch_write.
    If policy is set to SKIP it will skip any document that already exists.
    If policy is set to FAIL it will raise an exception if any of the documents already exists.

    :param documents:
        The Documents to write.
    :param policy:
        The policy to apply when a Document already exists.
    :returns:
        The number of Documents actually written.
    :raises ValueError:
        If any element of `documents` is not a Document.
    :raises DuplicateDocumentError:
        If policy is FAIL and inserting one or more Documents failed.
    """
    written = 0
    duplicate_errors_ids = []
    for doc in documents:
        if not isinstance(doc, Document):
            msg = f"Expected a Document, got '{type(doc)}' instead."
            raise ValueError(msg)

        # The same UUID is needed both to check for existence and to insert.
        doc_uuid = generate_uuid5(doc.id)

        if policy == DuplicatePolicy.SKIP and self._collection.data.exists(uuid=doc_uuid):
            # This Document already exists, we skip it
            continue

        try:
            self._collection.data.insert(
                uuid=doc_uuid,
                properties=self._to_data_object(doc),
                vector=doc.embedding,
            )
        except weaviate.exceptions.UnexpectedStatusCodeError as e:
            # NOTE(review): any unexpected status code is treated as a duplicate here,
            # confirm whether Weaviate can return it for other reasons too.
            if policy == DuplicatePolicy.FAIL:
                duplicate_errors_ids.append(doc.id)
            else:
                # The Document is still skipped, but not silently: the existence check
                # above said it was missing, so this failure is worth surfacing.
                logger.warning("Document %s was not written to Weaviate. Error: %s", doc.id, e)
        else:
            written += 1

    if duplicate_errors_ids:
        msg = f"IDs '{', '.join(duplicate_errors_ids)}' already exist in the document store."
        raise DuplicateDocumentError(msg)
    return written
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
    """
    Writes documents to Weaviate using the specified policy.

    We recommend using a OVERWRITE policy as it's faster than other policies for Weaviate since it uses
    the batch API.
    We can't use the batch API for other policies as it doesn't return any information whether the document
    already exists or not. That prevents us from returning errors when using the FAIL policy or skipping a
    Document when using the SKIP policy.

    :param documents: The Documents to write.
    :param policy: The policy to apply when a Document already exists. NONE behaves like OVERWRITE.
    :returns: The number of Documents reported as written by the underlying write method.
    """
    uses_batch_api = policy in (DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE)
    if not uses_batch_api:
        return self._write(documents, policy)
    return self._batch_write(documents)
def delete_documents(self, document_ids: List[str]) -> None:
    """
    Deletes all documents with matching document_ids from the DocumentStore.

    :param document_ids: The object_ids to delete. Nothing is done if the list is empty.
    """
    if not document_ids:
        # Nothing to delete: avoid building a filter with no values and a pointless request.
        return

    # Documents are stored under a UUID derived from their original id.
    weaviate_ids = [generate_uuid5(doc_id) for doc_id in document_ids]
    self._collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids))
def _bm25_retrieval(
    self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
) -> List[Document]:
    """
    Retrieves the documents matching the query with a BM25 search on the `content` property.

    :param query: The text to search for.
    :param filters: The filters to apply, if any.
    :param top_k: The maximum number of documents to return, passed to Weaviate as the limit.
    :returns: The Documents found, with the score returned by Weaviate.
    """
    property_names = [prop.name for prop in self._collection.config.get().properties]
    weaviate_filters = convert_filters(filters) if filters else None

    response = self._collection.query.bm25(
        query=query,
        filters=weaviate_filters,
        limit=top_k,
        include_vector=True,
        query_properties=["content"],
        return_properties=property_names,
        return_metadata=["score"],
    )
    return [self._to_document(obj) for obj in response.objects]
def _embedding_retrieval(
    self,
    query_embedding: List[float],
    filters: Optional[Dict[str, Any]] = None,
    top_k: Optional[int] = None,
    distance: Optional[float] = None,
    certainty: Optional[float] = None,
) -> List[Document]:
    """
    Retrieves the documents closest to the query embedding with a vector search.

    :param query_embedding: The embedding to search with.
    :param filters: The filters to apply, if any.
    :param top_k: The maximum number of documents to return, passed to Weaviate as the limit.
    :param distance: The distance threshold passed to Weaviate. Can't be used with `certainty`.
    :param certainty: The certainty threshold passed to Weaviate. Can't be used with `distance`.
    :returns: The Documents found, with the certainty returned by Weaviate as score.
    :raises ValueError: If both `distance` and `certainty` are set.
    """
    both_thresholds_set = distance is not None and certainty is not None
    if both_thresholds_set:
        msg = "Can't use 'distance' and 'certainty' parameters together"
        raise ValueError(msg)

    property_names = [prop.name for prop in self._collection.config.get().properties]
    weaviate_filters = convert_filters(filters) if filters else None

    response = self._collection.query.near_vector(
        near_vector=query_embedding,
        distance=distance,
        certainty=certainty,
        include_vector=True,
        filters=weaviate_filters,
        limit=top_k,
        return_properties=property_names,
        return_metadata=["certainty"],
    )
    return [self._to_document(obj) for obj in response.objects]