Skip to content

Commit 995cd1c

Browse files
GWeale and copybara-github
authored and committed
fix: add protection for arbitrary module imports
Close #4947 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 888296476
1 parent 0f351bf commit 995cd1c

File tree

3 files changed

+292
-3
lines changed

3 files changed

+292
-3
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 178 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import json
2121
import logging
2222
import os
23+
import re
2324
import sys
2425
import time
2526
import traceback
@@ -138,6 +139,158 @@ def _parse_cors_origins(
138139
return literal_origins, combined_regex
139140

140141

142+
def _is_origin_allowed(
143+
origin: str,
144+
allowed_literal_origins: list[str],
145+
allowed_origin_regex: Optional[re.Pattern[str]],
146+
) -> bool:
147+
"""Check whether the given origin matches the allowed origins."""
148+
if "*" in allowed_literal_origins:
149+
return True
150+
if origin in allowed_literal_origins:
151+
return True
152+
if allowed_origin_regex is not None:
153+
return allowed_origin_regex.fullmatch(origin) is not None
154+
return False
155+
156+
157+
def _normalize_origin_scheme(scheme: str) -> str:
158+
"""Normalize request schemes to the browser Origin scheme space."""
159+
if scheme == "ws":
160+
return "http"
161+
if scheme == "wss":
162+
return "https"
163+
return scheme
164+
165+
166+
def _strip_optional_quotes(value: str) -> str:
167+
"""Strip a single pair of wrapping quotes from a header value."""
168+
if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
169+
return value[1:-1]
170+
return value
171+
172+
173+
def _get_scope_header(
174+
scope: dict[str, Any], header_name: bytes
175+
) -> Optional[str]:
176+
"""Return the first matching header value from an ASGI scope."""
177+
for candidate_name, candidate_value in scope.get("headers", []):
178+
if candidate_name == header_name:
179+
return candidate_value.decode("latin-1").split(",", 1)[0].strip()
180+
return None
181+
182+
183+
def _get_request_origin(scope: dict[str, Any]) -> Optional[str]:
  """Derive the origin (``scheme://host``) the current request was sent to.

  Resolution order:
    1. The first element of the ``Forwarded`` header, used only when it
       carries both a ``proto`` and a ``host`` parameter.
    2. ``X-Forwarded-Host`` (falling back to ``Host``) combined with
       ``X-Forwarded-Proto`` (falling back to the scope's ``scheme``, then
       ``"http"``).

  NOTE(review): the forwarding headers are read from the request as-is; this
  presumably relies on a trusted proxy setting them -- confirm for deployments
  where clients can reach the server directly.

  Args:
    scope: The ASGI scope of the HTTP or WebSocket request.

  Returns:
    The computed origin string, or None when no host can be determined.
  """
  forwarded_header = _get_scope_header(scope, b"forwarded")
  if forwarded_header is not None:
    first_element, _, _ = forwarded_header.partition(",")
    forwarded_params: dict[str, str] = {}
    for pair in first_element.split(";"):
      key, separator, raw_value = pair.partition("=")
      if not separator:
        # Not a ``name=value`` pair; ignore it.
        continue
      key = key.strip().lower()
      if key in ("proto", "host"):
        # A repeated parameter overwrites the earlier one.
        forwarded_params[key] = _strip_optional_quotes(raw_value.strip())
    if "proto" in forwarded_params and "host" in forwarded_params:
      forwarded_scheme = _normalize_origin_scheme(forwarded_params["proto"])
      return f"{forwarded_scheme}://{forwarded_params['host']}"

  # An incomplete or absent Forwarded header is discarded entirely.
  effective_host = None
  for host_header in (b"x-forwarded-host", b"host"):
    effective_host = _get_scope_header(scope, host_header)
    if effective_host is not None:
      break
  if effective_host is None:
    return None

  effective_scheme = _get_scope_header(scope, b"x-forwarded-proto")
  if effective_scheme is None:
    effective_scheme = scope.get("scheme", "http")
  return f"{_normalize_origin_scheme(effective_scheme)}://{effective_host}"
210+
211+
212+
def _is_request_origin_allowed(
    origin: str,
    scope: dict[str, Any],
    allowed_literal_origins: list[str],
    allowed_origin_regex: Optional[re.Pattern[str]],
    has_configured_allowed_origins: bool,
) -> bool:
  """Decide whether an ``Origin`` header value is acceptable for a request.

  The origin is accepted when an allow-list is configured and matches it, or
  when it equals the origin computed for the request itself (same-origin).

  Args:
    origin: The ``Origin`` header value sent by the client.
    scope: The ASGI scope of the request, used to compute its own origin.
    allowed_literal_origins: Exact origins from the configured allow-list.
    allowed_origin_regex: Optional compiled pattern from the allow-list.
    has_configured_allowed_origins: Whether an allow-list was configured; the
      allow-list arguments are consulted only when this is True.

  Returns:
    True when the origin is allow-listed or same-origin, False otherwise.
  """
  if has_configured_allowed_origins:
    allow_listed = _is_origin_allowed(
        origin, allowed_literal_origins, allowed_origin_regex
    )
    if allow_listed:
      return True

  # Fall back to a same-origin comparison; an unknown request origin rejects.
  expected_origin = _get_request_origin(scope)
  return expected_origin is not None and origin == expected_origin
229+
230+
231+
_SAFE_HTTP_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
232+
233+
234+
class _OriginCheckMiddleware:
  """ASGI middleware that blocks cross-origin state-changing requests.

  A request is forwarded to the wrapped app unchanged when any of these hold:
    * the scope type is not ``"http"``;
    * the HTTP method is one of ``_SAFE_HTTP_METHODS``;
    * the request carries no ``Origin`` header;
    * the ``Origin`` header passes ``_is_request_origin_allowed``.

  Every other request is answered directly with a plain-text 403 response and
  never reaches the wrapped app.
  """

  def __init__(
      self,
      app: Any,
      has_configured_allowed_origins: bool,
      allowed_origins: list[str],
      allowed_origin_regex: Optional[re.Pattern[str]],
  ) -> None:
    """Store the wrapped app and the origin allow-list configuration.

    Args:
      app: The downstream ASGI application.
      has_configured_allowed_origins: Whether an allow-list was configured.
      allowed_origins: Literal origins from the allow-list.
      allowed_origin_regex: Optional compiled pattern from the allow-list.
    """
    self._app = app
    self._has_configured_allowed_origins = has_configured_allowed_origins
    self._allowed_origins = allowed_origins
    self._allowed_origin_regex = allowed_origin_regex

  async def __call__(
      self,
      scope: dict[str, Any],
      receive: Any,
      send: Any,
  ) -> None:
    """Forward the request, or reply 403 when its origin is not allowed."""
    if not self._must_reject(scope):
      await self._app(scope, receive, send)
      return

    forbidden_body = b"Forbidden: origin not allowed"
    start_message = {
        "type": "http.response.start",
        "status": 403,
        "headers": [
            (b"content-type", b"text/plain"),
            (b"content-length", str(len(forbidden_body)).encode()),
        ],
    }
    body_message = {
        "type": "http.response.body",
        "body": forbidden_body,
    }
    await send(start_message)
    await send(body_message)

  def _must_reject(self, scope: dict[str, Any]) -> bool:
    """Return True when the request must be answered with 403."""
    if scope["type"] != "http":
      return False
    if scope.get("method", "GET") in _SAFE_HTTP_METHODS:
      return False

    origin = _get_scope_header(scope, b"origin")
    if origin is None:
      # No Origin header: the request is let through.
      return False

    return not _is_request_origin_allowed(
        origin,
        scope,
        self._allowed_origins,
        self._allowed_origin_regex,
        self._has_configured_allowed_origins,
    )
292+
293+
141294
class ApiServerSpanExporter(export_lib.SpanExporter):
142295

143296
def __init__(self, trace_dict):
@@ -757,8 +910,12 @@ async def internal_lifespan(app: FastAPI):
757910
# Run the FastAPI server.
758911
app = FastAPI(lifespan=internal_lifespan)
759912

913+
has_configured_allowed_origins = bool(allow_origins)
760914
if allow_origins:
761915
literal_origins, combined_regex = _parse_cors_origins(allow_origins)
916+
compiled_origin_regex = (
917+
re.compile(combined_regex) if combined_regex is not None else None
918+
)
762919
app.add_middleware(
763920
CORSMiddleware,
764921
allow_origins=literal_origins,
@@ -767,6 +924,16 @@ async def internal_lifespan(app: FastAPI):
767924
allow_methods=["*"],
768925
allow_headers=["*"],
769926
)
927+
else:
928+
literal_origins = []
929+
compiled_origin_regex = None
930+
931+
app.add_middleware(
932+
_OriginCheckMiddleware,
933+
has_configured_allowed_origins=has_configured_allowed_origins,
934+
allowed_origins=literal_origins,
935+
allowed_origin_regex=compiled_origin_regex,
936+
)
770937

771938
@app.get("/health")
772939
async def health() -> dict[str, str]:
@@ -1802,14 +1969,23 @@ async def run_agent_live(
18021969
enable_affective_dialog: bool | None = Query(default=None),
18031970
enable_session_resumption: bool | None = Query(default=None),
18041971
) -> None:
1972+
ws_origin = websocket.headers.get("origin")
1973+
if ws_origin is not None and not _is_request_origin_allowed(
1974+
ws_origin,
1975+
websocket.scope,
1976+
literal_origins,
1977+
compiled_origin_regex,
1978+
has_configured_allowed_origins,
1979+
):
1980+
await websocket.close(code=1008, reason="Origin not allowed")
1981+
return
1982+
18051983
await websocket.accept()
18061984

18071985
session = await self.session_service.get_session(
18081986
app_name=app_name, user_id=user_id, session_id=session_id
18091987
)
18101988
if not session:
1811-
# Accept first so that the client is aware of connection establishment,
1812-
# then close with a specific code.
18131989
await websocket.close(code=1002, reason="Session not found")
18141990
return
18151991

tests/unittests/cli/test_adk_web_server_run_live.py

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from google.adk.events.event import Event
2323
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2424
import pytest
25+
from starlette.websockets import WebSocketDisconnect
2526

2627

2728
class _DummyAgent(BaseAgent):
@@ -203,3 +204,75 @@ async def _get_runner_async(_self, _app_name: str):
203204
run_config.session_resumption.transparent
204205
is expected_session_resumption_transparent
205206
)
207+
208+
209+
_WS_BASE_URL = (
210+
"/run_live"
211+
"?app_name=test_app"
212+
"&user_id=user"
213+
"&session_id=session"
214+
"&modalities=AUDIO"
215+
)
216+
217+
218+
def _build_ws_client():
  """Create a TestClient for an AdkWebServer backed by a capturing runner.

  A session (``test_app`` / ``user`` / ``session``) is created up front in an
  in-memory session service, and the server's ``get_runner_async`` is replaced
  so it always hands back the same ``_CapturingRunner`` instance.

  Returns:
    A ``TestClient`` wrapping the server's FastAPI app.
  """
  sessions = InMemorySessionService()
  asyncio.run(
      sessions.create_session(
          app_name="test_app",
          user_id="user",
          session_id="session",
          state={},
      )
  )

  capturing_runner = _CapturingRunner()

  # Each remaining dependency gets its own empty placeholder object.
  placeholder_services = {
      service_name: types.SimpleNamespace()
      for service_name in (
          "memory_service",
          "artifact_service",
          "credential_service",
          "eval_sets_manager",
          "eval_set_results_manager",
      )
  }
  server = AdkWebServer(
      agent_loader=_DummyAgentLoader(),
      session_service=sessions,
      agents_dir=".",
      **placeholder_services,
  )

  async def _get_runner_async(_self, _app_name: str):
    return capturing_runner

  # Bind the stub as a method so the server calls it like the original.
  server.get_runner_async = _get_runner_async.__get__(server)  # pytype: disable=attribute-error

  app = server.get_fast_api_app(
      setup_observer=lambda _observer, _server: None,
      tear_down_observer=lambda _observer, _server: None,
  )
  return TestClient(app)
