Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions core/src/main/java/com/google/adk/agents/RunConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,26 @@ public enum StreamingMode {
/**
* Execution mode when the model requests multiple tools.
*
* <p>NONE: defaults to SEQUENTIAL.
* <p>NONE: defaults to PARALLEL.
*
* <p>SEQUENTIAL: tools execute in request order on the caller thread.
* <p>SEQUENTIAL: tools execute strictly in request order on the caller thread; each tool must
* complete (including any asynchronous work) before the next one is subscribed to.
*
* <p>PARALLEL: tools execute concurrently on worker threads. Tool implementations must be
* thread-safe.
* <p>PARALLEL: tools are subscribed to eagerly on the caller thread (i.e. all are kicked off
* up-front), but no worker threads are introduced. Tools that are truly asynchronous (e.g. they
* return a {@code Single} backed by I/O or another scheduler) will run concurrently; tools that
* block the subscribing thread (e.g. {@code Single.fromCallable} that performs blocking work)
* will still execute sequentially. This preserves the historical default behavior.
*
* <p>PARALLEL_SUBSCRIBE: like {@code PARALLEL}, but every tool is additionally subscribed on a
* worker thread, so blocking tools also run concurrently. Tool implementations must be
* thread-safe. The worker is the agent's executor when set, otherwise the RxJava IO scheduler.
*/
public enum ToolExecutionMode {
NONE,
SEQUENTIAL,
PARALLEL
PARALLEL,
PARALLEL_SUBSCRIBE
}

public abstract @Nullable SpeechConfig speechConfig();
Expand Down
35 changes: 26 additions & 9 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -236,23 +236,40 @@ public static Maybe<Event> handleFunctionCallsLive(
}

/**
* Sequential by default; only {@link ToolExecutionMode#PARALLEL} with multiple calls dispatches
* tools on workers (using {@code concatMapEager} to preserve input order).
* Builds the tool-execution {@link Observable} for the configured {@link ToolExecutionMode}.
*
* <ul>
* <li>{@link ToolExecutionMode#SEQUENTIAL} (or a single call, where parallelism is moot) uses
* {@code concatMapMaybe}: each tool is subscribed only after the previous one completes.
* <li>{@link ToolExecutionMode#PARALLEL} (the default) uses {@code concatMapEager}: all tools
* are subscribed eagerly on the caller thread. Async tools therefore run concurrently, but
* tools that block the subscribing thread still execute sequentially. This matches the
* historical behavior of the default mode.
* <li>{@link ToolExecutionMode#PARALLEL_SUBSCRIBE} uses {@code concatMapEager} and additionally
* subscribes each tool on a worker scheduler, so blocking tools also run concurrently.
* {@code concatMapEager} preserves input order required by {@link
* #mergeParallelFunctionResponseEvents}.
* </ul>
*/
private static Observable<Event> buildToolExecutionObservable(
InvocationContext invocationContext,
List<FunctionCall> validFunctionCalls,
Function<FunctionCall, Maybe<Event>> functionCallMapper) {
boolean parallel =
invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.PARALLEL
&& validFunctionCalls.size() > 1;
if (!parallel) {
ToolExecutionMode mode = invocationContext.runConfig().toolExecutionMode();
boolean sequential = mode == ToolExecutionMode.SEQUENTIAL || validFunctionCalls.size() <= 1;
if (sequential) {
return Observable.fromIterable(validFunctionCalls).concatMapMaybe(functionCallMapper);
}
Scheduler scheduler = resolveToolExecutionScheduler(invocationContext);
if (mode == ToolExecutionMode.PARALLEL_SUBSCRIBE) {
Scheduler scheduler = resolveToolExecutionScheduler(invocationContext);
return Observable.fromIterable(validFunctionCalls)
.concatMapEager(
call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler));
}
// PARALLEL (and NONE, which defaults to PARALLEL): eager subscribe on the caller thread,
// without offloading to a worker. Async tools run concurrently; blocking tools still block.
return Observable.fromIterable(validFunctionCalls)
.concatMapEager(
call -> functionCallMapper.apply(call).toObservable().subscribeOn(scheduler));
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
}

/** Agent executor if set, otherwise the IO scheduler. */
Expand Down
70 changes: 57 additions & 13 deletions core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,10 @@ public void getAskUserConfirmationFunctionCalls_eventWithConfirmationFunctionCal
assertThat(result).containsExactly(confirmationCall1, confirmationCall2);
}

