From 6cb5492f02499ca4c3fd196d4e194f71c77aa422 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Wed, 27 Mar 2024 19:53:50 +0800 Subject: [PATCH 1/6] upgrade llama-index-vector-stores-chroma and rag test coverage 100% --- .../minecraft_env/minecraft_env.py | 2 +- metagpt/rag/factories/index.py | 2 +- metagpt/rag/factories/retriever.py | 2 +- metagpt/rag/vector_stores/__init__.py | 0 metagpt/rag/vector_stores/chroma/__init__.py | 3 - metagpt/rag/vector_stores/chroma/base.py | 290 ------------------ setup.py | 2 +- tests/metagpt/rag/engines/test_simple.py | 168 +++++++++- tests/metagpt/rag/factories/test_embedding.py | 43 +++ tests/metagpt/rag/factories/test_index.py | 89 ++++++ tests/metagpt/rag/factories/test_llm.py | 71 +++++ tests/metagpt/rag/factories/test_ranker.py | 58 ++-- tests/metagpt/rag/factories/test_retriever.py | 80 +++-- tests/metagpt/rag/rankers/test_base_ranker.py | 23 ++ .../metagpt/rag/rankers/test_object_ranker.py | 37 ++- .../rag/retrievers/test_base_retriever.py | 21 ++ .../rag/retrievers/test_bm25_retriever.py | 12 +- .../rag/retrievers/test_chroma_retriever.py | 20 ++ .../rag/retrievers/test_es_retriever.py | 20 ++ .../rag/retrievers/test_faiss_retriever.py | 11 +- .../rag/retrievers/test_hybrid_retriever.py | 28 +- 21 files changed, 600 insertions(+), 382 deletions(-) delete mode 100644 metagpt/rag/vector_stores/__init__.py delete mode 100644 metagpt/rag/vector_stores/chroma/__init__.py delete mode 100644 metagpt/rag/vector_stores/chroma/base.py create mode 100644 tests/metagpt/rag/factories/test_embedding.py create mode 100644 tests/metagpt/rag/factories/test_index.py create mode 100644 tests/metagpt/rag/factories/test_llm.py create mode 100644 tests/metagpt/rag/rankers/test_base_ranker.py create mode 100644 tests/metagpt/rag/retrievers/test_base_retriever.py create mode 100644 tests/metagpt/rag/retrievers/test_chroma_retriever.py create mode 100644 tests/metagpt/rag/retrievers/test_es_retriever.py diff --git a/metagpt/environment/minecraft_env/minecraft_env.py b/metagpt/environment/minecraft_env/minecraft_env.py index 26d4d03a86..6e1800b328 100644 --- a/metagpt/environment/minecraft_env/minecraft_env.py +++ b/metagpt/environment/minecraft_env/minecraft_env.py @@ -8,6 +8,7 @@ import time from typing import Any, Iterable +from llama_index.vector_stores.chroma import ChromaVectorStore from pydantic import ConfigDict, Field from metagpt.config2 import config as CONFIG @@ -15,7 +16,6 @@ from metagpt.environment.minecraft_env.const import MC_CKPT_DIR from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv from metagpt.logs import logger -from metagpt.rag.vector_stores.chroma import ChromaVectorStore from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py index f200fc94f0..a56471359e 100644 --- a/metagpt/rag/factories/index.py +++ b/metagpt/rag/factories/index.py @@ -5,6 +5,7 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore @@ -17,7 +18,6 @@ ElasticsearchKeywordIndexConfig, FAISSIndexConfig, ) -from metagpt.rag.vector_stores.chroma import ChromaVectorStore class RAGIndexFactory(ConfigBasedFactory): diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index a107d95733..65729002ea 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -6,6 +6,7 @@ import faiss from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.vector_stores.types import BasePydanticVectorStore +from llama_index.vector_stores.chroma import ChromaVectorStore from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.faiss import FaissVectorStore @@ -25,7 +26,6 @@ FAISSRetrieverConfig, IndexRetrieverConfig, ) -from metagpt.rag.vector_stores.chroma import ChromaVectorStore class RetrieverFactory(ConfigBasedFactory): diff --git a/metagpt/rag/vector_stores/__init__.py b/metagpt/rag/vector_stores/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/metagpt/rag/vector_stores/chroma/__init__.py b/metagpt/rag/vector_stores/chroma/__init__.py deleted file mode 100644 index 87ba4d8a76..0000000000 --- a/metagpt/rag/vector_stores/chroma/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore - -__all__ = ["ChromaVectorStore"] diff --git a/metagpt/rag/vector_stores/chroma/base.py b/metagpt/rag/vector_stores/chroma/base.py deleted file mode 100644 index 55e5bd40d2..0000000000 --- a/metagpt/rag/vector_stores/chroma/base.py +++ /dev/null @@ -1,290 +0,0 @@ -"""Chroma vector store. - -Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py. -The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7. -""" - -import math -from typing import Any, Dict, Generator, List, Optional, cast - -import chromadb -from chromadb.api.models.Collection import Collection -from llama_index.core.bridge.pydantic import Field, PrivateAttr -from llama_index.core.schema import BaseNode, MetadataMode, TextNode -from llama_index.core.utils import truncate_text -from llama_index.core.vector_stores.types import ( - BasePydanticVectorStore, - MetadataFilters, - VectorStoreQuery, - VectorStoreQueryResult, -) -from llama_index.core.vector_stores.utils import ( - legacy_metadata_dict_to_node, - metadata_dict_to_node, - node_to_metadata_dict, -) - -from metagpt.logs import logger - - -def _transform_chroma_filter_condition(condition: str) -> str: - """Translate standard metadata filter op to Chroma specific spec.""" - if condition == "and": - return "$and" - elif condition == "or": - return "$or" - else: - raise ValueError(f"Filter condition {condition} not supported") - - -def _transform_chroma_filter_operator(operator: str) -> str: - """Translate standard metadata filter operator to Chroma specific spec.""" - if operator == "!=": - return "$ne" - elif operator == "==": - return "$eq" - elif operator == ">": - return "$gt" - elif operator == "<": - return "$lt" - elif operator == ">=": - return "$gte" - elif operator == "<=": - return "$lte" - else: - raise ValueError(f"Filter operator {operator} not supported") - - -def _to_chroma_filter( - standard_filters: MetadataFilters, -) -> dict: - """Translate standard metadata filters to Chroma specific spec.""" - filters = {} - filters_list = [] - condition = standard_filters.condition or "and" - condition = _transform_chroma_filter_condition(condition) - if standard_filters.filters: - for filter in standard_filters.filters: - if filter.operator: - filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}}) - else: - filters_list.append({filter.key: filter.value}) - if len(filters_list) == 1: - # If there is only one filter, return it directly - return filters_list[0] - elif len(filters_list) > 1: - filters[condition] = filters_list - return filters - - -import_err_msg = "`chromadb` package not found, please run `pip install chromadb`" -MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB - - -def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]: - """Yield successive max_chunk_size-sized chunks from lst. - Args: - lst (List[BaseNode]): list of nodes with embeddings - max_chunk_size (int): max chunk size - Yields: - Generator[List[BaseNode], None, None]: list of nodes with embeddings - """ - for i in range(0, len(lst), max_chunk_size): - yield lst[i : i + max_chunk_size] - - -class ChromaVectorStore(BasePydanticVectorStore): - """Chroma vector store. - In this vector store, embeddings are stored within a ChromaDB collection. - During query time, the index uses ChromaDB to query for the top - k most similar nodes. - Args: - chroma_collection (chromadb.api.models.Collection.Collection): - ChromaDB collection instance - """ - - stores_text: bool = True - flat_metadata: bool = True - collection_name: Optional[str] - host: Optional[str] - port: Optional[str] - ssl: bool - headers: Optional[Dict[str, str]] - persist_dir: Optional[str] - collection_kwargs: Dict[str, Any] = Field(default_factory=dict) - _collection: Any = PrivateAttr() - - def __init__( - self, - chroma_collection: Optional[Any] = None, - collection_name: Optional[str] = None, - host: Optional[str] = None, - port: Optional[str] = None, - ssl: bool = False, - headers: Optional[Dict[str, str]] = None, - persist_dir: Optional[str] = None, - collection_kwargs: Optional[dict] = None, - **kwargs: Any, - ) -> None: - """Init params.""" - collection_kwargs = collection_kwargs or {} - if chroma_collection is None: - client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers) - self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs) - else: - self._collection = cast(Collection, chroma_collection) - super().__init__( - host=host, - port=port, - ssl=ssl, - headers=headers, - collection_name=collection_name, - persist_dir=persist_dir, - collection_kwargs=collection_kwargs or {}, - ) - - @classmethod - def from_collection(cls, collection: Any) -> "ChromaVectorStore": - try: - from chromadb import Collection - except ImportError: - raise ImportError(import_err_msg) - if not isinstance(collection, Collection): - raise Exception("argument is not chromadb collection instance") - return cls(chroma_collection=collection) - - @classmethod - def from_params( - cls, - collection_name: str, - host: Optional[str] = None, - port: Optional[str] = None, - ssl: bool = False, - headers: Optional[Dict[str, str]] = None, - persist_dir: Optional[str] = None, - collection_kwargs: dict = {}, - **kwargs: Any, - ) -> "ChromaVectorStore": - if persist_dir: - client = chromadb.PersistentClient(path=persist_dir) - collection = client.get_or_create_collection(name=collection_name, **collection_kwargs) - elif host and port: - client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers) - collection = client.get_or_create_collection(name=collection_name, **collection_kwargs) - else: - raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified") - return cls( - chroma_collection=collection, - host=host, - port=port, - ssl=ssl, - headers=headers, - persist_dir=persist_dir, - collection_kwargs=collection_kwargs, - **kwargs, - ) - - @classmethod - def class_name(cls) -> str: - return "ChromaVectorStore" - - def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: - """Add nodes to index. - Args: - nodes: List[BaseNode]: list of nodes with embeddings - """ - if not self._collection: - raise ValueError("Collection not initialized") - max_chunk_size = MAX_CHUNK_SIZE - node_chunks = chunk_list(nodes, max_chunk_size) - all_ids = [] - for node_chunk in node_chunks: - embeddings = [] - metadatas = [] - ids = [] - documents = [] - for node in node_chunk: - embeddings.append(node.get_embedding()) - metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata) - for key in metadata_dict: - if metadata_dict[key] is None: - metadata_dict[key] = "" - metadatas.append(metadata_dict) - ids.append(node.node_id) - documents.append(node.get_content(metadata_mode=MetadataMode.NONE)) - self._collection.add( - embeddings=embeddings, - ids=ids, - metadatas=metadatas, - documents=documents, - ) - all_ids.extend(ids) - return all_ids - - def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: - """ - Delete nodes using with ref_doc_id. - Args: - ref_doc_id (str): The doc_id of the document to delete. - """ - self._collection.delete(where={"document_id": ref_doc_id}) - - @property - def client(self) -> Any: - """Return client.""" - return self._collection - - def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: - """Query index for top k most similar nodes. - Args: - query_embedding (List[float]): query embedding - similarity_top_k (int): top k most similar nodes - """ - if query.filters is not None: - if "where" in kwargs: - raise ValueError( - "Cannot specify metadata filters via both query and kwargs. " - "Use kwargs only for chroma specific items that are " - "not supported via the generic query interface." - ) - where = _to_chroma_filter(query.filters) - else: - where = kwargs.pop("where", {}) - results = self._collection.query( - query_embeddings=query.query_embedding, - n_results=query.similarity_top_k, - where=where, - **kwargs, - ) - logger.debug(f"> Top {len(results['documents'])} nodes:") - nodes = [] - similarities = [] - ids = [] - for node_id, text, metadata, distance in zip( - results["ids"][0], - results["documents"][0], - results["metadatas"][0], - results["distances"][0], - ): - try: - node = metadata_dict_to_node(metadata) - node.set_content(text) - except Exception: - # NOTE: deprecated legacy logic for backward compatibility - metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata) - node = TextNode( - text=text, - id_=node_id, - metadata=metadata, - start_char_idx=node_info.get("start", None), - end_char_idx=node_info.get("end", None), - relationships=relationships, - ) - nodes.append(node) - similarity_score = math.exp(-distance) - similarities.append(similarity_score) - logger.debug( - f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}" - ) - ids.append(node_id) - return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) diff --git a/setup.py b/setup.py index c728872ef2..4fa5499da6 100644 --- a/setup.py +++ b/setup.py @@ -37,8 +37,8 @@ def run(self): "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", + "llama-index-vector-stores-chroma==0.1.6", "llama-index-postprocessor-colbert-rerank==0.1.1", - "chromadb==0.4.23", ], } diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py index 5627957c73..9262ccb07e 100644 --- a/tests/metagpt/rag/engines/test_simple.py +++ b/tests/metagpt/rag/engines/test_simple.py @@ -1,12 +1,26 @@ +import json + import pytest from llama_index.core import VectorStoreIndex -from llama_index.core.schema import Document, TextNode +from llama_index.core.embeddings import MockEmbedding +from llama_index.core.llms import MockLLM +from llama_index.core.schema import Document, NodeWithScore, TextNode from metagpt.rag.engines import SimpleEngine -from metagpt.rag.retrievers.base import ModifiableRAGRetriever +from metagpt.rag.retrievers import SimpleHybridRetriever +from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever +from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode class TestSimpleEngine: + @pytest.fixture + def mock_llm(self): + return MockLLM() + + @pytest.fixture + def mock_embedding(self): + return MockEmbedding(embed_dim=1) + @pytest.fixture def mock_simple_directory_reader(self, mocker): return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader") @@ -54,7 +68,7 @@ def test_from_docs( retriever_configs = [mocker.MagicMock()] ranker_configs = [mocker.MagicMock()] - # Execute + # Exec engine = SimpleEngine.from_docs( input_dir=input_dir, input_files=input_files, @@ -65,7 +79,7 @@ def test_from_docs( ranker_configs=ranker_configs, ) - # Assertions + # Assert mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files) mock_vector_store_index.assert_called_once() mock_get_retriever.assert_called_once_with( @@ -75,6 +89,68 @@ def test_from_docs( mock_get_response_synthesizer.assert_called_once_with(llm=llm) assert isinstance(engine, SimpleEngine) + def test_from_docs_without_file(self): + with pytest.raises(ValueError): + SimpleEngine.from_docs() + + def test_from_objs(self, mock_llm, mock_embedding): + # Mock + class MockRAGObject: + def rag_key(self): + return "key" + + def model_dump_json(self): + return "{}" + + objs = [MockRAGObject()] + + # Setup + retriever_configs = [] + ranker_configs = [] + + # Exec + engine = SimpleEngine.from_objs( + objs=objs, + llm=mock_llm, + embed_model=mock_embedding, + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, + ) + + # Assert + assert isinstance(engine, SimpleEngine) + assert engine.index is not None + + def test_from_objs_with_bm25_config(self): + # Setup + retriever_configs = [BM25RetrieverConfig()] + + # Exec + with pytest.raises(ValueError): + SimpleEngine.from_objs( + objs=[], + llm=MockLLM(), + retriever_configs=retriever_configs, + ranker_configs=[], + ) + + def test_from_index(self, mocker, mock_llm, mock_embedding): + # Mock + mock_index = mocker.MagicMock(spec=VectorStoreIndex) + mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index") + mock_get_index.return_value = mock_index + + # Exec + engine = SimpleEngine.from_index( + index_config=mock_index, + embed_model=mock_embedding, + llm=mock_llm, + ) + + # Assert + assert isinstance(engine, SimpleEngine) + assert engine.index is mock_index + @pytest.mark.asyncio async def test_asearch(self, mocker): # Mock @@ -86,10 +162,10 @@ async def test_asearch(self, mocker): engine = SimpleEngine(retriever=mocker.MagicMock()) engine.aquery = mock_aquery - # Execute + # Exec result = await engine.asearch(test_query) - # Assertions + # Assert mock_aquery.assert_called_once_with(test_query) assert result == expected_result @@ -106,10 +182,10 @@ async def test_aretrieve(self, mocker): engine = SimpleEngine(retriever=mocker.MagicMock()) test_query = "test query" - # Execute + # Exec result = await engine.aretrieve(test_query) - # Assertions + # Assert mock_query_bundle.assert_called_once_with(test_query) mock_super_aretrieve.assert_called_once_with("query_bundle") assert result[0].text == "node_with_score" @@ -134,10 +210,10 @@ def test_add_docs(self, mocker): engine = SimpleEngine(retriever=mock_retriever, index=mock_index) input_files = ["test_file1", "test_file2"] - # Execute + # Exec engine.add_docs(input_files=input_files) - # Assertions + # Assert mock_simple_directory_reader.assert_called_once_with(input_files=input_files) mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"]) @@ -156,11 +232,79 @@ def model_dump_json(self): objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)] engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock()) - # Execute + # Exec engine.add_objs(objs=objs) - # Assertions + # Assert assert mock_retriever.add_nodes.call_count == 1 for node in mock_retriever.add_nodes.call_args[0][0]: assert isinstance(node, TextNode) assert "is_obj" in node.metadata + + def test_persist_successfully(self, mocker): + # Mock + mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever) + mock_retriever.persist.return_value = mocker.MagicMock() + + # Setup + engine = SimpleEngine(retriever=mock_retriever) + + # Exec + engine.persist(persist_dir="") + + def test_ensure_retriever_of_type(self, mocker): + # Mock + class MyRetriever: + def add_nodes(self): + ... + + mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever) + mock_retriever.retrievers = [MyRetriever()] + + # Setup + engine = SimpleEngine(retriever=mock_retriever) + + # Assert + engine._ensure_retriever_of_type(ModifiableRAGRetriever) + + with pytest.raises(TypeError): + engine._ensure_retriever_of_type(PersistableRAGRetriever) + + with pytest.raises(TypeError): + other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever)) + other_engine._ensure_retriever_of_type(PersistableRAGRetriever) + + def test_with_obj_metadata(self, mocker): + # Mock + node = NodeWithScore( + node=ObjectNode( + text="example", + metadata={ + "is_obj": True, + "obj_cls_name": "ExampleObject", + "obj_mod_name": "__main__", + "obj_json": json.dumps({"key": "test_key", "value": "test_value"}), + }, + ) + ) + + class ExampleObject: + def __init__(self, key, value): + self.key = key + self.value = value + + def __eq__(self, other): + return self.key == other.key and self.value == other.value + + mock_import_class = mocker.patch("metagpt.rag.engines.simple.import_class") + mock_import_class.return_value = ExampleObject + + # Setup + SimpleEngine._try_reconstruct_obj([node]) + + # Exec + expected_obj = ExampleObject(key="test_key", value="test_value") + + # Assert + assert "obj" in node.node.metadata + assert node.node.metadata["obj"] == expected_obj diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py new file mode 100644 index 0000000000..1ded6b4a8d --- /dev/null +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -0,0 +1,43 @@ +import pytest + +from metagpt.configs.llm_config import LLMType +from metagpt.rag.factories.embedding import RAGEmbeddingFactory + + +class TestRAGEmbeddingFactory: + @pytest.fixture(autouse=True) + def mock_embedding_factory(self): + self.embedding_factory = RAGEmbeddingFactory() + + @pytest.fixture + def mock_openai_embedding(self, mocker): + return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") + + @pytest.fixture + def mock_azure_embedding(self, mocker): + return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") + + def test_get_rag_embedding_openai(self, mock_openai_embedding): + # Exec + self.embedding_factory.get_rag_embedding(LLMType.OPENAI) + + # Assert + mock_openai_embedding.assert_called_once() + + def test_get_rag_embedding_azure(self, mock_azure_embedding): + # Exec + self.embedding_factory.get_rag_embedding(LLMType.AZURE) + + # Assert + mock_azure_embedding.assert_called_once() + + def test_get_rag_embedding_default(self, mocker, mock_openai_embedding): + # Mock + mock_config = mocker.patch("metagpt.rag.factories.embedding.config") + mock_config.llm.api_type = LLMType.OPENAI + + # Exec + self.embedding_factory.get_rag_embedding() + + # Assert + mock_openai_embedding.assert_called_once() diff --git a/tests/metagpt/rag/factories/test_index.py b/tests/metagpt/rag/factories/test_index.py new file mode 100644 index 0000000000..9dc5bfb6be --- /dev/null +++ b/tests/metagpt/rag/factories/test_index.py @@ -0,0 +1,89 @@ +import pytest +from llama_index.core.embeddings import MockEmbedding + +from metagpt.rag.factories.index import RAGIndexFactory +from metagpt.rag.schema import ( + BM25IndexConfig, + ChromaIndexConfig, + ElasticsearchIndexConfig, + ElasticsearchStoreConfig, + FAISSIndexConfig, +) + + +class TestRAGIndexFactory: + @pytest.fixture(autouse=True) + def setup(self): + self.index_factory = RAGIndexFactory() + + @pytest.fixture + def faiss_config(self): + return FAISSIndexConfig(persist_path="") + + @pytest.fixture + def chroma_config(self): + return ChromaIndexConfig(persist_path="", collection_name="") + + @pytest.fixture + def bm25_config(self): + return BM25IndexConfig(persist_path="") + + @pytest.fixture + def es_config(self, mocker): + return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig()) + + @pytest.fixture + def mock_storage_context(self, mocker): + return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults") + + @pytest.fixture + def mock_load_index_from_storage(self, mocker): + return mocker.patch("metagpt.rag.factories.index.load_index_from_storage") + + @pytest.fixture + def mock_from_vector_store(self, mocker): + return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store") + + @pytest.fixture + def mock_embedding(self): + return MockEmbedding(embed_dim=1) + + def test_create_faiss_index( + self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding + ): + # Mock + mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir") + + # Exec + self.index_factory.get_index(faiss_config, embed_model=mock_embedding) + + # Assert + mock_faiss_store.assert_called_once() + + def test_create_bm25_index( + self, mocker, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding + ): + self.index_factory.get_index(bm25_config, embed_model=mock_embedding) + + def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding): + # Mock + mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient") + mock_chroma_db.get_or_create_collection.return_value = mocker.MagicMock() + + mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore") + + # Exec + self.index_factory.get_index(chroma_config, embed_model=mock_embedding) + + # Assert + mock_chroma_store.assert_called_once() + + def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding): + # Mock + mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore") + + # Exec + self.index_factory.get_index(es_config, embed_model=mock_embedding) + + # Assert + mock_es_store.assert_called_once() diff --git a/tests/metagpt/rag/factories/test_llm.py b/tests/metagpt/rag/factories/test_llm.py new file mode 100644 index 0000000000..e11b87076c --- /dev/null +++ b/tests/metagpt/rag/factories/test_llm.py @@ -0,0 +1,71 @@ +from typing import Optional, Union + +import pytest +from llama_index.core.llms import LLMMetadata + +from metagpt.configs.llm_config import LLMConfig +from metagpt.const import USE_CONFIG_TIMEOUT +from metagpt.provider.base_llm import BaseLLM +from metagpt.rag.factories.llm import RAGLLM, get_rag_llm + + +class MockLLM(BaseLLM): + def __init__(self, config: LLMConfig): + ... + + async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + """_achat_completion implemented by inherited class""" + + async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + return "ok" + + def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + return "ok" + + async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str: + """_achat_completion_stream implemented by inherited class""" + + async def aask( + self, + msg: Union[str, list[dict[str, str]]], + system_msgs: Optional[list[str]] = None, + format_msgs: Optional[list[dict[str, str]]] = None, + images: Optional[Union[str, list[str]]] = None, + timeout=USE_CONFIG_TIMEOUT, + stream=True, + ) -> str: + return "ok" + + +class TestRAGLLM: + @pytest.fixture + def mock_model_infer(self): + return MockLLM(config=LLMConfig()) + + @pytest.fixture + def rag_llm(self, mock_model_infer): + return RAGLLM(model_infer=mock_model_infer) + + def test_metadata(self, rag_llm): + metadata = rag_llm.metadata + assert isinstance(metadata, LLMMetadata) + assert metadata.context_window == rag_llm.context_window + assert metadata.num_output == rag_llm.num_output + assert metadata.model_name == rag_llm.model_name + + @pytest.mark.asyncio + async def test_acomplete(self, rag_llm, mock_model_infer): + response = await rag_llm.acomplete("question") + assert response.text == "ok" + + def test_complete(self, rag_llm, mock_model_infer): + response = rag_llm.complete("question") + assert response.text == "ok" + + def test_stream_complete(self, rag_llm, mock_model_infer): + rag_llm.stream_complete("question") + + +def test_get_rag_llm(): + result = get_rag_llm(MockLLM(config=LLMConfig())) + assert isinstance(result, RAGLLM) diff --git a/tests/metagpt/rag/factories/test_ranker.py b/tests/metagpt/rag/factories/test_ranker.py index 563cffa738..3f6b94b47a 100644 --- a/tests/metagpt/rag/factories/test_ranker.py +++ b/tests/metagpt/rag/factories/test_ranker.py @@ -1,41 +1,57 @@ import pytest -from llama_index.core.llms import LLM +from llama_index.core.llms import MockLLM from llama_index.core.postprocessor import LLMRerank from metagpt.rag.factories.ranker import RankerFactory -from metagpt.rag.schema import LLMRankerConfig +from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig class TestRankerFactory: - @pytest.fixture - def ranker_factory(self) -> RankerFactory: - return RankerFactory() + @pytest.fixture(autouse=True) + def ranker_factory(self): + self.ranker_factory: RankerFactory = RankerFactory() @pytest.fixture - def mock_llm(self, mocker): - return mocker.MagicMock(spec=LLM) + def mock_llm(self): + return MockLLM() - def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker): - mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm) - default_rankers = ranker_factory.get_rankers() + def test_get_rankers_with_no_configs(self, mock_llm, mocker): + mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm) + default_rankers = self.ranker_factory.get_rankers() assert len(default_rankers) == 0 - def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm): + def test_get_rankers_with_configs(self, mock_llm): mock_config = LLMRankerConfig(llm=mock_llm) - rankers = ranker_factory.get_rankers(configs=[mock_config]) + rankers = self.ranker_factory.get_rankers(configs=[mock_config]) assert len(rankers) == 1 assert isinstance(rankers[0], LLMRerank) - def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm): - mock_config = LLMRankerConfig(llm=mock_llm) - ranker = ranker_factory._create_llm_ranker(mock_config) - assert isinstance(ranker, LLMRerank) - - def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm): + def test_extract_llm_from_config(self, mock_llm): mock_config = LLMRankerConfig(llm=mock_llm) - extracted_llm = ranker_factory._extract_llm(config=mock_config) + extracted_llm = self.ranker_factory._extract_llm(config=mock_config) assert extracted_llm == mock_llm - def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm): - extracted_llm = ranker_factory._extract_llm(llm=mock_llm) + def test_extract_llm_from_kwargs(self, mock_llm): + extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm) assert extracted_llm == mock_llm + + def test_create_llm_ranker(self, mock_llm): + mock_config = LLMRankerConfig(llm=mock_llm) + ranker = self.ranker_factory._create_llm_ranker(mock_config) + assert isinstance(ranker, LLMRerank) + + def test_create_colbert_ranker(self, mocker, mock_llm): + mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert") + + mock_config = ColbertRerankConfig(llm=mock_llm) + ranker = self.ranker_factory._create_colbert_ranker(mock_config) + + assert ranker == "colbert" + + def test_create_object_ranker(self, mocker, mock_llm): + mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object") + + mock_config = ObjectRankerConfig(field_name="fake", llm=mock_llm) + ranker = self.ranker_factory._create_object_ranker(mock_config) + + assert ranker == "object" diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py index ac8926d468..ef1cef7e00 100644 --- a/tests/metagpt/rag/factories/test_retriever.py +++ b/tests/metagpt/rag/factories/test_retriever.py @@ -1,18 +1,28 @@ import faiss import pytest from llama_index.core import VectorStoreIndex +from llama_index.vector_stores.chroma import ChromaVectorStore +from llama_index.vector_stores.elasticsearch import ElasticsearchStore from metagpt.rag.factories.retriever import RetrieverFactory from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever +from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever +from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever -from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig +from metagpt.rag.schema import ( + BM25RetrieverConfig, + ChromaRetrieverConfig, + ElasticsearchRetrieverConfig, + ElasticsearchStoreConfig, + FAISSRetrieverConfig, +) class TestRetrieverFactory: - @pytest.fixture + @pytest.fixture(autouse=True) def retriever_factory(self): - return RetrieverFactory() + self.retriever_factory: RetrieverFactory = RetrieverFactory() @pytest.fixture def mock_faiss_index(self, mocker): @@ -25,55 +35,79 @@ def mock_vector_store_index(self, mocker): mock.docstore.docs.values.return_value = [] return mock - def test_get_retriever_with_faiss_config( - self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index - ): + @pytest.fixture + def mock_chroma_vector_store(self, mocker): + return mocker.MagicMock(spec=ChromaVectorStore) + + @pytest.fixture + def mock_es_vector_store(self, mocker): + return mocker.MagicMock(spec=ElasticsearchStore) + + def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index): mock_config = FAISSRetrieverConfig(dimensions=128) mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index) - mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config]) assert isinstance(retriever, FAISSRetriever) - def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index): + def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index): mock_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = retriever_factory.get_retriever(configs=[mock_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_config]) assert isinstance(retriever, DynamicBM25Retriever) - def test_get_retriever_with_multiple_configs_returns_hybrid( - self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index - ): + def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index): mock_faiss_config = FAISSRetrieverConfig(dimensions=128) mock_bm25_config = BM25RetrieverConfig() mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) - retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config]) + retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config]) assert isinstance(retriever, SimpleHybridRetriever) - def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index): - mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index) + def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store): + mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection") + mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient") + mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock() + mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store) + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) + + retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + + assert isinstance(retriever, ChromaRetriever) + + def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store): + mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig()) + mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store) + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) + + retriever = self.retriever_factory.get_retriever(configs=[mock_config]) + + assert isinstance(retriever, ElasticsearchRetriever) + + def test_create_default_retriever(self, mocker, mock_vector_store_index): + mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index) mock_vector_store_index.as_retriever = mocker.MagicMock() - retriever = retriever_factory.get_retriever() + retriever = self.retriever_factory.get_retriever() mock_vector_store_index.as_retriever.assert_called_once() assert retriever is mock_vector_store_index.as_retriever.return_value - def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index): + def test_extract_index_from_config(self, mock_vector_store_index): mock_config = FAISSRetrieverConfig(index=mock_vector_store_index) - extracted_index = retriever_factory._extract_index(config=mock_config) + extracted_index = self.retriever_factory._extract_index(config=mock_config) assert extracted_index == mock_vector_store_index - def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index): - extracted_index = retriever_factory._extract_index(index=mock_vector_store_index) + def test_extract_index_from_kwargs(self, mock_vector_store_index): + extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index) assert extracted_index == mock_vector_store_index diff --git a/tests/metagpt/rag/rankers/test_base_ranker.py b/tests/metagpt/rag/rankers/test_base_ranker.py new file mode 100644 index 0000000000..9755949f6a --- /dev/null +++ b/tests/metagpt/rag/rankers/test_base_ranker.py @@ -0,0 +1,23 @@ +import pytest +from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode + +from metagpt.rag.rankers.base import RAGRanker + + +class SimpleRAGRanker(RAGRanker): + def _postprocess_nodes(self, nodes, query_bundle=None): + return [NodeWithScore(node=node.node, score=node.score + 1) for node in nodes] + + +class TestSimpleRAGRanker: + @pytest.fixture + def ranker(self): + return SimpleRAGRanker() + + def test_postprocess_nodes_increases_scores(self, ranker): + nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)] + query_bundle = QueryBundle(query_str="test query") + + processed_nodes = ranker._postprocess_nodes(nodes, query_bundle) + + assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes)) diff --git a/tests/metagpt/rag/rankers/test_object_ranker.py b/tests/metagpt/rag/rankers/test_object_ranker.py index 7ea6b7488b..4a9f66a42d 100644 --- a/tests/metagpt/rag/rankers/test_object_ranker.py +++ b/tests/metagpt/rag/rankers/test_object_ranker.py @@ -14,7 +14,7 @@ class Record(BaseModel): class TestObjectSortPostprocessor: @pytest.fixture - def nodes_with_scores(self): + def mock_nodes_with_scores(self): nodes = [ NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10), NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20), @@ -23,38 +23,47 @@ def nodes_with_scores(self): return nodes @pytest.fixture - def query_bundle(self, mocker): + def mock_query_bundle(self, mocker): return mocker.MagicMock(spec=QueryBundle) - def test_sort_descending(self, nodes_with_scores, query_bundle): + def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle): postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") - sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle) assert [node.score for node in sorted_nodes] == [20, 10, 5] - def test_sort_ascending(self, nodes_with_scores, query_bundle): + def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle): postprocessor = ObjectSortPostprocessor(field_name="score", order="asc") - sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle) assert [node.score for node in sorted_nodes] == [5, 10, 20] - def test_top_n_limit(self, nodes_with_scores, query_bundle): + def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle): postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2) - sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle) + sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle) assert len(sorted_nodes) == 2 assert [node.score for node in sorted_nodes] == [20, 10] - def test_invalid_json_metadata(self, query_bundle): + def test_invalid_json_metadata(self, mock_query_bundle): nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)] postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") with pytest.raises(ValueError): - postprocessor._postprocess_nodes(nodes, query_bundle) + postprocessor._postprocess_nodes(nodes, mock_query_bundle) - def test_missing_query_bundle(self, nodes_with_scores): + def test_missing_query_bundle(self, mock_nodes_with_scores): postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") with pytest.raises(ValueError): - postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None) + postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None) - def test_field_not_found_in_object(self): + def test_field_not_found_in_object(self, mock_query_bundle): nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)] postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") with pytest.raises(ValueError): - postprocessor._postprocess_nodes(nodes) + postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle) + + def test_not_nodes(self, mock_query_bundle): + nodes = [] + postprocessor = ObjectSortPostprocessor(field_name="score", order="desc") + result = postprocessor._postprocess_nodes(nodes, mock_query_bundle) + assert result == [] + + def test_class_name(self): + assert ObjectSortPostprocessor.class_name() == "ObjectSortPostprocessor" diff --git a/tests/metagpt/rag/retrievers/test_base_retriever.py b/tests/metagpt/rag/retrievers/test_base_retriever.py new file mode 100644 index 0000000000..1065b9731d --- /dev/null +++ b/tests/metagpt/rag/retrievers/test_base_retriever.py @@ -0,0 +1,21 @@ +from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever + + +class SubModifiableRAGRetriever(ModifiableRAGRetriever): + ... + + +class SubPersistableRAGRetriever(PersistableRAGRetriever): + ... + + +class TestModifiableRAGRetriever: + def test_subclasshook(self): + result = SubModifiableRAGRetriever.__subclasshook__(SubModifiableRAGRetriever) + assert result is NotImplemented + + +class TestPersistableRAGRetriever: + def test_subclasshook(self): + result = SubPersistableRAGRetriever.__subclasshook__(SubPersistableRAGRetriever) + assert result is NotImplemented diff --git a/tests/metagpt/rag/retrievers/test_bm25_retriever.py b/tests/metagpt/rag/retrievers/test_bm25_retriever.py index 28b37c86b8..5a569f1036 100644 --- a/tests/metagpt/rag/retrievers/test_bm25_retriever.py +++ b/tests/metagpt/rag/retrievers/test_bm25_retriever.py @@ -8,30 +8,30 @@ class TestDynamicBM25Retriever: @pytest.fixture(autouse=True) def setup(self, mocker): - # 创建模拟的Document对象 self.doc1 = mocker.MagicMock(spec=Node) self.doc1.get_content.return_value = "Document content 1" self.doc2 = mocker.MagicMock(spec=Node) self.doc2.get_content.return_value = "Document content 2" self.mock_nodes = [self.doc1, self.doc2] - # 模拟index index = mocker.MagicMock(spec=VectorStoreIndex) + index.storage_context.persist.return_value = "ok" - # 模拟nodes和tokenizer参数 mock_nodes = [] mock_tokenizer = mocker.MagicMock() self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None) - # 初始化DynamicBM25Retriever对象,并提供必需的参数 self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index) def test_add_docs_updates_nodes_and_corpus(self): - # Execute + # Exec self.retriever.add_nodes(self.mock_nodes) - # Assertions + # Assert assert len(self.retriever._nodes) == len(self.mock_nodes) assert len(self.retriever._corpus) == len(self.mock_nodes) self.retriever._tokenizer.assert_called() self.mock_bm25okapi.assert_called() + + def test_persist(self): + self.retriever.persist("") diff --git a/tests/metagpt/rag/retrievers/test_chroma_retriever.py b/tests/metagpt/rag/retrievers/test_chroma_retriever.py new file mode 100644 index 0000000000..cf07903cf2 --- /dev/null +++ b/tests/metagpt/rag/retrievers/test_chroma_retriever.py @@ -0,0 +1,20 @@ +import pytest +from llama_index.core.schema import Node + +from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever + + +class TestChromaRetriever: + @pytest.fixture(autouse=True) + def setup(self, mocker): + self.doc1 = mocker.MagicMock(spec=Node) + self.doc2 = mocker.MagicMock(spec=Node) + self.mock_nodes = [self.doc1, self.doc2] + + self.mock_index = mocker.MagicMock() + self.retriever = ChromaRetriever(self.mock_index) + + def test_add_nodes(self): + self.retriever.add_nodes(self.mock_nodes) + + self.mock_index.insert_nodes.assert_called() diff --git a/tests/metagpt/rag/retrievers/test_es_retriever.py b/tests/metagpt/rag/retrievers/test_es_retriever.py new file mode 100644 index 0000000000..1824bfbd28 --- /dev/null +++ b/tests/metagpt/rag/retrievers/test_es_retriever.py @@ -0,0 +1,20 @@ +import pytest +from llama_index.core.schema import Node + +from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever + + +class TestElasticsearchRetriever: + @pytest.fixture(autouse=True) + def setup(self, mocker): + self.doc1 = mocker.MagicMock(spec=Node) + self.doc2 = mocker.MagicMock(spec=Node) + self.mock_nodes = [self.doc1, self.doc2] + + self.mock_index = mocker.MagicMock() + self.retriever = ElasticsearchRetriever(self.mock_index) + + def test_add_nodes(self): + self.retriever.add_nodes(self.mock_nodes) + + self.mock_index.insert_nodes.assert_called() diff --git a/tests/metagpt/rag/retrievers/test_faiss_retriever.py b/tests/metagpt/rag/retrievers/test_faiss_retriever.py index 9113f110cf..8546732157 100644 --- a/tests/metagpt/rag/retrievers/test_faiss_retriever.py +++ b/tests/metagpt/rag/retrievers/test_faiss_retriever.py @@ -7,16 +7,19 @@ class TestFAISSRetriever: @pytest.fixture(autouse=True) def setup(self, mocker): - # 创建模拟的Document对象 self.doc1 = mocker.MagicMock(spec=Node) self.doc2 = mocker.MagicMock(spec=Node) self.mock_nodes = [self.doc1, self.doc2] - # 模拟FAISSRetriever的_index属性 self.mock_index = mocker.MagicMock() self.retriever = FAISSRetriever(self.mock_index) - def test_add_docs_calls_insert_for_each_document(self, mocker): + def test_add_docs_calls_insert_for_each_document(self): self.retriever.add_nodes(self.mock_nodes) - assert self.mock_index.insert_nodes.assert_called + self.mock_index.insert_nodes.assert_called() + + def test_persist(self): + self.retriever.persist("") + + self.mock_index.storage_context.persist.assert_called() diff --git a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py index 8cc3087c86..da150d8793 100644 --- a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py +++ b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py @@ -1,5 +1,3 @@ -from unittest.mock import AsyncMock - import pytest from llama_index.core.schema import NodeWithScore, TextNode @@ -7,18 +5,30 @@ class TestSimpleHybridRetriever: + @pytest.fixture + def mock_retriever(self, mocker): + return mocker.MagicMock() + + @pytest.fixture + def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever: + return SimpleHybridRetriever(mock_retriever) + + @pytest.fixture + def mock_node(self): + return NodeWithScore(node=TextNode(id_="2"), score=0.95) + @pytest.mark.asyncio - async def test_aretrieve(self): + async def test_aretrieve(self, mocker): question = "test query" # Create mock retrievers - mock_retriever1 = AsyncMock() + mock_retriever1 = mocker.AsyncMock() mock_retriever1.aretrieve.return_value = [ NodeWithScore(node=TextNode(id_="1"), score=1.0), NodeWithScore(node=TextNode(id_="2"), score=0.95), ] - mock_retriever2 = AsyncMock() + mock_retriever2 = mocker.AsyncMock() mock_retriever2.aretrieve.return_value = [ NodeWithScore(node=TextNode(id_="2"), score=0.95), NodeWithScore(node=TextNode(id_="3"), score=0.8), @@ -37,3 +47,11 @@ async def test_aretrieve(self): # Check if the scores are correct (assuming you want the highest score) node_scores = {node.node.node_id: node.score for node in results} assert node_scores["2"] == 0.95 + + def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node): + mock_hybrid_retriever.add_nodes([mock_node]) + mock_hybrid_retriever.retrievers[0].add_nodes.assert_called_once() + + def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever): + mock_hybrid_retriever.persist("") + mock_hybrid_retriever.retrievers[0].persist.assert_called_once() From 676d5ff55ba6468292fdb99af3617fde819f3433 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 28 Mar 2024 15:47:38 +0800 Subject: [PATCH 2/6] lazy import colbert --- metagpt/rag/factories/ranker.py | 7 ++++++- setup.py | 1 - 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 07cb1b929f..476fe8c1a6 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -3,7 +3,6 @@ from llama_index.core.llms import LLM from llama_index.core.postprocessor import LLMRerank from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.postprocessor.colbert_rerank import ColbertRerank from metagpt.rag.factories.base import ConfigBasedFactory from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor @@ -38,6 +37,12 @@ def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank: return LLMRerank(**config.model_dump()) def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank: + try: + from llama_index.postprocessor.colbert_rerank import ColbertRerank + except ImportError: + raise ImportError( + "`llama-index-postprocessor-colbert-rerank` package not found, please run `pip install llama-index-postprocessor-colbert-rerank`" + ) return ColbertRerank(**config.model_dump()) def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: diff --git a/setup.py b/setup.py index 4fa5499da6..3eab2b6a09 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,6 @@ def run(self): "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", "llama-index-vector-stores-chroma==0.1.6", - "llama-index-postprocessor-colbert-rerank==0.1.1", ], } From b355f715bd26d338409a6371ea0a6dbdcdab38a8 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 28 Mar 2024 16:06:32 +0800 Subject: [PATCH 3/6] lazy import colbert --- tests/metagpt/rag/factories/test_ranker.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/metagpt/rag/factories/test_ranker.py b/tests/metagpt/rag/factories/test_ranker.py index 3f6b94b47a..e40f7f8dff 100644 --- a/tests/metagpt/rag/factories/test_ranker.py +++ b/tests/metagpt/rag/factories/test_ranker.py @@ -1,3 +1,5 @@ +import contextlib + import pytest from llama_index.core.llms import MockLLM from llama_index.core.postprocessor import LLMRerank @@ -41,12 +43,13 @@ def test_create_llm_ranker(self, mock_llm): assert isinstance(ranker, LLMRerank) def test_create_colbert_ranker(self, mocker, mock_llm): - mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert") + with contextlib.suppress(ImportError): + mocker.patch("llama_index.postprocessor.colbert_rerank.ColbertRerank", return_value="colbert") - mock_config = ColbertRerankConfig(llm=mock_llm) - ranker = self.ranker_factory._create_colbert_ranker(mock_config) + mock_config = ColbertRerankConfig(llm=mock_llm) + ranker = self.ranker_factory._create_colbert_ranker(mock_config) - assert ranker == "colbert" + assert ranker == "colbert" def test_create_object_ranker(self, mocker, mock_llm): mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object") From dae7492b92697cd96406d7f13e61a9d22beb6412 Mon Sep 17 00:00:00 2001 From: yzlin Date: Fri, 29 Mar 2024 10:15:46 +0800 Subject: [PATCH 4/6] update paddleocr version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3eab2b6a09..6d12514017 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def run(self): "selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"], "search-google": ["google-api-python-client==2.94.0"], "search-ddg": ["duckduckgo-search~=4.1.1"], - "ocr": ["paddlepaddle==2.4.2", "paddleocr>=2.0.1", "tabulate==0.9.0"], + "ocr": ["paddlepaddle==2.4.2", "paddleocr~=2.7.3", "tabulate==0.9.0"], "rag": [ "llama-index-core==0.10.15", "llama-index-embeddings-azure-openai==0.1.6", From 5fab800f09b22cb0cc8bd924efd11e9deb4f154f Mon Sep 17 00:00:00 2001 From: yzlin Date: Fri, 29 Mar 2024 14:19:56 +0800 Subject: [PATCH 5/6] v0.8.0 release --- README.md | 2 ++ setup.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index edb2066a33..9f129105c1 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@