252+
253+
254+
def test_run_live_rejects_disallowed_origin():
  """A WebSocket handshake with a foreign Origin is closed with code 1008."""
  ws_client = _build_ws_client()
  foreign_origin_headers = {"origin": "https://evil.com"}

  with pytest.raises(WebSocketDisconnect) as disconnect_info:
    with ws_client.websocket_connect(
        _WS_BASE_URL, headers=foreign_origin_headers
    ) as ws:
      ws.receive_text()

  assert disconnect_info.value.code == 1008
263+
264+
265+
def test_run_live_allows_matching_origin():
  """A WebSocket handshake whose Origin is ``http://testserver`` connects.

  NOTE(review): ``http://testserver`` is presumably the TestClient's default
  host, which would make this the same-origin case -- confirm.
  """
  ws_client = _build_ws_client()
  same_origin_headers = {"origin": "http://testserver"}

  with ws_client.websocket_connect(
      _WS_BASE_URL, headers=same_origin_headers
  ) as ws:
    _ = ws.receive_text()
272+
273+
274+
def test_run_live_allows_no_origin_header():
  """A handshake without an Origin header is accepted.

  Non-browser clients (curl, wscat, SDKs) send no Origin header.
  """
  ws_client = _build_ws_client()
  connection = ws_client.websocket_connect(_WS_BASE_URL)
  with connection as ws:
    _ = ws.receive_text()

