Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix generate_reply when sender is None. #1186

Merged
merged 9 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 36 additions & 32 deletions flaml/autogen/agentchat/responsive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def __init__(
self._default_auto_reply = default_auto_reply
self._reply_func_list = []
self.reply_at_receive = defaultdict(bool)
self.register_auto_reply(Agent, ResponsiveAgent.generate_oai_reply)
self.register_auto_reply(Agent, ResponsiveAgent.generate_code_execution_reply)
self.register_auto_reply(Agent, ResponsiveAgent.generate_function_call_reply)
self.register_auto_reply(Agent, ResponsiveAgent.check_termination_and_human_reply)
self.register_auto_reply([Agent, None], ResponsiveAgent.generate_oai_reply)
self.register_auto_reply([Agent, None], ResponsiveAgent.generate_code_execution_reply)
self.register_auto_reply([Agent, None], ResponsiveAgent.generate_function_call_reply)
self.register_auto_reply([Agent, None], ResponsiveAgent.check_termination_and_human_reply)

def register_auto_reply(
self,
Expand Down Expand Up @@ -748,17 +748,19 @@ def generate_reply(
str or dict or None: reply. None if no reply is generated.
"""
assert messages is not None or sender is not None, "Either messages or sender must be provided."
if sender is not None:
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
continue
if asyncio.coroutines.iscoroutinefunction(reply_func):
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
if final:
return reply
if messages is None:
messages = self._oai_messages[sender]

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
continue
if asyncio.coroutines.iscoroutinefunction(reply_func):
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
if final:
return reply
return self._default_auto_reply

async def a_generate_reply(
Expand Down Expand Up @@ -792,27 +794,29 @@ async def a_generate_reply(
str or dict or None: reply. None if no reply is generated.
"""
assert messages is not None or sender is not None, "Either messages or sender must be provided."
if sender is not None:
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
if asyncio.coroutines.iscoroutinefunction(reply_func):
final, reply = await reply_func(
self, messages=messages, sender=sender, config=reply_func_tuple["config"]
)
else:
final, reply = reply_func(
self, messages=messages, sender=sender, config=reply_func_tuple["config"]
)
if final:
return reply
if messages is None:
messages = self._oai_messages[sender]

for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
if asyncio.coroutines.iscoroutinefunction(reply_func):
final, reply = await reply_func(
self, messages=messages, sender=sender, config=reply_func_tuple["config"]
)
else:
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
if final:
return reply
return self._default_auto_reply

def _match_trigger(self, trigger, sender):
"""Check if the sender matches the trigger."""
if isinstance(trigger, str):
if trigger is None:
return sender is None
elif isinstance(trigger, str):
return trigger == sender.name
elif isinstance(trigger, type):
return isinstance(sender, trigger)
Expand Down
21 changes: 21 additions & 0 deletions test/autogen/agentchat/test_responsive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,27 @@ def test_responsive_agent():
assert dummy_agent_1.system_message == "new system message"


def test_generate_reply():
def add_num(num_to_be_added):
given_num = 10
return num_to_be_added + given_num

dummy_agent_2 = ResponsiveAgent(name="user_proxy", human_input_mode="TERMINATE", function_map={"add_num": add_num})
messsages = [{"function_call": {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}, "role": "assistant"}]

# when sender is None, messages is provided
assert (
dummy_agent_2.generate_reply(messages=messsages, sender=None)["content"] == "15"
), "generate_reply not working"
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved

# when sender is provided, messages is None
dummy_agent_1 = ResponsiveAgent(name="dummy_agent_1", human_input_mode="ALWAYS")
dummy_agent_2._oai_messages[dummy_agent_1] = messsages
assert (
dummy_agent_2.generate_reply(messages=None, sender=dummy_agent_1)["content"] == "15"
), "generate_reply not working"
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
test_trigger()
# test_context()
Expand Down
Loading