diff --git a/langchain/agents/initialize.py b/langchain/agents/initialize.py index cb26fb630a3426..8b4ff608f70ccb 100644 --- a/langchain/agents/initialize.py +++ b/langchain/agents/initialize.py @@ -51,7 +51,7 @@ def initialize_agent( f"Got unknown agent type: {agent}. " f"Valid types are: {AGENT_TO_CLASS.keys()}." ) - tags_.append(agent.value) + tags_.append(agent.value if isinstance(agent, AgentType) else agent) agent_cls = AGENT_TO_CLASS[agent] agent_kwargs = agent_kwargs or {} agent_obj = agent_cls.from_llm_and_tools( diff --git a/tests/unit_tests/agents/test_initialize.py b/tests/unit_tests/agents/test_initialize.py new file mode 100644 index 00000000000000..04d3de9e20ac5f --- /dev/null +++ b/tests/unit_tests/agents/test_initialize.py @@ -0,0 +1,23 @@ +"""Test the initialize module.""" + +from langchain.agents.agent_types import AgentType +from langchain.agents.initialize import initialize_agent +from langchain.tools.base import tool +from tests.unit_tests.llms.fake_llm import FakeLLM + + +@tool +def my_tool(query: str) -> str: + """A fake tool.""" + return "fake tool" + + +def test_initialize_agent_with_str_agent_type() -> None: + """Test initialize_agent with a string.""" + fake_llm = FakeLLM() + agent_executor = initialize_agent( + [my_tool], fake_llm, "zero-shot-react-description" # type: ignore + ) + assert agent_executor.agent._agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION + assert isinstance(agent_executor.tags, list) + assert "zero-shot-react-description" in agent_executor.tags