From 1a593a996607904eed24b64bc63eecd7708710af Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 25 Feb 2026 10:57:18 -0800 Subject: [PATCH] feat: remove model restrictions in BuiltInCodeExecutionTool PiperOrigin-RevId: 875242195 --- .../adk/tools/BuiltInCodeExecutionTool.java | 33 +++++- .../com/google/adk/utils/ModelNameUtils.java | 31 ++++++ .../com/google/adk/tools/BaseToolTest.java | 45 +++++++- .../google/adk/utils/ModelNameUtilsTest.java | 100 ++++++++++++++++++ 4 files changed, 203 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java b/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java index 060b3ffb8..ad97b96a6 100644 --- a/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java +++ b/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java @@ -16,13 +16,19 @@ package com.google.adk.tools; +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmRequest; +import com.google.adk.utils.ModelNameUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Tool; import com.google.genai.types.ToolCodeExecution; import io.reactivex.rxjava3.core.Completable; import java.util.List; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A built-in code execution tool that is automatically invoked by Gemini 2 models. @@ -32,6 +38,7 @@ */ public final class BuiltInCodeExecutionTool extends BaseTool { public static final BuiltInCodeExecutionTool INSTANCE = new BuiltInCodeExecutionTool(); + private static final Logger LOG = LoggerFactory.getLogger(BuiltInCodeExecutionTool.class); public BuiltInCodeExecutionTool() { super("code_execution", "code_execution"); @@ -41,10 +48,28 @@ public BuiltInCodeExecutionTool() { public Completable processLlmRequest( LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { - String model = llmRequestBuilder.build().model().get(); - if (model.isEmpty() || !model.startsWith("gemini-2")) { - return Completable.error( - new IllegalArgumentException("Code execution tool is not supported for model " + model)); + Optional model = + Optional.ofNullable(toolContext) + .flatMap(tCtx -> Optional.ofNullable(tCtx.invocationContext())) + .flatMap( + iCtx -> { + if (iCtx.agent() instanceof LlmAgent llmAgent) { + return Optional.of(llmAgent); + } else { + return Optional.empty(); + } + }) + .flatMap(llmAgent -> llmAgent.resolvedModel().model()); + + String modelName = llmRequestBuilder.build().model().get(); + if (!ModelNameUtils.isGeminiModel(modelName) + || model.filter(ModelNameUtils::isInstanceOfGemini).isEmpty()) { + // model name is not a gemini model, or the model isn't an instance of Gemini class (eg. + // LangChain case). + LOG.warn( + "Code execution tool is not supported for model: {} ({}).", + modelName, + model.map(Object::getClass).map(Class::toString).orElse("")); } GenerateContentConfig.Builder configBuilder = llmRequestBuilder diff --git a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java index 9995f18b2..c46f6e3a8 100644 --- a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java +++ b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java @@ -16,16 +16,24 @@ package com.google.adk.utils; +import com.google.common.base.Strings; +import java.util.Objects; import java.util.regex.Matcher; import java.util.regex.Pattern; public final class ModelNameUtils { + private static final String GEMINI_PREFIX = "gemini-"; private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*"); + private static final String GEMINI_CLASS = "com.google.adk.models.Gemini"; private static final Pattern PATH_PATTERN = Pattern.compile("^projects/[^/]+/locations/[^/]+/publishers/[^/]+/models/(.+)$"); private static final Pattern APIGEE_PATTERN = Pattern.compile("^apigee/(?:[^/]+/)?(?:[^/]+/)?(.+)$"); + public static boolean isGeminiModel(String modelString) { + return extractModelName(Strings.nullToEmpty(modelString)).startsWith(GEMINI_PREFIX); + } + public static boolean isGemini2Model(String modelString) { if (modelString == null) { return false; @@ -34,6 +42,29 @@ public static boolean isGemini2Model(String modelString) { return GEMINI_2_PATTERN.matcher(modelName).matches(); } + /** + * Checks whether an object is an instance of {@link com.google.adk.models.Gemini}, by searching + * through its class hierarchy for a class whose name equals the hardcoded String name of Gemini + * class. + * + *

This method can be used where the "real" instanceof check is not possible because the Gemini + * type is not known at compile time. + * + * @param o The object to check. + * @return true if object's class is {@link com.google.adk.models.Gemini}, false otherwise. + */ + public static boolean isInstanceOfGemini(Object o) { + if (o == null) { + return false; + } + for (Class clazz = o.getClass(); clazz != null; clazz = clazz.getSuperclass()) { + if (Objects.equals(clazz.getName(), GEMINI_CLASS)) { + return true; + } + } + return false; + } + /** * Extract the actual model name from either simple or path-based format. * diff --git a/core/src/test/java/com/google/adk/tools/BaseToolTest.java b/core/src/test/java/com/google/adk/tools/BaseToolTest.java index e8d3887a4..2a07e7a44 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolTest.java @@ -2,7 +2,11 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.Gemini; import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentConfig; @@ -171,13 +175,27 @@ public void processLlmRequestWithUrlContextToolAddsToolToConfig() { Tool.builder().urlContext(UrlContext.builder().build()).build()); } + private static InvocationContext.Builder testInvocationContext() { + InvocationContext.Builder builder = InvocationContext.builder(); + builder.agent(testAgent().build()); + InMemorySessionService inMemorySessionService = new InMemorySessionService(); + builder.sessionService(inMemorySessionService); + builder.session(inMemorySessionService.createSession("test-app", "test-user-id").blockingGet()); + return builder; + } + + private static LlmAgent.Builder testAgent() { + return LlmAgent.builder().name("test-agent"); + } + @Test - public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() { + public void + processLlmRequestWithBuiltInCodeExecutionToolAndNonGeminiModelAndNullContextAddsToolToConfig() { BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool(); LlmRequest llmRequest = LlmRequest.builder() .config(GenerateContentConfig.builder().build()) - .model("gemini-2") + .model("text-bison") .build(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); Completable unused = @@ -189,6 +207,29 @@ public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() { .containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build()); } + @Test + public void processLlmRequestWithBuiltInCodeExecutionToolAndGemini2ModelAddsToolToConfig() { + BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool(); + LlmRequest llmRequest = + LlmRequest.builder() + .config(GenerateContentConfig.builder().build()) + .model("gemini-2") + .build(); + LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); + ToolContext toolContext = + ToolContext.builder( + testInvocationContext() + .agent(testAgent().model(new Gemini("gemini-2", "")).build()) + .build()) + .build(); + Completable unused = builtInCodeExecutionTool.processLlmRequest(llmRequestBuilder, toolContext); + LlmRequest updatedLlmRequest = llmRequestBuilder.build(); + assertThat(updatedLlmRequest.config()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools().get()) + .containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build()); + } + @Test public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { GoogleMapsTool googleMapsTool = new GoogleMapsTool(); diff --git a/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java index 37853c477..20dda7034 100644 --- a/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java +++ b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java @@ -2,6 +2,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.models.Gemini; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -69,4 +70,103 @@ public void isGemini2Model_withApigeeProviderV1BetaGemini2Model_returnsTrue() { public void isGemini2Model_withNullModel_returnsFalse() { assertThat(ModelNameUtils.isGemini2Model(null)).isFalse(); } + + @Test + public void isGeminiModel_withGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withNonGeminiModel_returnsFalse() { + assertThat(ModelNameUtils.isGeminiModel("text-bison")).isFalse(); + } + + @Test + public void isGeminiModel_withPathBasedGeminiModel_returnsTrue() { + assertThat( + ModelNameUtils.isGeminiModel( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro")) + .isTrue(); + } + + @Test + public void isGeminiModel_withPathBasedNonGeminiModel_returnsFalse() { + assertThat( + ModelNameUtils.isGeminiModel( + "projects/test-project/locations/us-central1/publishers/google/models/text-bison")) + .isFalse(); + } + + @Test + public void isGeminiModel_withApigeeGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeV1GeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/v1/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderVertexGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderV1GeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/v1/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderV1BetaGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/v1beta/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withNullModel_returnsFalse() { + assertThat(ModelNameUtils.isGeminiModel(null)).isFalse(); + } + + @Test + public void isGeminiModel_withEmptyModel_returnsFalse() { + assertThat(ModelNameUtils.isGeminiModel("")).isFalse(); + } + + @Test + public void isInstanceOfGemini_withGeminiInstance_returnsTrue() { + assertThat(ModelNameUtils.isInstanceOfGemini(new Gemini("", ""))).isTrue(); + } + + @Test + public void isInstanceOfGemini_withNonGeminiInstance_returnsFalse() { + assertThat(ModelNameUtils.isInstanceOfGemini(new Object())).isFalse(); + } + + @Test + public void isInstanceOfGemini_withNullInstance_returnsFalse() { + assertThat(ModelNameUtils.isInstanceOfGemini(null)).isFalse(); + } + + private static class GeminiSubclass extends Gemini { + GeminiSubclass() { + super("test-model", "test-api-key"); + } + } + + private static class GeminiSubclassSubclass extends GeminiSubclass {} + + @Test + public void isInstanceOfGemini_withGeminiSubclassInstance_returnsTrue() { + assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclass())).isTrue(); + } + + @Test + public void isInstanceOfGemini_withSubclassOfGeminiSubclassInstance_returnsTrue() { + assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclassSubclass())).isTrue(); + } }