Skip to content

Commit

Permalink
Add conversational agents, memory
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed May 3, 2023
1 parent 8091ced commit 6875dbc
Show file tree
Hide file tree
Showing 10 changed files with 676 additions and 90 deletions.
17 changes: 17 additions & 0 deletions examples/conversational_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os

from haystack.agents.conversational import ConversationalAgent
from haystack.nodes import PromptNode

# Build a conversational agent on top of a chat-capable PromptNode.
prompt_node = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
agent = ConversationalAgent(prompt_node)

# Minimal REPL: 'exit'/'quit' stops the loop, 'memory' dumps the agent's
# conversation memory, anything else is sent to the agent.
while True:
    user_input = input("Human (type 'exit' or 'quit' to quit): ")
    command = user_input.lower()
    if command in ("exit", "quit"):
        break
    if command == "memory":
        print("\nMemory:\n", agent.memory.load())
    else:
        assistant_response = agent.run(user_input)
        print("\nAssistant:", assistant_response)
106 changes: 106 additions & 0 deletions examples/conversational_agent_with_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os

from haystack.agents import Tool
from haystack.agents.base import ToolsManager
from haystack.agents.conversational import ConversationalAgentWithTools
from haystack.agents.types import Color
from haystack.nodes import PromptNode, WebRetriever, PromptTemplate
from haystack.pipelines import WebQAPipeline

# Few-shot prompt that drives the agent's tool-using loop.
# {tool_names_with_descriptions}, {tool_names}, {query}, and {transcript} are
# filled in by the agent at run time.
few_shot_prompt = """
In the following conversation, a human user interacts with an AI Agent using the ChatGPT API. The human user poses questions, and the AI Agent goes through several steps to provide well-informed answers.
If the AI Agent knows the answer, the response begins with "Final Answer:" on a new line.
If the AI Agent is uncertain or concerned that the information may be outdated or inaccurate, it must use the available tools to find the most up-to-date information. The AI has access to these tools:
{tool_names_with_descriptions}
AI Agent responses must start with one of the following:
Thought: [AI Agent's reasoning process]
Tool: [{tool_names}] (on a new line) Tool Input: [input for the selected tool WITHOUT quotation marks and on a new line] (These must always be provided together and on separate lines.)
Final Answer: [final answer to the human user's question]
When selecting a tool, the AI Agent must provide both the "Tool:" and "Tool Input:" pair in the same response, but on separate lines. "Observation:" marks the beginning of a tool's result, and the AI Agent trusts these results.
The AI Agent must use the conversation_history tool to infer context when necessary.
If a question is vague or requires context, the AI Agent should explicitly use the conversation_history tool with a clear Tool Input focused on finding the relevant context.
The AI Agent should not ask the human user for additional information, clarification, or context.
If the AI Agent cannot find a specific answer after exhausting available tools and approaches, it answers with Final Answer: inconclusive
Question: {query}
Thought:
{transcript}
"""

search_key = os.environ.get("SERPERDEV_API_KEY")
if not search_key:
    raise ValueError("Please set the SERPERDEV_API_KEY environment variable")

openai_key = os.environ.get("OPENAI_API_KEY")
# BUG FIX: this check previously re-tested `search_key`, so a missing
# OPENAI_API_KEY slipped through and only failed later inside PromptNode.
if not openai_key:
    raise ValueError("Please set the OPENAI_API_KEY environment variable")

# Prompt used by the web QA pipeline to condense retrieved snippets into a
# short answer.
prompt_text = """
Synthesize a comprehensive answer from the following most relevant paragraphs and the given question.
Provide a clear and concise answer, no longer than 10-20 words.
\n\n Paragraphs: {documents} \n\n Question: {query} \n\n Answer:
"""

prompt_node = PromptNode(
    "gpt-3.5-turbo",
    default_prompt_template=PromptTemplate("lfqa", prompt_text=prompt_text),
    api_key=openai_key,
    max_length=256,
)

web_retriever = WebRetriever(api_key=search_key, top_search_results=3, mode="snippets")
pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=prompt_node)

few_shot_agent = PromptTemplate("conversational-agent-with-tools", prompt_text=few_shot_prompt)

# NOTE: the lambda closes over `agent`, which is assigned further below; the
# tool is only invoked after the agent exists, so the late binding is safe.
conversation_history = Tool(
    name="conversation_history",
    pipeline_or_node=lambda tool_input, **kwargs: agent.memory.load(),
    description="useful for when you need to remember what you've already discussed.",
    logging_color=Color.MAGENTA,
)
web_qa_tool = Tool(
    name="Search",
    pipeline_or_node=pipeline,
    description="useful for when you need to Google questions.",
    output_variable="results",
)
pn = PromptNode(
    "gpt-3.5-turbo",
    api_key=openai_key,
    max_length=256,
    stop_words=["Observation:"],
    model_kwargs={"temperature": 0.5, "top_p": 0.9},
)

