From b5115d31ead151a6d33bd1a00cd308cebd863ee6 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 20 Feb 2024 17:14:13 -0800 Subject: [PATCH] Assign message id in ChatOpenAI --- .../langchain_openai/chat_models/base.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index d2fc4845cff8d3..74849ea3eeead9 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -92,8 +92,9 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: The LangChain message. """ role = _dict.get("role") + id_ = _dict.get("id") if role == "user": - return HumanMessage(content=_dict.get("content", "")) + return HumanMessage(content=_dict.get("content", ""), id=id_) elif role == "assistant": # Fix for azure # Also OpenAI returns None for tool invocations @@ -103,11 +104,13 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: additional_kwargs["function_call"] = dict(function_call) if tool_calls := _dict.get("tool_calls"): additional_kwargs["tool_calls"] = tool_calls - return AIMessage(content=content, additional_kwargs=additional_kwargs) + return AIMessage(content=content, additional_kwargs=additional_kwargs, id=id_) elif role == "system": - return SystemMessage(content=_dict.get("content", "")) + return SystemMessage(content=_dict.get("content", ""), id=id_) elif role == "function": - return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name")) + return FunctionMessage( + content=_dict.get("content", ""), name=_dict.get("name"), id=id_ + ) elif role == "tool": additional_kwargs = {} if "name" in _dict: @@ -116,9 +119,10 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: content=_dict.get("content", ""), tool_call_id=_dict.get("tool_call_id"), additional_kwargs=additional_kwargs, + id=id_, ) else: - return ChatMessage(content=_dict.get("content", ""), role=role) + return ChatMessage(content=_dict.get("content", ""), role=role, id=id_) def _convert_message_to_dict(message: BaseMessage) -> dict: @@ -171,6 +175,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_delta_to_message_chunk( _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] ) -> BaseMessageChunk: + id_ = _dict.get("id") role = cast(str, _dict.get("role")) content = cast(str, _dict.get("content") or "") additional_kwargs: Dict = {} @@ -183,19 +188,23 @@ def _convert_delta_to_message_chunk( additional_kwargs["tool_calls"] = _dict["tool_calls"] if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) + return HumanMessageChunk(content=content, id=id_) elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + return AIMessageChunk( + content=content, additional_kwargs=additional_kwargs, id=id_ + ) elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) + return SystemMessageChunk(content=content, id=id_) elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"]) + return FunctionMessageChunk(content=content, name=_dict["name"], id=id_) elif role == "tool" or default_class == ToolMessageChunk: - return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) + return ToolMessageChunk( + content=content, tool_call_id=_dict["tool_call_id"], id=id_ + ) elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) + return ChatMessageChunk(content=content, role=role, id=id_) else: - return default_class(content=content) # type: ignore + return default_class(content=content, id=id_) # type: ignore class _FunctionCall(TypedDict):