From 6cb5492f02499ca4c3fd196d4e194f71c77aa422 Mon Sep 17 00:00:00 2001
From: seehi <6580@pm.me>
Date: Wed, 27 Mar 2024 19:53:50 +0800
Subject: [PATCH 1/6] upgrade llama-index-vector-stores-chroma and rag test
coverage 100%
---
.../minecraft_env/minecraft_env.py | 2 +-
metagpt/rag/factories/index.py | 2 +-
metagpt/rag/factories/retriever.py | 2 +-
metagpt/rag/vector_stores/__init__.py | 0
metagpt/rag/vector_stores/chroma/__init__.py | 3 -
metagpt/rag/vector_stores/chroma/base.py | 290 ------------------
setup.py | 2 +-
tests/metagpt/rag/engines/test_simple.py | 168 +++++++++-
tests/metagpt/rag/factories/test_embedding.py | 43 +++
tests/metagpt/rag/factories/test_index.py | 89 ++++++
tests/metagpt/rag/factories/test_llm.py | 71 +++++
tests/metagpt/rag/factories/test_ranker.py | 58 ++--
tests/metagpt/rag/factories/test_retriever.py | 80 +++--
tests/metagpt/rag/rankers/test_base_ranker.py | 23 ++
.../metagpt/rag/rankers/test_object_ranker.py | 37 ++-
.../rag/retrievers/test_base_retriever.py | 21 ++
.../rag/retrievers/test_bm25_retriever.py | 12 +-
.../rag/retrievers/test_chroma_retriever.py | 20 ++
.../rag/retrievers/test_es_retriever.py | 20 ++
.../rag/retrievers/test_faiss_retriever.py | 11 +-
.../rag/retrievers/test_hybrid_retriever.py | 28 +-
21 files changed, 600 insertions(+), 382 deletions(-)
delete mode 100644 metagpt/rag/vector_stores/__init__.py
delete mode 100644 metagpt/rag/vector_stores/chroma/__init__.py
delete mode 100644 metagpt/rag/vector_stores/chroma/base.py
create mode 100644 tests/metagpt/rag/factories/test_embedding.py
create mode 100644 tests/metagpt/rag/factories/test_index.py
create mode 100644 tests/metagpt/rag/factories/test_llm.py
create mode 100644 tests/metagpt/rag/rankers/test_base_ranker.py
create mode 100644 tests/metagpt/rag/retrievers/test_base_retriever.py
create mode 100644 tests/metagpt/rag/retrievers/test_chroma_retriever.py
create mode 100644 tests/metagpt/rag/retrievers/test_es_retriever.py
diff --git a/metagpt/environment/minecraft_env/minecraft_env.py b/metagpt/environment/minecraft_env/minecraft_env.py
index 26d4d03a86..6e1800b328 100644
--- a/metagpt/environment/minecraft_env/minecraft_env.py
+++ b/metagpt/environment/minecraft_env/minecraft_env.py
@@ -8,6 +8,7 @@
import time
from typing import Any, Iterable
+from llama_index.vector_stores.chroma import ChromaVectorStore
from pydantic import ConfigDict, Field
from metagpt.config2 import config as CONFIG
@@ -15,7 +16,6 @@
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
from metagpt.logs import logger
-from metagpt.rag.vector_stores.chroma import ChromaVectorStore
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
diff --git a/metagpt/rag/factories/index.py b/metagpt/rag/factories/index.py
index f200fc94f0..a56471359e 100644
--- a/metagpt/rag/factories/index.py
+++ b/metagpt/rag/factories/index.py
@@ -5,6 +5,7 @@
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
+from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
@@ -17,7 +18,6 @@
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
)
-from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RAGIndexFactory(ConfigBasedFactory):
diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py
index a107d95733..65729002ea 100644
--- a/metagpt/rag/factories/retriever.py
+++ b/metagpt/rag/factories/retriever.py
@@ -6,6 +6,7 @@
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
+from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
@@ -25,7 +26,6 @@
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
-from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RetrieverFactory(ConfigBasedFactory):
diff --git a/metagpt/rag/vector_stores/__init__.py b/metagpt/rag/vector_stores/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/metagpt/rag/vector_stores/chroma/__init__.py b/metagpt/rag/vector_stores/chroma/__init__.py
deleted file mode 100644
index 87ba4d8a76..0000000000
--- a/metagpt/rag/vector_stores/chroma/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
-
-__all__ = ["ChromaVectorStore"]
diff --git a/metagpt/rag/vector_stores/chroma/base.py b/metagpt/rag/vector_stores/chroma/base.py
deleted file mode 100644
index 55e5bd40d2..0000000000
--- a/metagpt/rag/vector_stores/chroma/base.py
+++ /dev/null
@@ -1,290 +0,0 @@
-"""Chroma vector store.
-
-Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
-The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7.
-"""
-
-import math
-from typing import Any, Dict, Generator, List, Optional, cast
-
-import chromadb
-from chromadb.api.models.Collection import Collection
-from llama_index.core.bridge.pydantic import Field, PrivateAttr
-from llama_index.core.schema import BaseNode, MetadataMode, TextNode
-from llama_index.core.utils import truncate_text
-from llama_index.core.vector_stores.types import (
- BasePydanticVectorStore,
- MetadataFilters,
- VectorStoreQuery,
- VectorStoreQueryResult,
-)
-from llama_index.core.vector_stores.utils import (
- legacy_metadata_dict_to_node,
- metadata_dict_to_node,
- node_to_metadata_dict,
-)
-
-from metagpt.logs import logger
-
-
-def _transform_chroma_filter_condition(condition: str) -> str:
- """Translate standard metadata filter op to Chroma specific spec."""
- if condition == "and":
- return "$and"
- elif condition == "or":
- return "$or"
- else:
- raise ValueError(f"Filter condition {condition} not supported")
-
-
-def _transform_chroma_filter_operator(operator: str) -> str:
- """Translate standard metadata filter operator to Chroma specific spec."""
- if operator == "!=":
- return "$ne"
- elif operator == "==":
- return "$eq"
- elif operator == ">":
- return "$gt"
- elif operator == "<":
- return "$lt"
- elif operator == ">=":
- return "$gte"
- elif operator == "<=":
- return "$lte"
- else:
- raise ValueError(f"Filter operator {operator} not supported")
-
-
-def _to_chroma_filter(
- standard_filters: MetadataFilters,
-) -> dict:
- """Translate standard metadata filters to Chroma specific spec."""
- filters = {}
- filters_list = []
- condition = standard_filters.condition or "and"
- condition = _transform_chroma_filter_condition(condition)
- if standard_filters.filters:
- for filter in standard_filters.filters:
- if filter.operator:
- filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
- else:
- filters_list.append({filter.key: filter.value})
- if len(filters_list) == 1:
- # If there is only one filter, return it directly
- return filters_list[0]
- elif len(filters_list) > 1:
- filters[condition] = filters_list
- return filters
-
-
-import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
-MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB
-
-
-def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
- """Yield successive max_chunk_size-sized chunks from lst.
- Args:
- lst (List[BaseNode]): list of nodes with embeddings
- max_chunk_size (int): max chunk size
- Yields:
- Generator[List[BaseNode], None, None]: list of nodes with embeddings
- """
- for i in range(0, len(lst), max_chunk_size):
- yield lst[i : i + max_chunk_size]
-
-
-class ChromaVectorStore(BasePydanticVectorStore):
- """Chroma vector store.
- In this vector store, embeddings are stored within a ChromaDB collection.
- During query time, the index uses ChromaDB to query for the top
- k most similar nodes.
- Args:
- chroma_collection (chromadb.api.models.Collection.Collection):
- ChromaDB collection instance
- """
-
- stores_text: bool = True
- flat_metadata: bool = True
- collection_name: Optional[str]
- host: Optional[str]
- port: Optional[str]
- ssl: bool
- headers: Optional[Dict[str, str]]
- persist_dir: Optional[str]
- collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
- _collection: Any = PrivateAttr()
-
- def __init__(
- self,
- chroma_collection: Optional[Any] = None,
- collection_name: Optional[str] = None,
- host: Optional[str] = None,
- port: Optional[str] = None,
- ssl: bool = False,
- headers: Optional[Dict[str, str]] = None,
- persist_dir: Optional[str] = None,
- collection_kwargs: Optional[dict] = None,
- **kwargs: Any,
- ) -> None:
- """Init params."""
- collection_kwargs = collection_kwargs or {}
- if chroma_collection is None:
- client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
- self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
- else:
- self._collection = cast(Collection, chroma_collection)
- super().__init__(
- host=host,
- port=port,
- ssl=ssl,
- headers=headers,
- collection_name=collection_name,
- persist_dir=persist_dir,
- collection_kwargs=collection_kwargs or {},
- )
-
- @classmethod
- def from_collection(cls, collection: Any) -> "ChromaVectorStore":
- try:
- from chromadb import Collection
- except ImportError:
- raise ImportError(import_err_msg)
- if not isinstance(collection, Collection):
- raise Exception("argument is not chromadb collection instance")
- return cls(chroma_collection=collection)
-
- @classmethod
- def from_params(
- cls,
- collection_name: str,
- host: Optional[str] = None,
- port: Optional[str] = None,
- ssl: bool = False,
- headers: Optional[Dict[str, str]] = None,
- persist_dir: Optional[str] = None,
- collection_kwargs: dict = {},
- **kwargs: Any,
- ) -> "ChromaVectorStore":
- if persist_dir:
- client = chromadb.PersistentClient(path=persist_dir)
- collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
- elif host and port:
- client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
- collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
- else:
- raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
- return cls(
- chroma_collection=collection,
- host=host,
- port=port,
- ssl=ssl,
- headers=headers,
- persist_dir=persist_dir,
- collection_kwargs=collection_kwargs,
- **kwargs,
- )
-
- @classmethod
- def class_name(cls) -> str:
- return "ChromaVectorStore"
-
- def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
- """Add nodes to index.
- Args:
- nodes: List[BaseNode]: list of nodes with embeddings
- """
- if not self._collection:
- raise ValueError("Collection not initialized")
- max_chunk_size = MAX_CHUNK_SIZE
- node_chunks = chunk_list(nodes, max_chunk_size)
- all_ids = []
- for node_chunk in node_chunks:
- embeddings = []
- metadatas = []
- ids = []
- documents = []
- for node in node_chunk:
- embeddings.append(node.get_embedding())
- metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
- for key in metadata_dict:
- if metadata_dict[key] is None:
- metadata_dict[key] = ""
- metadatas.append(metadata_dict)
- ids.append(node.node_id)
- documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
- self._collection.add(
- embeddings=embeddings,
- ids=ids,
- metadatas=metadatas,
- documents=documents,
- )
- all_ids.extend(ids)
- return all_ids
-
- def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
- """
- Delete nodes using with ref_doc_id.
- Args:
- ref_doc_id (str): The doc_id of the document to delete.
- """
- self._collection.delete(where={"document_id": ref_doc_id})
-
- @property
- def client(self) -> Any:
- """Return client."""
- return self._collection
-
- def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
- """Query index for top k most similar nodes.
- Args:
- query_embedding (List[float]): query embedding
- similarity_top_k (int): top k most similar nodes
- """
- if query.filters is not None:
- if "where" in kwargs:
- raise ValueError(
- "Cannot specify metadata filters via both query and kwargs. "
- "Use kwargs only for chroma specific items that are "
- "not supported via the generic query interface."
- )
- where = _to_chroma_filter(query.filters)
- else:
- where = kwargs.pop("where", {})
- results = self._collection.query(
- query_embeddings=query.query_embedding,
- n_results=query.similarity_top_k,
- where=where,
- **kwargs,
- )
- logger.debug(f"> Top {len(results['documents'])} nodes:")
- nodes = []
- similarities = []
- ids = []
- for node_id, text, metadata, distance in zip(
- results["ids"][0],
- results["documents"][0],
- results["metadatas"][0],
- results["distances"][0],
- ):
- try:
- node = metadata_dict_to_node(metadata)
- node.set_content(text)
- except Exception:
- # NOTE: deprecated legacy logic for backward compatibility
- metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
- node = TextNode(
- text=text,
- id_=node_id,
- metadata=metadata,
- start_char_idx=node_info.get("start", None),
- end_char_idx=node_info.get("end", None),
- relationships=relationships,
- )
- nodes.append(node)
- similarity_score = math.exp(-distance)
- similarities.append(similarity_score)
- logger.debug(
- f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
- )
- ids.append(node_id)
- return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
diff --git a/setup.py b/setup.py
index c728872ef2..4fa5499da6 100644
--- a/setup.py
+++ b/setup.py
@@ -37,8 +37,8 @@ def run(self):
"llama-index-retrievers-bm25==0.1.3",
"llama-index-vector-stores-faiss==0.1.1",
"llama-index-vector-stores-elasticsearch==0.1.6",
+ "llama-index-vector-stores-chroma==0.1.6",
"llama-index-postprocessor-colbert-rerank==0.1.1",
- "chromadb==0.4.23",
],
}
diff --git a/tests/metagpt/rag/engines/test_simple.py b/tests/metagpt/rag/engines/test_simple.py
index 5627957c73..9262ccb07e 100644
--- a/tests/metagpt/rag/engines/test_simple.py
+++ b/tests/metagpt/rag/engines/test_simple.py
@@ -1,12 +1,26 @@
+import json
+
import pytest
from llama_index.core import VectorStoreIndex
-from llama_index.core.schema import Document, TextNode
+from llama_index.core.embeddings import MockEmbedding
+from llama_index.core.llms import MockLLM
+from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
-from metagpt.rag.retrievers.base import ModifiableRAGRetriever
+from metagpt.rag.retrievers import SimpleHybridRetriever
+from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
+from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
class TestSimpleEngine:
+ @pytest.fixture
+ def mock_llm(self):
+ return MockLLM()
+
+ @pytest.fixture
+ def mock_embedding(self):
+ return MockEmbedding(embed_dim=1)
+
@pytest.fixture
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@@ -54,7 +68,7 @@ def test_from_docs(
retriever_configs = [mocker.MagicMock()]
ranker_configs = [mocker.MagicMock()]
- # Execute
+ # Exec
engine = SimpleEngine.from_docs(
input_dir=input_dir,
input_files=input_files,
@@ -65,7 +79,7 @@ def test_from_docs(
ranker_configs=ranker_configs,
)
- # Assertions
+ # Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
@@ -75,6 +89,68 @@ def test_from_docs(
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
+ def test_from_docs_without_file(self):
+ with pytest.raises(ValueError):
+ SimpleEngine.from_docs()
+
+ def test_from_objs(self, mock_llm, mock_embedding):
+ # Mock
+ class MockRAGObject:
+ def rag_key(self):
+ return "key"
+
+ def model_dump_json(self):
+ return "{}"
+
+ objs = [MockRAGObject()]
+
+ # Setup
+ retriever_configs = []
+ ranker_configs = []
+
+ # Exec
+ engine = SimpleEngine.from_objs(
+ objs=objs,
+ llm=mock_llm,
+ embed_model=mock_embedding,
+ retriever_configs=retriever_configs,
+ ranker_configs=ranker_configs,
+ )
+
+ # Assert
+ assert isinstance(engine, SimpleEngine)
+ assert engine.index is not None
+
+ def test_from_objs_with_bm25_config(self):
+ # Setup
+ retriever_configs = [BM25RetrieverConfig()]
+
+ # Exec
+ with pytest.raises(ValueError):
+ SimpleEngine.from_objs(
+ objs=[],
+ llm=MockLLM(),
+ retriever_configs=retriever_configs,
+ ranker_configs=[],
+ )
+
+ def test_from_index(self, mocker, mock_llm, mock_embedding):
+ # Mock
+ mock_index = mocker.MagicMock(spec=VectorStoreIndex)
+ mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
+ mock_get_index.return_value = mock_index
+
+ # Exec
+ engine = SimpleEngine.from_index(
+ index_config=mock_index,
+ embed_model=mock_embedding,
+ llm=mock_llm,
+ )
+
+ # Assert
+ assert isinstance(engine, SimpleEngine)
+ assert engine.index is mock_index
+
@pytest.mark.asyncio
async def test_asearch(self, mocker):
# Mock
@@ -86,10 +162,10 @@ async def test_asearch(self, mocker):
engine = SimpleEngine(retriever=mocker.MagicMock())
engine.aquery = mock_aquery
- # Execute
+ # Exec
result = await engine.asearch(test_query)
- # Assertions
+ # Assert
mock_aquery.assert_called_once_with(test_query)
assert result == expected_result
@@ -106,10 +182,10 @@ async def test_aretrieve(self, mocker):
engine = SimpleEngine(retriever=mocker.MagicMock())
test_query = "test query"
- # Execute
+ # Exec
result = await engine.aretrieve(test_query)
- # Assertions
+ # Assert
mock_query_bundle.assert_called_once_with(test_query)
mock_super_aretrieve.assert_called_once_with("query_bundle")
assert result[0].text == "node_with_score"
@@ -134,10 +210,10 @@ def test_add_docs(self, mocker):
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
input_files = ["test_file1", "test_file2"]
- # Execute
+ # Exec
engine.add_docs(input_files=input_files)
- # Assertions
+ # Assert
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
@@ -156,11 +232,79 @@ def model_dump_json(self):
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
- # Execute
+ # Exec
engine.add_objs(objs=objs)
- # Assertions
+ # Assert
assert mock_retriever.add_nodes.call_count == 1
for node in mock_retriever.add_nodes.call_args[0][0]:
assert isinstance(node, TextNode)
assert "is_obj" in node.metadata
+
+ def test_persist_successfully(self, mocker):
+ # Mock
+ mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever)
+ mock_retriever.persist.return_value = mocker.MagicMock()
+
+ # Setup
+ engine = SimpleEngine(retriever=mock_retriever)
+
+ # Exec
+ engine.persist(persist_dir="")
+
+ def test_ensure_retriever_of_type(self, mocker):
+ # Mock
+ class MyRetriever:
+ def add_nodes(self):
+ ...
+
+ mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever)
+ mock_retriever.retrievers = [MyRetriever()]
+
+ # Setup
+ engine = SimpleEngine(retriever=mock_retriever)
+
+ # Assert
+ engine._ensure_retriever_of_type(ModifiableRAGRetriever)
+
+ with pytest.raises(TypeError):
+ engine._ensure_retriever_of_type(PersistableRAGRetriever)
+
+ with pytest.raises(TypeError):
+ other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever))
+ other_engine._ensure_retriever_of_type(PersistableRAGRetriever)
+
+ def test_with_obj_metadata(self, mocker):
+ # Mock
+ node = NodeWithScore(
+ node=ObjectNode(
+ text="example",
+ metadata={
+ "is_obj": True,
+ "obj_cls_name": "ExampleObject",
+ "obj_mod_name": "__main__",
+ "obj_json": json.dumps({"key": "test_key", "value": "test_value"}),
+ },
+ )
+ )
+
+ class ExampleObject:
+ def __init__(self, key, value):
+ self.key = key
+ self.value = value
+
+ def __eq__(self, other):
+ return self.key == other.key and self.value == other.value
+
+ mock_import_class = mocker.patch("metagpt.rag.engines.simple.import_class")
+ mock_import_class.return_value = ExampleObject
+
+    # Exec
+    SimpleEngine._try_reconstruct_obj([node])
+
+    # Setup expected object for comparison
+    expected_obj = ExampleObject(key="test_key", value="test_value")
+
+ # Assert
+ assert "obj" in node.node.metadata
+ assert node.node.metadata["obj"] == expected_obj
diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py
new file mode 100644
index 0000000000..1ded6b4a8d
--- /dev/null
+++ b/tests/metagpt/rag/factories/test_embedding.py
@@ -0,0 +1,43 @@
+import pytest
+
+from metagpt.configs.llm_config import LLMType
+from metagpt.rag.factories.embedding import RAGEmbeddingFactory
+
+
+class TestRAGEmbeddingFactory:
+ @pytest.fixture(autouse=True)
+ def mock_embedding_factory(self):
+ self.embedding_factory = RAGEmbeddingFactory()
+
+ @pytest.fixture
+ def mock_openai_embedding(self, mocker):
+ return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
+
+ @pytest.fixture
+ def mock_azure_embedding(self, mocker):
+ return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
+
+ def test_get_rag_embedding_openai(self, mock_openai_embedding):
+ # Exec
+ self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
+
+ # Assert
+ mock_openai_embedding.assert_called_once()
+
+ def test_get_rag_embedding_azure(self, mock_azure_embedding):
+ # Exec
+ self.embedding_factory.get_rag_embedding(LLMType.AZURE)
+
+ # Assert
+ mock_azure_embedding.assert_called_once()
+
+ def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
+ # Mock
+ mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
+ mock_config.llm.api_type = LLMType.OPENAI
+
+ # Exec
+ self.embedding_factory.get_rag_embedding()
+
+ # Assert
+ mock_openai_embedding.assert_called_once()
diff --git a/tests/metagpt/rag/factories/test_index.py b/tests/metagpt/rag/factories/test_index.py
new file mode 100644
index 0000000000..9dc5bfb6be
--- /dev/null
+++ b/tests/metagpt/rag/factories/test_index.py
@@ -0,0 +1,89 @@
+import pytest
+from llama_index.core.embeddings import MockEmbedding
+
+from metagpt.rag.factories.index import RAGIndexFactory
+from metagpt.rag.schema import (
+ BM25IndexConfig,
+ ChromaIndexConfig,
+ ElasticsearchIndexConfig,
+ ElasticsearchStoreConfig,
+ FAISSIndexConfig,
+)
+
+
+class TestRAGIndexFactory:
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ self.index_factory = RAGIndexFactory()
+
+ @pytest.fixture
+ def faiss_config(self):
+ return FAISSIndexConfig(persist_path="")
+
+ @pytest.fixture
+ def chroma_config(self):
+ return ChromaIndexConfig(persist_path="", collection_name="")
+
+ @pytest.fixture
+ def bm25_config(self):
+ return BM25IndexConfig(persist_path="")
+
+ @pytest.fixture
+ def es_config(self, mocker):
+ return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig())
+
+ @pytest.fixture
+ def mock_storage_context(self, mocker):
+ return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults")
+
+ @pytest.fixture
+ def mock_load_index_from_storage(self, mocker):
+ return mocker.patch("metagpt.rag.factories.index.load_index_from_storage")
+
+ @pytest.fixture
+ def mock_from_vector_store(self, mocker):
+ return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store")
+
+ @pytest.fixture
+ def mock_embedding(self):
+ return MockEmbedding(embed_dim=1)
+
+ def test_create_faiss_index(
+ self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
+ ):
+ # Mock
+ mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir")
+
+ # Exec
+ self.index_factory.get_index(faiss_config, embed_model=mock_embedding)
+
+ # Assert
+ mock_faiss_store.assert_called_once()
+
+ def test_create_bm25_index(
+ self, mocker, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
+ ):
+ self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
+
+ def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
+ # Mock
+ mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
+ mock_chroma_db.get_or_create_collection.return_value = mocker.MagicMock()
+
+ mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore")
+
+ # Exec
+ self.index_factory.get_index(chroma_config, embed_model=mock_embedding)
+
+ # Assert
+ mock_chroma_store.assert_called_once()
+
+ def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding):
+ # Mock
+ mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore")
+
+ # Exec
+ self.index_factory.get_index(es_config, embed_model=mock_embedding)
+
+ # Assert
+ mock_es_store.assert_called_once()
diff --git a/tests/metagpt/rag/factories/test_llm.py b/tests/metagpt/rag/factories/test_llm.py
new file mode 100644
index 0000000000..e11b87076c
--- /dev/null
+++ b/tests/metagpt/rag/factories/test_llm.py
@@ -0,0 +1,71 @@
+from typing import Optional, Union
+
+import pytest
+from llama_index.core.llms import LLMMetadata
+
+from metagpt.configs.llm_config import LLMConfig
+from metagpt.const import USE_CONFIG_TIMEOUT
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.rag.factories.llm import RAGLLM, get_rag_llm
+
+
+class MockLLM(BaseLLM):
+ def __init__(self, config: LLMConfig):
+ ...
+
+ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
+ """_achat_completion implemented by inherited class"""
+
+ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
+ return "ok"
+
+ def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
+ return "ok"
+
+ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+ """_achat_completion_stream implemented by inherited class"""
+
+ async def aask(
+ self,
+ msg: Union[str, list[dict[str, str]]],
+ system_msgs: Optional[list[str]] = None,
+ format_msgs: Optional[list[dict[str, str]]] = None,
+ images: Optional[Union[str, list[str]]] = None,
+ timeout=USE_CONFIG_TIMEOUT,
+ stream=True,
+ ) -> str:
+ return "ok"
+
+
+class TestRAGLLM:
+ @pytest.fixture
+ def mock_model_infer(self):
+ return MockLLM(config=LLMConfig())
+
+ @pytest.fixture
+ def rag_llm(self, mock_model_infer):
+ return RAGLLM(model_infer=mock_model_infer)
+
+ def test_metadata(self, rag_llm):
+ metadata = rag_llm.metadata
+ assert isinstance(metadata, LLMMetadata)
+ assert metadata.context_window == rag_llm.context_window
+ assert metadata.num_output == rag_llm.num_output
+ assert metadata.model_name == rag_llm.model_name
+
+ @pytest.mark.asyncio
+ async def test_acomplete(self, rag_llm, mock_model_infer):
+ response = await rag_llm.acomplete("question")
+ assert response.text == "ok"
+
+ def test_complete(self, rag_llm, mock_model_infer):
+ response = rag_llm.complete("question")
+ assert response.text == "ok"
+
+ def test_stream_complete(self, rag_llm, mock_model_infer):
+ rag_llm.stream_complete("question")
+
+
+def test_get_rag_llm():
+ result = get_rag_llm(MockLLM(config=LLMConfig()))
+ assert isinstance(result, RAGLLM)
diff --git a/tests/metagpt/rag/factories/test_ranker.py b/tests/metagpt/rag/factories/test_ranker.py
index 563cffa738..3f6b94b47a 100644
--- a/tests/metagpt/rag/factories/test_ranker.py
+++ b/tests/metagpt/rag/factories/test_ranker.py
@@ -1,41 +1,57 @@
import pytest
-from llama_index.core.llms import LLM
+from llama_index.core.llms import MockLLM
from llama_index.core.postprocessor import LLMRerank
from metagpt.rag.factories.ranker import RankerFactory
-from metagpt.rag.schema import LLMRankerConfig
+from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig
class TestRankerFactory:
- @pytest.fixture
- def ranker_factory(self) -> RankerFactory:
- return RankerFactory()
+ @pytest.fixture(autouse=True)
+ def ranker_factory(self):
+ self.ranker_factory: RankerFactory = RankerFactory()
@pytest.fixture
- def mock_llm(self, mocker):
- return mocker.MagicMock(spec=LLM)
+ def mock_llm(self):
+ return MockLLM()
- def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
- mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
- default_rankers = ranker_factory.get_rankers()
+ def test_get_rankers_with_no_configs(self, mock_llm, mocker):
+ mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm)
+ default_rankers = self.ranker_factory.get_rankers()
assert len(default_rankers) == 0
- def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
+ def test_get_rankers_with_configs(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
- rankers = ranker_factory.get_rankers(configs=[mock_config])
+ rankers = self.ranker_factory.get_rankers(configs=[mock_config])
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
- def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
- mock_config = LLMRankerConfig(llm=mock_llm)
- ranker = ranker_factory._create_llm_ranker(mock_config)
- assert isinstance(ranker, LLMRerank)
-
- def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
+ def test_extract_llm_from_config(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
- extracted_llm = ranker_factory._extract_llm(config=mock_config)
+ extracted_llm = self.ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
- def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
- extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
+ def test_extract_llm_from_kwargs(self, mock_llm):
+ extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm
+
+ def test_create_llm_ranker(self, mock_llm):
+ mock_config = LLMRankerConfig(llm=mock_llm)
+ ranker = self.ranker_factory._create_llm_ranker(mock_config)
+ assert isinstance(ranker, LLMRerank)
+
+ def test_create_colbert_ranker(self, mocker, mock_llm):
+ mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert")
+
+ mock_config = ColbertRerankConfig(llm=mock_llm)
+ ranker = self.ranker_factory._create_colbert_ranker(mock_config)
+
+ assert ranker == "colbert"
+
+ def test_create_object_ranker(self, mocker, mock_llm):
+ mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")
+
+ mock_config = ObjectRankerConfig(field_name="fake", llm=mock_llm)
+ ranker = self.ranker_factory._create_object_ranker(mock_config)
+
+ assert ranker == "object"
diff --git a/tests/metagpt/rag/factories/test_retriever.py b/tests/metagpt/rag/factories/test_retriever.py
index ac8926d468..ef1cef7e00 100644
--- a/tests/metagpt/rag/factories/test_retriever.py
+++ b/tests/metagpt/rag/factories/test_retriever.py
@@ -1,18 +1,28 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
+from llama_index.vector_stores.chroma import ChromaVectorStore
+from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
+from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
+from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
-from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
+from metagpt.rag.schema import (
+ BM25RetrieverConfig,
+ ChromaRetrieverConfig,
+ ElasticsearchRetrieverConfig,
+ ElasticsearchStoreConfig,
+ FAISSRetrieverConfig,
+)
class TestRetrieverFactory:
- @pytest.fixture
+ @pytest.fixture(autouse=True)
def retriever_factory(self):
- return RetrieverFactory()
+ self.retriever_factory: RetrieverFactory = RetrieverFactory()
@pytest.fixture
def mock_faiss_index(self, mocker):
@@ -25,55 +35,79 @@ def mock_vector_store_index(self, mocker):
mock.docstore.docs.values.return_value = []
return mock
- def test_get_retriever_with_faiss_config(
- self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
- ):
+ @pytest.fixture
+ def mock_chroma_vector_store(self, mocker):
+ return mocker.MagicMock(spec=ChromaVectorStore)
+
+ @pytest.fixture
+ def mock_es_vector_store(self, mocker):
+ return mocker.MagicMock(spec=ElasticsearchStore)
+
+ def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
- mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
+ mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
- retriever = retriever_factory.get_retriever(configs=[mock_config])
+ retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, FAISSRetriever)
- def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
+ def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
- mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
+ mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
- retriever = retriever_factory.get_retriever(configs=[mock_config])
+ retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, DynamicBM25Retriever)
- def test_get_retriever_with_multiple_configs_returns_hybrid(
- self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
- ):
+ def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
- mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
+ mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
- retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
+ retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
assert isinstance(retriever, SimpleHybridRetriever)
- def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
- mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
+ def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
+ mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
+ mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
+ mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
+ mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
+ mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
+
+ retriever = self.retriever_factory.get_retriever(configs=[mock_config])
+
+ assert isinstance(retriever, ChromaRetriever)
+
+ def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
+ mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
+ mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
+ mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
+
+ retriever = self.retriever_factory.get_retriever(configs=[mock_config])
+
+ assert isinstance(retriever, ElasticsearchRetriever)
+
+ def test_create_default_retriever(self, mocker, mock_vector_store_index):
+ mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mock_vector_store_index.as_retriever = mocker.MagicMock()
- retriever = retriever_factory.get_retriever()
+ retriever = self.retriever_factory.get_retriever()
mock_vector_store_index.as_retriever.assert_called_once()
assert retriever is mock_vector_store_index.as_retriever.return_value
- def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
+ def test_extract_index_from_config(self, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
- extracted_index = retriever_factory._extract_index(config=mock_config)
+ extracted_index = self.retriever_factory._extract_index(config=mock_config)
assert extracted_index == mock_vector_store_index
- def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
- extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
+ def test_extract_index_from_kwargs(self, mock_vector_store_index):
+ extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index
diff --git a/tests/metagpt/rag/rankers/test_base_ranker.py b/tests/metagpt/rag/rankers/test_base_ranker.py
new file mode 100644
index 0000000000..9755949f6a
--- /dev/null
+++ b/tests/metagpt/rag/rankers/test_base_ranker.py
@@ -0,0 +1,23 @@
+import pytest
+from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
+
+from metagpt.rag.rankers.base import RAGRanker
+
+
+class SimpleRAGRanker(RAGRanker):
+ def _postprocess_nodes(self, nodes, query_bundle=None):
+ return [NodeWithScore(node=node.node, score=node.score + 1) for node in nodes]
+
+
+class TestSimpleRAGRanker:
+ @pytest.fixture
+ def ranker(self):
+ return SimpleRAGRanker()
+
+ def test_postprocess_nodes_increases_scores(self, ranker):
+ nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)]
+ query_bundle = QueryBundle(query_str="test query")
+
+ processed_nodes = ranker._postprocess_nodes(nodes, query_bundle)
+
+ assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes))
diff --git a/tests/metagpt/rag/rankers/test_object_ranker.py b/tests/metagpt/rag/rankers/test_object_ranker.py
index 7ea6b7488b..4a9f66a42d 100644
--- a/tests/metagpt/rag/rankers/test_object_ranker.py
+++ b/tests/metagpt/rag/rankers/test_object_ranker.py
@@ -14,7 +14,7 @@ class Record(BaseModel):
class TestObjectSortPostprocessor:
@pytest.fixture
- def nodes_with_scores(self):
+ def mock_nodes_with_scores(self):
nodes = [
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
@@ -23,38 +23,47 @@ def nodes_with_scores(self):
return nodes
@pytest.fixture
- def query_bundle(self, mocker):
+ def mock_query_bundle(self, mocker):
return mocker.MagicMock(spec=QueryBundle)
- def test_sort_descending(self, nodes_with_scores, query_bundle):
+ def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
- sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
+ sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [20, 10, 5]
- def test_sort_ascending(self, nodes_with_scores, query_bundle):
+ def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
- sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
+ sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [5, 10, 20]
- def test_top_n_limit(self, nodes_with_scores, query_bundle):
+ def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
- sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
+ sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert len(sorted_nodes) == 2
assert [node.score for node in sorted_nodes] == [20, 10]
- def test_invalid_json_metadata(self, query_bundle):
+ def test_invalid_json_metadata(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
- postprocessor._postprocess_nodes(nodes, query_bundle)
+ postprocessor._postprocess_nodes(nodes, mock_query_bundle)
- def test_missing_query_bundle(self, nodes_with_scores):
+ def test_missing_query_bundle(self, mock_nodes_with_scores):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
- postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
+ postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None)
- def test_field_not_found_in_object(self):
+ def test_field_not_found_in_object(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
- postprocessor._postprocess_nodes(nodes)
+ postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle)
+
+ def test_not_nodes(self, mock_query_bundle):
+ nodes = []
+ postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
+ result = postprocessor._postprocess_nodes(nodes, mock_query_bundle)
+ assert result == []
+
+ def test_class_name(self):
+ assert ObjectSortPostprocessor.class_name() == "ObjectSortPostprocessor"
diff --git a/tests/metagpt/rag/retrievers/test_base_retriever.py b/tests/metagpt/rag/retrievers/test_base_retriever.py
new file mode 100644
index 0000000000..1065b9731d
--- /dev/null
+++ b/tests/metagpt/rag/retrievers/test_base_retriever.py
@@ -0,0 +1,21 @@
+from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
+
+
+class SubModifiableRAGRetriever(ModifiableRAGRetriever):
+ ...
+
+
+class SubPersistableRAGRetriever(PersistableRAGRetriever):
+ ...
+
+
+class TestModifiableRAGRetriever:
+ def test_subclasshook(self):
+ result = SubModifiableRAGRetriever.__subclasshook__(SubModifiableRAGRetriever)
+ assert result is NotImplemented
+
+
+class TestPersistableRAGRetriever:
+ def test_subclasshook(self):
+ result = SubPersistableRAGRetriever.__subclasshook__(SubPersistableRAGRetriever)
+ assert result is NotImplemented
diff --git a/tests/metagpt/rag/retrievers/test_bm25_retriever.py b/tests/metagpt/rag/retrievers/test_bm25_retriever.py
index 28b37c86b8..5a569f1036 100644
--- a/tests/metagpt/rag/retrievers/test_bm25_retriever.py
+++ b/tests/metagpt/rag/retrievers/test_bm25_retriever.py
@@ -8,30 +8,30 @@
class TestDynamicBM25Retriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
- # 创建模拟的Document对象
self.doc1 = mocker.MagicMock(spec=Node)
self.doc1.get_content.return_value = "Document content 1"
self.doc2 = mocker.MagicMock(spec=Node)
self.doc2.get_content.return_value = "Document content 2"
self.mock_nodes = [self.doc1, self.doc2]
- # 模拟index
index = mocker.MagicMock(spec=VectorStoreIndex)
+ index.storage_context.persist.return_value = "ok"
- # 模拟nodes和tokenizer参数
mock_nodes = []
mock_tokenizer = mocker.MagicMock()
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
- # 初始化DynamicBM25Retriever对象,并提供必需的参数
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
def test_add_docs_updates_nodes_and_corpus(self):
- # Execute
+ # Exec
self.retriever.add_nodes(self.mock_nodes)
- # Assertions
+ # Assert
assert len(self.retriever._nodes) == len(self.mock_nodes)
assert len(self.retriever._corpus) == len(self.mock_nodes)
self.retriever._tokenizer.assert_called()
self.mock_bm25okapi.assert_called()
+
+ def test_persist(self):
+ self.retriever.persist("")
diff --git a/tests/metagpt/rag/retrievers/test_chroma_retriever.py b/tests/metagpt/rag/retrievers/test_chroma_retriever.py
new file mode 100644
index 0000000000..cf07903cf2
--- /dev/null
+++ b/tests/metagpt/rag/retrievers/test_chroma_retriever.py
@@ -0,0 +1,20 @@
+import pytest
+from llama_index.core.schema import Node
+
+from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
+
+
+class TestChromaRetriever:
+ @pytest.fixture(autouse=True)
+ def setup(self, mocker):
+ self.doc1 = mocker.MagicMock(spec=Node)
+ self.doc2 = mocker.MagicMock(spec=Node)
+ self.mock_nodes = [self.doc1, self.doc2]
+
+ self.mock_index = mocker.MagicMock()
+ self.retriever = ChromaRetriever(self.mock_index)
+
+ def test_add_nodes(self):
+ self.retriever.add_nodes(self.mock_nodes)
+
+ self.mock_index.insert_nodes.assert_called()
diff --git a/tests/metagpt/rag/retrievers/test_es_retriever.py b/tests/metagpt/rag/retrievers/test_es_retriever.py
new file mode 100644
index 0000000000..1824bfbd28
--- /dev/null
+++ b/tests/metagpt/rag/retrievers/test_es_retriever.py
@@ -0,0 +1,20 @@
+import pytest
+from llama_index.core.schema import Node
+
+from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
+
+
+class TestElasticsearchRetriever:
+ @pytest.fixture(autouse=True)
+ def setup(self, mocker):
+ self.doc1 = mocker.MagicMock(spec=Node)
+ self.doc2 = mocker.MagicMock(spec=Node)
+ self.mock_nodes = [self.doc1, self.doc2]
+
+ self.mock_index = mocker.MagicMock()
+ self.retriever = ElasticsearchRetriever(self.mock_index)
+
+ def test_add_nodes(self):
+ self.retriever.add_nodes(self.mock_nodes)
+
+ self.mock_index.insert_nodes.assert_called()
diff --git a/tests/metagpt/rag/retrievers/test_faiss_retriever.py b/tests/metagpt/rag/retrievers/test_faiss_retriever.py
index 9113f110cf..8546732157 100644
--- a/tests/metagpt/rag/retrievers/test_faiss_retriever.py
+++ b/tests/metagpt/rag/retrievers/test_faiss_retriever.py
@@ -7,16 +7,19 @@
class TestFAISSRetriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
- # 创建模拟的Document对象
self.doc1 = mocker.MagicMock(spec=Node)
self.doc2 = mocker.MagicMock(spec=Node)
self.mock_nodes = [self.doc1, self.doc2]
- # 模拟FAISSRetriever的_index属性
self.mock_index = mocker.MagicMock()
self.retriever = FAISSRetriever(self.mock_index)
- def test_add_docs_calls_insert_for_each_document(self, mocker):
+ def test_add_docs_calls_insert_for_each_document(self):
self.retriever.add_nodes(self.mock_nodes)
- assert self.mock_index.insert_nodes.assert_called
+ self.mock_index.insert_nodes.assert_called()
+
+ def test_persist(self):
+ self.retriever.persist("")
+
+ self.mock_index.storage_context.persist.assert_called()
diff --git a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py
index 8cc3087c86..da150d8793 100644
--- a/tests/metagpt/rag/retrievers/test_hybrid_retriever.py
+++ b/tests/metagpt/rag/retrievers/test_hybrid_retriever.py
@@ -1,5 +1,3 @@
-from unittest.mock import AsyncMock
-
import pytest
from llama_index.core.schema import NodeWithScore, TextNode
@@ -7,18 +5,30 @@
class TestSimpleHybridRetriever:
+ @pytest.fixture
+ def mock_retriever(self, mocker):
+ return mocker.MagicMock()
+
+ @pytest.fixture
+ def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever:
+ return SimpleHybridRetriever(mock_retriever)
+
+ @pytest.fixture
+ def mock_node(self):
+ return NodeWithScore(node=TextNode(id_="2"), score=0.95)
+
@pytest.mark.asyncio
- async def test_aretrieve(self):
+ async def test_aretrieve(self, mocker):
question = "test query"
# Create mock retrievers
- mock_retriever1 = AsyncMock()
+ mock_retriever1 = mocker.AsyncMock()
mock_retriever1.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="1"), score=1.0),
NodeWithScore(node=TextNode(id_="2"), score=0.95),
]
- mock_retriever2 = AsyncMock()
+ mock_retriever2 = mocker.AsyncMock()
mock_retriever2.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="2"), score=0.95),
NodeWithScore(node=TextNode(id_="3"), score=0.8),
@@ -37,3 +47,11 @@ async def test_aretrieve(self):
# Check if the scores are correct (assuming you want the highest score)
node_scores = {node.node.node_id: node.score for node in results}
assert node_scores["2"] == 0.95
+
+ def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node):
+ mock_hybrid_retriever.add_nodes([mock_node])
+ mock_hybrid_retriever.retrievers[0].add_nodes.assert_called_once()
+
+ def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever):
+ mock_hybrid_retriever.persist("")
+ mock_hybrid_retriever.retrievers[0].persist.assert_called_once()
From 676d5ff55ba6468292fdb99af3617fde819f3433 Mon Sep 17 00:00:00 2001
From: seehi <6580@pm.me>
Date: Thu, 28 Mar 2024 15:47:38 +0800
Subject: [PATCH 2/6] lazy import colbert
---
metagpt/rag/factories/ranker.py | 7 ++++++-
setup.py | 1 -
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py
index 07cb1b929f..476fe8c1a6 100644
--- a/metagpt/rag/factories/ranker.py
+++ b/metagpt/rag/factories/ranker.py
@@ -3,7 +3,6 @@
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
-from llama_index.postprocessor.colbert_rerank import ColbertRerank
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
@@ -38,6 +37,12 @@ def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
return LLMRerank(**config.model_dump())
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank:
+ try:
+ from llama_index.postprocessor.colbert_rerank import ColbertRerank
+ except ImportError:
+ raise ImportError(
+ "`llama-index-postprocessor-colbert-rerank` package not found, please run `pip install llama-index-postprocessor-colbert-rerank`"
+ )
return ColbertRerank(**config.model_dump())
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
diff --git a/setup.py b/setup.py
index 4fa5499da6..3eab2b6a09 100644
--- a/setup.py
+++ b/setup.py
@@ -38,7 +38,6 @@ def run(self):
"llama-index-vector-stores-faiss==0.1.1",
"llama-index-vector-stores-elasticsearch==0.1.6",
"llama-index-vector-stores-chroma==0.1.6",
- "llama-index-postprocessor-colbert-rerank==0.1.1",
],
}
From b355f715bd26d338409a6371ea0a6dbdcdab38a8 Mon Sep 17 00:00:00 2001
From: seehi <6580@pm.me>
Date: Thu, 28 Mar 2024 16:06:32 +0800
Subject: [PATCH 3/6] lazy import colbert
---
tests/metagpt/rag/factories/test_ranker.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/tests/metagpt/rag/factories/test_ranker.py b/tests/metagpt/rag/factories/test_ranker.py
index 3f6b94b47a..e40f7f8dff 100644
--- a/tests/metagpt/rag/factories/test_ranker.py
+++ b/tests/metagpt/rag/factories/test_ranker.py
@@ -1,3 +1,5 @@
+import contextlib
+
import pytest
from llama_index.core.llms import MockLLM
from llama_index.core.postprocessor import LLMRerank
@@ -41,12 +43,13 @@ def test_create_llm_ranker(self, mock_llm):
assert isinstance(ranker, LLMRerank)
def test_create_colbert_ranker(self, mocker, mock_llm):
- mocker.patch("metagpt.rag.factories.ranker.ColbertRerank", return_value="colbert")
+ with contextlib.suppress(ImportError):
+ mocker.patch("llama_index.postprocessor.colbert_rerank.ColbertRerank", return_value="colbert")
- mock_config = ColbertRerankConfig(llm=mock_llm)
- ranker = self.ranker_factory._create_colbert_ranker(mock_config)
+ mock_config = ColbertRerankConfig(llm=mock_llm)
+ ranker = self.ranker_factory._create_colbert_ranker(mock_config)
- assert ranker == "colbert"
+ assert ranker == "colbert"
def test_create_object_ranker(self, mocker, mock_llm):
mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")
From dae7492b92697cd96406d7f13e61a9d22beb6412 Mon Sep 17 00:00:00 2001
From: yzlin