Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional output_parser param in create_react_agent #18320

Merged
merged 4 commits into from
Feb 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions libs/langchain/langchain/agents/react/agent.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
from __future__ import annotations

from typing import Sequence
from typing import Optional, Sequence

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool

from langchain.agents import AgentOutputParser
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActSingleInputOutputParser
from langchain.tools.render import render_text_description


def create_react_agent(
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: BasePromptTemplate
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: BasePromptTemplate,
output_parser: Optional[AgentOutputParser] = None,
) -> Runnable:
"""Create an agent that uses ReAct prompting.

Args:
llm: LLM to use as the agent.
tools: Tools this agent has access to.
prompt: The prompt to use. See Prompt section below for more.
output_parser: AgentOutputParser for parse the LLM output.

Returns:
A Runnable sequence representing an agent. It takes as input all the same input
Expand Down Expand Up @@ -101,12 +106,13 @@ def create_react_agent(
tool_names=", ".join([t.name for t in tools]),
)
llm_with_stop = llm.bind(stop=["\nObservation"])
output_parser = output_parser or ReActSingleInputOutputParser()
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
)
| prompt
| llm_with_stop
| ReActSingleInputOutputParser()
| output_parser
)
return agent
Loading