tests/unittests/cli/test_fast_api.py

Lines changed: 41 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -593,7 +593,7 @@ def builder_test_client(
593593
session_service_uri="",
594594
artifact_service_uri="",
595595
memory_service_uri="",
596-
allow_origins=["*"],
596+
allow_origins=None,
597597
a2a=False,
598598
host="127.0.0.1",
599599
port=8000,
@@ -1595,6 +1595,46 @@ def test_builder_final_save_preserves_tools_and_cleans_tmp(
15951595
assert not tmp_dir.exists() or not any(tmp_dir.iterdir())
15961596

15971597

1598+
def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path):
  """A POST to /builder/save with a foreign Origin gets 403 and saves nothing."""
  agent_yaml = ("app/root_agent.yaml", b"name: app\n", "application/x-yaml")
  saved_agent_dir = tmp_path / "app" / "tmp" / "app"

  response = builder_test_client.post(
      "/builder/save?tmp=true",
      headers={"origin": "https://evil.com"},
      files=[("files", agent_yaml)],
  )

  assert response.status_code == 403
  assert response.text == "Forbidden: origin not allowed"
  # The rejected upload must not have been written to disk.
  assert not saved_agent_dir.exists()
1611+
1612+
1613+
def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path):
  """A POST to /builder/save with Origin ``http://testserver`` is saved."""
  agent_yaml = ("app/root_agent.yaml", b"name: app\n", "application/x-yaml")
  saved_agent_file = tmp_path / "app" / "tmp" / "app" / "root_agent.yaml"

  response = builder_test_client.post(
      "/builder/save?tmp=true",
      headers={"origin": "http://testserver"},
      files=[("files", agent_yaml)],
  )

  assert response.status_code == 200
  assert response.json() is True
  # The accepted upload lands in the temporary agent directory.
  assert saved_agent_file.is_file()
1626+
1627+
1628+
def test_builder_get_allows_cross_origin_get(builder_test_client):
  """A GET with a foreign Origin is not blocked by the origin check."""
  foreign_origin_headers = {"origin": "https://evil.com"}

  response = builder_test_client.get(
      "/builder/app/missing?tmp=true",
      headers=foreign_origin_headers,
  )

  assert response.status_code == 200
  assert response.text == ""
1636+
1637+
15981638
def test_builder_cancel_deletes_tmp_idempotent(builder_test_client, tmp_path):
15991639
tmp_agent_root = tmp_path / "app" / "tmp" / "app"
16001640
tmp_agent_root.mkdir(parents=True, exist_ok=True)

0 commit comments

Comments
 (0)