Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed May 22, 2023
1 parent 168042c commit e631778
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 21 deletions.
21 changes: 13 additions & 8 deletions haystack/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,19 @@ def extract_tool_name_and_tool_input(self, llm_response: str) -> Tuple[Optional[
return None, None


def react_parameter_resolver(query: str, agent: Agent, agent_step: AgentStep, **kwargs) -> Dict[str, Any]:
"""
A parameter resolver for ReAct based agents that returns the query, the tool names, the tool names
with descriptions, and the transcript (internal monologue).
"""
return {
"query": query,
"tool_names": agent.tm.get_tool_names(),
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
"transcript": agent_step.transcript,
}


class Agent:
"""
An Agent answers queries using the tools you give to it. The tools are pipelines or nodes. The Agent uses a large
Expand Down Expand Up @@ -268,14 +281,6 @@ def __init__(
f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name."
)
self.prompt_template = resolved_prompt_template
react_parameter_resolver: Callable[
[str, Agent, AgentStep, Dict[str, Any]], Dict[str, Any]
] = lambda query, agent, agent_step, **kwargs: {
"query": query,
"tool_names": agent.tm.get_tool_names(),
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
"transcript": agent_step.transcript,
}
self.prompt_parameters_resolver = (
prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver
)
Expand Down
9 changes: 2 additions & 7 deletions haystack/agents/conversational.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Callable, Union

from haystack.agents import Agent
from haystack.agents.base import ToolsManager
from haystack.agents.base import ToolsManager, react_parameter_resolver
from haystack.agents.memory import Memory, ConversationMemory, ConversationSummaryMemory
from haystack.nodes import PromptNode, PromptTemplate

Expand Down Expand Up @@ -130,11 +130,6 @@ def __init__(
memory=memory if memory else ConversationSummaryMemory(prompt_node),
prompt_parameters_resolver=prompt_parameters_resolver
if prompt_parameters_resolver
else lambda query, agent, agent_step, **kwargs: {
"query": query,
"tool_names": agent.tm.get_tool_names(),
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
"transcript": agent_step.transcript,
},
else react_parameter_resolver,
final_answer_pattern=final_answer_pattern,
)
36 changes: 30 additions & 6 deletions test/agents/test_conversational_agent_with_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock, call

from haystack.agents import Tool
from haystack.agents import Tool, Agent, AgentStep
from haystack.agents.conversational import ConversationalAgentWithTools
from haystack.agents.memory import ConversationSummaryMemory, NoMemory
from haystack.nodes import PromptNode
Expand Down Expand Up @@ -49,9 +49,33 @@ def test_agent_with_memory(prompt_node):

@pytest.mark.unit
def test_run(prompt_node):
"""
Test that the invocation of ConversationalAgentWithTools run method in turn invokes _step of the Agent superclass
Make sure that the agent is starting from the correct step 1, and max_steps is 5
"""
mock_step = Mock(spec=Agent._step)

# Replace the original _step method with the mock
Agent._step = mock_step

# Initialize agent
prompt_node = PromptNode()
agent = ConversationalAgentWithTools(prompt_node)

# Mock the Agent run method
agent.run = MagicMock(return_value="Hello")
assert agent.run("query") == "Hello"
agent.run.assert_called_once_with("query")
# Run agent
agent.run(query="query")

assert mock_step.call_count == 1

# Check the parameters passed to _step method
assert mock_step.call_args[0][0] == "query"
agent_step = mock_step.call_args[0][1]
expected_agent_step = AgentStep(
current_step=1,
max_steps=5,
prompt_node_response="",
final_answer_pattern=r"Final Answer\s*:\s*(.*)",
transcript="",
)
# compare the string representation of the objects
assert str(agent_step) == str(expected_agent_step)

0 comments on commit e631778

Please sign in to comment.