Skip to content

Commit

Permalink
feat: Add conversational agent (#4931)
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed May 17, 2023
1 parent ca68601 commit 9d52998
Show file tree
Hide file tree
Showing 11 changed files with 398 additions and 102 deletions.
1 change: 1 addition & 0 deletions examples/agent_multihop_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
##
Question: {query}
Thought:
{transcript}
"""
few_shot_agent_template = PromptTemplate("few-shot-react", prompt_text=few_shot_prompt)
prompt_node = PromptNode(
Expand Down
17 changes: 17 additions & 0 deletions examples/conversational_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os

from haystack.agents.conversational import ConversationalAgent
from haystack.nodes import PromptNode

# Interactive demo: chat with a ConversationalAgent backed by an OpenAI chat
# model. Type 'exit' or 'quit' to stop, or 'memory' to inspect the agent's
# conversation memory. Requires the OPENAI_API_KEY environment variable.
pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
agent = ConversationalAgent(pn)

while True:
    user_input = input("Human (type 'exit' or 'quit' to quit, 'memory' for agent's memory): ")
    # Normalize once instead of calling .lower() on every comparison.
    command = user_input.lower()
    if command in ("exit", "quit"):
        break
    if command == "memory":
        print("\nMemory:\n", agent.memory.load())
    else:
        assistant_response = agent.run(user_input)
        print("\nAssistant:", assistant_response)
100 changes: 52 additions & 48 deletions haystack/agents/agent_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,49 +20,38 @@ def __init__(
self,
current_step: int = 1,
max_steps: int = 10,
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
final_answer_pattern: Optional[str] = None,
prompt_node_response: str = "",
transcript: str = "",
):
"""
:param current_step: The current step in the execution of the agent.
:param max_steps: The maximum number of steps the agent can execute.
:param final_answer_pattern: The regex pattern to extract the final answer from the PromptNode response.
:param final_answer_pattern: The regex pattern to extract the final answer from the PromptNode response. If no
pattern is provided, the entire PromptNode response is considered the final answer.
:param prompt_node_response: The PromptNode response received.
:param transcript: The full Agent execution transcript based on the Agent's initial prompt template and the
text it generated during execution up to this step. The transcript is used to generate the next prompt.
"""
self.current_step = current_step
self.max_steps = max_steps
self.final_answer_pattern = final_answer_pattern
self.final_answer_pattern = final_answer_pattern or r"^([\s\S]+)$"
self.prompt_node_response = prompt_node_response
self.transcript = transcript

def prepare_prompt(self):
    """
    Return the prompt text to send to the PromptNode for the next step.

    The accumulated transcript doubles as the next prompt, so the model
    sees the full history of the run so far.
    """
    next_prompt = self.transcript
    return next_prompt

def create_next_step(self, prompt_node_response: Any) -> AgentStep:
def create_next_step(self, prompt_node_response: Any, current_step: Optional[int] = None) -> AgentStep:
"""
Creates the next agent step based on the current step and the PromptNode response.
:param prompt_node_response: The PromptNode response received.
:param current_step: The current step in the execution of the agent.
"""
if not isinstance(prompt_node_response, list):
if not isinstance(prompt_node_response, list) or not prompt_node_response:
raise AgentError(
f"Agent output must be a list of str, but {prompt_node_response} received. "
f"Agent output must be a non-empty list of str, but {prompt_node_response} received. "
f"Transcript:\n{self.transcript}"
)

if not prompt_node_response:
raise AgentError(
f"Agent output must be a non empty list of str, but {prompt_node_response} received. "
f"Transcript:\n{self.transcript}"
)

return AgentStep(
current_step=self.current_step + 1,
cls = type(self)
return cls(
current_step=current_step if current_step else self.current_step + 1,
max_steps=self.max_steps,
final_answer_pattern=self.final_answer_pattern,
prompt_node_response=prompt_node_response[0],
Expand All @@ -81,18 +70,18 @@ def final_answer(self, query: str) -> Dict[str, Any]:
"answers": [Answer(answer="", type="generative")],
"transcript": self.transcript,
}
if self.current_step >= self.max_steps:
if self.current_step > self.max_steps:
logger.warning(
"Maximum number of iterations (%s) reached for query (%s). Increase max_steps "
"or no answer can be provided for this query.",
self.max_steps,
query,
)
else:
final_answer = self.extract_final_answer()
final_answer = self.parse_final_answer()
if not final_answer:
logger.warning(
"Final answer pattern (%s) not found in PromptNode response (%s).",
"Final answer parser (%s) could not parse PromptNode response (%s).",
self.final_answer_pattern,
self.prompt_node_response,
)
Expand All @@ -104,35 +93,14 @@ def final_answer(self, query: str) -> Dict[str, Any]:
}
return answer

def extract_final_answer(self) -> Optional[str]:
    """
    Extract the final answer from the PromptNode response using the
    configured final-answer regex pattern.

    :raises AgentError: If this step is not a terminal step.
    :return: The first capture group of the match with surrounding double
        quotes and spaces stripped, or None when the pattern does not match.
    """
    if not self.is_last():
        raise AgentError("Cannot extract final answer from non terminal step.")

    match = re.search(self.final_answer_pattern, self.prompt_node_response)
    if match is None:
        return None
    return match.group(1).strip('" ')

def is_final_answer_pattern_found(self) -> bool:
    """
    Report whether the final-answer regex matches the PromptNode response.

    :return: True if the final answer pattern was found in the PromptNode
        response, False otherwise.
    """
    match = re.search(self.final_answer_pattern, self.prompt_node_response)
    return match is not None

def is_last(self) -> bool:
"""
Check if this is the last step of the Agent.
:return: True if this is the last step of the Agent, False otherwise.
"""
return self.is_final_answer_pattern_found() or self.current_step >= self.max_steps
return bool(self.parse_final_answer()) or self.current_step > self.max_steps

def completed(self, observation: Optional[str]):
def completed(self, observation: Optional[str]) -> None:
"""
Update the transcript with the observation
:param observation: received observation from the Agent environment.
Expand All @@ -142,3 +110,39 @@ def completed(self, observation: Optional[str]):
if observation
else self.prompt_node_response
)

