-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
676 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import os | ||
|
||
from haystack.agents.conversational import ConversationalAgent | ||
from haystack.nodes import PromptNode | ||
|
||
pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256) | ||
agent = ConversationalAgent(pn) | ||
|
||
while True: | ||
user_input = input("Human (type 'exit' or 'quit' to quit): ") | ||
if user_input.lower() == "exit" or user_input.lower() == "quit": | ||
break | ||
elif user_input.lower() == "memory": | ||
print("\nMemory:\n", agent.memory.load()) | ||
else: | ||
assistant_response = agent.run(user_input) | ||
print("\nAssistant:", assistant_response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import os | ||
|
||
from haystack.agents import Tool | ||
from haystack.agents.base import ToolsManager | ||
from haystack.agents.conversational import ConversationalAgentWithTools | ||
from haystack.agents.types import Color | ||
from haystack.nodes import PromptNode, WebRetriever, PromptTemplate | ||
from haystack.pipelines import WebQAPipeline | ||
|
||
few_shot_prompt = """ | ||
In the following conversation, a human user interacts with an AI Agent using the ChatGPT API. The human user poses questions, and the AI Agent goes through several steps to provide well-informed answers. | ||
If the AI Agent knows the answer, the response begins with "Final Answer:" on a new line. | ||
If the AI Agent is uncertain or concerned that the information may be outdated or inaccurate, it must use the available tools to find the most up-to-date information. The AI has access to these tools: | ||
{tool_names_with_descriptions} | ||
AI Agent responses must start with one of the following: | ||
Thought: [AI Agent's reasoning process] | ||
Tool: [{tool_names}] (on a new line) Tool Input: [input for the selected tool WITHOUT quotation marks and on a new line] (These must always be provided together and on separate lines.) | ||
Final Answer: [final answer to the human user's question] | ||
When selecting a tool, the AI Agent must provide both the "Tool:" and "Tool Input:" pair in the same response, but on separate lines. "Observation:" marks the beginning of a tool's result, and the AI Agent trusts these results. | ||
The AI Agent must use the conversation_history tool to infer context when necessary. | ||
If a question is vague or requires context, the AI Agent should explicitly use the conversation_history tool with a clear Tool Input focused on finding the relevant context. | ||
The AI Agent should not ask the human user for additional information, clarification, or context. | ||
If the AI Agent cannot find a specific answer after exhausting available tools and approaches, it answers with Final Answer: inconclusive | ||
Question: {query} | ||
Thought: | ||
{transcript} | ||
""" | ||
search_key = os.environ.get("SERPERDEV_API_KEY") | ||
if not search_key: | ||
raise ValueError("Please set the SERPERDEV_API_KEY environment variable") | ||
|
||
openai_key = os.environ.get("OPENAI_API_KEY") | ||
if not search_key: | ||
raise ValueError("Please set the OPENAI_API_KEY environment variable") | ||
|
||
prompt_text = """ | ||
Synthesize a comprehensive answer from the following most relevant paragraphs and the given question. | ||
Provide a clear and concise answer, no longer than 10-20 words. | ||
\n\n Paragraphs: {documents} \n\n Question: {query} \n\n Answer: | ||
""" | ||
|
||
prompt_node = PromptNode( | ||
"gpt-3.5-turbo", | ||
default_prompt_template=PromptTemplate("lfqa", prompt_text=prompt_text), | ||
api_key=openai_key, | ||
max_length=256, | ||
) | ||
|
||
web_retriever = WebRetriever(api_key=search_key, top_search_results=3, mode="snippets") | ||
pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=prompt_node) | ||
|
||
few_shot_agent = PromptTemplate("conversational-agent-with-tools", prompt_text=few_shot_prompt) | ||
conversation_history = Tool( | ||
name="conversation_history", | ||
pipeline_or_node=lambda tool_input, **kwargs: agent.memory.load(), | ||
description="useful for when you need to remember what you've already discussed.", | ||
logging_color=Color.MAGENTA, | ||
) | ||
web_qa_tool = Tool( | ||
name="Search", | ||
pipeline_or_node=pipeline, | ||
description="useful for when you need to Google questions.", | ||
output_variable="results", | ||
) | ||
pn = PromptNode( | ||
"gpt-3.5-turbo", | ||
api_key=os.environ.get("OPENAI_API_KEY"), | ||
max_length=256, | ||
stop_words=["Observation:"], | ||
model_kwargs={"temperature": 0.5, "top_p": 0.9}, | ||
) | ||
|
||
agent = ConversationalAgentWithTools( | ||
pn, | ||
max_steps=6, | ||
prompt_template=few_shot_agent, | ||
tools_manager=ToolsManager(tools=[web_qa_tool, conversation_history]), | ||
) | ||
|
||
test = False | ||
if test: | ||
questions = [ | ||
"Why was Jamie Foxx recently hospitalized?", | ||
"Where was he hospitalized?", | ||
"What movie was he filming at the time?", | ||
"Who is Jamie's female co-star in the movie he was filing at that time?", | ||
"Tell me more about her, who is her partner?", | ||
] | ||
for question in questions: | ||
agent.run(question) | ||
else: | ||
while True: | ||
user_input = input("\nHuman (type 'exit' or 'quit' to quit): ") | ||
if user_input.lower() == "exit" or user_input.lower() == "quit": | ||
break | ||
if user_input.lower() == "memory": | ||
print("\nMemory:\n", agent.memory.load()) | ||
else: | ||
assistant_response = agent.run(user_input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import re | ||
from abc import abstractmethod, ABC | ||
from typing import Any | ||
|
||
|
||
class AgentAnswerParser(ABC): | ||
""" | ||
Abstract base class for parsing agent's answer. | ||
""" | ||
|
||
@abstractmethod | ||
def can_parse(self, parser_input: Any) -> bool: | ||
""" | ||
Check if the parser can parse the input. | ||
:param parser_input: The input to parse. | ||
:return: True if the parser can parse the input, False otherwise. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def parse(self, parser_input: Any) -> str: | ||
""" | ||
Parse the input. | ||
:param parser_input: The input to parse. | ||
:return: The parsed input. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class RegexAnswerParser(AgentAnswerParser): | ||
""" | ||
Parser that uses a regex to parse the agent's answer. | ||
""" | ||
|
||
def __init__(self, final_answer_pattern: str): | ||
self.pattern = final_answer_pattern | ||
|
||
def can_parse(self, parser_input: Any) -> bool: | ||
if isinstance(parser_input, str): | ||
return bool(re.search(self.pattern, parser_input)) | ||
return False | ||
|
||
def parse(self, parser_input: Any) -> str: | ||
if self.can_parse(parser_input): | ||
final_answer_match = re.search(self.pattern, parser_input) | ||
if final_answer_match: | ||
final_answer = final_answer_match.group(1) | ||
return final_answer.strip('" ') # type: ignore | ||
return "" | ||
|
||
|
||
class BasicAnswerParser(AgentAnswerParser): | ||
""" | ||
Parser that returns the input if it is a non-empty string. | ||
""" | ||
|
||
def can_parse(self, parser_input: Any) -> bool: | ||
return isinstance(parser_input, str) and parser_input | ||
|
||
def parse(self, parser_input: Any) -> str: | ||
if self.can_parse(parser_input): | ||
return parser_input | ||
return "" |
Oops, something went wrong.