agent = ConversationalAgentWithTools(
    pn,
    max_steps=6,
    prompt_template=few_shot_agent,
    tools_manager=ToolsManager(tools=[web_qa_tool, conversation_history]),
)

# Set to True to replay a scripted multi-turn conversation instead of the REPL.
test = False
if test:
    questions = [
        "Why was Jamie Foxx recently hospitalized?",
        "Where was he hospitalized?",
        "What movie was he filming at the time?",
        "Who is Jamie's female co-star in the movie he was filing at that time?",
        "Tell me more about her, who is her partner?",
    ]
    for question in questions:
        agent.run(question)
else:
    while True:
        user_input = input("\nHuman (type 'exit' or 'quit' to quit): ")
        if user_input.lower() == "exit" or user_input.lower() == "quit":
            break
        if user_input.lower() == "memory":
            print("\nMemory:\n", agent.memory.load())
        else:
            # NOTE(review): the response is not printed here, unlike the plain
            # conversational_agent example — presumably the agent logs its own
            # colored output; confirm before "fixing".
            assistant_response = agent.run(user_input)
73 changes: 23 additions & 50 deletions haystack/agents/agent_step.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import logging
import re
from typing import Optional, Dict, Any

from haystack import Answer
from haystack.agents.answer_parser import AgentAnswerParser
from haystack.errors import AgentError

logger = logging.getLogger(__name__)
Expand All @@ -20,51 +20,39 @@ def __init__(
self,
current_step: int = 1,
max_steps: int = 10,
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
final_answer_parser: AgentAnswerParser = None,
prompt_node_response: str = "",
transcript: str = "",
):
"""
:param current_step: The current step in the execution of the agent.
:param max_steps: The maximum number of steps the agent can execute.
:param final_answer_pattern: The regex pattern to extract the final answer from the PromptNode response.
:param final_answer_parser: AgentAnswerParser to extract the final answer from the PromptNode response.
:param prompt_node_response: The PromptNode response received.
:param transcript: The full Agent execution transcript based on the Agent's initial prompt template and the
text it generated during execution up to this step. The transcript is used to generate the next prompt.
"""
self.current_step = current_step
self.max_steps = max_steps
self.final_answer_pattern = final_answer_pattern
self.final_answer_parser = final_answer_parser
self.prompt_node_response = prompt_node_response
self.transcript = transcript

def prepare_prompt(self):
    """
    Prepares the prompt for the next step.

    :return: The accumulated execution transcript, used verbatim as the next prompt.
    """
    return self.transcript

def create_next_step(self, prompt_node_response: Any, current_step: Optional[int] = None) -> AgentStep:
    """
    Creates the next agent step based on the current step and the PromptNode response.

    :param prompt_node_response: The PromptNode response received; must be a non-empty list of strings.
    :param current_step: Explicit step number for the new step; defaults to the current step plus one.
    :raises AgentError: If the PromptNode response is not a non-empty list.
    :return: The next step, an instance of the same (sub)class as ``self``.
    """
    if not isinstance(prompt_node_response, list) or not prompt_node_response:
        raise AgentError(
            f"Agent output must be a non-empty list of str, but {prompt_node_response} received. "
            f"Transcript:\n{self.transcript}"
        )

    # Use type(self) so subclasses of AgentStep produce steps of their own type.
    cls = type(self)
    return cls(
        # `is not None` (rather than truthiness) so an explicit current_step=0
        # is honored instead of being silently replaced.
        current_step=current_step if current_step is not None else self.current_step + 1,
        max_steps=self.max_steps,
        final_answer_parser=self.final_answer_parser,
        prompt_node_response=prompt_node_response[0],
        transcript=self.transcript,
    )
Expand All @@ -81,19 +69,19 @@ def final_answer(self, query: str) -> Dict[str, Any]:
"answers": [Answer(answer="", type="generative")],
"transcript": self.transcript,
}
if self.current_step >= self.max_steps:
if self.current_step > self.max_steps:
logger.warning(
"Maximum number of iterations (%s) reached for query (%s). Increase max_steps "
"or no answer can be provided for this query.",
self.max_steps,
query,
)
else:
final_answer = self.extract_final_answer()
final_answer = self.final_answer_parser.parse(self.prompt_node_response)
if not final_answer:
logger.warning(
"Final answer pattern (%s) not found in PromptNode response (%s).",
self.final_answer_pattern,
"Final answer parser (%s) could not parse PromptNode response (%s).",
self.final_answer_parser,
self.prompt_node_response,
)
else:
Expand All @@ -104,33 +92,12 @@ def final_answer(self, query: str) -> Dict[str, Any]:
}
return answer

