Skip to content

Commit

Permalink
Merge ee3dbc5 into 9b67611
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed May 3, 2023
2 parents 9b67611 + ee3dbc5 commit d125138
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 201 deletions.
13 changes: 4 additions & 9 deletions examples/agent_multihop_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
raise ValueError("Please set the SERPERDEV_API_KEY environment variable")

openai_key = os.environ.get("OPENAI_API_KEY")
if not search_key:
if not openai_key:
raise ValueError("Please set the OPENAI_API_KEY environment variable")


pn = PromptNode(
"text-davinci-003",
"gpt-3.5-turbo",
api_key=openai_key,
max_length=256,
default_prompt_template="question-answering-with-document-scores",
Expand Down Expand Up @@ -84,7 +84,7 @@
"""
few_shot_agent_template = PromptTemplate("few-shot-react", prompt_text=few_shot_prompt)
prompt_node = PromptNode(
"text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"), max_length=512, stop_words=["Observation:"]
"gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=512, stop_words=["Observation:"]
)

web_qa_tool = Tool(
Expand All @@ -94,12 +94,7 @@
output_variable="results",
)

agent = Agent(
prompt_node=prompt_node,
prompt_template=few_shot_agent_template,
tools=[web_qa_tool],
final_answer_pattern=r"Final Answer\s*:\s*(.*)",
)
agent = Agent(prompt_node=prompt_node, prompt_template=few_shot_agent_template, tools=[web_qa_tool])

hotpot_questions = [
"What year was the father of the Princes in the Tower born?",
Expand Down
15 changes: 1 addition & 14 deletions haystack/agents/agent_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import re
from typing import Optional, Dict, Tuple, Any
from typing import Optional, Dict, Any

from haystack import Answer
from haystack.errors import AgentError
Expand Down Expand Up @@ -69,19 +69,6 @@ def create_next_step(self, prompt_node_response: Any) -> AgentStep:
transcript=self.transcript,
)

def extract_tool_name_and_tool_input(self, tool_pattern: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Pull the tool name and the tool input out of the PromptNode response stored on this step.

    :param tool_pattern: The regex applied to the PromptNode response. Its first group must capture the tool
        name and its third group the tool input.
    :return: A `(tool_name, tool_input)` tuple with surrounding quotes, brackets, and whitespace removed,
        or `(None, None)` if the pattern doesn't match the response.
    """
    match = re.search(tool_pattern, self.prompt_node_response)
    if match is None:
        return None, None

    raw_name, raw_input = match.group(1), match.group(3)
    # The name may be wrapped in quotes or brackets; the input may only be wrapped in quotes.
    cleaned_name = raw_name.strip('" []\n').strip()
    cleaned_input = raw_input.strip('" \n')
    return cleaned_name, cleaned_input

def final_answer(self, query: str) -> Dict[str, Any]:
"""
Formats an answer as a dict containing `query` and `answers` similar to the output of a Pipeline.
Expand Down
208 changes: 110 additions & 98 deletions haystack/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import re
from hashlib import md5
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Union, Dict, Any, Tuple

from events import Events

Expand Down Expand Up @@ -111,6 +111,97 @@ def _process_result(self, result: Any) -> str:
return str(result)


class ToolsManager:
    """
    The ToolsManager manages tools for an Agent.

    It keeps a registry that maps tool names to tools, parses the text generated by the LLM to find out which
    tool to invoke, and runs that tool while firing the `on_tool_start`, `on_tool_finish`, and `on_tool_error`
    callbacks.
    """

    def __init__(
        self,
        tools: Optional[List[Tool]] = None,
        tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*',
    ):
        """
        :param tools: A list of tools to add to the ToolsManager. Each tool must have a unique name; if two tools
            share a name, the later one replaces the earlier one.
        :param tool_pattern: A regular expression pattern that matches the text that the Agent generates to invoke
            a tool. Its first group must capture the tool name and its third group the tool input.
        """
        self._tools: Dict[str, Tool] = {tool.name: tool for tool in tools} if tools else {}
        self.tool_pattern = tool_pattern
        self.callback_manager = Events(("on_tool_start", "on_tool_finish", "on_tool_error"))

    def add_tool(self, tool: Tool):
        """
        Add a tool to the ToolsManager.

        :param tool: The tool to add. Any previously added tool with the same name will be overwritten.

        Example:
        ```
        tools_manager.add_tool(
            Tool(
                name="Calculator",
                pipeline_or_node=calculator,
                description="Useful when you need to answer questions about math.",
            )
        )
        ```
        """
        self.tools[tool.name] = tool

    @property
    def tools(self) -> Dict[str, Tool]:
        """
        The registry of tools, keyed by tool name.
        """
        return self._tools

    def get_tool_names(self) -> str:
        """
        Returns a string with the names of all registered tools.
        """
        return ", ".join(self.tools.keys())

    def get_tools(self) -> List[Tool]:
        """
        Returns a list of all registered tool instances.
        """
        return list(self.tools.values())

    def get_tool_names_with_descriptions(self) -> str:
        """
        Returns a string with the names and descriptions of all registered tools.
        """
        return "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools.values()])

    def run_tool(self, llm_response: str, params: Optional[Dict[str, Any]] = None) -> str:
        """
        Run the tool that the LLM response asks for.

        :param llm_response: The text generated by the LLM. It's searched with `tool_pattern` to find the tool name
            and the tool input.
        :param params: Optional parameters that are passed on to the tool's `run()` method.
        :return: The result of the tool, or an empty string if no tools are registered or the response doesn't
            contain a tool invocation.
        :raises KeyError: If the LLM response names a tool that isn't registered.
        """
        if not self.tools:
            return ""

        tool_name, tool_input = self.extract_tool_name_and_tool_input(llm_response)
        if not tool_name or not tool_input:
            return ""

        tool = self.tools.get(tool_name)
        if tool is None:
            # Keep the KeyError type of a plain dict lookup, but tell the user what went wrong.
            raise KeyError(
                f"The tool '{tool_name}' isn't registered with the ToolsManager. "
                f"Registered tools: {self.get_tool_names()}."
            )

        try:
            self.callback_manager.on_tool_start(tool_input, tool=tool)
            tool_result: str = tool.run(tool_input, params)
            self.callback_manager.on_tool_finish(
                tool_result,
                observation_prefix="Observation: ",
                llm_prefix="Thought: ",
                color=tool.logging_color,
            )
        except Exception as e:
            # Notify listeners, then re-raise with the original traceback intact.
            self.callback_manager.on_tool_error(e, tool=tool)
            raise
        return tool_result

    def extract_tool_name_and_tool_input(self, llm_response: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Parse the tool name and the tool input from the PromptNode response.

        :param llm_response: The PromptNode response.
        :return: A tuple containing the tool name and the tool input, stripped of surrounding quotes, brackets,
            and whitespace, or `(None, None)` if `tool_pattern` doesn't match the response.
        """
        tool_match = re.search(self.tool_pattern, llm_response)
        if tool_match:
            tool_name = tool_match.group(1)
            tool_input = tool_match.group(3)
            return tool_name.strip('" []\n').strip(), tool_input.strip('" \n')
        return None, None


class Agent:
"""
An Agent answers queries using the tools you give to it. The tools are pipelines or nodes. The Agent uses a large
Expand All @@ -133,8 +224,8 @@ def __init__(
prompt_template: Union[str, PromptTemplate] = "zero-shot-react",
tools: Optional[List[Tool]] = None,
max_steps: int = 8,
tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*',
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
tool_pattern: Optional[str] = None,
):
"""
Creates an Agent instance.
Expand All @@ -144,40 +235,28 @@ def __init__(
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and
choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a
new template in a similar format.
:param tools: A list of tools the Agent can run. If you don't specify any tools here, you must add them
with `add_tool()` before running the Agent.
:param tools: A list of tools to add to the Agent. Each tool must have a unique name. You can also add tools
with `add_tool()` before running the Agent.
:param max_steps: The number of times the Agent can run a tool +1 to let it infer it knows the final answer.
Set it to at least 2, so that the Agent can run a tool once and then infer it knows the final answer.
The default is 8.
:param tool_pattern: A regular expression to extract the name of the tool and the corresponding input from the
text the Agent generated.
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
"""
self.callback_manager = Events(
(
"on_tool_start",
"on_tool_finish",
"on_tool_error",
"on_agent_start",
"on_agent_step",
"on_agent_finish",
"on_new_token",
)
)
self.max_steps = max_steps
if tool_pattern:
self.tm = ToolsManager(tools, tool_pattern=tool_pattern)
else:
self.tm = ToolsManager(tools)
self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token"))
self.prompt_node = prompt_node
resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
if not resolved_prompt_template:
raise ValueError(
f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name."
)
self.prompt_template = resolved_prompt_template
self.tools = {tool.name: tool for tool in tools} if tools else {}
self.tool_names = ", ".join(self.tools.keys())
self.tool_names_with_descriptions = "\n".join(
[f"{tool.name}: {tool.description}" for tool in self.tools.values()]
)
self.max_steps = max_steps
self.tool_pattern = tool_pattern
self.final_answer_pattern = final_answer_pattern
# Resolve model name to check if it's a streaming model
if isinstance(self.prompt_node.model_name_or_path, str):
Expand All @@ -195,7 +274,7 @@ def update_hash(self):
See haystack/telemetry.py::send_event
"""
try:
tool_names = " ".join([tool.pipeline_or_node.__class__.__name__ for tool in self.tools.values()])
tool_names = " ".join([tool.pipeline_or_node.__class__.__name__ for tool in self.tm.get_tools()])
self.hash = md5(tool_names.encode()).hexdigest()
except Exception as exc:
logger.debug("Telemetry exception: %s", str(exc))
Expand All @@ -217,7 +296,7 @@ def on_agent_start(**kwargs: Any) -> None:
agent_name = kwargs.pop("name", "react")
print_text(f"\nAgent {agent_name} started with {kwargs}\n")

self.callback_manager.on_tool_finish += on_tool_finish
self.tm.callback_manager.on_tool_finish += on_tool_finish
self.callback_manager.on_agent_start += on_agent_start

if streaming:
Expand All @@ -241,19 +320,15 @@ def add_tool(self, tool: Tool):
)
)
"""
self.tools[tool.name] = tool
self.tool_names = ", ".join(self.tools.keys())
self.tool_names_with_descriptions = "\n".join(
[f"{tool.name}: {tool.description}" for tool in self.tools.values()]
)
self.tm.add_tool(tool)

def has_tool(self, tool_name: str):
def has_tool(self, tool_name: str) -> bool:
"""
Check whether the Agent has a tool with the name you provide.
:param tool_name: The name of the tool for which you want to check whether the Agent has it.
"""
return tool_name in self.tools
return tool_name in self.tm.tools

def run(
self, query: str, max_steps: Optional[int] = None, params: Optional[dict] = None
Expand All @@ -279,11 +354,6 @@ def run(
except Exception as exc:
logger.debug("Telemetry exception: %s", exc)

if not self.tools:
raise AgentError(
"An Agent needs tools to run. Add at least one tool using `add_tool()` or set the parameter `tools` "
"when initializing the Agent."
)
if max_steps is None:
max_steps = self.max_steps
if max_steps < 2:
Expand Down Expand Up @@ -321,72 +391,12 @@ def _step(self, current_step: AgentStep, params: Optional[dict] = None):
self.callback_manager.on_agent_step(next_step)

# run the tool selected by the LLM
observation = self._run_tool(next_step, params) if not next_step.is_last() else None
observation = self.tm.run_tool(next_step.prompt_node_response, params) if not next_step.is_last() else None

# update the next step with the observation
next_step.completed(observation)
return next_step

def run_batch(
    self, queries: List[str], max_steps: Optional[int] = None, params: Optional[dict] = None
) -> Dict[str, List[Any]]:
    """
    Runs the Agent in a batch mode.

    Each query is processed one after another with `run()`, and the individual results are collected into
    parallel lists.

    :param queries: List of search queries.
    :param max_steps: The number of times the Agent can run a tool +1 to infer it knows the final answer.
        If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows
        the final answer.
    :param params: A dictionary of parameters you want to pass to the tools that are pipelines.
        To pass a parameter to all nodes in those pipelines, use the format: `{"top_k": 10}`.
        To pass a parameter to targeted nodes in those pipelines, use the format:
        `{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}`.
        You can only pass parameters to tools that are pipelines but not nodes.
    :return: A dictionary with the keys `queries`, `answers`, and `transcripts`. Each value is a list with one
        entry per query, in the order of `queries`.
    """
    # Telemetry is best effort: a failure here must never prevent the queries from running.
    try:
        if self.hash != self.last_hash:
            self.last_hash = self.hash
            send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
    except Exception as exc:
        logger.debug("Telemetry exception: %s", exc)

    results: Dict[str, List[Any]] = {"queries": [], "answers": [], "transcripts": []}
    for query in queries:
        result = self.run(query=query, max_steps=max_steps, params=params)
        results["queries"].append(result["query"])
        results["answers"].append(result["answers"])
        results["transcripts"].append(result["transcript"])

    return results

def _run_tool(self, next_step: AgentStep, params: Optional[Dict[str, Any]] = None) -> str:
    """
    Run the tool that the given Agent step asks for and return its result.

    :param next_step: The Agent step whose PromptNode response is parsed with `tool_pattern` to find the tool name
        and the tool input.
    :param params: Optional parameters that are passed on to the tool's `run()` method.
    :return: The result of the tool.
    :raises AgentError: If no tool name or tool input can be extracted from the step, or if the extracted tool
        wasn't added to the Agent.
    """
    tool_name, tool_input = next_step.extract_tool_name_and_tool_input(self.tool_pattern)
    if tool_name is None or tool_input is None:
        raise AgentError(
            f"Could not identify the next tool or input for that tool from Agent's output. "
            f"Adjust the Agent's param 'tool_pattern' or 'prompt_template'. \n"
            f"# 'tool_pattern' to identify next tool: {self.tool_pattern} \n"
            f"# Agent Step:\n{next_step}"
        )
    if not self.has_tool(tool_name):
        raise AgentError(
            f"The tool {tool_name} wasn't added to the Agent tools: {', '.join(self.tools)}. "
            "Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent. \n"
            f"# Agent Step:\n{next_step}"
        )

    tool: Tool = self.tools[tool_name]
    try:
        self.callback_manager.on_tool_start(tool_input, tool=tool)
        tool_result: str = tool.run(tool_input, params)
        self.callback_manager.on_tool_finish(
            tool_result, observation_prefix="Observation: ", llm_prefix="Thought: ", color=tool.logging_color
        )
    except Exception as e:
        # Notify listeners, then re-raise with the original traceback intact.
        self.callback_manager.on_tool_error(e, tool=tool)
        raise
    return tool_result

def _get_initial_transcript(self, query: str):
"""
Fills the Agent's PromptTemplate with the query, tool names, and descriptions.
Expand All @@ -395,7 +405,9 @@ def _get_initial_transcript(self, query: str):
"""
return next(
self.prompt_template.fill(
query=query, tool_names=self.tool_names, tool_names_with_descriptions=self.tool_names_with_descriptions
query=query,
tool_names=self.tm.get_tool_names(),
tool_names_with_descriptions=self.tm.get_tool_names_with_descriptions(),
),
"",
)
Loading

0 comments on commit d125138

Please sign in to comment.