## News +🚀 Mar. 29, 2024: [v0.8.0](https://github.com/geekan/MetaGPT/releases/tag/v0.8.0) released. Now you can use Data Interpreter via pypi package import. Meanwhile, we integrated RAG module and supported multiple new LLMs. + 🚀 Mar. 14, 2024: Our **Data Interpreter** paper is on [arxiv](https://arxiv.org/abs/2402.18679). Check the [example](https://docs.deepwisdom.ai/main/en/DataInterpreter/) and [code](https://github.com/geekan/MetaGPT/tree/main/examples/di)! 🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Data Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/di/README.md), a powerful agent capable of solving a wide range of real-world problems. diff --git a/setup.py b/setup.py index 6d12514017..4ea0b366eb 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ def run(self): setup( name="metagpt", - version="0.7.6", + version="0.8.0", description="The Multi-Agent Framework", long_description=long_description, long_description_content_type="text/markdown", From ae2a89c9dd244745ebe38a741c68444e211671f0 Mon Sep 17 00:00:00 2001 From: geekan Date: Sat, 30 Mar 2024 15:15:03 +0800 Subject: [PATCH 6/6] add contributor form --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index edb2066a33..df1cf9e776 100644 --- a/README.md +++ b/README.md @@ -145,10 +145,13 @@ https://github.com/geekan/MetaGPT/assets/34952977/34345016-5d13-489d-b9f9-b82ace ## Support -### Discard Join US -📢 Join Our [Discord Channel](https://discord.gg/ZRHeExS6xv)! +### Discord Join US -Looking forward to seeing you there! 🎉 +📢 Join Our [Discord Channel](https://discord.gg/ZRHeExS6xv)! Looking forward to seeing you there! 🎉 + +### Contributor form + +📝 [Fill out the form](https://airtable.com/appInfdG0eJ9J4NNL/pagK3Fh1sGclBvVkV/form) to become a contributor. We are looking forward to your participation! ### Contact Information