def __repr__(self) -> str:
    """
    Return a debug-friendly string representation of this AgentStep.

    :return: A string that represents the AgentStep object.
    """
    fields = (
        f"current_step={self.current_step}",
        f"max_steps={self.max_steps}",
        f"prompt_node_response={self.prompt_node_response}",
        f"final_answer_pattern={self.final_answer_pattern}",
        f"transcript={self.transcript}",
    )
    return "AgentStep(" + ", ".join(fields) + ")"

def parse_final_answer(self) -> Optional[str]:
    """
    Parse the final answer out of the prompt node response.

    The response is searched with ``final_answer_pattern``; when it matches,
    the first capture group is returned with surrounding double quotes and
    whitespace stripped. When it does not match, None is returned.

    :return: The final answer as a string if a match is found, otherwise None.
    """
    match = re.search(self.final_answer_pattern, self.prompt_node_response)
    if match is None:
        return None
    # The first capture group carries the answer text; trim quotes/spaces.
    return match.group(1).strip('" ')
100 changes: 60 additions & 40 deletions haystack/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

import logging
import re
from collections.abc import Iterable, Callable
from hashlib import md5
from typing import List, Optional, Union, Dict, Any, Tuple

from events import Events

from haystack import Pipeline, BaseComponent, Answer, Document
from haystack.agents.memory import Memory, NoMemory
from haystack.telemetry import send_event
from haystack.agents.agent_step import AgentStep
from haystack.agents.types import Color, AgentTokenStreamingHandler
from haystack.agents.utils import print_text, STREAMING_CAPABLE_MODELS
from haystack.errors import AgentError
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
from haystack.pipelines import (
BaseStandardPipeline,
Expand Down Expand Up @@ -221,8 +222,10 @@ class Agent:
def __init__(
self,
prompt_node: PromptNode,
prompt_template: Union[str, PromptTemplate] = "zero-shot-react",
prompt_template: Optional[Union[str, PromptTemplate]] = None,
tools_manager: Optional[ToolsManager] = None,
memory: Optional[Memory] = None,
prompt_parameters_resolver: Optional[Callable] = None,
max_steps: int = 8,
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
):
Expand All @@ -235,24 +238,40 @@ def __init__(
choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a
new template in a similar format.
with `add_tool()` before running the Agent.
:param tools: A list of tools to add to the Agent. Each tool must have a unique name. You can also add tools
with `add_tool()` before running the Agent.
:param tools_manager: A ToolsManager instance that the Agent uses to run tools. Each tool must have a unique name.
You can also add tools with `add_tool()` before running the Agent.
:param memory: A Memory instance that the Agent uses to store information between iterations.
:param prompt_parameters_resolver: A callable that takes query, agent, and agent_step as parameters and returns
a dictionary of parameters to pass to the prompt_template. The default is a callable that returns a dictionary
of keys and values needed for the React agent prompt template.
:param max_steps: The number of times the Agent can run a tool +1 to let it infer it knows the final answer.
Set it to at least 2 so that the Agent can run a tool once and then infer it knows the final answer.
The default is 5.
text the Agent generated.
The default is 8.
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
"""
self.max_steps = max_steps
self.tm = tools_manager or ToolsManager()
self.memory = memory or NoMemory()
self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token"))
self.prompt_node = prompt_node
prompt_template = prompt_template or "zero-shot-react"
resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
if not resolved_prompt_template:
raise ValueError(
f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name."
)
self.prompt_template = resolved_prompt_template
react_parameter_resolver: Callable[
[str, Agent, AgentStep, Dict[str, Any]], Dict[str, Any]
] = lambda query, agent, agent_step, **kwargs: {
"query": query,
"tool_names": agent.tm.get_tool_names(),
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
"transcript": agent_step.transcript,
}
self.prompt_parameters_resolver = (
prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver
)
self.final_answer_pattern = final_answer_pattern
# Resolve model name to check if it's a streaming model
if isinstance(self.prompt_node.model_name_or_path, str):
Expand Down Expand Up @@ -350,37 +369,18 @@ def run(
except Exception as exc:
logger.debug("Telemetry exception: %s", exc)

if max_steps is None:
max_steps = self.max_steps
if max_steps < 2:
raise AgentError(
f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
f"answer. It was set to {max_steps}."
)
self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params)
agent_step = self._create_first_step(query, max_steps)
agent_step = self.create_agent_step(max_steps)
try:
while not agent_step.is_last():
agent_step = self._step(agent_step, params)
agent_step = self._step(query, agent_step, params)
finally:
self.callback_manager.on_agent_finish(agent_step)
return agent_step.final_answer(query=query)

def _create_first_step(self, query: str, max_steps: int = 10):
transcript = self._get_initial_transcript(query=query)
return AgentStep(
current_step=1,
max_steps=max_steps,
final_answer_pattern=self.final_answer_pattern,
prompt_node_response="", # no LLM response for the first step
transcript=transcript,
)

def _step(self, current_step: AgentStep, params: Optional[dict] = None):
def _step(self, query: str, current_step: AgentStep, params: Optional[dict] = None):
# plan next step using the LLM
prompt_node_response = self.prompt_node(
current_step.prepare_prompt(), stream_handler=AgentTokenStreamingHandler(self.callback_manager)
)
prompt_node_response = self._plan(query, current_step)

# from the LLM response, create the next step
next_step = current_step.create_next_step(prompt_node_response)
Expand All @@ -389,21 +389,41 @@ def _step(self, current_step: AgentStep, params: Optional[dict] = None):
# run the tool selected by the LLM
observation = self.tm.run_tool(next_step.prompt_node_response, params) if not next_step.is_last() else None

# save the input, output and observation to memory (if memory is enabled)
memory_data = self.prepare_data_for_memory(input=query, output=prompt_node_response, observation=observation)
self.memory.save(data=memory_data)

# update the next step with the observation
next_step.completed(observation)
return next_step

def _get_initial_transcript(self, query: str):
def _plan(self, query: str, current_step):
    """
    Ask the LLM (via the PromptNode) to plan the next step.

    Resolves the prompt template parameters with the configured
    ``prompt_parameters_resolver`` and invokes the PromptNode, streaming
    tokens through the callback manager.

    :param query: The user's query being answered.
    :param current_step: The current AgentStep; its state (e.g. transcript)
        is available to the parameter resolver.
    :return: The raw PromptNode response for this planning call.
    """
    # first resolve prompt template params
    template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step)

    # if prompt node has no default prompt template, use agent's prompt template
    if self.prompt_node.default_prompt_template is None:
        # fill() yields prompts; only the first one is used here
        prepared_prompt = next(self.prompt_template.fill(**template_params))
        prompt_node_response = self.prompt_node(
            prepared_prompt, stream_handler=AgentTokenStreamingHandler(self.callback_manager)
        )
    # otherwise, if prompt node has default prompt template, use it
    else:
        prompt_node_response = self.prompt_node(
            stream_handler=AgentTokenStreamingHandler(self.callback_manager), **template_params
        )
    return prompt_node_response

def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep:
"""
Fills the Agent's PromptTemplate with the query, tool names, and descriptions.
Create an AgentStep object. Override this method to customize the AgentStep class used by the Agent.
"""
return AgentStep(max_steps=max_steps or self.max_steps, final_answer_pattern=self.final_answer_pattern)

:param query: The search query.
def prepare_data_for_memory(self, **kwargs) -> dict:
"""
return next(
self.prompt_template.fill(
query=query,
tool_names=self.tm.get_tool_names(),
tool_names_with_descriptions=self.tm.get_tool_names_with_descriptions(),
),
"",
)
Prepare data for saving to the Agent's memory. Override this method to customize the data saved to the memory.
"""
return {
k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable))
}
Loading

0 comments on commit 9d52998

Please sign in to comment.