Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def init_pyaudio_playback():
try:
pya_interface_instance.terminate()
except:
# TODO: be more specific about exception type
pass
pya_interface_instance = None
pya_output_stream_instance = None
Expand Down
32 changes: 30 additions & 2 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import asyncio
import re
from contextlib import asynccontextmanager
import importlib
import json
Expand Down Expand Up @@ -757,16 +758,25 @@ async def internal_lifespan(app: FastAPI):
# Run the FastAPI server.
app = FastAPI(lifespan=internal_lifespan)

# Store parsed allow_origins for WebSocket Origin validation
# (CORS middleware doesn't apply to WebSocket upgrades)
_ws_allowed_origins: tuple[list[str], Optional[re.Pattern[str]], bool] = (
[],
None,
False,
)
if allow_origins:
literal_origins, combined_regex = _parse_cors_origins(allow_origins)
literal_origins, combined_regex_str = _parse_cors_origins(allow_origins)
compiled_regex = re.compile(combined_regex_str) if combined_regex_str else None
app.add_middleware(
CORSMiddleware,
allow_origins=literal_origins,
allow_origin_regex=combined_regex,
allow_origin_regex=combined_regex_str,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
_ws_allowed_origins = (literal_origins, compiled_regex, True)

@app.get("/health")
async def health() -> dict[str, str]:
Expand Down Expand Up @@ -1802,6 +1812,24 @@ async def run_agent_live(
enable_affective_dialog: bool | None = Query(default=None),
enable_session_resumption: bool | None = Query(default=None),
) -> None:
# Validate Origin header to prevent cross-origin WebSocket hijacking.
# WebSocket connections are not protected by CORS, so we must validate
# the Origin ourselves. See https://github.com/google/adk-python/issues/4947
origin = websocket.headers.get("origin")
literal_origins, compiled_regex, origins_configured = _ws_allowed_origins
if origins_configured:
# CORS origins were configured: allow only listed origins
allowed = origin in literal_origins or (
compiled_regex and origin and compiled_regex.match(origin)
)
elif origin:
# No CORS config: only allow same-origin requests
allowed = False
else:
allowed = True # No Origin header (non-browser client)
if not allowed:
await websocket.close(code=1008, reason="Origin not allowed")
return
await websocket.accept()

session = await self.session_service.get_session(
Expand Down
87 changes: 87 additions & 0 deletions tests/unittests/cli/test_adk_web_server_run_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,90 @@ async def _get_runner_async(_self, _app_name: str):
run_config.session_resumption.transparent
is expected_session_resumption_transparent
)


def _make_app(allow_origins=None):
"""Helper to create a test FastAPI app with optional allow_origins."""
session_service = InMemorySessionService()
asyncio.run(
session_service.create_session(
app_name="test_app",
user_id="user",
session_id="session",
state={},
)
)

runner = _CapturingRunner()
adk_web_server = AdkWebServer(
agent_loader=_DummyAgentLoader(),
session_service=session_service,
memory_service=types.SimpleNamespace(),
artifact_service=types.SimpleNamespace(),
credential_service=types.SimpleNamespace(),
eval_sets_manager=types.SimpleNamespace(),
eval_set_results_manager=types.SimpleNamespace(),
agents_dir=".",
)

async def _get_runner_async(_self, _app_name: str):
return runner

adk_web_server.get_runner_async = _get_runner_async.__get__(adk_web_server) # pytype: disable=attribute-error

fast_api_app = adk_web_server.get_fast_api_app(
setup_observer=lambda _observer, _server: None,
tear_down_observer=lambda _observer, _server: None,
allow_origins=allow_origins,
)
return TestClient(fast_api_app)


def test_websocket_rejects_cross_origin_without_config():
"""WebSocket without allow_origins config rejects cross-origin requests.

Regression test for https://github.com/google/adk-python/issues/4947
"""
client = _make_app(allow_origins=None)
url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT"

# Simulate a cross-origin request by manually providing an Origin header
with pytest.raises(Exception) as exc_info:
client.websocket_connect(url, headers={"origin": "http://evil.com"})

# Connection should be rejected (1008 or connection error)
assert "1008" in str(exc_info.value) or "WebSocket" in str(type(exc_info.value).__name__)


def test_websocket_accepts_same_origin_without_config():
"""WebSocket without allow_origins accepts requests without Origin header (non-browser clients)."""
client = _make_app(allow_origins=None)
url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT"

# No Origin header = non-browser client = allowed
with client.websocket_connect(url) as ws:
_ = ws.receive_text()


def test_websocket_accepts_configured_origin():
"""WebSocket accepts when origin matches the configured allow_origins list."""
client = _make_app(allow_origins=["http://localhost:8000"])
url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT"

with client.websocket_connect(
url, headers={"origin": "http://localhost:8000"}
) as ws:
_ = ws.receive_text()


def test_websocket_rejects_unlisted_origin():
"""WebSocket rejects when origin is not in the configured allow_origins list."""
client = _make_app(allow_origins=["http://localhost:8000"])
url = "/run_live?app_name=test_app&user_id=user&session_id=session&modalities=TEXT"

with pytest.raises(Exception) as exc_info:
client.websocket_connect(
url, headers={"origin": "http://evil.com"}
)

assert "1008" in str(exc_info.value) or "WebSocket" in str(type(exc_info.value).__name__)