Skip to content

Commit

Permalink
core(minor): Add bulk add messages to BaseChatMessageHistory interface (
Browse files Browse the repository at this point in the history
#15709)

* Add bulk add_messages method to the interface.
* Update documentation for add_ai_message and add_user_message to
mark them as candidates for deprecation. We should stop using them
because they encourage inefficient (one-message-at-a-time) usage patterns
  • Loading branch information
eyurtsev committed Jan 31, 2024
1 parent af8c5c1 commit 2e5949b
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 7 deletions.
60 changes: 53 additions & 7 deletions libs/core/langchain_core/chat_history.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Union
from typing import List, Sequence, Union

from langchain_core.messages import (
AIMessage,
Expand All @@ -14,9 +14,18 @@
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Implementations should override the add_messages method to handle bulk addition
of messages.
The default implementation of add_message will correctly call add_messages, so
it is not necessary to implement both methods.
When used for updating history, users should favor usage of `add_messages`
over `add_message` or other variants like `add_user_message` and `add_ai_message`
to avoid unnecessary round-trips to the underlying persistence layer.
Example: Shows a default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
Expand All @@ -29,8 +38,13 @@ def messages(self):
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
messages = self.messages.append(_message_to_dict(message))
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
all_messages = list(self.messages) # Existing messages
all_messages.extend(messages) # Add new messages
serialized = [message_to_dict(message) for message in all_messages]
# Can be further optimized by only writing new messages
# using append mode.
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(serialized, f)
Expand All @@ -45,6 +59,12 @@ def clear(self):
def add_user_message(self, message: Union[HumanMessage, str]) -> None:
"""Convenience method for adding a human message string to the store.
Please note that this is a convenience method. Code should favor the
bulk add_messages interface instead to save on round-trips to the underlying
persistence layer.
This method may be deprecated in a future release.
Args:
message: The human message to add
"""
Expand All @@ -56,6 +76,12 @@ def add_user_message(self, message: Union[HumanMessage, str]) -> None:
def add_ai_message(self, message: Union[AIMessage, str]) -> None:
"""Convenience method for adding an AI message string to the store.
Please note that this is a convenience method. Code should favor the bulk
add_messages interface instead to save on round-trips to the underlying
persistence layer.
This method may be deprecated in a future release.
Args:
message: The AI message to add.
"""
Expand All @@ -64,18 +90,38 @@ def add_ai_message(self, message: Union[AIMessage, str]) -> None:
else:
self.add_message(AIMessage(content=message))

@abstractmethod
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError()
if type(self).add_messages != BaseChatMessageHistory.add_messages:
# This means that the sub-class has implemented an efficient add_messages
# method, so we should route add_message calls through it.
self.add_messages([message])
else:
raise NotImplementedError(
"add_message is not implemented for this class. "
"Please implement add_message or add_messages."
)

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
    """Add several messages to the store.

    This default implementation forwards each message to ``add_message``
    in order. Implementations should override this method to handle bulk
    addition efficiently and avoid unnecessary round-trips to the
    underlying store.

    Args:
        messages: The BaseMessage objects to store, in insertion order.
    """
    for item in messages:
        self.add_message(item)

@abstractmethod
def clear(self) -> None:
    """Remove all messages from the store.

    Declared abstract, so concrete subclasses must provide the
    implementation.
    """

def __str__(self) -> str:
    """Return the stored messages formatted by ``get_buffer_string``."""
    history = self.messages
    return get_buffer_string(history)
Empty file.
68 changes: 68 additions & 0 deletions libs/core/tests/unit_tests/chat_history/test_chat_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import List, Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage


def test_add_message_implementation_only() -> None:
    """Check a history class that overrides only ``add_message``.

    Single adds and the inherited bulk ``add_messages`` must both append to
    the backing list in call order.
    """

    class SampleChatHistory(BaseChatMessageHistory):
        def __init__(self, *, store: List[BaseMessage]) -> None:
            self.store = store

        def add_message(self, message: BaseMessage) -> None:
            """Append one message to the backing list."""
            self.store.append(message)

        def clear(self) -> None:
            """Not exercised by this test."""
            raise NotImplementedError()

    backing: List[BaseMessage] = []
    history = SampleChatHistory(store=backing)

    # One-at-a-time adds go straight through the overridden add_message.
    for expected_len, text in enumerate(("Hello", "World"), start=1):
        history.add_message(HumanMessage(content=text))
        assert len(backing) == expected_len
        assert backing[expected_len - 1] == HumanMessage(content=text)

    # The bulk call uses the base-class add_messages, which loops over
    # add_message.
    batch = [HumanMessage(content="Hello"), HumanMessage(content="World")]
    history.add_messages(batch)
    assert len(backing) == 4
    assert backing[2] == HumanMessage(content="Hello")
    assert backing[3] == HumanMessage(content="World")


def test_bulk_message_implementation_only() -> None:
    """Check a history class that overrides only ``add_messages``.

    Single ``add_message`` calls must be routed by the base class to the
    bulk override, and direct bulk calls must extend the store in order.
    """
    store: List[BaseMessage] = []

    class BulkAddHistory(BaseChatMessageHistory):
        def __init__(self, *, store: List[BaseMessage]) -> None:
            self.store = store

        # Parameter is named ``messages`` to match the base-class signature,
        # so keyword calls (``add_messages(messages=...)``) keep working.
        def add_messages(self, messages: Sequence[BaseMessage]) -> None:
            """Add a batch of messages to the store."""
            self.store.extend(messages)

        def clear(self) -> None:
            """Not exercised by this test."""
            raise NotImplementedError()

    chat_history = BulkAddHistory(store=store)

    # add_message is not overridden; the base class forwards it to
    # add_messages([message]).
    chat_history.add_message(HumanMessage(content="Hello"))
    assert len(store) == 1
    assert store[0] == HumanMessage(content="Hello")
    chat_history.add_message(HumanMessage(content="World"))
    assert len(store) == 2
    assert store[1] == HumanMessage(content="World")

    chat_history.add_messages(
        [HumanMessage(content="Hello"), HumanMessage(content="World")]
    )
    assert len(store) == 4
    assert store[2] == HumanMessage(content="Hello")
    assert store[3] == HumanMessage(content="World")

0 comments on commit 2e5949b

Please sign in to comment.