Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: RAG citation and function call #537

Merged
merged 10 commits into from
May 15, 2024
5 changes: 2 additions & 3 deletions camel/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from ..loaders.unstructured_io import UnstructuredIO
from .google_maps_function import MAP_FUNCS
from .math_functions import MATH_FUNCS
from .openai_function import (
OpenAIFunction,
get_openai_function_schema,
get_openai_tool_schema,
)
from .retrieval_functions import RETRIEVAL_FUNCS
from .search_functions import SEARCH_FUNCS
from .twitter_function import TWITTER_FUNCS
from .weather_functions import WEATHER_FUNCS
Expand All @@ -33,5 +32,5 @@
'WEATHER_FUNCS',
'MAP_FUNCS',
'TWITTER_FUNCS',
'UnstructuredIO',
'RETRIEVAL_FUNCS',
]
61 changes: 61 additions & 0 deletions camel/functions/retrieval_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import List, Union

from camel.functions import OpenAIFunction
from camel.retrievers import AutoRetriever
from camel.types import StorageType


def information_retrieval(
    query: str, content_input_paths: Union[str, List[str]]
) -> str:
    r"""Retrieves information from a local vector storage based on the
    specified query. This function connects to a local vector storage system
    and retrieves relevant information by processing the input query. It is
    essential to use this function when the answer to a question requires
    external knowledge sources.

    Args:
        query (str): The question or query for which an answer is required.
        content_input_paths (Union[str, List[str]]): Paths to local
            files or remote URLs.

    Returns:
        str: The information retrieved in response to the query, aggregated
            and formatted as a string.

    Example:
        # Retrieve information about CAMEL AI.
        information_retrieval(query = "what is CAMEL AI?",
                            content_input_paths="https://www.camel-ai.org/")
    """
    # Qdrant storage type, with the local storage path fixed to a directory
    # inside the package tree.
    retriever = AutoRetriever(
        storage_type=StorageType.QDRANT,
        vector_storage_local_path="camel/temp_storage",
    )

    # Ask for at most three results and hand the formatted string straight
    # back to the caller.
    return retriever.run_vector_retriever(
        query=query,
        content_input_paths=content_input_paths,
        top_k=3,
    )


# Wrap every retrieval helper as an OpenAIFunction; extend the tuple below
# to register additional helpers.
RETRIEVAL_FUNCS: List[OpenAIFunction] = [
    OpenAIFunction(retrieval_fn) for retrieval_fn in (information_retrieval,)
]
68 changes: 36 additions & 32 deletions camel/retrievers/auto_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def _initialize_vector_storage(
self,
collection_name: Optional[str] = None,
) -> BaseVectorStorage:
r"""Sets up and returns a vector storage instance with specified parameters.
r"""Sets up and returns a vector storage instance with specified
parameters.

Args:
collection_name (Optional[str]): Name of the collection in the
Expand Down Expand Up @@ -195,7 +196,8 @@ def run_vector_retriever(
similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
return_detailed_info: bool = False,
) -> str:
r"""Executes the automatic vector retriever process using vector storage.
r"""Executes the automatic vector retriever process using vector
storage.

Args:
query (str): Query string for information retriever.
Expand Down Expand Up @@ -233,9 +235,7 @@ def run_vector_retriever(

vr = VectorRetriever()

retrieved_infos = ""
retrieved_infos_text = ""

all_retrieved_info = []
for content_input_path in content_input_paths:
# Generate a valid collection name
collection_name = self._collection_name_generator(
Expand Down Expand Up @@ -290,37 +290,41 @@ def run_vector_retriever(
)
# Retrieve info by given query from the vector storage
retrieved_info = vr.query(query, top_k)
# Reorganize the retrieved info with original query
for info in retrieved_info:
retrieved_infos += "\n" + str(info)
retrieved_infos_text += "\n" + str(info['text'])
output = (
"Original Query:"
+ "\n"
+ "{"
+ query
+ "}"
+ "\n"
+ "Retrieved Context:"
+ retrieved_infos
)
output_text = (
"Original Query:"
+ "\n"
+ "{"
+ query
+ "}"
+ "\n"
+ "Retrieved Context:"
+ retrieved_infos_text
)

all_retrieved_info.extend(retrieved_info)
except Exception as e:
raise RuntimeError(
f"Error in auto vector retriever processing: {e!s}"
) from e

# Split records into those with and without a 'similarity score'
Wendong-Fan marked this conversation as resolved.
Show resolved Hide resolved
# Records with 'similarity score' lower than 'similarity_threshold'
# will not have a 'similarity score' in the output content
with_score = [
info for info in all_retrieved_info if 'similarity score' in info
]
without_score = [
info
for info in all_retrieved_info
if 'similarity score' not in info
]
# Sort only the list with scores
with_score_sorted = sorted(
with_score, key=lambda x: x['similarity score'], reverse=True
)
# Merge back the sorted scored items with the non-scored items
all_retrieved_info_sorted = with_score_sorted + without_score
# Select the 'top_k' results
all_retrieved_info = all_retrieved_info_sorted[:top_k]

retrieved_infos = "\n".join(str(info) for info in all_retrieved_info)
retrieved_infos_text = "\n".join(
info['text'] for info in all_retrieved_info if 'text' in info
)

detailed_info = f"Original Query:\n{{ {query} }}\nRetrieved Context:\n{retrieved_infos}"
text_info = f"Original Query:\n{{ {query} }}\nRetrieved Context:\n{retrieved_infos_text}"

if return_detailed_info:
return output
return detailed_info
else:
return output_text
return text_info
2 changes: 1 addition & 1 deletion camel/retrievers/vector_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, List, Optional

from camel.embeddings import BaseEmbedding, OpenAIEmbedding
from camel.functions import UnstructuredIO
from camel.loaders import UnstructuredIO
from camel.retrievers.base import BaseRetriever
from camel.storages import (
BaseVectorStorage,
Expand Down
6 changes: 4 additions & 2 deletions test/storages/graph_storages/test_neo4j_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ def test_neo4j_timeout() -> None:
except Exception as e:
assert (
e.code # type: ignore[attr-defined]
== "Neo.ClientError.Transaction."
"TransactionTimedOutClientConfiguration"
in [
"Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration",
"Neo.ClientError.Transaction.LockClientStopped",
]
)


Expand Down