-
Notifications
You must be signed in to change notification settings - Fork 4.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
--------- Co-authored-by: Jerry Liu <jerryjliu98@gmail.com>
- Loading branch information
Showing
3 changed files
with
180 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# LoadAndSearch Tool | ||
|
||
This Tool Spec is intended to wrap other tools, allowing the Agent to perform separate loading and reading of data. This is very useful when tools return information that is larger than, or close to, the size of the context window. | ||
|
||
|
||
## Usage | ||
|
||
Here's an example usage of the LoadAndSearchToolSpec. | ||
|
||
```python | ||
from llama_hub.tools.tool_spec.load_and_search.base import LoadAndSearchToolSpec | ||
from llama_index.agent import OpenAIAgent | ||
from llama_hub.tools.tool_spec.wikipedia.base import WikipediaToolSpec | ||
|
||
wiki_spec = WikipediaToolSpec() | ||
|
||
# Get the search_data tool from the wikipedia tool spec | ||
tool = wiki_spec.to_tool_list()[1] | ||
|
||
# Wrap the tool, splitting it into a loader and a reader | ||
agent = OpenAIAgent.from_tools( | ||
LoadAndSearchToolSpec.from_defaults( | ||
tool | ||
).to_tool_list(), verbose=True) | ||
|
||
agent.chat('who is ben affleck married to') | ||
``` | ||
|
||
`load`: Calls the wrapped function and loads the data into an index | ||
`read`: Searches the index for the specified query | ||
|
||
This loader is designed to be used as a way to load data as a Tool in an Agent. See [here](https://github.com/emptycrown/llama-hub/tree/main) for examples. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# init |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
"""Load and Search tool spec. | ||
Tool spec that wraps another tool, loading its output into an index on-demand so the data can then be searched with a separate read tool. | ||
""" | ||
|
||
|
||
from llama_index.tools.types import ToolMetadata | ||
from typing import Any, Optional, Dict, Type, List | ||
from llama_index.indices.base import BaseIndex | ||
from llama_index.indices.vector_store import VectorStoreIndex | ||
from llama_index.tools.utils import create_schema_from_function | ||
from pydantic import BaseModel | ||
from llama_index.tools.function_tool import FunctionTool | ||
from llama_index.tools.tool_spec.base import BaseToolSpec | ||
|
||
|
||
class LoadAndSearchToolSpec(BaseToolSpec):
    """Wrap a tool as a pair of "load" and "read" tools.

    This tool can be used with other tools that load large amounts of
    information. Compared to OndemandLoaderTool this returns two tools,
    one to retrieve data to an index and another to allow the Agent to search
    the retrieved data with a natural language query string.

    The load tool keeps the wrapped tool's name, description and schema; the
    read tool is named ``read_<name>`` and takes a single ``query`` string.
    """

    # NOTE: the prompt templates below are runtime text handed to the agent as
    # tool descriptions; they are filled in with ``str.format`` in ``__init__``.
    loader_prompt = """
Use this tool to load data from the following function. It must then be read from
the corresponding read_{} function.
{}
"""  # noqa: E501

    # TODO, more general read prompt, not always natural language?
    reader_prompt = """
Once data has been loaded from {} it can then be read using a natural
language query from this function.
You are required to pass the natural language query argument when calling this endpoint
Args:
query (str): The natural language query used to retreieve information from the index
"""  # noqa: E501

    def __init__(
        self,
        tool: FunctionTool,
        index_cls: Type[BaseIndex],
        index_kwargs: Dict,
        metadata: ToolMetadata,
        index: Optional[BaseIndex] = None,
    ) -> None:
        """Build the load/read tool pair around ``tool``.

        Args:
            tool: The tool whose output should be loaded into the index.
            index_cls: Index class used to build the index on the first load.
            index_kwargs: Keyword arguments forwarded to the index when
                documents are loaded.
            metadata: Name, description and schema exposed by the load tool.
            index: Optional existing index to insert loaded documents into.

        Raises:
            ValueError: If ``metadata.name`` is None.
        """
        self._index_cls = index_cls
        self._index_kwargs = index_kwargs
        self._index = index
        self._metadata = metadata
        self._tool = tool

        if self._metadata.name is None:
            raise ValueError("Tool name cannot be None")

        load_name = self._metadata.name
        read_name = "read_{}".format(load_name)

        self.spec_functions = [load_name, read_name]
        self._tool_list = [
            # Loader: same public surface as the wrapped tool, but the result
            # is stored in the index instead of being returned to the agent.
            FunctionTool.from_defaults(
                fn=self.load,
                name=load_name,
                description=self.loader_prompt.format(
                    load_name, self._metadata.description
                ),
                fn_schema=self._metadata.fn_schema,
            ),
            # Reader: queries whatever the loader has put into the index.
            FunctionTool.from_defaults(
                fn=self.read,
                name=read_name,
                description=self.reader_prompt.format(load_name),
                fn_schema=create_schema_from_function("ReadData", self.read),
            ),
        ]

    @property
    def metadata(self) -> ToolMetadata:
        """Metadata describing the load tool."""
        return self._metadata

    @classmethod
    def from_defaults(
        cls,
        tool: FunctionTool,
        index_cls: Optional[Type[BaseIndex]] = None,
        index_kwargs: Optional[Dict] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        fn_schema: Optional[Type[BaseModel]] = None,
        index: Optional[BaseIndex] = None,
    ) -> "LoadAndSearchToolSpec":
        """Create a spec, filling unset options from ``tool`` and defaults.

        Args:
            tool: The tool to wrap.
            index_cls: Index class to use; defaults to ``VectorStoreIndex``.
            index_kwargs: Keyword arguments for the index; defaults to ``{}``.
            name: Load tool name; defaults to ``tool.metadata.name``.
            description: Load tool description; defaults to
                ``tool.metadata.description``.
            fn_schema: Load tool schema; defaults to
                ``tool.metadata.fn_schema``.
            index: Optional existing index to load documents into.

        Returns:
            The configured ``LoadAndSearchToolSpec``.
        """
        index_cls = index_cls or VectorStoreIndex
        index_kwargs = index_kwargs or {}
        if name is None:
            name = tool.metadata.name
        if description is None:
            description = tool.metadata.description
        if fn_schema is None:
            fn_schema = tool.metadata.fn_schema
        metadata = ToolMetadata(name=name, description=description, fn_schema=fn_schema)
        return cls(
            tool=tool,
            index_cls=index_cls,
            index_kwargs=index_kwargs,
            metadata=metadata,
            index=index,
        )

    def to_tool_list(
        self,
        func_to_metadata_mapping: Optional[Dict[str, ToolMetadata]] = None,
    ) -> List[FunctionTool]:
        """Return the load and read tools, in that order.

        ``func_to_metadata_mapping`` is accepted for signature compatibility
        and is not used; the tools are built once in ``__init__``.
        """
        # Hand out a copy so callers cannot mutate the spec's own list.
        return list(self._tool_list)

    def load(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped tool and store its raw output in the index.

        All arguments are forwarded to the wrapped tool. Returns a short
        message telling the agent which read tool to call next.
        """
        docs = self._tool(*args, **kwargs).raw_output
        # Compare against None explicitly: an index object that happens to be
        # falsy (e.g. empty) must still be reused rather than rebuilt.
        if self._index is not None:
            for doc in docs:
                self._index.insert(doc, **self._index_kwargs)
        else:
            self._index = self._index_cls.from_documents(docs, **self._index_kwargs)
        return (
            "Content loaded! You can now search the information using read_{}".format(
                self._metadata.name
            )
        )

    def read(self, query: str) -> Any:
        """Query the index with ``query`` and return the response as a string.

        If nothing has been loaded yet, an error message naming the load tool
        is returned instead of raising, so the agent can recover.
        """
        if self._index is None:
            err_msg = (
                "Error: No content has been loaded into the index. "
                "You must call {} first"
            ).format(self._metadata.name)
            return err_msg
        query_engine = self._index.as_query_engine()
        response = query_engine.query(query)
        return str(response)