2020import json
2121import logging
2222import os
23+ import re
2324import sys
2425import time
2526import traceback
@@ -138,6 +139,158 @@ def _parse_cors_origins(
138139 return literal_origins , combined_regex
139140
140141
142+ def _is_origin_allowed (
143+ origin : str ,
144+ allowed_literal_origins : list [str ],
145+ allowed_origin_regex : Optional [re .Pattern [str ]],
146+ ) -> bool :
147+ """Check whether the given origin matches the allowed origins."""
148+ if "*" in allowed_literal_origins :
149+ return True
150+ if origin in allowed_literal_origins :
151+ return True
152+ if allowed_origin_regex is not None :
153+ return allowed_origin_regex .fullmatch (origin ) is not None
154+ return False
155+
156+
157+ def _normalize_origin_scheme (scheme : str ) -> str :
158+ """Normalize request schemes to the browser Origin scheme space."""
159+ if scheme == "ws" :
160+ return "http"
161+ if scheme == "wss" :
162+ return "https"
163+ return scheme
164+
165+
166+ def _strip_optional_quotes (value : str ) -> str :
167+ """Strip a single pair of wrapping quotes from a header value."""
168+ if len (value ) >= 2 and value [0 ] == '"' and value [- 1 ] == '"' :
169+ return value [1 :- 1 ]
170+ return value
171+
172+
173+ def _get_scope_header (
174+ scope : dict [str , Any ], header_name : bytes
175+ ) -> Optional [str ]:
176+ """Return the first matching header value from an ASGI scope."""
177+ for candidate_name , candidate_value in scope .get ("headers" , []):
178+ if candidate_name == header_name :
179+ return candidate_value .decode ("latin-1" ).split ("," , 1 )[0 ].strip ()
180+ return None
181+
182+
183+ def _get_request_origin (scope : dict [str , Any ]) -> Optional [str ]:
184+ """Compute the effective origin for the current HTTP/WebSocket request."""
185+ forwarded = _get_scope_header (scope , b"forwarded" )
186+ if forwarded is not None :
187+ proto = None
188+ host = None
189+ for element in forwarded .split ("," , 1 )[0 ].split (";" ):
190+ if "=" not in element :
191+ continue
192+ name , value = element .split ("=" , 1 )
193+ if name .strip ().lower () == "proto" :
194+ proto = _strip_optional_quotes (value .strip ())
195+ elif name .strip ().lower () == "host" :
196+ host = _strip_optional_quotes (value .strip ())
197+ if proto is not None and host is not None :
198+ return f"{ _normalize_origin_scheme (proto )} ://{ host } "
199+
200+ host = _get_scope_header (scope , b"x-forwarded-host" )
201+ if host is None :
202+ host = _get_scope_header (scope , b"host" )
203+ if host is None :
204+ return None
205+
206+ proto = _get_scope_header (scope , b"x-forwarded-proto" )
207+ if proto is None :
208+ proto = scope .get ("scheme" , "http" )
209+ return f"{ _normalize_origin_scheme (proto )} ://{ host } "
210+
211+
212+ def _is_request_origin_allowed (
213+ origin : str ,
214+ scope : dict [str , Any ],
215+ allowed_literal_origins : list [str ],
216+ allowed_origin_regex : Optional [re .Pattern [str ]],
217+ has_configured_allowed_origins : bool ,
218+ ) -> bool :
219+ """Validate an Origin header against explicit config or same-origin."""
220+ if has_configured_allowed_origins and _is_origin_allowed (
221+ origin , allowed_literal_origins , allowed_origin_regex
222+ ):
223+ return True
224+
225+ request_origin = _get_request_origin (scope )
226+ if request_origin is None :
227+ return False
228+ return origin == request_origin
229+
230+
231+ _SAFE_HTTP_METHODS = frozenset ({"GET" , "HEAD" , "OPTIONS" })
232+
233+
234+ class _OriginCheckMiddleware :
235+ """ASGI middleware that blocks cross-origin state-changing requests."""
236+
237+ def __init__ (
238+ self ,
239+ app : Any ,
240+ has_configured_allowed_origins : bool ,
241+ allowed_origins : list [str ],
242+ allowed_origin_regex : Optional [re .Pattern [str ]],
243+ ) -> None :
244+ self ._app = app
245+ self ._has_configured_allowed_origins = has_configured_allowed_origins
246+ self ._allowed_origins = allowed_origins
247+ self ._allowed_origin_regex = allowed_origin_regex
248+
249+ async def __call__ (
250+ self ,
251+ scope : dict [str , Any ],
252+ receive : Any ,
253+ send : Any ,
254+ ) -> None :
255+ if scope ["type" ] != "http" :
256+ await self ._app (scope , receive , send )
257+ return
258+
259+ method = scope .get ("method" , "GET" )
260+ if method in _SAFE_HTTP_METHODS :
261+ await self ._app (scope , receive , send )
262+ return
263+
264+ origin = _get_scope_header (scope , b"origin" )
265+ if origin is None :
266+ await self ._app (scope , receive , send )
267+ return
268+
269+ if _is_request_origin_allowed (
270+ origin ,
271+ scope ,
272+ self ._allowed_origins ,
273+ self ._allowed_origin_regex ,
274+ self ._has_configured_allowed_origins ,
275+ ):
276+ await self ._app (scope , receive , send )
277+ return
278+
279+ response_body = b"Forbidden: origin not allowed"
280+ await send ({
281+ "type" : "http.response.start" ,
282+ "status" : 403 ,
283+ "headers" : [
284+ (b"content-type" , b"text/plain" ),
285+ (b"content-length" , str (len (response_body )).encode ()),
286+ ],
287+ })
288+ await send ({
289+ "type" : "http.response.body" ,
290+ "body" : response_body ,
291+ })
292+
293+
141294class ApiServerSpanExporter (export_lib .SpanExporter ):
142295
143296 def __init__ (self , trace_dict ):
@@ -757,8 +910,12 @@ async def internal_lifespan(app: FastAPI):
757910 # Run the FastAPI server.
758911 app = FastAPI (lifespan = internal_lifespan )
759912
913+ has_configured_allowed_origins = bool (allow_origins )
760914 if allow_origins :
761915 literal_origins , combined_regex = _parse_cors_origins (allow_origins )
916+ compiled_origin_regex = (
917+ re .compile (combined_regex ) if combined_regex is not None else None
918+ )
762919 app .add_middleware (
763920 CORSMiddleware ,
764921 allow_origins = literal_origins ,
@@ -767,6 +924,16 @@ async def internal_lifespan(app: FastAPI):
767924 allow_methods = ["*" ],
768925 allow_headers = ["*" ],
769926 )
927+ else :
928+ literal_origins = []
929+ compiled_origin_regex = None
930+
931+ app .add_middleware (
932+ _OriginCheckMiddleware ,
933+ has_configured_allowed_origins = has_configured_allowed_origins ,
934+ allowed_origins = literal_origins ,
935+ allowed_origin_regex = compiled_origin_regex ,
936+ )
770937
771938 @app .get ("/health" )
772939 async def health () -> dict [str , str ]:
@@ -1802,14 +1969,23 @@ async def run_agent_live(
18021969 enable_affective_dialog : bool | None = Query (default = None ),
18031970 enable_session_resumption : bool | None = Query (default = None ),
18041971 ) -> None :
1972+ ws_origin = websocket .headers .get ("origin" )
1973+ if ws_origin is not None and not _is_request_origin_allowed (
1974+ ws_origin ,
1975+ websocket .scope ,
1976+ literal_origins ,
1977+ compiled_origin_regex ,
1978+ has_configured_allowed_origins ,
1979+ ):
1980+ await websocket .close (code = 1008 , reason = "Origin not allowed" )
1981+ return
1982+
18051983 await websocket .accept ()
18061984
18071985 session = await self .session_service .get_session (
18081986 app_name = app_name , user_id = user_id , session_id = session_id
18091987 )
18101988 if not session :
1811- # Accept first so that the client is aware of connection establishment,
1812- # then close with a specific code.
18131989 await websocket .close (code = 1002 , reason = "Session not found" )
18141990 return
18151991
0 commit comments