Skip to content

Commit

Permalink
Refactor CORS Middleware to be much faster (#7801)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* log

* add changeset

* changes

* changes

* lint

* middlware

* lint

* lint

* s
implified docstring

* fix

* revert test change

* docstring

* remove print

* update

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Mar 26, 2024
1 parent 8b099a0 commit 05db0c4
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 31 deletions.
5 changes: 5 additions & 0 deletions .changeset/stale-grapes-roll.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Refactor CORS Middleware to be much faster
109 changes: 78 additions & 31 deletions gradio/route_utils.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import functools
import hashlib
import hmac
import json
Expand Down Expand Up @@ -33,9 +34,10 @@
import multipart
from gradio_client.documentation import document
from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.datastructures import FormData, Headers, MutableHeaders, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from gradio import processing_utils, utils
from gradio.data_classes import PredictBody
Expand Down Expand Up @@ -648,41 +650,86 @@ def get_hostname(url: str) -> str:
return ""


class CustomCORSMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: fastapi.Request, call_next):
host: str = request.headers.get("host", "")
origin: str = request.headers.get("origin", "")
host_name = get_hostname(host)
origin_name = get_hostname(origin)
class CustomCORSMiddleware:
# This is a modified version of the Starlette CORSMiddleware that restricts the allowed origins when the host is localhost.
# Adapted from: https://github.com/encode/starlette/blob/89fae174a1ea10f59ae248fe030d9b7e83d0b8a0/starlette/middleware/cors.py

def __init__(
self,
app: ASGIApp,
) -> None:
self.app = app
self.all_methods = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
self.preflight_headers = {
"Access-Control-Allow-Methods": ", ".join(self.all_methods),
"Access-Control-Max-Age": str(600),
}
self.simple_headers = {"Access-Control-Allow-Credentials": "true"}
# Any of these hosts suggests that the Gradio app is running locally.
# Note: "null" is a special case that happens if a Gradio app is running
# as an embedded web component in a local static webpage.
localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
is_preflight = (
request.method == "OPTIONS"
and "access-control-request-method" in request.headers
)
self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
origin = headers.get("origin")
if origin is None:
await self.app(scope, receive, send)
return
if scope["method"] == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
await self.simple_response(scope, receive, send, request_headers=headers)

def preflight_response(self, request_headers: Headers) -> Response:
headers = dict(self.preflight_headers)
origin = request_headers["Origin"]
if self.is_valid_origin(request_headers):
headers["Access-Control-Allow-Origin"] = origin
requested_headers = request_headers.get("access-control-request-headers")
if requested_headers is not None:
headers["Access-Control-Allow-Headers"] = requested_headers
return PlainTextResponse("OK", status_code=200, headers=headers)

async def simple_response(
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
) -> None:
send = functools.partial(self._send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)

if host_name in localhost_aliases and origin_name not in localhost_aliases:
allow_origin_header = None
else:
allow_origin_header = origin
async def _send(
self, message: Message, send: Send, request_headers: Headers
) -> None:
if message["type"] != "http.response.start":
await send(message)
return
message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers.update(self.simple_headers)
has_cookie = "cookie" in request_headers
origin = request_headers["Origin"]
if has_cookie or self.is_valid_origin(request_headers):
self.allow_explicit_origin(headers, origin)
await send(message)

def is_valid_origin(self, request_headers: Headers) -> bool:
origin = request_headers["Origin"]
host = request_headers["Host"]
host_name = get_hostname(host)
origin_name = get_hostname(origin)
return (
host_name not in self.localhost_aliases
or origin_name in self.localhost_aliases
)

if is_preflight:
response = fastapi.Response()
else:
response = await call_next(request)

if allow_origin_header:
response.headers["Access-Control-Allow-Origin"] = allow_origin_header
response.headers[
"Access-Control-Allow-Methods"
] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers[
"Access-Control-Allow-Headers"
] = "Origin, Content-Type, Accept"
return response
@staticmethod
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
headers["Access-Control-Allow-Origin"] = origin
headers.add_vary_header("Origin")


def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
Expand Down

0 comments on commit 05db0c4

Please sign in to comment.