Skip to content

Commit

Permalink
feat: Milvus storage connector (#1198) (#1400)
Browse files Browse the repository at this point in the history
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
  • Loading branch information
sarahwooders and zc277584121 committed May 26, 2024
1 parent f357719 commit 8d7ac83
Show file tree
Hide file tree
Showing 9 changed files with 447 additions and 62 deletions.
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ jobs:
- name: Build and run container
run: bash db/run_postgres.sh


- name: "Setup Python, Poetry and Dependencies"
uses: packetcoders/action-setup-cache-python-poetry@main
with:
Expand Down
20 changes: 20 additions & 0 deletions docs/storage.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,23 @@ memgpt configure
```

and selecting `lancedb` for archival storage, along with a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, in which case the default URI `./.lancedb` is used. For more details, check out the [LanceDB docs](https://lancedb.github.io/lancedb/)


## Milvus

To enable the Milvus backend, make sure to install the required dependencies with:

```sh
pip install 'pymemgpt[milvus]'
```
You can configure the Milvus connection via the `memgpt configure` command.

```sh
...
? Select storage backend for archival data: milvus
? Enter the Milvus connection URI (Default: ~/.memgpt/milvus.db): ~/.memgpt/milvus.db
```
Setting the URI to a local file path, e.g. `~/.memgpt/milvus.db`, will automatically use [Milvus Lite](https://milvus.io/docs/milvus_lite.md) to run a local Milvus service instance.

If you have a large amount of data, such as more than a million documents, we recommend setting up a more performant Milvus server on [Docker or Kubernetes](https://milvus.io/docs/quickstart.md).
In that case, your URI should be the server URI, e.g. `http://localhost:19530`.
198 changes: 198 additions & 0 deletions memgpt/agent_store/milvus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import uuid
from copy import deepcopy
from typing import Dict, Iterator, List, Optional, cast

from pymilvus import DataType, MilvusClient
from pymilvus.client.constants import ConsistencyLevel

from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.data_types import Passage, Record, RecordType
from memgpt.utils import datetime_to_timestamp, printd, timestamp_to_datetime


class MilvusStorageConnector(StorageConnector):
    """Archival-memory (passage) storage backed by Milvus.

    Records live in a single Milvus collection with an explicit schema of
    ``id`` (string primary key), ``text``, and ``embedding`` (float vector,
    inner-product index); all remaining record metadata is stored via
    Milvus dynamic fields. UUID-valued fields are serialized to strings on
    write and parsed back to ``uuid.UUID`` on read.
    """

    def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
        """Connect to Milvus and ensure the backing collection exists.

        Args:
            table_type: Must be ARCHIVAL_MEMORY or PASSAGES; Milvus is only
                used for archival storage.
            config: MemGPT config; ``archival_storage_uri`` is required
                (a local file path invokes Milvus Lite, an http(s) URI
                connects to a Milvus server).
            user_id: Owning user id (used by the base class to build filters).
            agent_id: Optional owning agent id.

        Raises:
            ValueError: If ``config.archival_storage_uri`` is not set.
        """
        super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

        assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Milvus only supports archival memory"
        # Guard clause: fail fast when no URI is configured.
        if not config.archival_storage_uri:
            raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")
        self.client = MilvusClient(uri=config.archival_storage_uri)
        self._create_collection()

        # These fields hold UUIDs; they are stored as strings in Milvus and
        # must be converted back to uuid.UUID when records are read out.
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def _create_collection(self):
        """Define the schema/indexes and create the backing collection.

        Dynamic fields are enabled so arbitrary record metadata can be
        stored alongside the explicitly declared fields.
        """
        schema = MilvusClient.create_schema(
            auto_id=False,  # ids are supplied by us (stringified UUIDs)
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=65_535)
        schema.add_field(field_name="text", datatype=DataType.VARCHAR, is_primary=False, max_length=65_535)
        schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=MAX_EMBEDDING_DIM)
        index_params = self.client.prepare_index_params()
        index_params.add_index(field_name="id")
        # AUTOINDEX with inner-product (IP) similarity for vector search.
        index_params.add_index(field_name="embedding", index_type="AUTOINDEX", metric_type="IP")
        self.client.create_collection(
            collection_name=self.table_name, schema=schema, index_params=index_params, consistency_level=ConsistencyLevel.Strong
        )

    def get_milvus_filter(self, filters: Optional[Dict] = None) -> str:
        """Build a Milvus boolean filter expression.

        Merges the connector-level ``self.filters`` with the optional
        per-call ``filters`` (per-call entries win). UUID and string values
        are double-quoted; other values (e.g. numerics) are emitted as-is.

        Returns:
            A filter expression string, or ``""`` when there are no conditions.
        """
        filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
        if not filter_conditions:
            return ""
        conditions = []
        for key, value in filter_conditions.items():
            # BUG FIX: was `isinstance(key, str)` — keys are always strings,
            # so every value (including numerics) was being quoted. Quote
            # based on the VALUE type instead (UUID fields are stored as
            # strings, so they are quoted too).
            if key in self.uuid_fields or isinstance(value, str):
                condition = f'({key} == "{value}")'
            else:
                condition = f"({key} == {value})"
            conditions.append(condition)
        filter_expr = " and ".join(conditions)
        if len(conditions) == 1:
            # Strip the redundant parentheses around a single condition.
            filter_expr = filter_expr[1:-1]
        return filter_expr

    def get_all_paginated(self, filters: Optional[Dict] = None, page_size: int = 1000) -> Iterator[List[RecordType]]:
        """Yield matching records in chunks of ``page_size``."""
        if not self.client.has_collection(collection_name=self.table_name):
            yield []
            # BUG FIX: previously fell through into the query loop below and
            # issued queries against the nonexistent collection.
            return
        filter_expr = self.get_milvus_filter(filters)
        offset = 0
        while True:
            # Retrieve a chunk of records with the given page_size
            query_res = self.client.query(
                collection_name=self.table_name,
                filter=filter_expr,
                offset=offset,
                limit=page_size,
            )
            if not query_res:
                break
            # Yield a list of Record objects converted from the chunk
            yield self._list_to_records(query_res)

            # Increment the offset to get the next chunk in the next iteration
            offset += page_size

    def get_all(self, filters: Optional[Dict] = None, limit=None) -> List[RecordType]:
        """Return all matching records (up to ``limit`` when given)."""
        if not self.client.has_collection(collection_name=self.table_name):
            return []
        filter_expr = self.get_milvus_filter(filters)
        query_res = self.client.query(
            collection_name=self.table_name,
            filter=filter_expr,
            limit=limit,
        )
        return self._list_to_records(query_res)

    def get(self, id: uuid.UUID) -> Optional[RecordType]:
        """Fetch a single record by its UUID, or None if not found."""
        res = self.client.get(collection_name=self.table_name, ids=str(id))
        return self._list_to_records(res)[0] if res else None

    def size(self, filters: Optional[Dict] = None) -> int:
        """Return the number of records matching ``filters``."""
        if not self.client.has_collection(collection_name=self.table_name):
            return 0
        filter_expr = self.get_milvus_filter(filters)
        count_expr = "count(*)"
        query_res = self.client.query(
            collection_name=self.table_name,
            filter=filter_expr,
            output_fields=[count_expr],
        )
        # A count(*) query always returns a single row keyed by the expression.
        return query_res[0][count_expr]

    def insert(self, record: RecordType):
        """Insert (or overwrite) a single record."""
        self.insert_many([record])

    def insert_many(self, records: List[RecordType], show_progress=False):
        """Insert records, overwriting any with matching ids.

        NOTE: Milvus Lite does not support upsert, so this deletes the ids
        first and then inserts (delete of nonexistent ids is a no-op).
        """
        if not records:
            return

        # Milvus lite currently does not support upsert, so we delete and insert instead
        # self.client.upsert(collection_name=self.table_name, data=self._records_to_list(records))
        ids = [str(record.id) for record in records]
        self.client.delete(collection_name=self.table_name, ids=ids)
        data = self._records_to_list(records)
        self.client.insert(collection_name=self.table_name, data=data)

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = None) -> List[RecordType]:
        """Vector-similarity search; ``query`` (the raw text) is unused here.

        Returns up to ``top_k`` records ranked by inner-product similarity
        to ``query_vec``, restricted by ``filters``.
        """
        if not self.client.has_collection(self.table_name):
            return []
        search_res = self.client.search(
            collection_name=self.table_name, data=[query_vec], filter=self.get_milvus_filter(filters), limit=top_k, output_fields=["*"]
        )[0]
        entity_res = [res["entity"] for res in search_res]
        return self._list_to_records(entity_res)

    def delete_table(self):
        """Drop the entire backing collection."""
        self.client.drop_collection(collection_name=self.table_name)

    def delete(self, filters: Optional[Dict] = None):
        """Delete all records matching ``filters`` (no-op if no collection)."""
        if not self.client.has_collection(collection_name=self.table_name):
            return
        filter_expr = self.get_milvus_filter(filters)
        self.client.delete(collection_name=self.table_name, filter=filter_expr)

    def save(self):
        # save to persistence file (nothing needs to be done)
        printd("Saving milvus")

    def _records_to_list(self, records: List[Record]) -> List[Dict]:
        """Convert Passage records into Milvus row dicts.

        Flattens ``metadata_`` into top-level dynamic fields, stringifies
        UUID fields, converts ``created_at`` to an integer timestamp, and
        drops None-valued fields. Duplicate records are deduplicated.
        """
        if not records:
            return []
        assert all(isinstance(r, Passage) for r in records)
        record_list = []
        records = list(set(records))  # dedupe: ids are primary keys in Milvus
        for record in records:
            record_vars = deepcopy(vars(record))
            _id = record_vars.pop("id")
            text = record_vars.pop("text", "")
            embedding = record_vars.pop("embedding")
            record_metadata = record_vars.pop("metadata_", None) or {}
            if "created_at" in record_vars:
                # Milvus dynamic fields can't hold datetimes; store as int timestamp.
                record_vars["created_at"] = datetime_to_timestamp(record_vars["created_at"])
            record_dict = {key: value for key, value in record_vars.items() if value is not None}
            record_dict = {
                **record_dict,
                **record_metadata,
                "id": str(_id),
                "text": text,
                "embedding": embedding,
            }
            for key, value in record_dict.items():
                if key in self.uuid_fields:
                    record_dict[key] = str(value)
            record_list.append(record_dict)
        return record_list

    def _list_to_records(self, query_res: List[Dict]) -> List[RecordType]:
        """Convert Milvus row dicts back into Record objects.

        Inverse of ``_records_to_list``: UUID fields are parsed from strings
        and ``created_at`` is converted back to a datetime.
        """
        records = []
        for res_dict in query_res:
            _id = res_dict.pop("id")
            embedding = res_dict.pop("embedding")
            text = res_dict.pop("text")
            metadata = deepcopy(res_dict)
            for key, value in metadata.items():
                if key in self.uuid_fields:
                    metadata[key] = uuid.UUID(value)
                elif key == "created_at":
                    metadata[key] = timestamp_to_datetime(value)
            records.append(
                cast(
                    RecordType,
                    self.type(
                        text=text,
                        embedding=embedding,
                        id=uuid.UUID(_id),
                        **metadata,
                    ),
                )
            )
        return records
