-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
milvus.py
521 lines (447 loc) · 25.2 KB
/
milvus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
import logging
from typing import Any, Dict, Generator, List, Optional, Union
import numpy
import numpy as np
from milvus import IndexType, MetricType, Milvus, Status
from scipy.special import expit
from tqdm import tqdm
from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.utils import get_batches_from_generator
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class MilvusDocumentStore(SQLDocumentStore):
    """
    Milvus (https://milvus.io/) is a highly reliable, scalable Document Store specialized on storing and processing vectors.
    Therefore, it is particularly suited for Haystack users that work with dense retrieval methods (like DPR).
    In contrast to FAISS, Milvus ...
     - runs as a separate service (e.g. a Docker container) and can scale easily in a distributed environment
     - allows dynamic data management (i.e. you can insert/delete vectors without recreating the whole index)
     - encapsulates multiple ANN libraries (FAISS, ANNOY ...)

    This class uses Milvus for all vector related storage, processing and querying.
    The meta-data (e.g. for filtering) and the document text are however stored in a separate SQL Database as Milvus
    does not allow these data types (yet).

    Usage:
    1. Start a Milvus server (see https://milvus.io/docs/v0.10.5/install_milvus.md)
    2. Init a MilvusDocumentStore in Haystack
    """

    def __init__(
        self,
        sql_url: str = "sqlite:///",
        milvus_url: str = "tcp://localhost:19530",
        connection_pool: str = "SingletonThread",
        index: str = "document",
        vector_dim: int = 768,
        index_file_size: int = 1024,
        similarity: str = "dot_product",
        index_type: IndexType = IndexType.FLAT,
        index_param: Optional[Dict[str, Any]] = None,
        search_param: Optional[Dict[str, Any]] = None,
        update_existing_documents: bool = False,
        return_embedding: bool = False,
        embedding_field: str = "embedding",
        progress_bar: bool = True,
        **kwargs,
    ):
        """
        :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
                        deployment, Postgres is recommended. If using MySQL then same server can also be used for
                        Milvus metadata. For more details see https://milvus.io/docs/v0.10.5/data_manage.md.
        :param milvus_url: Milvus server connection URL for storing and processing vectors.
                           Protocol, host and port will automatically be inferred from the URL.
                           See https://milvus.io/docs/v0.10.5/install_milvus.md for instructions to start a Milvus instance.
        :param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
        :param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
        :param vector_dim: The embedding vector size. Default: 768.
        :param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
         When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
         Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
         As a rule of thumb, we would see a 30% ~ 50% increase in the search performance after changing the value of index_file_size from 1024 to 2048.
         Note that an overly large index_file_size value may cause failure to load a segment into the memory or graphics memory.
         (From https://milvus.io/docs/v0.10.5/performance_faq.md#How-can-I-get-the-best-performance-from-Milvus-through-setting-index_file_size)
        :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default and recommended for DPR embeddings.
                           'cosine' is recommended for Sentence Transformers, but is not directly supported by Milvus.
                           However, you can normalize your embeddings and use `dot_product` to get the same results.
                           See https://milvus.io/docs/v0.10.5/metric.md?Inner-product-(IP)#floating.
        :param index_type: Type of approximate nearest neighbour (ANN) index used. The choice here determines your tradeoff between speed and accuracy.
                           Some popular options:
                           - FLAT (default): Exact method, slow
                           - IVF_FLAT, inverted file based heuristic, fast
                           - HNSW: Graph based, fast
                           - ANNOY: Tree based, fast
                           See: https://milvus.io/docs/v0.10.5/index.md
        :param index_param: Configuration parameters for the chosen index_type needed at indexing time.
                            For example: {"nlist": 16384} as the number of cluster units to create for index_type IVF_FLAT.
                            See https://milvus.io/docs/v0.10.5/index.md
        :param search_param: Configuration parameters for the chosen index_type needed at query time
                             For example: {"nprobe": 10} as the number of cluster units to query for index_type IVF_FLAT.
                             See https://milvus.io/docs/v0.10.5/index.md
        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
                                          documents. When set as True, any document with an existing ID gets updated.
                                          If set to False, an error is raised if the document ID of the document being
                                          added already exists.
        :param return_embedding: To return document embedding.
        :param embedding_field: Name of field containing an embedding vector.
        :param progress_bar: Whether to show a tqdm progress bar or not.
                             Can be helpful to disable in production deployments to keep the logs clean.
        :raises ValueError: If `similarity` is anything other than "dot_product".
        :raises RuntimeError: If the Milvus collection or index cannot be checked or created.
        """
        self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
        self.vector_dim = vector_dim
        self.index_file_size = index_file_size

        if similarity == "dot_product":
            # Inner product (IP) is the Milvus metric for dot-product similarity.
            # query_by_embedding() scores larger distances as better matches, which
            # only holds for IP, not for L2. The metric is applied only when a
            # collection is created, so existing collections keep their metric.
            self.metric_type = MetricType.IP
        else:
            raise ValueError("The Milvus document store can currently only support dot_product similarity. "
                             "Please set similarity=\"dot_product\"")

        self.index_type = index_type
        self.index_param = index_param or {"nlist": 16384}
        self.search_param = search_param or {"nprobe": 10}
        self.index = index
        self._create_collection_and_index_if_not_exist(self.index)
        self.return_embedding = return_embedding
        self.embedding_field = embedding_field
        self.progress_bar = progress_bar

        super().__init__(
            url=sql_url,
            update_existing_documents=update_existing_documents,
            index=index
        )

    def __del__(self):
        """Close the Milvus connection when the store is garbage collected."""
        # __init__ may have failed before `milvus_server` was assigned; in that
        # case there is nothing to close.
        server = getattr(self, "milvus_server", None)
        if server is None:
            return None
        return server.close()

    def _create_collection_and_index_if_not_exist(
        self,
        index: Optional[str] = None,
        index_param: Optional[Dict[str, Any]] = None
    ):
        """
        Create the Milvus collection and its index unless the collection already exists.

        :param index: Collection name. Defaults to `self.index`.
        :param index_param: Index build parameters. Defaults to `self.index_param`.
        :raises RuntimeError: If the existence check, the collection creation or the index creation fails.
        """
        index = index or self.index
        index_param = index_param or self.index_param

        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if ok:
            return

        collection_param = {
            'collection_name': index,
            'dimension': self.vector_dim,
            'index_file_size': self.index_file_size,
            'metric_type': self.metric_type
        }
        status = self.milvus_server.create_collection(collection_param)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Collection creation on Milvus server failed: {status}')

        status = self.milvus_server.create_index(index, self.index_type, index_param)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Index creation on Milvus server failed: {status}')

    def _create_document_field_map(self) -> Dict:
        """Return the field map handed to `Document.from_dict()` when dicts are written."""
        # NOTE(review): this maps the index name to the embedding field name;
        # confirm this is the mapping Document.from_dict() expects.
        return {
            self.index: self.embedding_field,
        }

    def write_documents(
        self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
    ):
        """
        Add new documents to the DocumentStore.

        :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
                                  them right away in Milvus. If not, you can later call update_embeddings() to create & index them.
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :raises AttributeError: If a document embedding is neither a list nor a numpy.ndarray.
        :raises RuntimeError: If inserting the vectors into Milvus fails.
        :return:
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)
        field_map = self._create_document_field_map()

        if len(documents) == 0:
            logger.warning("Calling DocumentStore.write_documents() with empty list")
            return

        document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]

        # The first document decides whether this call writes vectors at all.
        add_vectors = document_objects[0].embedding is not None

        batched_documents = get_batches_from_generator(document_objects, batch_size)
        with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar:
            for document_batch in batched_documents:
                vector_ids = []
                if add_vectors:
                    doc_ids = []
                    embeddings = []
                    for doc in document_batch:
                        doc_ids.append(doc.id)
                        if isinstance(doc.embedding, np.ndarray):
                            embeddings.append(doc.embedding.tolist())
                        elif isinstance(doc.embedding, list):
                            embeddings.append(doc.embedding)
                        else:
                            raise AttributeError(f'Format of supplied document embedding {type(doc.embedding)} is not '
                                                 f'supported. Please use list or numpy.ndarray')

                    if self.update_existing_documents:
                        # Remove the old vectors of documents that are about to be overwritten.
                        existing_docs = super().get_documents_by_id(ids=doc_ids, index=index)
                        self._delete_vector_ids_from_milvus(documents=existing_docs, index=index)

                    status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings)
                    if status.code != Status.SUCCESS:
                        raise RuntimeError(f'Vector embedding insertion failed: {status}')

                docs_to_write_in_sql = []
                for idx, doc in enumerate(document_batch):
                    if add_vectors:
                        # Link the SQL record to the vector Milvus just stored.
                        doc.meta["vector_id"] = vector_ids[idx]
                    docs_to_write_in_sql.append(doc)

                super().write_documents(docs_to_write_in_sql, index=index)
                # Advance by the real batch size so the last, partial batch
                # does not push the bar past its total.
                progress_bar.update(len(document_batch))

        self.milvus_server.flush([index])
        if self.update_existing_documents:
            self.milvus_server.compact(collection_name=index)

    def update_embeddings(
        self,
        retriever: BaseRetriever,
        index: Optional[str] = None,
        batch_size: int = 10_000,
        update_existing_embeddings: bool = True,
        filters: Optional[Dict[str, List[str]]] = None,
    ):
        """
        Updates the embeddings in the document store using the encoding model specified in the retriever.
        This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config).

        :param retriever: Retriever to use to get embeddings for text
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :param update_existing_embeddings: Whether to update existing embeddings of the documents. If set to False,
                                           only documents without embeddings are processed. This mode can be used for
                                           incremental updating of embeddings, wherein, only newly indexed documents
                                           get processed.
        :param filters: Optional filters to narrow down the documents for which embeddings are to be updated.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :raises RuntimeError: If inserting the new vectors into Milvus fails.
        :return: None
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)

        document_count = self.get_document_count(index=index)
        if document_count == 0:
            logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
            return

        logger.info(f"Updating embeddings for {document_count} docs...")

        result = self._query(
            index=index,
            vector_ids=None,
            batch_size=batch_size,
            filters=filters,
            only_documents_without_embedding=not update_existing_embeddings
        )
        batched_documents = get_batches_from_generator(result, batch_size)
        with tqdm(total=document_count, disable=not self.progress_bar) as progress_bar:
            for document_batch in batched_documents:
                # Drop the stale vectors before inserting the re-computed ones.
                self._delete_vector_ids_from_milvus(documents=document_batch, index=index)

                embeddings = retriever.embed_passages(document_batch)  # type: ignore
                embeddings_list = [embedding.tolist() for embedding in embeddings]
                assert len(document_batch) == len(embeddings_list)

                status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list)
                if status.code != Status.SUCCESS:
                    raise RuntimeError(f'Vector embedding insertion failed: {status}')

                vector_id_map = {doc.id: vector_id for vector_id, doc in zip(vector_ids, document_batch)}
                self.update_vector_ids(vector_id_map, index=index)
                # Advance by the real batch size so the last, partial batch
                # does not push the bar past its total.
                progress_bar.update(len(document_batch))

        self.milvus_server.flush([index])
        self.milvus_server.compact(collection_name=index)

    def query_by_embedding(self,
                           query_emb: np.ndarray,
                           filters: Optional[dict] = None,
                           top_k: int = 10,
                           index: Optional[str] = None,
                           return_embedding: Optional[bool] = None) -> List[Document]:
        """
        Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param top_k: How many documents to return
        :param index: (SQL) index name for storing the docs and metadata
        :param return_embedding: To return document embedding
        :raises Exception: If `filters` are passed (not implemented) or the collection does not exist.
        :raises RuntimeError: If the collection check or the vector search fails.
        :return: The matching documents with `score` and `probability` set.
        """
        if filters:
            raise Exception("Query filters are not implemented for the MilvusDocumentStore.")

        index = index or self.index
        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if not ok:
            raise Exception("No index exists. Use 'update_embeddings()` to create an index.")

        if return_embedding is None:
            return_embedding = self.return_embedding

        # Milvus is queried with a single query record of float32 values.
        query_emb = query_emb.reshape(1, -1).astype(np.float32)
        status, search_result = self.milvus_server.search(
            collection_name=index,
            query_records=query_emb,
            top_k=top_k,
            params=self.search_param
        )
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Vector embedding search failed: {status}')

        vector_ids_for_query = []
        scores_for_vector_ids: Dict[str, float] = {}
        for vector_id_list, distance_list in zip(search_result.id_array, search_result.distance_array):
            for vector_id, distance in zip(vector_id_list, distance_list):
                # Keys are strings because they are looked up via doc.meta["vector_id"] below.
                vector_ids_for_query.append(str(vector_id))
                scores_for_vector_ids[str(vector_id)] = distance

        documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)

        if return_embedding:
            self._populate_embeddings_to_docs(index=index, docs=documents)

        for doc in documents:
            doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
            # Squash the raw score into (0, 1) with a scaled sigmoid.
            doc.probability = float(expit(np.asarray(doc.score / 100)))

        return documents

    def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
        """
        Delete documents from SQL AND Milvus.

        Without `filters` the whole Milvus collection is dropped. With `filters` only the vectors of the
        matching documents are removed, so the remaining documents keep valid vector ids.

        :param index: (SQL) index name for storing the docs and metadata
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :raises RuntimeError: If the collection check, the collection drop or the vector deletion fails.
        :return: None
        """
        index = index or self.index

        if filters:
            # Collect the affected documents while they still exist in SQL.
            affected_docs = list(super().get_all_documents_generator(index=index, filters=filters))
            self._delete_vector_ids_from_milvus(documents=affected_docs, index=index)
            self.milvus_server.flush([index])
            self.milvus_server.compact(collection_name=index)
        else:
            status, ok = self.milvus_server.has_collection(collection_name=index)
            if status.code != Status.SUCCESS:
                raise RuntimeError(f'Milvus has collection check failed: {status}')
            if ok:
                status = self.milvus_server.drop_collection(collection_name=index)
                if status.code != Status.SUCCESS:
                    raise RuntimeError(f'Milvus drop collection failed: {status}')

                # NOTE(review): these target the collection that was just dropped and
                # their statuses are ignored; kept from the original flow. Confirm
                # whether they are still needed.
                self.milvus_server.flush([index])
                self.milvus_server.compact(collection_name=index)

        # SQL is cleaned up last so the filtered lookup above can still see the documents.
        super().delete_all_documents(index=index, filters=filters)

    def get_all_documents_generator(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
    ) -> Generator[Document, None, None]:
        """
        Get all documents from the document store. Under-the-hood, documents are fetched in batches from the
        document store and yielded as individual documents. This method can be used to iteratively process
        a large number of documents without having to load all documents in memory.

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings.
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        documents = super().get_all_documents_generator(
            index=index, filters=filters, batch_size=batch_size
        )
        if return_embedding is None:
            return_embedding = self.return_embedding

        for doc in documents:
            if return_embedding:
                self._populate_embeddings_to_docs(index=index, docs=[doc])
            yield doc

    def get_all_documents(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
    ) -> List[Document]:
        """
        Get documents from the document store (optionally using filter criteria).

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings.
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        result = self.get_all_documents_generator(
            index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
        )
        return list(result)

    def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
        """
        Fetch a document by specifying its text id string

        :param id: ID of the document
        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :return: The document, or None if no document with this ID was found.
        """
        documents = self.get_documents_by_id([id], index)
        return documents[0] if documents else None

    def get_documents_by_id(
        self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000
    ) -> List[Document]:
        """
        Fetch multiple documents by specifying their IDs (strings)

        :param ids: List of IDs of the documents
        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
                           (Accepted for interface compatibility; not used by this method.)
        """
        index = index or self.index
        documents = super().get_documents_by_id(ids=ids, index=index)
        if self.return_embedding:
            self._populate_embeddings_to_docs(index=index, docs=documents)

        return documents

    def _populate_embeddings_to_docs(self, docs: List[Document], index: Optional[str] = None):
        """
        Fetch the vectors of `docs` from Milvus and store them in each document's `embedding`.

        Documents without a `vector_id` in their meta data are left untouched.

        :param docs: Documents to enrich in place.
        :param index: Collection name. Defaults to `self.index`.
        :raises RuntimeError: If fetching the vectors from Milvus fails.
        """
        index = index or self.index
        docs_with_vector_ids = [
            doc for doc in docs if doc.meta and doc.meta.get("vector_id") is not None
        ]
        if len(docs_with_vector_ids) == 0:
            return

        ids = [int(doc.meta.get("vector_id")) for doc in docs_with_vector_ids]  # type: ignore
        status, vector_embeddings = self.milvus_server.get_entity_by_id(
            collection_name=index,
            ids=ids
        )
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Getting vector embedding by id failed: {status}')

        # Embeddings are paired with the documents in the order the ids were requested.
        for embedding, doc in zip(vector_embeddings, docs_with_vector_ids):
            doc.embedding = np.array(embedding, dtype="float32")

    def _delete_vector_ids_from_milvus(self, documents: List[Document], index: Optional[str] = None):
        """
        Delete the Milvus vectors referenced by the `vector_id` meta field of `documents`.

        :param documents: Documents whose vectors should be removed. Documents without a vector id are skipped.
        :param index: Collection name. Defaults to `self.index`.
        :raises RuntimeError: If the deletion in Milvus fails.
        """
        index = index or self.index
        existing_vector_ids = [
            int(doc.meta["vector_id"])
            for doc in documents
            if doc.meta and doc.meta.get("vector_id") is not None
        ]
        if len(existing_vector_ids) > 0:
            status = self.milvus_server.delete_entity_by_id(
                collection_name=index,
                id_array=existing_vector_ids
            )
            if status.code != Status.SUCCESS:
                raise RuntimeError(f"Existing vector ids deletion failed: {status}")

    def get_all_vectors(self, index: Optional[str] = None) -> List[np.ndarray]:
        """
        Helper function to dump all vectors stored in Milvus server.

        Failures are logged and result in an empty list (or in the affected segment being skipped).

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :return: List[np.array]: List of vectors, as returned by the Milvus client.
        """
        index = index or self.index
        status, collection_info = self.milvus_server.get_collection_stats(collection_name=index)
        if not status.OK():
            logger.info("Failed fetch stats from store ...")
            return []

        logger.debug(f"collection_info = {collection_info}")

        ids = []
        for partition in collection_info["partitions"]:
            # Tolerate partitions that report no segment list.
            for segment in partition["segments"] or []:
                segment_name = segment["name"]
                status, id_list = self.milvus_server.list_id_in_segment(
                    collection_name=index,
                    segment_name=segment_name)
                if not status.OK() or id_list is None:
                    logger.warning(f"{status}: failed to list ids of segment {segment_name} ...")
                    continue
                logger.debug(f"{status}: segment {segment_name} has {len(id_list)} vectors ...")
                ids.extend(id_list)

        if len(ids) == 0:
            logger.info("No documents in the store ...")
            return []

        status, vectors = self.milvus_server.get_entity_by_id(collection_name=index, ids=ids)
        if not status.OK():
            # Log the count only; the id list itself can be very large.
            logger.info(f"Failed fetch document for {len(ids)} ids from store ...")
            return []

        return vectors