diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 5af8ba14..6ac46878 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -13,6 +13,7 @@ from .client_tool import ClientTool +DEFAULT_MAX_ITER = 10 class Agent: def __init__( @@ -75,21 +76,32 @@ def create_turn( toolgroups: Optional[List[Toolgroup]] = None, documents: Optional[List[Document]] = None, ): - response = self.client.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - ) - for chunk in response: - if hasattr(chunk, "error"): - yield chunk - return - elif not self._has_tool_call(chunk): - yield chunk - else: - next_message = self._run_tool(chunk) - yield next_message + stop = False + n_iter = 0 + max_iter = self.agent_config.get('max_infer_iters', DEFAULT_MAX_ITER) + while not stop and n_iter < max_iter: + response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + ) + # by default, we stop after the first turn + stop = True + for chunk in response: + if hasattr(chunk, "error"): + yield chunk + return + elif not self._has_tool_call(chunk): + yield chunk + else: + next_message = self._run_tool(chunk) + yield next_message + + # continue the turn when there's a tool call + stop = False + messages = [next_message] + n_iter += 1