core[patch]: Deduplicate of callback handlers in merge_configs (#22478)
This PR adds deduplication of callback handlers in merge_configs.

Fix for this issue:
#22227

The issue appears when the code:

1) runs on Python >= 3.11,
2) invokes a runnable from within a runnable, and
3) binds the callbacks to the child runnable from the parent runnable using `with_config`.

In this case, the same callbacks end up appearing twice: (1) first via `with_config`, and (2) a second time when LangChain automatically propagates them on behalf of the user.
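For illustration, here is a minimal sketch (not part of this PR) of what the duplication looks like at the config level: two callback managers carrying the same handler object are merged, with `StdOutCallbackHandler` used purely as a placeholder handler.

```python
from langchain_core.callbacks import CallbackManager, StdOutCallbackHandler
from langchain_core.runnables.config import merge_configs

shared = StdOutCallbackHandler()

# One config bound via with_config, one propagated automatically; both end up
# carrying the same handler object.
bound = {"callbacks": CallbackManager(handlers=[shared], inheritable_handlers=[shared])}
propagated = {
    "callbacks": CallbackManager(handlers=[shared], inheritable_handlers=[shared])
}

merged = merge_configs(bound, propagated)
# Before this PR the shared handler was registered twice (duplicate events);
# with this change it is registered once.
print(len(merged["callbacks"].handlers))
```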


Prior to this PR, this would emit duplicate events:

```python
@tool
async def get_items(question: str, callbacks: Callbacks):  # <--- Accept callbacks
    """Ask question"""
    template = ChatPromptTemplate.from_messages([("human", "{question}")])
    chain = template | chat_model.with_config(
        {
            "callbacks": callbacks,  # <-- Propagate callbacks
        }
    )
    return await chain.ainvoke({"question": question})
```

Prior to this PR, this already worked correctly (no duplicate events):

```python
@tool
async def get_items(question: str, callbacks: Callbacks):  # <--- Accept callbacks
    """Ask question"""
    template = ChatPromptTemplate.from_messages([("human", "{question}")])
    chain = template | chat_model
    return await chain.ainvoke({"question": question}, {"callbacks": callbacks})
```

This will also work (as long as the user is on Python >= 3.11), since LangChain will automatically propagate the callbacks:

```python
@tool
async def get_items(question: str):
    """Ask question"""
    template = ChatPromptTemplate.from_messages([("human", "{question}")])
    chain = template | chat_model
    return await chain.ainvoke({"question": question})
```
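The examples above leave `chat_model` undefined. Below is a self-contained sketch (not from this PR, and only an assumption about how one might check the behavior) that uses `GenericFakeChatModel` as a hypothetical stand-in for `chat_model` and counts the streamed chat-model events surfaced by `astream_events`:

```python
import asyncio
from itertools import cycle

from langchain_core.callbacks import Callbacks
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool

# Stand-in for the `chat_model` referenced in the examples above.
chat_model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world")]))


@tool
async def get_items(question: str, callbacks: Callbacks):  # <--- Accept callbacks
    """Ask question"""
    template = ChatPromptTemplate.from_messages([("human", "{question}")])
    chain = template | chat_model.with_config({"callbacks": callbacks})
    return await chain.ainvoke({"question": question})


async def main() -> None:
    chunks = [
        event["data"]["chunk"].content
        async for event in get_items.astream_events(
            {"question": "where are the cats?"}, version="v2"
        )
        if event["event"] == "on_chat_model_stream"
    ]
    # With deduplicated handlers each streamed chunk appears exactly once;
    # prior to this PR (on Python >= 3.11) each chunk showed up twice.
    print(chunks)


asyncio.run(main())
```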
eyurtsev committed Jun 4, 2024
1 parent 64dbc52 commit 9120cf5
Showing 2 changed files with 49 additions and 4 deletions.
22 changes: 18 additions & 4 deletions libs/core/langchain_core/runnables/config.py
@@ -305,12 +305,12 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
                         base["callbacks"] = mngr
                     else:
                         # base_callbacks is also a manager
-                        base["callbacks"] = base_callbacks.__class__(
+
+                        manager = base_callbacks.__class__(
                             parent_run_id=base_callbacks.parent_run_id
                             or these_callbacks.parent_run_id,
-                            handlers=base_callbacks.handlers + these_callbacks.handlers,
-                            inheritable_handlers=base_callbacks.inheritable_handlers
-                            + these_callbacks.inheritable_handlers,
+                            handlers=[],
+                            inheritable_handlers=[],
                             tags=list(set(base_callbacks.tags + these_callbacks.tags)),
                             inheritable_tags=list(
                                 set(
@@ -323,6 +323,20 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
                                 **these_callbacks.metadata,
                             },
                         )
+
+                        handlers = base_callbacks.handlers + these_callbacks.handlers
+                        inheritable_handlers = (
+                            base_callbacks.inheritable_handlers
+                            + these_callbacks.inheritable_handlers
+                        )
+
+                        for handler in handlers:
+                            manager.add_handler(handler)
+
+                        for handler in inheritable_handlers:
+                            manager.add_handler(handler, inherit=True)
+
+                        base["callbacks"] = manager
             else:
                 base[key] = config[key] or base.get(key)  # type: ignore
     return base
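The diff relies on `add_handler` ignoring handlers that are already registered: the merged manager starts with empty handler lists and every handler is re-added one by one. A minimal sketch of that idea (not the library's internal code), using `StdOutCallbackHandler` as a placeholder handler:

```python
from langchain_core.callbacks import CallbackManager, StdOutCallbackHandler

shared = StdOutCallbackHandler()

# Start with empty handler lists, as the new merge_configs code does, then
# re-add every handler; duplicates are skipped by add_handler.
manager = CallbackManager(handlers=[], inheritable_handlers=[])
for handler in [shared, shared]:  # the same handler arriving from both configs
    manager.add_handler(handler, inherit=True)

print(len(manager.handlers))  # 1
```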
31 changes: 31 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
@@ -1876,3 +1876,34 @@ async def generator(inputs: AsyncIterator[str]) -> AsyncIterator[str]:
             "tags": [],
         },
     ]
+
+
+async def test_with_explicit_config() -> None:
+    """Test astream events with explicit callbacks being passed."""
+    infinite_cycle = cycle([AIMessage(content="hello world", id="ai3")])
+    model = GenericFakeChatModel(messages=infinite_cycle)
+
+    @tool
+    async def say_hello(query: str, callbacks: Callbacks) -> BaseMessage:
+        """Use this tool to look up which items are in the given place."""
+
+        @RunnableLambda
+        def passthrough_to_trigger_issue(x: str) -> str:
+            """Add passthrough to trigger issue."""
+            return x
+
+        chain = passthrough_to_trigger_issue | model.with_config(
+            {"tags": ["hello"], "callbacks": callbacks}
+        )
+
+        return await chain.ainvoke(query)
+
+    events = await _collect_events(
+        say_hello.astream_events("meow", version="v2")  # type: ignore
+    )
+
+    assert [
+        event["data"]["chunk"].content
+        for event in events
+        if event["event"] == "on_chat_model_stream"
+    ] == ["hello", " ", "world"]
