-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com>
- Loading branch information
1 parent
f357719
commit 8d7ac83
Showing
9 changed files
with
447 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
import uuid | ||
from copy import deepcopy | ||
from typing import Dict, Iterator, List, Optional, cast | ||
|
||
from pymilvus import DataType, MilvusClient | ||
from pymilvus.client.constants import ConsistencyLevel | ||
|
||
from memgpt.agent_store.storage import StorageConnector, TableType | ||
from memgpt.config import MemGPTConfig | ||
from memgpt.constants import MAX_EMBEDDING_DIM | ||
from memgpt.data_types import Passage, Record, RecordType | ||
from memgpt.utils import datetime_to_timestamp, printd, timestamp_to_datetime | ||
|
||
|
||
class MilvusStorageConnector(StorageConnector): | ||
"""Storage via Milvus""" | ||
|
||
def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): | ||
super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) | ||
|
||
assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Milvus only supports archival memory" | ||
if config.archival_storage_uri: | ||
self.client = MilvusClient(uri=config.archival_storage_uri) | ||
self._create_collection() | ||
else: | ||
raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.") | ||
|
||
# need to be converted to strings | ||
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"] | ||
|
||
def _create_collection(self): | ||
schema = MilvusClient.create_schema( | ||
auto_id=False, | ||
enable_dynamic_field=True, | ||
) | ||
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=65_535) | ||
schema.add_field(field_name="text", datatype=DataType.VARCHAR, is_primary=False, max_length=65_535) | ||
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=MAX_EMBEDDING_DIM) | ||
index_params = self.client.prepare_index_params() | ||
index_params.add_index(field_name="id") | ||
index_params.add_index(field_name="embedding", index_type="AUTOINDEX", metric_type="IP") | ||
self.client.create_collection( | ||
collection_name=self.table_name, schema=schema, index_params=index_params, consistency_level=ConsistencyLevel.Strong | ||
) | ||
|
||
def get_milvus_filter(self, filters: Optional[Dict] = {}) -> str: | ||
filter_conditions = {**self.filters, **filters} if filters is not None else self.filters | ||
if not filter_conditions: | ||
return "" | ||
conditions = [] | ||
for key, value in filter_conditions.items(): | ||
if key in self.uuid_fields or isinstance(key, str): | ||
condition = f'({key} == "{value}")' | ||
else: | ||
condition = f"({key} == {value})" | ||
conditions.append(condition) | ||
filter_expr = " and ".join(conditions) | ||
if len(conditions) == 1: | ||
filter_expr = filter_expr[1:-1] | ||
return filter_expr | ||
|
||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]: | ||
if not self.client.has_collection(collection_name=self.table_name): | ||
yield [] | ||
filter_expr = self.get_milvus_filter(filters) | ||
offset = 0 | ||
while True: | ||
# Retrieve a chunk of records with the given page_size | ||
query_res = self.client.query( | ||
collection_name=self.table_name, | ||
filter=filter_expr, | ||
offset=offset, | ||
limit=page_size, | ||
) | ||
if not query_res: | ||
break | ||
# Yield a list of Record objects converted from the chunk | ||
yield self._list_to_records(query_res) | ||
|
||
# Increment the offset to get the next chunk in the next iteration | ||
offset += page_size | ||
|
||
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]: | ||
if not self.client.has_collection(collection_name=self.table_name): | ||
return [] | ||
filter_expr = self.get_milvus_filter(filters) | ||
query_res = self.client.query( | ||
collection_name=self.table_name, | ||
filter=filter_expr, | ||
limit=limit, | ||
) | ||
return self._list_to_records(query_res) | ||
|
||
def get(self, id: uuid.UUID) -> Optional[RecordType]: | ||
res = self.client.get(collection_name=self.table_name, ids=str(id)) | ||
return self._list_to_records(res)[0] if res else None | ||
|
||
def size(self, filters: Optional[Dict] = {}) -> int: | ||
if not self.client.has_collection(collection_name=self.table_name): | ||
return 0 | ||
filter_expr = self.get_milvus_filter(filters) | ||
count_expr = "count(*)" | ||
query_res = self.client.query( | ||
collection_name=self.table_name, | ||
filter=filter_expr, | ||
output_fields=[count_expr], | ||
) | ||
doc_num = query_res[0][count_expr] | ||
return doc_num | ||
|
||
def insert(self, record: RecordType): | ||
self.insert_many([record]) | ||
|
||
def insert_many(self, records: List[RecordType], show_progress=False): | ||
if not records: | ||
return | ||
|
||
# Milvus lite currently does not support upsert, so we delete and insert instead | ||
# self.client.upsert(collection_name=self.table_name, data=self._records_to_list(records)) | ||
ids = [str(record.id) for record in records] | ||
self.client.delete(collection_name=self.table_name, ids=ids) | ||
data = self._records_to_list(records) | ||
self.client.insert(collection_name=self.table_name, data=data) | ||
|
||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]: | ||
if not self.client.has_collection(self.table_name): | ||
return [] | ||
search_res = self.client.search( | ||
collection_name=self.table_name, data=[query_vec], filter=self.get_milvus_filter(filters), limit=top_k, output_fields=["*"] | ||
)[0] | ||
entity_res = [res["entity"] for res in search_res] | ||
return self._list_to_records(entity_res) | ||
|
||
def delete_table(self): | ||
self.client.drop_collection(collection_name=self.table_name) | ||
|
||
def delete(self, filters: Optional[Dict] = {}): | ||
if not self.client.has_collection(collection_name=self.table_name): | ||
return | ||
filter_expr = self.get_milvus_filter(filters) | ||
self.client.delete(collection_name=self.table_name, filter=filter_expr) | ||
|
||
def save(self): | ||
# save to persistence file (nothing needs to be done) | ||
printd("Saving milvus") | ||
|
||
def _records_to_list(self, records: List[Record]) -> List[Dict]: | ||
if records == []: | ||
return [] | ||
assert all(isinstance(r, Passage) for r in records) | ||
record_list = [] | ||
records = list(set(records)) | ||
for record in records: | ||
record_vars = deepcopy(vars(record)) | ||
_id = record_vars.pop("id") | ||
text = record_vars.pop("text", "") | ||
embedding = record_vars.pop("embedding") | ||
record_metadata = record_vars.pop("metadata_", None) or {} | ||
if "created_at" in record_vars: | ||
record_vars["created_at"] = datetime_to_timestamp(record_vars["created_at"]) | ||
record_dict = {key: value for key, value in record_vars.items() if value is not None} | ||
record_dict = { | ||
**record_dict, | ||
**record_metadata, | ||
"id": str(_id), | ||
"text": text, | ||
"embedding": embedding, | ||
} | ||
for key, value in record_dict.items(): | ||
if key in self.uuid_fields: | ||
record_dict[key] = str(value) | ||
record_list.append(record_dict) | ||
return record_list | ||
|
||
def _list_to_records(self, query_res: List[Dict]) -> List[RecordType]: | ||
records = [] | ||
for res_dict in query_res: | ||
_id = res_dict.pop("id") | ||
embedding = res_dict.pop("embedding") | ||
text = res_dict.pop("text") | ||
metadata = deepcopy(res_dict) | ||
for key, value in metadata.items(): | ||
if key in self.uuid_fields: | ||
metadata[key] = uuid.UUID(value) | ||
elif key == "created_at": | ||
metadata[key] = timestamp_to_datetime(value) | ||
records.append( | ||
cast( | ||
RecordType, | ||
self.type( | ||
text=text, | ||
embedding=embedding, | ||
id=uuid.UUID(_id), | ||
**metadata, | ||
), | ||
) | ||
) | ||
return records |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.