Skip to content

Multiple Agents run in parallel with different inputs. How to implement? #2124

@namjsuh

Description

@namjsuh

** Please make sure you read the contribution guide and file the issues in the right place. **
Contribution guide.

Is your feature request related to a problem? Please describe.
I have a list of subtasks, and each subtask is dealt by the sub_task agent. These agents are basically same agents but they are instantiated separtely. I want each sub_task agent deals with each subtask and gives output in parallel fashion.
I have no idea on how to isolate the ctx session state variables. I constantly have the race condition problems, and hope you can correct me if I make any mistakes.. Below is my code:

Describe the solution you'd like
`# Define the build_battery_db_subtask_agent function
def build_battery_db_subtask_agent(
db_tool: LongRunningFunctionTool,
ff_tool: LongRunningFunctionTool,
chemical_property_agent: LlmAgent,
pdf_agent: LlmAgent,
idx: int = 0
) -> LlmAgent:
"""Create the sub-task answering agent."""
return LlmAgent(
model=LiteLlm(
model=-,
stream=False,
reasoning_effort="low",
stream_options={"include_usage": True}
),
name="Battery_Agent",
description=(
"An agent which retrieves battery-related knowledge from the proprietary vector database."
"Based on the retrieved information, output the scientifically well-supported answers for "
"the given sub-task."
),
instruction=SUBTASK_EXPERT_INSTRUCTION,
include_contents="none",
input_schema=SubtaskInput,
tools=[db_tool, ff_tool, AgentTool(agent=chemical_property_agent), AgentTool(agent=pdf_agent)],
output_key=f"sub_task_answer_{idx}",

Define a debugging version to see what's happening

def create_subtask_worker(name: str, idx: int, task_string: str, battery_db_subtask: BaseAgent) -> BaseAgent:
"""Create a subtask worker with extensive debugging"""

class SubtaskWorkerImpl(BaseAgent):
    @override
    async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event, None]:
        answer_key = f"sub_task_answer_{idx}"
        
        logger.info(f"=== [{name}] STARTING TASK {idx} ===")
        logger.info(f"[{name}] Task string: {task_string}")
        logger.info(f"[{name}] Answer key: {answer_key}")
        
        # Log current session state
        logger.info(f"[{name}] Session keys before: {list(ctx.session.state.keys())}")
        
        # Create isolated input
        isolated_input = SubtaskInput(
            rewritten_query=ctx.session.state.get("rewritten_query", ""),
            task_list=ctx.session.state.get("task_list", []),
            sub_task_answers=ctx.session.state.get("sub_task_answers", []),
            current_sub_task_string=task_string,
            extra_outputs=ctx.session.state.get("extra_outputs", {})
        )
        
        logger.info(f"[{name}] Created isolated input with task: {isolated_input.current_sub_task_string}")
        
        # Set session state with locking mechanism
        ctx.session.state[f"worker_{idx}_active"] = True
        ctx.session.state["current_sub_task_string"] = task_string
        ctx.session.state["current_worker"] = name
        
        logger.info(f"[{name}] Set session state - current_sub_task_string: {ctx.session.state['current_sub_task_string']}")
        logger.info(f"[{name}] Set session state - current_worker: {ctx.session.state['current_worker']}")
        
        # Run the agent with timeout simulation
        logger.info(f"[{name}] Starting battery_db_subtask.run_async...")
        start_time = time.time()
        
        event_count = 0
        async for event in battery_db_subtask.run_async(ctx):
            event_count += 1
            elapsed = time.time() - start_time
            logger.info(f"[{name}] Event {event_count} after {elapsed:.1f}s: {event.type if hasattr(event, 'type') else str(event)}")
            
            # Safety timeout
            if elapsed > 120:  # 2 minute timeout per worker
                logger.error(f"[{name}] TIMEOUT after {elapsed:.1f} seconds!")
                break
                
            yield event
        
        total_time = time.time() - start_time
        logger.info(f"[{name}] Finished battery_db_subtask after {total_time:.1f}s with {event_count} events")

        # Extract answer
        raw = ctx.session.state.get(answer_key, "<no answer>")
        logger.info(f"[{name}] Raw answer from {answer_key}: {str(raw)[:100]}...")
        
        parsed = safe_json_parse(raw)
        main_answer = (
            parsed.get("sub_task_answer", parsed)
            if isinstance(parsed, dict) else parsed
        )
        
        logger.info(f"[{name}] Parsed answer: {str(main_answer)[:100]}...")

        # Store result
        result = {"task": task_string, "sub_task_answer": main_answer}
        ctx.session.state["sub_task_answers"].append(result)
        
        logger.info(f"[{name}] Added result to sub_task_answers. Total answers: {len(ctx.session.state['sub_task_answers'])}")
        
        # Cleanup
        ctx.session.state.pop(f"worker_{idx}_active", None)
        if ctx.session.state.get("current_worker") == name:
            ctx.session.state.pop("current_sub_task_string", None)
            ctx.session.state.pop("current_worker", None)
        
        logger.info(f"=== [{name}] COMPLETED TASK {idx} in {total_time:.1f}s ===")

return SubtaskWorkerImpl(name=name)

Define the main async function

async def call_agent_async(query: str,
sub_task_list: list[str],
app_name: str,
user_id: str,
session_id: str):

logger.info(f"Starting parallel execution with {len(sub_task_list)} tasks")

# Create shared tools
GetBatteryDBTool = LongRunningFunctionTool(func=GetBatteryDB)
FindFriendTool = LongRunningFunctionTool(func=find_friend_wrapper)
GetMolPropTool = LongRunningFunctionTool(func=GetMoleculeProperty)
ChemPropAgent = build_chemical_property_agent(GetMolPropTool)
PDFAgent = build_pdf_agent(md_tool, fig_tool)

# Create workers using the factory function
workers = []
for i, task in enumerate(sub_task_list):
    battery_db_subtask = build_battery_db_subtask_agent(
        GetBatteryDBTool, FindFriendTool,
        ChemPropAgent, PDFAgent,
        idx=i
    )
    
    worker = create_subtask_worker(
        name=f"Worker_{i}",
        idx=i,
        task_string=task,
        battery_db_subtask=battery_db_subtask
    )
    workers.append(worker)

# Create the parallel pipeline
pipeline = ParallelAgent(
    name="ParallelBatteryPipeline",
    sub_agents=workers,
    description="Each worker receives an isolated input to prevent race conditions."
)

# Create session and runner
sess_service = InMemorySessionService()
await sess_service.create_session(
    app_name=app_name, 
    user_id=user_id, 
    session_id=session_id,
    state={
        "rewritten_query": query,
        "task_list": sub_task_list,
        "sub_task_answers": [],
        "extra_outputs": {},
    }
)

runner = Runner(
    agent=pipeline,
    app_name=app_name,
    session_service=sess_service
)

# Execute the pipeline
content = types.Content(role="user", parts=[types.Part(text=query)])
async for _ in runner.run_async(
    user_id=user_id, 
    session_id=session_id, 
    new_message=content
):
    pass

# Get the final results
final = await sess_service.get_session(
    app_name=app_name, 
    user_id=user_id, 
    session_id=session_id
)
answers = final.state.get("sub_task_answers", [])

# Clean up
await sess_service.delete_session(
    app_name=app_name, 
    user_id=user_id, 
    session_id=session_id
)

logger.info(f"Parallel execution completed. Got {len(answers)} answers.")
return answers

`

Please let me know what are the problems in my current code and how to fix them..
Your help is appreciated in advance!!

Metadata

Metadata

Assignees

No one assigned

    Labels

    question[Component] This issue is asking a question or clarification

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions