diff --git a/flaml/autogen/agentchat/responsive_agent.py b/flaml/autogen/agentchat/responsive_agent.py index 0884b88913..0305068fe3 100644 --- a/flaml/autogen/agentchat/responsive_agent.py +++ b/flaml/autogen/agentchat/responsive_agent.py @@ -253,7 +253,9 @@ def _message_to_dict(message: Union[Dict, str]): else: return message - def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent) -> bool: + def _append_oai_message( + self, message: Union[Dict, str], conversation_id: Agent, role: Optional[str] = None + ) -> bool: """Append a message to the ChatCompletion conversation. If the message received is a string, it will be put in the "content" field of the new dictionary. @@ -262,7 +264,7 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Args: message (dict or str): message to be appended to the ChatCompletion conversation. - role (str): role of the message, can be "assistant" or "function". + role (str): role of the message. It can be "assistant", "user" or "function". It will be used to overwrite the "role" in the message only when original role is not "function". conversation_id (Agent): id of the conversation, should be the recipient or sender. Returns: @@ -270,11 +272,17 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: """ message = self._message_to_dict(message) # create oai message to be appended to the oai conversation that can be passed to oai directly. - oai_message = {k: message[k] for k in ("content", "function_call", "name", "context") if k in message} - if "content" not in oai_message and "function_call" not in oai_message: + oai_message = {k: message[k] for k in ("content", "function_call", "name", "role", "context") if k in message} + # 'content' is a required field + if "content" not in oai_message: + return False + if role is not None: + # cannot overwrite the role if the original role is "function" + oai_message["role"] = role if oai_message.get("role", "") != "function" else "function" + elif oai_message.get("role", "") not in ("assistant", "user", "function", "system"): + # role is None and oai_message["role"] is not valid. return False - oai_message["role"] = "function" if message.get("role") == "function" else role self._oai_messages[conversation_id].append(oai_message) return True @@ -319,7 +327,7 @@ def send( """ # When the agent composes and sends the message, the role of the message is "assistant" # unless it's "function". - valid = self._append_oai_message(message, "assistant", recipient) + valid = self._append_oai_message(message, recipient, role="assistant") if valid: recipient.receive(message, self, request_reply, silent) else: @@ -368,7 +376,7 @@ async def a_send( """ # When the agent composes and sends the message, the role of the message is "assistant" # unless it's "function". - valid = self._append_oai_message(message, "assistant", recipient) + valid = self._append_oai_message(message, recipient, role="assistant") if valid: await recipient.a_receive(message, self, request_reply, silent) else: @@ -409,7 +417,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent): def _process_received_message(self, message, sender, silent): message = self._message_to_dict(message) # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.) - valid = self._append_oai_message(message, "user", sender) + valid = self._append_oai_message(message, sender, role="user") if not valid: raise ValueError( "Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."