From 101b4ee545f624d07c2abc1da862e10f2bd727e1 Mon Sep 17 00:00:00 2001 From: Jinni Gu Date: Thu, 20 Nov 2025 22:42:09 -0800 Subject: [PATCH] support output schema --- .../java/com/google/adk/agents/LlmAgent.java | 19 ++- .../com/google/adk/flows/llmflows/Basic.java | 4 +- .../adk/tools/SetModelResponseTool.java | 58 +++++++++ .../com/google/adk/agents/LlmAgentTest.java | 28 ++-- .../adk/tools/SetModelResponseToolTest.java | 123 ++++++++++++++++++ 5 files changed, 208 insertions(+), 24 deletions(-) create mode 100644 core/src/main/java/com/google/adk/tools/SetModelResponseTool.java create mode 100644 core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index b0840ecb..07a2a39d 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -48,6 +48,7 @@ import com.google.adk.models.Model; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.SetModelResponseTool; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -135,7 +136,17 @@ protected LlmAgent(Builder builder) { this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); this.outputKey = Optional.ofNullable(builder.outputKey); - this.toolsUnion = builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of(); + ImmutableList userTools = + builder.toolsUnion != null ? builder.toolsUnion : ImmutableList.of(); + if (builder.outputSchema != null && !userTools.isEmpty()) { + this.toolsUnion = + ImmutableList.builder() + .addAll(userTools) + .add(new SetModelResponseTool(builder.outputSchema)) + .build(); + } else { + this.toolsUnion = userTools; + } this.toolsets = extractToolsets(this.toolsUnion); this.codeExecutor = Optional.ofNullable(builder.codeExecutor); @@ -541,12 +552,6 @@ protected void validate() { + ": if outputSchema is set, subAgents must be empty to disable agent" + " transfer."); } - if (this.toolsUnion != null && !this.toolsUnion.isEmpty()) { - throw new IllegalArgumentException( - "Invalid config for agent " - + this.name - + ": if outputSchema is set, tools must be empty."); - } } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java index 68cd4e1a..bb11359a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java @@ -57,7 +57,9 @@ public Single processRequest( .config(agent.generateContentConfig().orElse(GenerateContentConfig.builder().build())) .liveConnectConfig(liveConnectConfigBuilder.build()); - agent.outputSchema().ifPresent(builder::outputSchema); + if (agent.outputSchema().isPresent() && agent.toolsUnion().isEmpty()) { + builder.outputSchema(agent.outputSchema().get()); + } return Single.just( RequestProcessor.RequestProcessingResult.create(builder.build(), ImmutableList.of())); } diff --git a/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java new file mode 100644 index 00000000..26a6a842 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/SetModelResponseTool.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.google.adk.SchemaUtils; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; + +/** Tool for setting model response when using output_schema with other tools. */ +public class SetModelResponseTool extends BaseTool { + + private final Schema outputSchema; + + public SetModelResponseTool(Schema outputSchema) { + super( + "set_model_response", + "Set your final response using the required output schema. Use this tool to provide your" + + " final structured answer instead of outputting text directly."); + this.outputSchema = outputSchema; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters(outputSchema) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + try { + SchemaUtils.validateMapOnSchema(args, outputSchema, /* isInput= */ false); + } catch (IllegalArgumentException e) { + return Single.error(e); + } + return Single.just(args); + } +} diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 1efc6abd..359ec9bf 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -140,7 +140,7 @@ public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() { } @Test - public void build_withOutputSchemaAndTools_throwsIllegalArgumentException() { + public void build_withOutputSchemaAndTools_addsSetModelResponseTool() { BaseTool tool = new BaseTool("test_tool", "test_description") { @Override @@ -156,22 +156,18 @@ public Optional declaration() { .required(ImmutableList.of("status")) .build(); - // Expecting an IllegalArgumentException when building the agent - IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> - LlmAgent.builder() // Use the agent builder directly - .name("agent with invalid tool config") - .outputSchema(outputSchema) // Set the output schema - .tools(ImmutableList.of(tool)) // Set tools (this should cause the error) - .build()); // Attempt to build the agent + LlmAgent agent = + LlmAgent.builder() + .name("agent with valid tool config") + .outputSchema(outputSchema) + .tools(ImmutableList.of(tool)) + .build(); - assertThat(exception) - .hasMessageThat() - .contains( - "Invalid config for agent agent with invalid tool config: if outputSchema is set, tools" - + " must be empty"); + List tools = agent.tools(); + assertThat(tools).hasSize(2); + assertThat(tools.stream().anyMatch(t -> t instanceof com.google.adk.tools.SetModelResponseTool)) + .isTrue(); + assertThat(tools.stream().anyMatch(t -> t.name().equals("test_tool"))).isTrue(); } @Test diff --git a/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java new file mode 100644 index 00000000..9a4e6f41 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/SetModelResponseToolTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SetModelResponseToolTest { + + @Test + public void declaration_returnsCorrectFunctionDeclaration() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + FunctionDeclaration declaration = tool.declaration().get(); + + assertThat(declaration.name()).hasValue("set_model_response"); + assertThat(declaration.description()).isPresent(); + assertThat(declaration.description().get()).contains("Set your final response"); + assertThat(declaration.parameters()).hasValue(outputSchema); + } + + @Test + public void runAsync_returnsArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map args = ImmutableMap.of("field1", "value1"); + + Map result = tool.runAsync(args, null).blockingGet(); + + assertThat(result).isEqualTo(args); + } + + @Test + public void runAsync_validatesArgs() { + Schema outputSchema = + Schema.builder() + .type("OBJECT") + .properties(ImmutableMap.of("field1", Schema.builder().type("STRING").build())) + .required(ImmutableList.of("field1")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(outputSchema); + Map invalidArgs = ImmutableMap.of("field2", "value2"); + + // Should throw validation error + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> tool.runAsync(invalidArgs, null).blockingGet()); + + assertThat(exception).hasMessageThat().contains("does not match agent output schema"); + } + + @Test + public void runAsync_validatesComplexArgs() { + Schema complexSchema = + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "id", Schema.builder().type("INTEGER").build(), + "tags", + Schema.builder() + .type("ARRAY") + .items(Schema.builder().type("STRING").build()) + .build(), + "metadata", + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of("key", Schema.builder().type("STRING").build())) + .build())) + .required(ImmutableList.of("id", "tags", "metadata")) + .build(); + + SetModelResponseTool tool = new SetModelResponseTool(complexSchema); + Map complexArgs = + ImmutableMap.of( + "id", 123, + "tags", ImmutableList.of("tag1", "tag2"), + "metadata", ImmutableMap.of("key", "value")); + + Map result = tool.runAsync(complexArgs, null).blockingGet(); + + assertThat(result).containsEntry("id", 123); + assertThat(result).containsEntry("tags", ImmutableList.of("tag1", "tag2")); + assertThat(result).containsEntry("metadata", ImmutableMap.of("key", "value")); + } +}