def extract_final_answer(self) -> Optional[str]:
    """
    Parse the final answer from the PromptNode response.

    :raises AgentError: If called on a step that is not terminal.
    :return: The captured answer with surrounding quotes and spaces stripped,
        or None if the final answer pattern did not match.
    """
    if not self.is_last():
        raise AgentError("Cannot extract final answer from non terminal step.")

    final_answer_match = re.search(self.final_answer_pattern, self.prompt_node_response)
    if final_answer_match:
        final_answer = final_answer_match.group(1)
        # Models often wrap the answer in quotes; strip them off.
        return final_answer.strip('" ')
    return None

def is_final_answer_pattern_found(self) -> bool:
    """
    Check if the final answer pattern was found in PromptNode response.

    :return: True if the final answer pattern was found in PromptNode response, False otherwise.
    """
    # bool() collapses the Match-or-None result of re.search into a plain flag.
    return bool(re.search(self.final_answer_pattern, self.prompt_node_response))

def is_last(self) -> bool:
    """
    Check if this is the last step of the Agent.

    :return: True if the final answer can be parsed from the PromptNode response
        or the step budget is exhausted, False otherwise.
    """
    return self.final_answer_parser.can_parse(self.prompt_node_response) or self.current_step > self.max_steps

def completed(self, observation: Optional[str]):
"""
Expand All @@ -142,3 +109,9 @@ def completed(self, observation: Optional[str]):
if observation
else self.prompt_node_response
)

def __repr__(self):
    """Debug representation showing step counters, response, and transcript."""
    # type(self).__name__ instead of the hardcoded "AgentStep" so subclasses
    # (which create_next_step supports via type(self)) report their own name.
    return (
        f"{type(self).__name__}(current_step={self.current_step}, max_steps={self.max_steps}, "
        f"prompt_node_response={self.prompt_node_response}, transcript={self.transcript})"
    )
63 changes: 63 additions & 0 deletions haystack/agents/answer_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import re
from abc import abstractmethod, ABC
from typing import Any


class AgentAnswerParser(ABC):
    """
    Abstract base class for parsing the agent's final answer out of a PromptNode response.
    """

    @abstractmethod
    def can_parse(self, parser_input: Any) -> bool:
        """
        Check if the parser can parse the input.

        :param parser_input: The input to parse.
        :return: True if the parser can parse the input, False otherwise.
        """
        raise NotImplementedError

    @abstractmethod
    def parse(self, parser_input: Any) -> str:
        """
        Parse the input.

        :param parser_input: The input to parse.
        :return: The parsed answer, or an empty string if the input cannot be parsed.
        """
        raise NotImplementedError


class RegexAnswerParser(AgentAnswerParser):
    """
    Parser that uses a regex to extract the agent's final answer.
    """

    def __init__(self, final_answer_pattern: str):
        """
        :param final_answer_pattern: Regex with at least one capture group;
            group 1 is taken as the answer.
        """
        self.pattern = final_answer_pattern

    def can_parse(self, parser_input: Any) -> bool:
        if isinstance(parser_input, str):
            return bool(re.search(self.pattern, parser_input))
        return False

    def parse(self, parser_input: Any) -> str:
        # Search once instead of delegating to can_parse() and then searching
        # again — the original scanned the input with the same regex twice.
        if isinstance(parser_input, str):
            final_answer_match = re.search(self.pattern, parser_input)
            if final_answer_match:
                # Models often wrap the answer in quotes; strip them off.
                return final_answer_match.group(1).strip('" ')
        return ""


class BasicAnswerParser(AgentAnswerParser):
    """
    Parser that accepts any non-empty string and returns it unchanged.
    """

    def can_parse(self, parser_input: Any) -> bool:
        # bool(...) honors the declared return type: `isinstance(...) and s`
        # previously returned the string itself for non-empty input.
        return isinstance(parser_input, str) and bool(parser_input)

    def parse(self, parser_input: Any) -> str:
        if self.can_parse(parser_input):
            return parser_input
        return ""
Loading

0 comments on commit 6875dbc

Please sign in to comment.