Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from fastapi import Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.responses import Response
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocket
Expand Down Expand Up @@ -1033,6 +1034,24 @@ async def get_trace_dict(event_id: str) -> Any:

if web_assets_dir:

@app.get("/dev/build_graph_image/{app_name}")
async def get_app_graph_image(
app_name: str, dark_mode: bool = False
) -> Response:
agent_or_app = self.agent_loader.load_agent(app_name)
root_agent = self._get_root_agent(agent_or_app)

graph_image = await agent_graph.get_agent_graph(
root_agent, [], image=True, dark_mode=dark_mode
)

if isinstance(graph_image, bytes):
return Response(content=graph_image, media_type="image/png")

raise HTTPException(
status_code=500, detail="Failed to render app graph image"
)

@app.get("/dev/build_graph/{app_name}")
async def get_app_info(app_name: str) -> Any:
runner = await self.get_runner_async(app_name)
Expand Down
25 changes: 22 additions & 3 deletions src/google/adk/cli/trigger_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,22 @@ class TriggerResponse(BaseModel):
)


def _make_trigger_user_id(
raw_value: Optional[str],
*,
default: str,
) -> str:
"""Normalize trigger metadata into a session-safe user_id."""
if not raw_value:
return default

normalized = raw_value.strip().strip("/")
if not normalized:
return default

return normalized.replace("/", "--")


# ---------------------------------------------------------------------------
# Trigger Router
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -411,7 +427,9 @@ def register(self, app: FastAPI) -> None:
async def trigger_pubsub(
app_name: str, req: PubSubTriggerRequest, request: Request
) -> TriggerResponse:
user_id = req.subscription or "pubsub-caller"
user_id = _make_trigger_user_id(
req.subscription, default="pubsub-caller"
)

decoded_data = None
data_payload = None
Expand Down Expand Up @@ -477,8 +495,9 @@ async def trigger_eventarc(
app_name: str, req: EventarcTriggerRequest, request: Request
) -> TriggerResponse:

user_id = (
req.source or request.headers.get("ce-source") or "eventarc-caller"
user_id = _make_trigger_user_id(
req.source or request.headers.get("ce-source"),
default="eventarc-caller",
)

logger.info(
Expand Down
47 changes: 47 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,53 @@ def list_agents(self):
assert "dotSrc" in response.json()


def test_build_graph_image_returns_png_bytes():
"""Ensure legacy graph-image endpoint still returns a PNG for the dev UI."""
from google.adk.cli.adk_web_server import AdkWebServer

root_agent = DummyAgent(name="dummy_agent")
app_agent = App(name="test_app", root_agent=root_agent)

class Loader:

def load_agent(self, app_name):
return app_agent

def list_agents(self):
return [app_agent.name]

adk_web_server = AdkWebServer(
agent_loader=Loader(),
session_service=AsyncMock(),
memory_service=MagicMock(),
artifact_service=MagicMock(),
credential_service=MagicMock(),
eval_sets_manager=MagicMock(),
eval_set_results_manager=MagicMock(),
agents_dir=".",
)

fast_api_app = adk_web_server.get_fast_api_app(
setup_observer=lambda _observer, _server: None,
tear_down_observer=lambda _observer, _server: None,
)

client = TestClient(fast_api_app)

with patch(
"google.adk.cli.agent_graph.get_agent_graph",
new=AsyncMock(return_value=b"png-bytes"),
) as mock_get_agent_graph:
response = client.get("/dev/build_graph_image/test_app?dark_mode=true")

assert response.status_code == 200
assert response.content == b"png-bytes"
assert response.headers["content-type"] == "image/png"
mock_get_agent_graph.assert_awaited_once_with(
root_agent, [], image=True, dark_mode=True
)


def test_a2a_agent_discovery(test_app_with_a2a):
"""Test that A2A agents are properly discovered and configured."""
# This test mainly verifies that the A2A setup doesn't break the app
Expand Down
37 changes: 37 additions & 0 deletions tests/unittests/cli/test_trigger_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,24 @@ def test_with_subscription_metadata(self, client):

assert resp.status_code == 200

def test_subscription_user_id_is_path_safe(
self, client, mock_session_service
):
"""Pub/Sub subscription-derived user_id is stored without slashes."""
message_data = base64.b64encode(b"test").decode("utf-8")
payload = {
"message": {"data": message_data},
"subscription": "projects/p/subscriptions/orders-sub",
}

resp = client.post("/apps/test_app/trigger/pubsub", json=payload)

assert resp.status_code == 200
assert (
"projects--p--subscriptions--orders-sub"
in mock_session_service.sessions["test_app"]
)

def test_unknown_app_fails_early(
self, client, mock_agent_loader, mock_session_service
):
Expand Down Expand Up @@ -513,6 +531,25 @@ def test_source_from_ce_header(self, client):
)
assert resp.status_code == 200

def test_eventarc_source_user_id_is_path_safe(
self, client, mock_session_service
):
"""Eventarc ce-source-derived user_id is stored without slashes."""
payload = {
"data": {"key": "value"},
}
resp = client.post(
"/apps/test_app/trigger/eventarc",
json=payload,
headers={"ce-source": "//pubsub.googleapis.com/projects/p/topics/t"},
)

assert resp.status_code == 200
assert (
"pubsub.googleapis.com--projects--p--topics--t"
in mock_session_service.sessions["test_app"]
)

def test_complex_event_data(self, client, monkeypatch):
"""Complex nested event data is serialized as JSON for the agent."""
captured_messages = []
Expand Down