diff --git a/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py b/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py index e0a017adb4..3455847051 100644 --- a/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py +++ b/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py @@ -165,7 +165,7 @@ def __init__( default_auto_reply=default_auto_reply, **kwargs, ) - self.register_auto_reply(Agent, MathUserProxyAgent._generate_math_reply, 1) + self.register_auto_reply([Agent, None], MathUserProxyAgent._generate_math_reply, 1) # fixed var self._max_invalid_q_per_step = max_invalid_q_per_step diff --git a/flaml/autogen/agentchat/responsive_agent.py b/flaml/autogen/agentchat/responsive_agent.py index 1daa35ec9e..00fb4ad303 100644 --- a/flaml/autogen/agentchat/responsive_agent.py +++ b/flaml/autogen/agentchat/responsive_agent.py @@ -119,10 +119,10 @@ def __init__( self._default_auto_reply = default_auto_reply self._reply_func_list = [] self.reply_at_receive = defaultdict(bool) - self.register_auto_reply(Agent, ResponsiveAgent.generate_oai_reply) - self.register_auto_reply(Agent, ResponsiveAgent.generate_code_execution_reply) - self.register_auto_reply(Agent, ResponsiveAgent.generate_function_call_reply) - self.register_auto_reply(Agent, ResponsiveAgent.check_termination_and_human_reply) + self.register_auto_reply([Agent, None], ResponsiveAgent.generate_oai_reply) + self.register_auto_reply([Agent, None], ResponsiveAgent.generate_code_execution_reply) + self.register_auto_reply([Agent, None], ResponsiveAgent.generate_function_call_reply) + self.register_auto_reply([Agent, None], ResponsiveAgent.check_termination_and_human_reply) def register_auto_reply( self, @@ -145,6 +145,8 @@ def register_auto_reply( - If an agent instance is provided, the reply function will be called when the sender is the agent instance. - If a callable is provided, the reply function will be called when the callable returns True. - If a list is provided, the reply function will be called when any of the triggers in the list is activated. + - If None is provided, the reply function will be called only when the sender is None. + Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`. reply_func (Callable): the reply function. The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. ```python @@ -726,6 +728,7 @@ def generate_reply( """Reply based on the conversation history and the sender. Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. Use registered auto reply functions to generate replies. By default, the following functions are checked in order: 1. check_termination_and_human_reply @@ -748,17 +751,19 @@ def generate_reply( str or dict or None: reply. None if no reply is generated. """ assert messages is not None or sender is not None, "Either messages or sender must be provided." - if sender is not None: - for reply_func_tuple in self._reply_func_list: - reply_func = reply_func_tuple["reply_func"] - if exclude and reply_func in exclude: - continue - if asyncio.coroutines.iscoroutinefunction(reply_func): - continue - if self._match_trigger(reply_func_tuple["trigger"], sender): - final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) - if final: - return reply + if messages is None: + messages = self._oai_messages[sender] + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if exclude and reply_func in exclude: + continue + if asyncio.coroutines.iscoroutinefunction(reply_func): + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if final: + return reply return self._default_auto_reply async def a_generate_reply( @@ -770,6 +775,7 @@ async def a_generate_reply( """(async) Reply based on the conversation history and the sender. Either messages or sender must be provided. + Register a reply_func with `None` as one trigger for it to be activated when `messages` is non-empty and `sender` is `None`. Use registered auto reply functions to generate replies. By default, the following functions are checked in order: 1. check_termination_and_human_reply @@ -792,27 +798,29 @@ async def a_generate_reply( str or dict or None: reply. None if no reply is generated. """ assert messages is not None or sender is not None, "Either messages or sender must be provided." - if sender is not None: - for reply_func_tuple in self._reply_func_list: - reply_func = reply_func_tuple["reply_func"] - if exclude and reply_func in exclude: - continue - if self._match_trigger(reply_func_tuple["trigger"], sender): - if asyncio.coroutines.iscoroutinefunction(reply_func): - final, reply = await reply_func( - self, messages=messages, sender=sender, config=reply_func_tuple["config"] - ) - else: - final, reply = reply_func( - self, messages=messages, sender=sender, config=reply_func_tuple["config"] - ) - if final: - return reply + if messages is None: + messages = self._oai_messages[sender] + + for reply_func_tuple in self._reply_func_list: + reply_func = reply_func_tuple["reply_func"] + if exclude and reply_func in exclude: + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + if asyncio.coroutines.iscoroutinefunction(reply_func): + final, reply = await reply_func( + self, messages=messages, sender=sender, config=reply_func_tuple["config"] + ) + else: + final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"]) + if final: + return reply return self._default_auto_reply def _match_trigger(self, trigger, sender): """Check if the sender matches the trigger.""" - if isinstance(trigger, str): + if trigger is None: + return sender is None + elif isinstance(trigger, str): return trigger == sender.name elif isinstance(trigger, type): return isinstance(sender, trigger) diff --git a/test/autogen/agentchat/test_responsive_agent.py b/test/autogen/agentchat/test_responsive_agent.py index 8d169c7778..4d44f02ce5 100644 --- a/test/autogen/agentchat/test_responsive_agent.py +++ b/test/autogen/agentchat/test_responsive_agent.py @@ -154,6 +154,27 @@ def test_responsive_agent(): assert dummy_agent_1.system_message == "new system message" +def test_generate_reply(): + def add_num(num_to_be_added): + given_num = 10 + return num_to_be_added + given_num + + dummy_agent_2 = ResponsiveAgent(name="user_proxy", human_input_mode="TERMINATE", function_map={"add_num": add_num}) + messsages = [{"function_call": {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}, "role": "assistant"}] + + # when sender is None, messages is provided + assert ( + dummy_agent_2.generate_reply(messages=messsages, sender=None)["content"] == "15" + ), "generate_reply not working when sender is None" + + # when sender is provided, messages is None + dummy_agent_1 = ResponsiveAgent(name="dummy_agent_1", human_input_mode="ALWAYS") + dummy_agent_2._oai_messages[dummy_agent_1] = messsages + assert ( + dummy_agent_2.generate_reply(messages=None, sender=dummy_agent_1)["content"] == "15" + ), "generate_reply not working when messages is None" + + if __name__ == "__main__": test_trigger() # test_context()