Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: fix no current event loop for sql history in async mode #22933

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions libs/community/langchain_community/chat_message_histories/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import json
import logging
Expand Down Expand Up @@ -252,17 +251,11 @@ async def aadd_message(self, message: BaseMessage) -> None:
await session.commit()

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
# RunnableWithMessageHistory._exit_history() mistakenly calls the
# add_message method instead of aadd_message.
# See https://github.com/langchain-ai/langchain/issues/22021
if self.async_mode:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.aadd_messages(messages))
else:
with self._make_sync_session() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()
# Add all messages in one transaction
with self._make_sync_session() as session:
for message in messages:
session.add(self.converter.to_sql_model(message, self.session_id))
session.commit()

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
# Add all messages in one transaction
Expand Down
37 changes: 35 additions & 2 deletions libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
Expand Down Expand Up @@ -306,8 +307,17 @@ def get_session_history(
history_chain = RunnablePassthrough.assign(
**{messages_key: history_chain}
).with_config(run_name="insert_history")
bound = (
history_chain | runnable.with_listeners(on_end=self._exit_history)
bound: Runnable = (
history_chain
| RunnableBranch(
(
RunnableLambda(
self._is_not_async, afunc=self._is_async
).with_config(run_name="RunnableWithMessageHistoryInAsyncMode"),
runnable.with_alisteners(on_end=self._aexit_history),
),
runnable.with_listeners(on_end=self._exit_history),
)
).with_config(run_name="RunnableWithMessageHistory")

if history_factory_config:
Expand Down Expand Up @@ -367,6 +377,12 @@ def get_input_schema(
else:
return super_schema

def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False

async def _is_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return True

def _get_input_messages(
self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
) -> List[BaseMessage]:
Expand Down Expand Up @@ -483,6 +499,23 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None:
output_messages = self._get_output_messages(output_val)
hist.add_messages(input_messages + output_messages)

async def _aexit_history(self, run: Run, config: RunnableConfig) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mackong would you mind unit testing code to cover the async path?

Copy link
Contributor Author

@mackong mackong Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eyurtsev unit tests added, but one unit test will fail because of an unrelated issue. See https://github.com/langchain-ai/langchain/actions/runs/9594957636/job/26458682186?pr=22933#step:6:164

Currently AsyncRootListenersTracer's schema format is "original", so on_chat_model_start falls back to on_llm_start; the type of the Run's input is then str rather than BaseMessage, and it is therefore ignored by ChatMessageHistory's add_message.

if not isinstance(message, BaseMessage):
raise ValueError
self.messages.append(message)

I have created PR #23214, which fixes the issue — please review #23214 first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eyurtsev now added unit tests passed

hist: BaseChatMessageHistory = config["configurable"]["message_history"]

# Get the input messages
inputs = load(run.inputs)
input_messages = self._get_input_messages(inputs)
# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
if not self.history_messages_key:
historic_messages = config["configurable"]["message_history"].messages
input_messages = input_messages[len(historic_messages) :]

# Get the output messages
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)
await hist.aadd_messages(input_messages + output_messages)

def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = super()._merge_configs(*configs)
expected_keys = [field_spec.id for field_spec in self.history_factory_config]
Expand Down
Loading
Loading