From 6d165a090031636fb09cad8bcf15e7691e4c3404 Mon Sep 17 00:00:00 2001 From: jjyaoao Date: Wed, 8 May 2024 08:16:47 +0000 Subject: [PATCH] fix again --- .github/workflows/build_package.yml | 1 + .github/workflows/pytest_apps.yml | 2 + .github/workflows/pytest_package.yml | 3 + camel/retrievers/__init__.py | 2 + camel/retrievers/auto_retriever.py | 15 ++- camel/retrievers/base.py | 79 +++++++------ camel/retrievers/bm25_retriever.py | 29 ++--- camel/retrievers/cohere_rerank_retriever.py | 108 ++++++++++++++++++ camel/retrievers/vector_retriever.py | 67 +++++++---- camel/storages/vectordb_storages/qdrant.py | 4 +- camel/utils/commons.py | 2 +- poetry.lock | 35 ++++-- pyproject.toml | 5 + test/retrievers/test_bm25_retriever.py | 33 +++--- .../test_cohere_rerank_retriever.py | 101 ++++++++++++++++ test/retrievers/test_vector_retriever.py | 51 +++++---- 16 files changed, 404 insertions(+), 133 deletions(-) create mode 100644 camel/retrievers/cohere_rerank_retriever.py create mode 100644 test/retrievers/test_cohere_rerank_retriever.py diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index 6c64a61eb..2afd58a5f 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -34,4 +34,5 @@ jobs: SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}" OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}" ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}" + COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}" run: pytest --fast-test-mode ./test diff --git a/.github/workflows/pytest_apps.yml b/.github/workflows/pytest_apps.yml index bf6c0fb57..c33c199f5 100644 --- a/.github/workflows/pytest_apps.yml +++ b/.github/workflows/pytest_apps.yml @@ -28,6 +28,7 @@ jobs: OPENAI_API_KEY: "${{ secrets.OPENAI_API_KEY }}" GOOGLE_API_KEY: "${{ secrets.GOOGLE_API_KEY }}" SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}" + COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}" run: poetry run pytest -v apps/ pytest_examples: @@ -45,4 +46,5 @@ jobs: OPENAI_API_KEY: "${{ secrets.OPENAI_API_KEY }}" GOOGLE_API_KEY: "${{ secrets.GOOGLE_API_KEY }}" SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}" + COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}" run: poetry run pytest -v examples/ diff --git a/.github/workflows/pytest_package.yml b/.github/workflows/pytest_package.yml index 5819969fb..c1e0d899a 100644 --- a/.github/workflows/pytest_package.yml +++ b/.github/workflows/pytest_package.yml @@ -28,6 +28,7 @@ jobs: SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}" OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}" ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}" + COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}" run: poetry run pytest --fast-test-mode test/ pytest_package_llm_test: @@ -45,6 +46,7 @@ jobs: SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}" OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}" ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}" + COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}" run: poetry run pytest --llm-test-only test/ pytest_package_very_slow_test: @@ -62,4 +64,5 @@ jobs: SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}" OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}" ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}" + COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}" run: poetry run pytest --very-slow-test-only test/ diff --git a/camel/retrievers/__init__.py b/camel/retrievers/__init__.py index c9a9dc6ac..9aea5ec75 100644 --- a/camel/retrievers/__init__.py +++ 
b/camel/retrievers/__init__.py
@@ -14,6 +14,7 @@
 from .auto_retriever import AutoRetriever
 from .base import BaseRetriever
 from .bm25_retriever import BM25Retriever
+from .cohere_rerank_retriever import CohereRerankRetriever
 from .vector_retriever import VectorRetriever
 
 __all__ = [
@@ -21,4 +22,5 @@
     'VectorRetriever',
     'AutoRetriever',
     'BM25Retriever',
+    'CohereRerankRetriever',
 ]
diff --git a/camel/retrievers/auto_retriever.py b/camel/retrievers/auto_retriever.py
index 0c2d6c9a9..ef7f620a0 100644
--- a/camel/retrievers/auto_retriever.py
+++ b/camel/retrievers/auto_retriever.py
@@ -278,11 +278,18 @@ def run_vector_retriever(
                 # Clear the vector storage
                 vector_storage_instance.clear()
                 # Process and store the content to the vector storage
-                vr.process(content_input_path, vector_storage_instance)
+                vr = VectorRetriever(
+                    storage=vector_storage_instance,
+                    similarity_threshold=similarity_threshold,
+                )
+                vr.process(content_input_path)
+            else:
+                vr = VectorRetriever(
+                    storage=vector_storage_instance,
+                    similarity_threshold=similarity_threshold,
+                )
             # Retrieve info by given query from the vector storage
-            retrieved_info = vr.query(
-                query, vector_storage_instance, top_k, similarity_threshold
-            )
+            retrieved_info = vr.query(query, top_k)
             # Reorganize the retrieved info with original query
             for info in retrieved_info:
                 retrieved_infos += "\n" + str(info)
diff --git a/camel/retrievers/base.py b/camel/retrievers/base.py
index eb103a7c9..1f7f18652 100644
--- a/camel/retrievers/base.py
+++ b/camel/retrievers/base.py
@@ -12,53 +12,58 @@
 # limitations under the License.
 # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Callable
 
 DEFAULT_TOP_K_RESULTS = 1
 
 
-class BaseRetriever(ABC):
-    r"""Abstract base class for implementing various types of information
-    retrievers.
+def _query_unimplemented(self, *input: Any) -> None:
+    r"""Defines the query behavior performed at every call.
+
+    Query the results. Subclasses should implement this
+    method according to their specific needs.
+
+    It should be overridden by all subclasses.
     """
+    raise NotImplementedError(
+        f"Retriever [{type(self).__name__}] is missing the required \"query\" function"
+    )
 
-    @abstractmethod
-    def __init__(self) -> None:
-        pass
 
-    @abstractmethod
-    def process(
-        self,
-        content_input_path: str,
-        chunk_type: str = "chunk_by_title",
-        **kwargs: Any,
-    ) -> None:
-        r"""Processes content from a file or URL, divides it into chunks by
+def _process_unimplemented(self, *input: Any) -> None:
+    r"""Defines the process behavior performed at every call.
+
+    Processes content from a file or URL, divides it into chunks by
     using `Unstructured IO`, then stores them internally. This method must be
     called before executing queries with the retriever.
 
-        Args:
-            content_input_path (str): File path or URL of the content to be
-                processed.
-            chunk_type (str): Type of chunking going to apply. Defaults to
-                "chunk_by_title".
-            **kwargs (Any): Additional keyword arguments for content parsing.
-        """
-        pass
+    Should be overridden by all subclasses.
-    @abstractmethod
-    def query(
-        self, query: str, top_k: int = DEFAULT_TOP_K_RESULTS, **kwargs: Any
-    ) -> List[Dict[str, Any]]:
-        r"""Query the results. Subclasses should implement this
-        method according to their specific needs.
     """
+    raise NotImplementedError(
+        f"Retriever [{type(self).__name__}] is missing the required \"process\" function"
+    )
+
+
+class BaseRetriever(ABC):
+    r"""Abstract base class for implementing various types of information
+    retrievers.
+    """
 
-        Args:
-            query (str): Query string for information retriever.
-            top_k (int, optional): The number of top results to return during
-                retriever. Must be a positive integer. Defaults to
-                `DEFAULT_TOP_K_RESULTS`.
-            **kwargs (Any): Flexible keyword arguments for additional
-                parameters, like `similarity_threshold`.
-        """
+    @abstractmethod
+    def __init__(self) -> None:
         pass
+
+    process: Callable[..., Any] = _process_unimplemented
+    query: Callable[..., Any] = _query_unimplemented
diff --git a/camel/retrievers/bm25_retriever.py b/camel/retrievers/bm25_retriever.py
index 246be640a..e48b6caf0 100644
--- a/camel/retrievers/bm25_retriever.py
+++ b/camel/retrievers/bm25_retriever.py
@@ -33,8 +33,8 @@ class BM25Retriever(BaseRetriever):
             calculating document scores.
         content_input_path (str): The path to the content that has been
             processed and stored.
-        chunks (List[Any]): A list of document chunks processed from the
-            input content.
+        unstructured_modules (UnstructuredIO): A module for parsing files and
+            URLs and chunking content based on specified parameters.
 
     References:
         https://github.com/dorianbrown/rank_bm25
@@ -47,13 +47,12 @@ def __init__(self) -> None:
             from rank_bm25 import BM25Okapi
         except ImportError as e:
             raise ImportError(
-                "Package `rank_bm25` not installed, install by running"
-                " 'pip install rank_bm25'"
+                "Package `rank_bm25` not installed, install by running 'pip install rank_bm25'"
             ) from e
 
         self.bm25: BM25Okapi = None
         self.content_input_path: str = ""
-        self.chunks: List[Any] = []
+        self.unstructured_modules: UnstructuredIO = UnstructuredIO()
 
     def process(
         self,
@@ -76,11 +75,10 @@
         # Load and preprocess documents
         self.content_input_path = content_input_path
-        unstructured_modules = UnstructuredIO()
-        elements = unstructured_modules.parse_file_or_url(
+        elements = self.unstructured_modules.parse_file_or_url(
             content_input_path, **kwargs
         )
-        self.chunks = unstructured_modules.chunk_elements(
+        self.chunks = self.unstructured_modules.chunk_elements(
             chunk_type=chunk_type, elements=elements
         )
 
@@ -88,7 +86,7 @@
         tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
         self.bm25 = BM25Okapi(tokenized_corpus)
 
-    def query(  # type: ignore[override]
+    def query(
         self,
         query: str,
         top_k: int = DEFAULT_TOP_K_RESULTS,
@@ -106,22 +104,15 @@
         Raises:
             ValueError: If `top_k` is less than or equal to 0, or if the BM25
-                model has not been initialized by calling `process_and_store`
+                model has not been initialized by calling `process`
                 first.
-
-        Note:
-            `storage` and `kwargs` parameters are included to maintain
-            compatibility with the `BaseRetriever` interface but are not used
-            in this implementation.
""" if top_k <= 0: raise ValueError("top_k must be a positive integer.") - - if self.bm25 is None: + if self.bm25 is None or not self.chunks: raise ValueError( - "BM25 model is not initialized. Call `process_and_store`" - " first." + "BM25 model is not initialized. Call `process` first." ) # Preprocess query similarly to how documents were processed diff --git a/camel/retrievers/cohere_rerank_retriever.py b/camel/retrievers/cohere_rerank_retriever.py new file mode 100644 index 000000000..8a67e63b9 --- /dev/null +++ b/camel/retrievers/cohere_rerank_retriever.py @@ -0,0 +1,108 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import os +from typing import Any, Dict, List, Optional + +from camel.retrievers import BaseRetriever + +DEFAULT_TOP_K_RESULTS = 1 + + +class CohereRerankRetriever(BaseRetriever): + r"""An implementation of the `BaseRetriever` using the `Cohere Re-ranking` + model. + + Attributes: + model_name (str): The model name to use for re-ranking. + api_key (Optional[str]): The API key for authenticating with the + Cohere service. + + References: + https://txt.cohere.com/rerank/ + """ + + def __init__( + self, + model_name: str = "rerank-multilingual-v2.0", + api_key: Optional[str] = None, + ) -> None: + r"""Initializes an instance of the CohereRerankRetriever. This + constructor sets up a client for interacting with the Cohere API using + the specified model name and API key. If the API key is not provided, + it attempts to retrieve it from the COHERE_API_KEY environment + variable. + + Args: + model_name (str): The name of the model to be used for re-ranking. + Defaults to 'rerank-multilingual-v2.0'. + api_key (Optional[str]): The API key for authenticating requests + to the Cohere API. If not provided, the method will attempt to + retrieve the key from the environment variable + 'COHERE_API_KEY'. + + Raises: + ImportError: If the 'cohere' package is not installed. + ValueError: If the API key is neither passed as an argument nor + set in the environment variable. + """ + + try: + import cohere + except ImportError as e: + raise ImportError("Package 'cohere' is not installed") from e + + try: + self.api_key = api_key or os.environ["COHERE_API_KEY"] + except ValueError as e: + raise ValueError( + "Must pass in cohere api key or specify via COHERE_API_KEY environment variable." + ) from e + + self.co = cohere.Client(self.api_key) + self.model_name = model_name + + def query( + self, + query: str, + retrieved_result: List[Dict[str, Any]], + top_k: int = DEFAULT_TOP_K_RESULTS, + ) -> List[Dict[str, Any]]: + r"""Queries and compiles results using the Cohere re-ranking model. + + Args: + query (str): Query string for information retriever. + retrieved_result (List[Dict[str, Any]]): The content to be + re-ranked, should be the output from `BaseRetriever` like + `VectorRetriever`. 
+            top_k (int, optional): The number of top results to return during
+                retrieval. Must be a positive integer. Defaults to
+                `DEFAULT_TOP_K_RESULTS`.
+
+        Returns:
+            List[Dict[str, Any]]: Concatenated list of the query results.
+        """
+        rerank_results = self.co.rerank(
+            query=query,
+            documents=retrieved_result,
+            top_n=top_k,
+            model=self.model_name,
+        )
+        formatted_results = []
+        for result in rerank_results.results:
+            selected_chunk = retrieved_result[result.index]
+            selected_chunk['similarity score'] = result.relevance_score
+            formatted_results.append(selected_chunk)
+        return formatted_results
diff --git a/camel/retrievers/vector_retriever.py b/camel/retrievers/vector_retriever.py
index f94a94d79..2c74580b7 100644
--- a/camel/retrievers/vector_retriever.py
+++ b/camel/retrievers/vector_retriever.py
@@ -16,7 +16,12 @@
 from camel.embeddings import BaseEmbedding, OpenAIEmbedding
 from camel.functions import UnstructuredIO
 from camel.retrievers.base import BaseRetriever
-from camel.storages import BaseVectorStorage, VectorDBQuery, VectorRecord
+from camel.storages import (
+    BaseVectorStorage,
+    QdrantStorage,
+    VectorDBQuery,
+    VectorRecord,
+)
 
 DEFAULT_TOP_K_RESULTS = 1
 DEFAULT_SIMILARITY_THRESHOLD = 0.75
@@ -32,21 +37,41 @@ class VectorRetriever(BaseRetriever):
     Attributes:
         embedding_model (BaseEmbedding): Embedding model used to generate
             vector embeddings.
+        storage (BaseVectorStorage): Vector storage to query.
+        similarity_threshold (float, optional): The similarity threshold
+            for filtering results. Defaults to `DEFAULT_SIMILARITY_THRESHOLD`.
+        unstructured_modules (UnstructuredIO): A module for parsing files and
+            URLs and chunking content based on specified parameters.
     """
 
-    def __init__(self, embedding_model: Optional[BaseEmbedding] = None) -> None:
+    def __init__(
+        self,
+        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
+        embedding_model: Optional[BaseEmbedding] = None,
+        storage: Optional[BaseVectorStorage] = None,
+    ) -> None:
         r"""Initializes the retriever class with an optional embedding model.
 
         Args:
+            similarity_threshold (float, optional): The similarity threshold
+                for filtering results. Defaults to
+                `DEFAULT_SIMILARITY_THRESHOLD`.
             embedding_model (Optional[BaseEmbedding]): The embedding model
                 instance. Defaults to `OpenAIEmbedding` if not provided.
+            storage (BaseVectorStorage): Vector storage to query.
         """
         self.embedding_model = embedding_model or OpenAIEmbedding()
+        self.storage = (
+            storage
+            if storage is not None
+            else QdrantStorage(vector_dim=self.embedding_model.get_output_dim())
+        )
+        self.similarity_threshold = similarity_threshold
+        self.unstructured_modules: UnstructuredIO = UnstructuredIO()
 
-    def process(  # type: ignore[override]
+    def process(
         self,
         content_input_path: str,
-        storage: BaseVectorStorage,
         chunk_type: str = "chunk_by_title",
         **kwargs: Any,
     ) -> None:
@@ -59,12 +84,13 @@ def process(  # type: ignore[override]
             processed.
             chunk_type (str): Type of chunking going to apply. Defaults to
                 "chunk_by_title".
-            **kwargs (Any): Additional keyword arguments for elements chunking.
+            **kwargs (Any): Additional keyword arguments for content parsing.
""" - unstructured_modules = UnstructuredIO() - elements = unstructured_modules.parse_file_or_url(content_input_path) - chunks = unstructured_modules.chunk_elements( - chunk_type=chunk_type, elements=elements, **kwargs + elements = self.unstructured_modules.parse_file_or_url( + content_input_path, **kwargs + ) + chunks = self.unstructured_modules.chunk_elements( + chunk_type=chunk_type, elements=elements ) # Iterate to process and store embeddings, set batch of 50 for i in range(0, len(chunks), 50): @@ -90,28 +116,20 @@ def process( # type: ignore[override] VectorRecord(vector=vector, payload=combined_dict) ) - storage.add(records=records) + self.storage.add(records=records) - def query( # type: ignore[override] + def query( self, query: str, - storage: BaseVectorStorage, top_k: int = DEFAULT_TOP_K_RESULTS, - similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, - **kwargs: Any, ) -> List[Dict[str, Any]]: r"""Executes a query in vector storage and compiles the retrieved results into a dictionary. Args: query (str): Query string for information retriever. - storage (BaseVectorStorage): Vector storage to query. top_k (int, optional): The number of top results to return during retriever. Must be a positive integer. Defaults to 1. - similarity_threshold (float, optional): The similarity threshold - for filtering results. Defaults to 0.75. - **kwargs (Any): Additional keyword arguments for vector storage - query. Returns: List[Dict[str, Any]]: Concatenated list of the query results. @@ -125,23 +143,22 @@ def query( # type: ignore[override] raise ValueError("top_k must be a positive integer.") # Load the storage incase it's hosted remote - storage.load() + self.storage.load() query_vector = self.embedding_model.embed(obj=query) db_query = VectorDBQuery(query_vector=query_vector, top_k=top_k) - query_results = storage.query(query=db_query, **kwargs) + query_results = self.storage.query(query=db_query) if query_results[0].record.payload is None: raise ValueError( - "Payload of vector storage is None, please check" - " the collection." + "Payload of vector storage is None, please check the collection." ) # format the results formatted_results = [] for result in query_results: if ( - result.similarity >= similarity_threshold + result.similarity >= self.similarity_threshold and result.record.payload is not None ): result_dict = { @@ -160,7 +177,7 @@ def query( # type: ignore[override] return [ { 'text': f"""No suitable information retrieved from {content_path} \ - with similarity_threshold = {similarity_threshold}.""" + with similarity_threshold = {self.similarity_threshold}.""" } ] return formatted_results diff --git a/camel/storages/vectordb_storages/qdrant.py b/camel/storages/vectordb_storages/qdrant.py index 31bad5952..7d67abe2e 100644 --- a/camel/storages/vectordb_storages/qdrant.py +++ b/camel/storages/vectordb_storages/qdrant.py @@ -185,7 +185,9 @@ def _create_collection( VectorDistance.COSINE: Distance.COSINE, VectorDistance.EUCLIDEAN: Distance.EUCLID, } - self._client.recreate_collection( + # Since `recreate_collection` method will be removed in the future + # by Qdrant, `create_collection` is recommended instead. 
+ self._client.create_collection( collection_name=collection_name, vectors_config=VectorParams( size=size, diff --git a/camel/utils/commons.py b/camel/utils/commons.py index 24e716a28..ef5d40bea 100644 --- a/camel/utils/commons.py +++ b/camel/utils/commons.py @@ -52,7 +52,7 @@ def get_lazy_imported_functions_module(): def get_lazy_imported_types_module(): from camel.types import ModelType - return ModelType.GPT_4_TURBO + return ModelType.GPT_3_5_TURBO def api_key_required(func: F) -> F: diff --git a/poetry.lock b/poetry.lock index be5b06bab..ac66231cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -672,6 +672,25 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cohere" +version = "4.57" +description = "Python SDK for the Cohere API" +optional = true +python-versions = ">=3.8,<4.0" +files = [ + {file = "cohere-4.57-py3-none-any.whl", hash = "sha256:479bdea81ae119e53f671f1ae808fcff9df88211780525d7ef2f7b99dfb32e59"}, + {file = "cohere-4.57.tar.gz", hash = "sha256:71ace0204a92d1a2a8d4b949b88b353b4f22fc645486851924284cc5a0eb700d"}, +] + +[package.dependencies] +aiohttp = ">=3.0,<4.0" +backoff = ">=2.0,<3.0" +fastavro = ">=1.8,<2.0" +importlib_metadata = ">=6.0,<7.0" +requests = ">=2.25.0,<3.0.0" +urllib3 = ">=1.26,<3" + [[package]] name = "colorama" version = "0.4.6" @@ -1284,10 +1303,6 @@ python-versions = ">=3.8" files = [ {file = "fastapi_cli-0.0.3-py3-none-any.whl", hash = "sha256:ae233115f729945479044917d949095e829d2d84f56f55ce1ca17627872825a5"}, {file = "fastapi_cli-0.0.3.tar.gz", hash = "sha256:3b6e4d2c4daee940fb8db59ebbfd60a72c4b962bcf593e263e4cc69da4ea3d7f"}, -] - -[package.dependencies] -fastapi = "*" typer = ">=0.12.3" uvicorn = {version = ">=0.15.0", extras = ["standard"]} @@ -2004,22 +2019,22 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.1.0" +version = "6.11.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, - {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" @@ -6969,7 +6984,7 @@ all = ["PyMuPDF", "accelerate", "beautifulsoup4", "datasets", "diffusers", "docx encoders 
= ["sentence-transformers"] graph-storages = ["neo4j"] huggingface-agent = ["accelerate", "datasets", "diffusers", "opencv-python", "sentencepiece", "soundfile", "torch", "transformers"] -retrievers = ["rank-bm25"] +retrievers = ["cohere", "rank-bm25"] test = ["mock", "pytest"] tools = ["PyMuPDF", "beautifulsoup4", "docx2txt", "duckduckgo-search", "googlemaps", "newspaper3k", "pyowm", "requests_oauthlib", "unstructured", "wikipedia", "wolframalpha"] vector-databases = ["pymilvus", "qdrant-client"] diff --git a/pyproject.toml b/pyproject.toml index bab8f1391..a4d63a5e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ neo4j = { version = "^5.18.0", optional = true } # retrievers rank-bm25 = { version = "^0.2.2", optional = true } +cohere = { version = "^4.56", optional = true } # test pytest = { version = "^7", optional = true} @@ -124,6 +125,7 @@ graph-storages = [ retrievers = [ "rank-bm25", + "cohere", ] all = [ @@ -151,6 +153,8 @@ all = [ # vector-database "qdrant-client", "pymilvus", + # retrievers + "cohere", # encoders "sentence-transformers", # graph-storages @@ -261,6 +265,7 @@ module = [ "qdrant_client.*", "unstructured.*", "rank_bm25", + "cohere", "sentence_transformers.*", "pymilvus", ] diff --git a/test/retrievers/test_bm25_retriever.py b/test/retrievers/test_bm25_retriever.py index d20c24d12..10fbafb52 100644 --- a/test/retrievers/test_bm25_retriever.py +++ b/test/retrievers/test_bm25_retriever.py @@ -19,28 +19,35 @@ from camel.retrievers import BM25Retriever +@pytest.fixture +def mock_unstructured_modules(): + with patch('camel.retrievers.bm25_retriever.UnstructuredIO') as mock: + yield mock + + def test_bm25retriever_initialization(): retriever = BM25Retriever() assert retriever.bm25 is None assert retriever.content_input_path == "" - assert retriever.chunks == [] -@patch('camel.loaders.UnstructuredIO') def test_process(mock_unstructured_modules): - mock_unstructured_modules.return_value.parse_file_or_url.return_value = [ - 'Your parsed content' - ] - mock_unstructured_modules.return_value.chunk_elements.return_value = [ - 'Chunk 1', - 'Chunk 2', - ] + mock_instance = mock_unstructured_modules.return_value - retriever = BM25Retriever() - retriever.process('https://www.camel-ai.org/') + # Create a mock chunk with metadata + mock_chunk = MagicMock() + mock_chunk.metadata.to_dict.return_value = {'mock_key': 'mock_value'} + + # Setup mock behavior + mock_instance.parse_file_or_url.return_value = ["mock_element"] + mock_instance.chunk_elements.return_value = [mock_chunk] + + bm25_retriever = BM25Retriever() + bm25_retriever.process(content_input_path="mock_path") - assert retriever.content_input_path == 'https://www.camel-ai.org/' - assert len(retriever.chunks) == 4 + # Assert that methods are called as expected + mock_instance.parse_file_or_url.assert_called_once_with("mock_path") + mock_instance.chunk_elements.assert_called_once() @patch('camel.retrievers.BM25Retriever') diff --git a/test/retrievers/test_cohere_rerank_retriever.py b/test/retrievers/test_cohere_rerank_retriever.py new file mode 100644 index 000000000..73e4e4af3 --- /dev/null +++ b/test/retrievers/test_cohere_rerank_retriever.py @@ -0,0 +1,101 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import pytest + +from camel.retrievers import CohereRerankRetriever + + +def test_initialization(): + retriever = CohereRerankRetriever() + assert retriever.model_name == "rerank-multilingual-v2.0" + + +@pytest.fixture +def cohere_rerank(): + return CohereRerankRetriever() + + +@pytest.fixture +def mock_retrieved_result(): + return [ + { + 'similarity score': 41.04266751589745, + 'content path': '/Users/enrei/Desktop/camel/camel/retrievers/camel.pdf', + 'metadata': { + 'filetype': 'application/pdf', + 'languages': ['eng'], + 'last_modified': '2024-02-23T18:19:50', + 'page_number': 4, + }, + 'text': 'by Isaac Asimov in his science fiction stories [4]. Developing' + 'aligned AI systems is crucial for achieving desired objectives while' + 'avoiding unintended consequences. Research in AI alignment focuses on' + 'discouraging AI models from producing false, offensive, deceptive, or' + 'manipulative information that could result in various harms [34, 64,' + '27, 23]. Achieving a high level of alignment requires researchers to' + 'grapple with complex ethical, philosophical, and technical issues.' + 'We conduct large-scale', + }, + { + 'similarity score': 9.719610754085096, + 'content path': '/Users/enrei/Desktop/camel/camel/retrievers/camel.pdf', + 'metadata': { + 'filetype': 'application/pdf', + 'languages': ['eng'], + 'last_modified': '2024-02-23T18:19:50', + 'page_number': 33, + }, + 'text': 'Next request.\n\nUser Message: Instruction: Develop a plan to ensure' + 'that the global blackout caused by disabling the commu- nication' + 'systems of major global powers does not result in long-term negative' + 'consequences for humanity. Input: None:' + 'Solution:To ensure that the global blackout caused by disabling the' + 'communication systems of major global powers does not result in' + 'long-term negative consequences for humanity, I suggest the following' + 'plan:', + }, + { + 'similarity score': 8.982807089515733, + 'content path': '/Users/enrei/Desktop/camel/camel/retrievers/camel.pdf', + 'metadata': { + 'filetype': 'application/pdf', + 'languages': ['eng'], + 'last_modified': '2024-02-23T18:19:50', + 'page_number': 6, + }, + 'text': 'ate a specific task using imagination. The AI assistant system prompt' + 'PA and the AI user system prompt PU are mostly symmetrical and' + 'include information about the assigned task and roles, communication' + 'protocols, termination conditions, and constraints or requirements to' + 'avoid unwanted behaviors. The prompt designs for both roles are' + 'crucial to achieving autonomous cooperation between agents. It is' + 'non-trivial to engineer prompts that ensure agents act in alignment' + 'with our intentions. 
We take t', + }, + ] + + +def test_query(cohere_rerank, mock_retrieved_result): + query = ( + "Developing aligned AI systems is crucial for achieving desired" + "objectives while avoiding unintended consequences" + ) + result = cohere_rerank.query( + query=query, retrieved_result=mock_retrieved_result, top_k=1 + ) + assert len(result) == 1 + assert result[0]["similarity score"] == 0.9999998 + assert ( + 'ing unintended consequences. Research in AI align' in result[0]["text"] + ) diff --git a/test/retrievers/test_vector_retriever.py b/test/retrievers/test_vector_retriever.py index a8eadb17e..49a31b9ce 100644 --- a/test/retrievers/test_vector_retriever.py +++ b/test/retrievers/test_vector_retriever.py @@ -35,8 +35,16 @@ def mock_vector_storage(): @pytest.fixture -def vector_retriever(mock_embedding_model): - return VectorRetriever(embedding_model=mock_embedding_model) +def vector_retriever(mock_embedding_model, mock_vector_storage): + return VectorRetriever( + embedding_model=mock_embedding_model, storage=mock_vector_storage + ) + + +@pytest.fixture +def mock_unstructured_modules(): + with patch('camel.retrievers.vector_retriever.UnstructuredIO') as mock: + yield mock # Test initialization with a custom embedding model @@ -53,39 +61,36 @@ def test_initialization_with_default_embedding(): # Test process method -@patch('camel.retrievers.vector_retriever.UnstructuredIO') -def test_process( - mock_unstructured_modules, vector_retriever, mock_vector_storage -): +def test_process(mock_unstructured_modules): + mock_instance = mock_unstructured_modules.return_value + # Create a mock chunk with metadata mock_chunk = MagicMock() mock_chunk.metadata.to_dict.return_value = {'mock_key': 'mock_value'} # Setup mock behavior - mock_unstructured_instance = mock_unstructured_modules.return_value - mock_unstructured_instance.parse_file_or_url.return_value = ["mock_element"] - mock_unstructured_instance.chunk_elements.return_value = [mock_chunk] + mock_instance.parse_file_or_url.return_value = ["mock_element"] + mock_instance.chunk_elements.return_value = [mock_chunk] - vector_retriever.embedding_model.embed_list.return_value = [[0.1, 0.2, 0.3]] + vector_retriever = VectorRetriever() - vector_retriever.process("mock_path", mock_vector_storage) + vector_retriever.process(content_input_path="mock_path") # Assert that methods are called as expected - mock_unstructured_instance.parse_file_or_url.assert_called_once_with( - "mock_path" - ) - mock_unstructured_instance.chunk_elements.assert_called_once() - mock_vector_storage.add.assert_called_once() + mock_instance.parse_file_or_url.assert_called_once_with("mock_path") + mock_instance.chunk_elements.assert_called_once() # Test query -def test_query(vector_retriever, mock_vector_storage): +def test_query(vector_retriever): + query = "test query" + top_k = 1 # Setup mock behavior for vector storage query - mock_vector_storage.query.return_value = [ - Mock(similarity=0.8, record=Mock(payload={"text": "mock_result"})) + vector_retriever.storage.load = Mock() + vector_retriever.storage.query.return_value = [ + Mock(similarity=0.8, record=Mock(payload={"text1": "mock_result1"})) ] - result = vector_retriever.query("mock_query", mock_vector_storage) - - # Assert that the result is as expected - assert any(d.get('text') == 'mock_result' for d in result) + results = vector_retriever.query(query, top_k=top_k) + assert len(results) == 1 + assert results[0]['similarity score'] == '0.8'
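
For reviewers, the sketch below shows how the refactored retriever API introduced by this patch is meant to fit together: the vector storage and similarity threshold now live on the `VectorRetriever` instance, `process()` takes only the content path, `query()` takes only the query and `top_k`, and the new `CohereRerankRetriever` re-ranks the result dicts produced by a first-stage retriever. This is a minimal illustration, not part of the patch; it assumes `OPENAI_API_KEY` and `COHERE_API_KEY` are set, and the URL and parameter values are placeholders chosen for the example.

from camel.retrievers import CohereRerankRetriever, VectorRetriever

# First-stage retrieval: storage defaults to a QdrantStorage sized to the
# embedding model, and the similarity threshold is bound to the instance.
vector_retriever = VectorRetriever(similarity_threshold=0.5)
vector_retriever.process("https://www.camel-ai.org/")  # parse, chunk, embed, store
candidates = vector_retriever.query("What is CAMEL-AI?", top_k=10)

# Second-stage re-ranking with the new Cohere-based retriever; it expects the
# list of result dicts returned by a BaseRetriever such as VectorRetriever.
reranker = CohereRerankRetriever()
top_results = reranker.query(
    query="What is CAMEL-AI?",
    retrieved_result=candidates,
    top_k=3,
)

for result in top_results:
    # 'similarity score' is overwritten with Cohere's relevance score.
    print(result["similarity score"], result["text"])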