|
37 | 37 | except (ImportError, AttributeError): |
38 | 38 | Event = Any |
39 | 39 |
|
| 40 | + try: |
| 41 | + from google.adk.apps import App |
| 42 | + |
| 43 | + App = App |
| 44 | + except (ImportError, AttributeError): |
| 45 | + App = Any |
| 46 | + |
40 | 47 | try: |
41 | 48 | from google.adk.agents import BaseAgent |
42 | 49 |
|
@@ -449,7 +456,8 @@ class AdkApp: |
449 | 456 | def __init__( |
450 | 457 | self, |
451 | 458 | *, |
452 | | - agent: "BaseAgent", |
| 459 | + app: "App" = None, |
| 460 | + agent: "BaseAgent" = None, |
453 | 461 | app_name: Optional[str] = None, |
454 | 462 | plugins: Optional[List["BasePlugin"]] = None, |
455 | 463 | enable_tracing: Optional[bool] = None, |
@@ -505,10 +513,26 @@ def __init__( |
505 | 513 | ) |
506 | 514 | raise ValueError(msg) |
507 | 515 |
|
| 516 | + if not agent and not app: |
| 517 | + raise ValueError("One of `agent` or `app` must be provided.") |
| 518 | + if app: |
| 519 | + if app_name: |
| 520 | + raise ValueError( |
| 521 | + "When app is provided, app_name should not be provided." |
| 522 | + ) |
| 523 | + if agent: |
| 524 | + raise ValueError("When app is provided, agent should not be provided.") |
| 525 | + if plugins: |
| 526 | + raise ValueError( |
| 527 | + "When app is provided, plugins should not be provided and" |
| 528 | + " should be provided in the app instead." |
| 529 | + ) |
| 530 | + |
508 | 531 | self._tmpl_attrs: Dict[str, Any] = { |
509 | 532 | "project": initializer.global_config.project, |
510 | 533 | "location": initializer.global_config.location, |
511 | 534 | "agent": agent, |
| 535 | + "app": app, |
512 | 536 | "app_name": app_name, |
513 | 537 | "plugins": plugins, |
514 | 538 | "enable_tracing": enable_tracing, |
@@ -624,10 +648,23 @@ def clone(self): |
624 | 648 | import copy |
625 | 649 |
|
626 | 650 | return self.__class__( |
627 | | - agent=copy.deepcopy(self._tmpl_attrs.get("agent")), |
| 651 | + app=copy.deepcopy(self._tmpl_attrs.get("app")), |
628 | 652 | enable_tracing=self._tmpl_attrs.get("enable_tracing"), |
629 | | - app_name=self._tmpl_attrs.get("app_name"), |
630 | | - plugins=self._tmpl_attrs.get("plugins"), |
| 653 | + agent=( |
| 654 | + None |
| 655 | + if self._tmpl_attrs.get("app") |
| 656 | + else copy.deepcopy(self._tmpl_attrs.get("agent")) |
| 657 | + ), |
| 658 | + app_name=( |
| 659 | + None |
| 660 | + if self._tmpl_attrs.get("app") |
| 661 | + else self._tmpl_attrs.get("app_name") |
| 662 | + ), |
| 663 | + plugins=( |
| 664 | + None |
| 665 | + if self._tmpl_attrs.get("app") |
| 666 | + else copy.deepcopy(self._tmpl_attrs.get("plugins")) |
| 667 | + ), |
631 | 668 | session_service_builder=self._tmpl_attrs.get("session_service_builder"), |
632 | 669 | artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"), |
633 | 670 | memory_service_builder=self._tmpl_attrs.get("memory_service_builder"), |
@@ -774,20 +811,38 @@ def tracing_enabled() -> bool: |
774 | 811 | self._tmpl_attrs["memory_service"] = InMemoryMemoryService() |
775 | 812 |
|
776 | 813 | self._tmpl_attrs["runner"] = Runner( |
777 | | - agent=self._tmpl_attrs.get("agent"), |
778 | | - plugins=self._tmpl_attrs.get("plugins"), |
| 814 | + app=self._tmpl_attrs.get("app"), |
| 815 | + agent=( |
| 816 | + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent") |
| 817 | + ), |
| 818 | + app_name=( |
| 819 | + None |
| 820 | + if self._tmpl_attrs.get("app") |
| 821 | + else self._tmpl_attrs.get("app_name") |
| 822 | + ), |
| 823 | + plugins=( |
| 824 | + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins") |
| 825 | + ), |
779 | 826 | session_service=self._tmpl_attrs.get("session_service"), |
780 | 827 | artifact_service=self._tmpl_attrs.get("artifact_service"), |
781 | 828 | memory_service=self._tmpl_attrs.get("memory_service"), |
782 | | - app_name=self._tmpl_attrs.get("app_name"), |
783 | 829 | ) |
784 | 830 | self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() |
785 | 831 | self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() |
786 | 832 | self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() |
787 | 833 | self._tmpl_attrs["in_memory_runner"] = Runner( |
788 | | - app_name=self._tmpl_attrs.get("app_name"), |
789 | | - agent=self._tmpl_attrs.get("agent"), |
790 | | - plugins=self._tmpl_attrs.get("plugins"), |
| 834 | + app=self._tmpl_attrs.get("app"), |
| 835 | + app_name=( |
| 836 | + None |
| 837 | + if self._tmpl_attrs.get("app") |
| 838 | + else self._tmpl_attrs.get("app_name") |
| 839 | + ), |
| 840 | + agent=( |
| 841 | + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent") |
| 842 | + ), |
| 843 | + plugins=( |
| 844 | + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins") |
| 845 | + ), |
791 | 846 | session_service=self._tmpl_attrs.get("in_memory_session_service"), |
792 | 847 | artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"), |
793 | 848 | memory_service=self._tmpl_attrs.get("in_memory_memory_service"), |
@@ -968,12 +1023,13 @@ async def streaming_agent_run_with_events(self, request_json: str): |
968 | 1023 | self.set_up() |
969 | 1024 | session_service = self._tmpl_attrs.get("in_memory_session_service") |
970 | 1025 | artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") |
| 1026 | + app = self._tmpl_attrs.get("app") |
971 | 1027 | # Try to get the session, if it doesn't exist, create a new one. |
972 | 1028 | session = None |
973 | 1029 | if request.session_id: |
974 | 1030 | try: |
975 | 1031 | session = await session_service.get_session( |
976 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1032 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
977 | 1033 | user_id=request.user_id, |
978 | 1034 | session_id=request.session_id, |
979 | 1035 | ) |
@@ -1006,8 +1062,9 @@ async def streaming_agent_run_with_events(self, request_json: str): |
1006 | 1062 | yield converted_event |
1007 | 1063 | finally: |
1008 | 1064 | if session and not request.session_id: |
| 1065 | + app = self._tmpl_attrs.get("app") |
1009 | 1066 | await session_service.delete_session( |
1010 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1067 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1011 | 1068 | user_id=request.user_id, |
1012 | 1069 | session_id=session.id, |
1013 | 1070 | ) |
@@ -1039,8 +1096,9 @@ async def async_get_session( |
1039 | 1096 | """ |
1040 | 1097 | if not self._tmpl_attrs.get("session_service"): |
1041 | 1098 | self.set_up() |
| 1099 | + app = self._tmpl_attrs.get("app") |
1042 | 1100 | session = await self._tmpl_attrs.get("session_service").get_session( |
1043 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1101 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1044 | 1102 | user_id=user_id, |
1045 | 1103 | session_id=session_id, |
1046 | 1104 | **kwargs, |
@@ -1116,8 +1174,9 @@ async def async_list_sessions(self, *, user_id: str, **kwargs): |
1116 | 1174 | """ |
1117 | 1175 | if not self._tmpl_attrs.get("session_service"): |
1118 | 1176 | self.set_up() |
| 1177 | + app = self._tmpl_attrs.get("app") |
1119 | 1178 | return await self._tmpl_attrs.get("session_service").list_sessions( |
1120 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1179 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1121 | 1180 | user_id=user_id, |
1122 | 1181 | **kwargs, |
1123 | 1182 | ) |
@@ -1188,8 +1247,9 @@ async def async_create_session( |
1188 | 1247 | """ |
1189 | 1248 | if not self._tmpl_attrs.get("session_service"): |
1190 | 1249 | self.set_up() |
| 1250 | + app = self._tmpl_attrs.get("app") |
1191 | 1251 | session = await self._tmpl_attrs.get("session_service").create_session( |
1192 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1252 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1193 | 1253 | user_id=user_id, |
1194 | 1254 | session_id=session_id, |
1195 | 1255 | state=state, |
@@ -1269,8 +1329,9 @@ async def async_delete_session( |
1269 | 1329 | """ |
1270 | 1330 | if not self._tmpl_attrs.get("session_service"): |
1271 | 1331 | self.set_up() |
| 1332 | + app = self._tmpl_attrs.get("app") |
1272 | 1333 | await self._tmpl_attrs.get("session_service").delete_session( |
1273 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1334 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1274 | 1335 | user_id=user_id, |
1275 | 1336 | session_id=session_id, |
1276 | 1337 | **kwargs, |
@@ -1359,8 +1420,9 @@ async def async_search_memory(self, *, user_id: str, query: str): |
1359 | 1420 | """ |
1360 | 1421 | if not self._tmpl_attrs.get("memory_service"): |
1361 | 1422 | self.set_up() |
| 1423 | + app = self._tmpl_attrs.get("app") |
1362 | 1424 | return await self._tmpl_attrs.get("memory_service").search_memory( |
1363 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1425 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1364 | 1426 | user_id=user_id, |
1365 | 1427 | query=query, |
1366 | 1428 | ) |
|
0 commit comments