diff --git a/.changeset/stale-grapes-roll.md b/.changeset/stale-grapes-roll.md new file mode 100644 index 000000000000..2d465678c924 --- /dev/null +++ b/.changeset/stale-grapes-roll.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Refactor CORS Middleware to be much faster diff --git a/gradio/route_utils.py b/gradio/route_utils.py index b72bde427ee3..11add279f25b 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import functools import hashlib import hmac import json @@ -33,9 +34,10 @@ import multipart from gradio_client.documentation import document from multipart.multipart import parse_options_header -from starlette.datastructures import FormData, Headers, UploadFile +from starlette.datastructures import FormData, Headers, MutableHeaders, UploadFile from starlette.formparsers import MultiPartException, MultipartPart -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send from gradio import processing_utils, utils from gradio.data_classes import PredictBody @@ -648,41 +650,86 @@ def get_hostname(url: str) -> str: return "" -class CustomCORSMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: fastapi.Request, call_next): - host: str = request.headers.get("host", "") - origin: str = request.headers.get("origin", "") - host_name = get_hostname(host) - origin_name = get_hostname(origin) +class CustomCORSMiddleware: + # This is a modified version of the Starlette CORSMiddleware that restricts the allowed origins when the host is localhost. + # Adapted from: https://github.com/encode/starlette/blob/89fae174a1ea10f59ae248fe030d9b7e83d0b8a0/starlette/middleware/cors.py + def __init__( + self, + app: ASGIApp, + ) -> None: + self.app = app + self.all_methods = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") + self.preflight_headers = { + "Access-Control-Allow-Methods": ", ".join(self.all_methods), + "Access-Control-Max-Age": str(600), + } + self.simple_headers = {"Access-Control-Allow-Credentials": "true"} # Any of these hosts suggests that the Gradio app is running locally. # Note: "null" is a special case that happens if a Gradio app is running # as an embedded web component in a local static webpage. - localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"] - is_preflight = ( - request.method == "OPTIONS" - and "access-control-request-method" in request.headers - ) + self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + headers = Headers(scope=scope) + origin = headers.get("origin") + if origin is None: + await self.app(scope, receive, send) + return + if scope["method"] == "OPTIONS" and "access-control-request-method" in headers: + response = self.preflight_response(request_headers=headers) + await response(scope, receive, send) + return + await self.simple_response(scope, receive, send, request_headers=headers) + + def preflight_response(self, request_headers: Headers) -> Response: + headers = dict(self.preflight_headers) + origin = request_headers["Origin"] + if self.is_valid_origin(request_headers): + headers["Access-Control-Allow-Origin"] = origin + requested_headers = request_headers.get("access-control-request-headers") + if requested_headers is not None: + headers["Access-Control-Allow-Headers"] = requested_headers + return PlainTextResponse("OK", status_code=200, headers=headers) + + async def simple_response( + self, scope: Scope, receive: Receive, send: Send, request_headers: Headers + ) -> None: + send = functools.partial(self._send, send=send, request_headers=request_headers) + await self.app(scope, receive, send) - if host_name in localhost_aliases and origin_name not in localhost_aliases: - allow_origin_header = None - else: - allow_origin_header = origin + async def _send( + self, message: Message, send: Send, request_headers: Headers + ) -> None: + if message["type"] != "http.response.start": + await send(message) + return + message.setdefault("headers", []) + headers = MutableHeaders(scope=message) + headers.update(self.simple_headers) + has_cookie = "cookie" in request_headers + origin = request_headers["Origin"] + if has_cookie or self.is_valid_origin(request_headers): + self.allow_explicit_origin(headers, origin) + await send(message) + + def is_valid_origin(self, request_headers: Headers) -> bool: + origin = request_headers["Origin"] + host = request_headers["Host"] + host_name = get_hostname(host) + origin_name = get_hostname(origin) + return ( + host_name not in self.localhost_aliases + or origin_name in self.localhost_aliases + ) - if is_preflight: - response = fastapi.Response() - else: - response = await call_next(request) - - if allow_origin_header: - response.headers["Access-Control-Allow-Origin"] = allow_origin_header - response.headers[ - "Access-Control-Allow-Methods" - ] = "GET, POST, PUT, DELETE, OPTIONS" - response.headers[ - "Access-Control-Allow-Headers" - ] = "Origin, Content-Type, Accept" - return response + @staticmethod + def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: + headers["Access-Control-Allow-Origin"] = origin + headers.add_vary_header("Origin") def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None: