Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix _append_oai_message #1195

Closed
wants to merge 12 commits into from
26 changes: 18 additions & 8 deletions flaml/autogen/agentchat/responsive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def _message_to_dict(message: Union[Dict, str]):
else:
return message

def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent) -> bool:
def _append_oai_message(
self, message: Union[Dict, str], conversation_id: Agent, role: Optional[str] = None
) -> bool:
"""Append a message to the ChatCompletion conversation.

If the message received is a string, it will be put in the "content" field of the new dictionary.
Expand All @@ -262,19 +264,27 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:

Args:
message (dict or str): message to be appended to the ChatCompletion conversation.
role (str): role of the message, can be "assistant" or "function".
role (str): role of the message, can be "assistant", "user" or "function", this will overwrite the `role` in the message only when original role is not "function".
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved
conversation_id (Agent): id of the conversation, should be the recipient or sender.

Returns:
bool: whether the message is appended to the ChatCompletion conversation.
"""
message = self._message_to_dict(message)
# create oai message to be appended to the oai conversation that can be passed to oai directly.
oai_message = {k: message[k] for k in ("content", "function_call", "name", "context") if k in message}
if "content" not in oai_message and "function_call" not in oai_message:
oai_message = {k: message[k] for k in ("content", "function_call", "name", "role", "context") if k in message}
# 'content' is a required field
if "content" not in oai_message:
return False
if role is not None:
if oai_message["role"] is "function" and role is not "function":
print(f"Warning: Attempt to overwrite role 'function' with {role}. Rejected.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tested? This will insert a warning in the conversation will may confuse readers. I don't know why we want to print this message.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, I will remove the warning add add a comment.

return False
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved
oai_message["role"] = role
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved
elif oai_message.get("role", "") not in ("assistant", "user", "function", "system"):
# role is None and oai_message["role"] is not valid.
return False

oai_message["role"] = "function" if message.get("role") == "function" else role
self._oai_messages[conversation_id].append(oai_message)
return True

Expand Down Expand Up @@ -319,7 +329,7 @@ def send(
"""
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
valid = self._append_oai_message(message, recipient, role="assistant")
if valid:
recipient.receive(message, self, request_reply, silent)
else:
Expand Down Expand Up @@ -368,7 +378,7 @@ async def a_send(
"""
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
valid = self._append_oai_message(message, recipient, role="assistant")
if valid:
await recipient.a_receive(message, self, request_reply, silent)
else:
Expand Down Expand Up @@ -409,7 +419,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
def _process_received_message(self, message, sender, silent):
message = self._message_to_dict(message)
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
valid = self._append_oai_message(message, "user", sender)
valid = self._append_oai_message(message, sender, role="user")
if not valid:
raise ValueError(
"Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
Expand Down
Loading