3 changes: 3 additions & 0 deletions memgpt/agent_store/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def get_storage_connector(
from memgpt.agent_store.db import SQLLiteStorageConnector

return SQLLiteStorageConnector(table_type, config, user_id, agent_id)
elif storage_type == "milvus":
from memgpt.agent_store.milvus import MilvusStorageConnector

return MilvusStorageConnector(table_type, config, user_id, agent_id)
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

Expand Down
9 changes: 8 additions & 1 deletion memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden

def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredentials):
# Configure archival storage backend
archival_storage_options = ["postgres", "chroma"]
archival_storage_options = ["postgres", "chroma", "milvus"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
Expand Down Expand Up @@ -914,6 +914,13 @@ def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredenti
if chroma_type == "persistent":
archival_storage_path = os.path.join(MEMGPT_DIR, "chroma")

if archival_storage_type == "milvus":
default_milvus_uri = archival_storage_path = os.path.join(MEMGPT_DIR, "milvus.db")
archival_storage_uri = questionary.text(
f"Enter the Milvus connection URI (Default: {default_milvus_uri}):", default=default_milvus_uri
).ask()
if archival_storage_uri is None:
raise KeyboardInterrupt
return archival_storage_type, archival_storage_uri, archival_storage_path

# TODO: allow configuring embedding model
Expand Down
Loading

0 comments on commit 8d7ac83

Please sign in to comment.