From 558ab7d1f72279dabcfdbf87f59574965f37478b Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 21 Dec 2023 12:31:56 -0800 Subject: [PATCH] Add option to make messages placeholder optional --- libs/core/langchain_core/prompts/chat.py | 14 +++++++++++--- libs/core/tests/unit_tests/prompts/test_chat.py | 9 +++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 2b24bf3d3345b..e2978d383bff9 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -87,13 +87,17 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): variable_name: str """Name of variable to use as messages.""" + optional: bool = False + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] - def __init__(self, variable_name: str, **kwargs: Any): - return super().__init__(variable_name=variable_name, **kwargs) + def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any): + return super().__init__( + variable_name=variable_name, optional=optional, **kwargs + ) def format_messages(self, **kwargs: Any) -> List[BaseMessage]: """Format messages from kwargs. @@ -104,7 +108,11 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: Returns: List of BaseMessage. """ - value = kwargs[self.variable_name] + value = ( + kwargs.get(self.variable_name, []) + if self.optional + else kwargs[self.variable_name] + ) if not isinstance(value, list): raise ValueError( f"variable {self.variable_name} should be a list of base messages, " diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 3719573a6c674..2765d030d5686 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -19,6 +19,7 @@ ChatMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, + MessagesPlaceholder, SystemMessagePromptTemplate, _convert_to_message, ) @@ -360,3 +361,11 @@ def test_chat_message_partial() -> None: ] assert res == expected assert template2.format(input="hello") == get_buffer_string(expected) + + +def test_messages_placeholder() -> None: + prompt = MessagesPlaceholder("history") + with pytest.raises(KeyError): + prompt.format_messages() + prompt = MessagesPlaceholder("history", optional=True) + assert prompt.format_messages() == []