// Default ToolExecutionMode.NONE must execute tools sequentially.
// Default ToolExecutionMode.NONE behaves like PARALLEL: blocking tools still execute serially
// on the caller thread (no worker scheduler is used), preserving the historical default.
@Test
public void handleFunctionCalls_defaultMode_blockingTools_runSequentially() {
public void handleFunctionCalls_defaultMode_blockingTools_runSerially() {
long sleepMillis = 300L;
int toolCount = 2;
InvocationContext invocationContext =
Expand Down Expand Up @@ -435,29 +436,69 @@ public void handleFunctionCalls_defaultMode_blockingTools_runSequentially() {
assertThat(durationMillis).isAtLeast((long) toolCount * sleepMillis);
}

// PARALLEL mode does NOT introduce worker threads; blocking tools still run serially on the
// caller thread. PARALLEL_SUBSCRIBE is the mode that runs blocking tools concurrently.
@Test
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_twoTools() {
runParallelBlockingToolsTest(/* toolCount= */ 2);
public void handleFunctionCalls_parallel_blockingTools_runSerially() {
long sleepMillis = 300L;
int toolCount = 2;
InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());

Map<String, BaseTool> tools = new LinkedHashMap<>();
List<Part> callParts = new ArrayList<>();
for (int i = 1; i <= toolCount; i++) {
String toolName = "slow_tool_" + i;
tools.put(toolName, new SleepingTool(toolName, sleepMillis));
callParts.add(
Part.builder()
.functionCall(
FunctionCall.builder()
.id("call_" + i)
.name(toolName)
.args(ImmutableMap.of())
.build())
.build());
}
Event event =
createEvent("event").toBuilder()
.content(Content.fromParts(callParts.toArray(new Part[0])))
.build();

long start = System.currentTimeMillis();
Event functionResponseEvent =
Functions.handleFunctionCalls(invocationContext, event, tools).blockingGet();
long durationMillis = System.currentTimeMillis() - start;

assertThat(functionResponseEvent).isNotNull();
assertThat(durationMillis).isAtLeast((long) toolCount * sleepMillis);
}

@Test
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_threeTools() {
runParallelBlockingToolsTest(/* toolCount= */ 3);
public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_twoTools() {
runParallelSubscribeBlockingToolsTest(/* toolCount= */ 2);
}

@Test
public void handleFunctionCalls_parallel_blockingTools_runConcurrently_fiveTools() {
runParallelBlockingToolsTest(/* toolCount= */ 5);
public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_threeTools() {
runParallelSubscribeBlockingToolsTest(/* toolCount= */ 3);
}

@Test
public void handleFunctionCalls_parallelSubscribe_blockingTools_runConcurrently_fiveTools() {
runParallelSubscribeBlockingToolsTest(/* toolCount= */ 5);
}

/** Single-tool case bypasses the parallel scheduler path; must still return the correct event. */
@Test
public void handleFunctionCalls_parallel_blockingTool_singleTool() {
public void handleFunctionCalls_parallelSubscribe_blockingTool_singleTool() {
long sleepMillis = 200L;
InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL_SUBSCRIBE).build());
SleepingTool tool = new SleepingTool("slow_tool_1", sleepMillis);
Event event =
createEvent("event").toBuilder()
Expand Down Expand Up @@ -491,13 +532,16 @@ public void handleFunctionCalls_parallel_blockingTool_singleTool() {
.build());
}

/** Asserts that {@code toolCount} blocking tools in PARALLEL mode run faster than sequential. */
private static void runParallelBlockingToolsTest(int toolCount) {
/**
* Asserts that {@code toolCount} blocking tools in PARALLEL_SUBSCRIBE mode run faster than
* sequential, since each tool is subscribed on a worker thread.
*/
private static void runParallelSubscribeBlockingToolsTest(int toolCount) {
long sleepMillis = 500L;
InvocationContext invocationContext =
createInvocationContext(
createRootAgent(),
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL_SUBSCRIBE).build());

Map<String, BaseTool> tools = new LinkedHashMap<>();
List<Part> callParts = new ArrayList<>();
Expand Down
Loading