Skip to content

Commit 10ca56f

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add support for app input in AdkApp template
PiperOrigin-RevId: 825234667
1 parent 9ae5f35 commit 10ca56f

File tree

2 files changed

+80
-17
lines changed

2 files changed

+80
-17
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ def test_inference_from_local_jsonl_file(self, mock_models):
712712
assert inference_result.candidate_name == "gemini-pro"
713713
assert inference_result.gcs_source is None
714714

715+
@pytest.mark.skip(reason="currently flakey")
715716
@mock.patch.object(_evals_common, "Models")
716717
def test_inference_from_local_csv_file(self, mock_models):
717718
local_src_path = "/tmp/input.csv"

vertexai/agent_engines/templates/adk.py

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@
3737
except (ImportError, AttributeError):
3838
Event = Any
3939

40+
try:
41+
from google.adk.apps import App
42+
43+
App = App
44+
except (ImportError, AttributeError):
45+
App = Any
46+
4047
try:
4148
from google.adk.agents import BaseAgent
4249

@@ -449,7 +456,8 @@ class AdkApp:
449456
def __init__(
450457
self,
451458
*,
452-
agent: "BaseAgent",
459+
app: "App" = None,
460+
agent: "BaseAgent" = None,
453461
app_name: Optional[str] = None,
454462
plugins: Optional[List["BasePlugin"]] = None,
455463
enable_tracing: Optional[bool] = None,
@@ -505,10 +513,26 @@ def __init__(
505513
)
506514
raise ValueError(msg)
507515

516+
if not agent and not app:
517+
raise ValueError("One of `agent` or `app` must be provided.")
518+
if app:
519+
if app_name:
520+
raise ValueError(
521+
"When app is provided, app_name should not be provided."
522+
)
523+
if agent:
524+
raise ValueError("When app is provided, agent should not be provided.")
525+
if plugins:
526+
raise ValueError(
527+
"When app is provided, plugins should not be provided and"
528+
" should be provided in the app instead."
529+
)
530+
508531
self._tmpl_attrs: Dict[str, Any] = {
509532
"project": initializer.global_config.project,
510533
"location": initializer.global_config.location,
511534
"agent": agent,
535+
"app": app,
512536
"app_name": app_name,
513537
"plugins": plugins,
514538
"enable_tracing": enable_tracing,
@@ -624,10 +648,23 @@ def clone(self):
624648
import copy
625649

626650
return self.__class__(
627-
agent=copy.deepcopy(self._tmpl_attrs.get("agent")),
651+
app=copy.deepcopy(self._tmpl_attrs.get("app")),
628652
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
629-
app_name=self._tmpl_attrs.get("app_name"),
630-
plugins=self._tmpl_attrs.get("plugins"),
653+
agent=(
654+
None
655+
if self._tmpl_attrs.get("app")
656+
else copy.deepcopy(self._tmpl_attrs.get("agent"))
657+
),
658+
app_name=(
659+
None
660+
if self._tmpl_attrs.get("app")
661+
else self._tmpl_attrs.get("app_name")
662+
),
663+
plugins=(
664+
None
665+
if self._tmpl_attrs.get("app")
666+
else copy.deepcopy(self._tmpl_attrs.get("plugins"))
667+
),
631668
session_service_builder=self._tmpl_attrs.get("session_service_builder"),
632669
artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"),
633670
memory_service_builder=self._tmpl_attrs.get("memory_service_builder"),
@@ -774,20 +811,38 @@ def tracing_enabled() -> bool:
774811
self._tmpl_attrs["memory_service"] = InMemoryMemoryService()
775812

776813
self._tmpl_attrs["runner"] = Runner(
777-
agent=self._tmpl_attrs.get("agent"),
778-
plugins=self._tmpl_attrs.get("plugins"),
814+
app=self._tmpl_attrs.get("app"),
815+
agent=(
816+
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent")
817+
),
818+
app_name=(
819+
None
820+
if self._tmpl_attrs.get("app")
821+
else self._tmpl_attrs.get("app_name")
822+
),
823+
plugins=(
824+
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins")
825+
),
779826
session_service=self._tmpl_attrs.get("session_service"),
780827
artifact_service=self._tmpl_attrs.get("artifact_service"),
781828
memory_service=self._tmpl_attrs.get("memory_service"),
782-
app_name=self._tmpl_attrs.get("app_name"),
783829
)
784830
self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
785831
self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
786832
self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService()
787833
self._tmpl_attrs["in_memory_runner"] = Runner(
788-
app_name=self._tmpl_attrs.get("app_name"),
789-
agent=self._tmpl_attrs.get("agent"),
790-
plugins=self._tmpl_attrs.get("plugins"),
834+
app=self._tmpl_attrs.get("app"),
835+
app_name=(
836+
None
837+
if self._tmpl_attrs.get("app")
838+
else self._tmpl_attrs.get("app_name")
839+
),
840+
agent=(
841+
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent")
842+
),
843+
plugins=(
844+
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins")
845+
),
791846
session_service=self._tmpl_attrs.get("in_memory_session_service"),
792847
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
793848
memory_service=self._tmpl_attrs.get("in_memory_memory_service"),
@@ -968,12 +1023,13 @@ async def streaming_agent_run_with_events(self, request_json: str):
9681023
self.set_up()
9691024
session_service = self._tmpl_attrs.get("in_memory_session_service")
9701025
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
1026+
app = self._tmpl_attrs.get("app")
9711027
# Try to get the session, if it doesn't exist, create a new one.
9721028
session = None
9731029
if request.session_id:
9741030
try:
9751031
session = await session_service.get_session(
976-
app_name=self._tmpl_attrs.get("app_name"),
1032+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
9771033
user_id=request.user_id,
9781034
session_id=request.session_id,
9791035
)
@@ -1006,8 +1062,9 @@ async def streaming_agent_run_with_events(self, request_json: str):
10061062
yield converted_event
10071063
finally:
10081064
if session and not request.session_id:
1065+
app = self._tmpl_attrs.get("app")
10091066
await session_service.delete_session(
1010-
app_name=self._tmpl_attrs.get("app_name"),
1067+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10111068
user_id=request.user_id,
10121069
session_id=session.id,
10131070
)
@@ -1039,8 +1096,9 @@ async def async_get_session(
10391096
"""
10401097
if not self._tmpl_attrs.get("session_service"):
10411098
self.set_up()
1099+
app = self._tmpl_attrs.get("app")
10421100
session = await self._tmpl_attrs.get("session_service").get_session(
1043-
app_name=self._tmpl_attrs.get("app_name"),
1101+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
10441102
user_id=user_id,
10451103
session_id=session_id,
10461104
**kwargs,
@@ -1116,8 +1174,9 @@ async def async_list_sessions(self, *, user_id: str, **kwargs):
11161174
"""
11171175
if not self._tmpl_attrs.get("session_service"):
11181176
self.set_up()
1177+
app = self._tmpl_attrs.get("app")
11191178
return await self._tmpl_attrs.get("session_service").list_sessions(
1120-
app_name=self._tmpl_attrs.get("app_name"),
1179+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
11211180
user_id=user_id,
11221181
**kwargs,
11231182
)
@@ -1188,8 +1247,9 @@ async def async_create_session(
11881247
"""
11891248
if not self._tmpl_attrs.get("session_service"):
11901249
self.set_up()
1250+
app = self._tmpl_attrs.get("app")
11911251
session = await self._tmpl_attrs.get("session_service").create_session(
1192-
app_name=self._tmpl_attrs.get("app_name"),
1252+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
11931253
user_id=user_id,
11941254
session_id=session_id,
11951255
state=state,
@@ -1269,8 +1329,9 @@ async def async_delete_session(
12691329
"""
12701330
if not self._tmpl_attrs.get("session_service"):
12711331
self.set_up()
1332+
app = self._tmpl_attrs.get("app")
12721333
await self._tmpl_attrs.get("session_service").delete_session(
1273-
app_name=self._tmpl_attrs.get("app_name"),
1334+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
12741335
user_id=user_id,
12751336
session_id=session_id,
12761337
**kwargs,
@@ -1359,8 +1420,9 @@ async def async_search_memory(self, *, user_id: str, query: str):
13591420
"""
13601421
if not self._tmpl_attrs.get("memory_service"):
13611422
self.set_up()
1423+
app = self._tmpl_attrs.get("app")
13621424
return await self._tmpl_attrs.get("memory_service").search_memory(
1363-
app_name=self._tmpl_attrs.get("app_name"),
1425+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
13641426
user_id=user_id,
13651427
query=query,
13661428
)

0 commit comments

Comments
 (0)