-
Notifications
You must be signed in to change notification settings - Fork 3
Add logging for FS instantiation, better README and Allow-Private-Network #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| sources: | ||
| - name: inmemory | ||
| path: memory://mytests | ||
| - name: local | ||
| path: file:///Users | ||
| readonly: true | ||
| - name: "Conda Stats" | ||
| path: "s3://anaconda-package-data/conda/hourly/" | ||
| kwargs: | ||
| anon: True | ||
| - name: "MyAnaconda" | ||
| path: "anaconda://my/" | ||
| allow_reload: true | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| """Copy of fastapi.middleware.cors, with """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import functools | ||
| import os | ||
| import re | ||
| import typing | ||
|
|
||
| from starlette.datastructures import Headers, MutableHeaders | ||
| from starlette.responses import PlainTextResponse, Response | ||
| from starlette.types import ASGIApp, Message, Receive, Scope, Send | ||
|
|
||
| PRIVATE = os.getenv("FS_PROXY_PRIVATE", False) == "True" | ||
| ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") | ||
| SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} | ||
|
|
||
|
|
||
| class CORSMiddleware: | ||
| def __init__( | ||
| self, | ||
| app: ASGIApp, | ||
| allow_origins: typing.Sequence[str] = (), | ||
| allow_methods: typing.Sequence[str] = ("GET",), | ||
| allow_headers: typing.Sequence[str] = (), | ||
| allow_credentials: bool = False, | ||
| allow_origin_regex: str | None = None, | ||
| expose_headers: typing.Sequence[str] = (), | ||
| max_age: int = 600, | ||
| ) -> None: | ||
| if "*" in allow_methods: | ||
| allow_methods = ALL_METHODS | ||
|
|
||
| compiled_allow_origin_regex = None | ||
| if allow_origin_regex is not None: | ||
| compiled_allow_origin_regex = re.compile(allow_origin_regex) | ||
|
|
||
| allow_all_origins = "*" in allow_origins | ||
| allow_all_headers = "*" in allow_headers | ||
| preflight_explicit_allow_origin = not allow_all_origins or allow_credentials | ||
|
|
||
| simple_headers = {} | ||
| if allow_all_origins: | ||
| simple_headers["Access-Control-Allow-Origin"] = "*" | ||
| if allow_credentials: | ||
| simple_headers["Access-Control-Allow-Credentials"] = "true" | ||
| if expose_headers: | ||
| simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) | ||
|
|
||
| preflight_headers = {} | ||
| if preflight_explicit_allow_origin: | ||
| # The origin value will be set in preflight_response() if it is allowed. | ||
| preflight_headers["Vary"] = "Origin" | ||
| else: | ||
| preflight_headers["Access-Control-Allow-Origin"] = "*" | ||
| preflight_headers.update( | ||
| { | ||
| "Access-Control-Allow-Methods": ", ".join(allow_methods), | ||
| "Access-Control-Max-Age": str(max_age), | ||
| } | ||
| ) | ||
| allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) | ||
| if allow_headers and not allow_all_headers: | ||
| preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) | ||
| if allow_credentials: | ||
| preflight_headers["Access-Control-Allow-Credentials"] = "true" | ||
|
|
||
| self.app = app | ||
| self.allow_origins = allow_origins | ||
| self.allow_methods = allow_methods | ||
| self.allow_headers = [h.lower() for h in allow_headers] | ||
| self.allow_all_origins = allow_all_origins | ||
| self.allow_all_headers = allow_all_headers | ||
| self.preflight_explicit_allow_origin = preflight_explicit_allow_origin | ||
| self.allow_origin_regex = compiled_allow_origin_regex | ||
| self.simple_headers = simple_headers | ||
| self.preflight_headers = preflight_headers | ||
|
|
||
| async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
| if scope["type"] != "http": # pragma: no cover | ||
| await self.app(scope, receive, send) | ||
| return | ||
|
|
||
| method = scope["method"] | ||
| headers = Headers(scope=scope) | ||
| origin = headers.get("origin") | ||
|
|
||
| if origin is None: | ||
| await self.app(scope, receive, send) | ||
| return | ||
|
|
||
| if method == "OPTIONS" and "access-control-request-method" in headers: | ||
| response = self.preflight_response(request_headers=headers) | ||
| await response(scope, receive, send) | ||
| return | ||
|
|
||
| await self.simple_response(scope, receive, send, request_headers=headers) | ||
|
|
||
| def is_allowed_origin(self, origin: str) -> bool: | ||
| if self.allow_all_origins: | ||
| return True | ||
|
|
||
| if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin): | ||
| return True | ||
|
|
||
| return origin in self.allow_origins | ||
|
|
||
| def preflight_response(self, request_headers: Headers) -> Response: | ||
| requested_origin = request_headers["origin"] | ||
| requested_method = request_headers["access-control-request-method"] | ||
| requested_headers = request_headers.get("access-control-request-headers") | ||
|
|
||
| headers = dict(self.preflight_headers) | ||
| failures = [] | ||
|
|
||
| if self.is_allowed_origin(origin=requested_origin): | ||
| if self.preflight_explicit_allow_origin: | ||
| # The "else" case is already accounted for in self.preflight_headers | ||
| # and the value would be "*". | ||
| headers["Access-Control-Allow-Origin"] = requested_origin | ||
| else: | ||
| failures.append("origin") | ||
|
|
||
| if requested_method not in self.allow_methods: | ||
| failures.append("method") | ||
|
|
||
| # If we allow all headers, then we have to mirror back any requested | ||
| # headers in the response. | ||
| if self.allow_all_headers and requested_headers is not None: | ||
| headers["Access-Control-Allow-Headers"] = requested_headers | ||
| elif requested_headers is not None: | ||
| for header in [h.lower() for h in requested_headers.split(",")]: | ||
| if header.strip() not in self.allow_headers: | ||
| failures.append("headers") | ||
| break | ||
|
|
||
| # We don't strictly need to use 400 responses here, since its up to | ||
| # the browser to enforce the CORS policy, but its more informative | ||
| # if we do. | ||
| if failures: | ||
| failure_text = "Disallowed CORS " + ", ".join(failures) | ||
| return PlainTextResponse(failure_text, status_code=400, headers=headers) | ||
|
|
||
| if PRIVATE: | ||
| headers["Access-Control-Allow-Private-Network"] = "true" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how could we tell upstream this should be somehow handled by this file?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be a good idea to leave an issue or PR there, yes. For those that have a little time. |
||
| return PlainTextResponse("OK", status_code=200, headers=headers) | ||
|
|
||
| async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None: | ||
| send = functools.partial(self.send, send=send, request_headers=request_headers) | ||
| await self.app(scope, receive, send) | ||
|
|
||
| async def send(self, message: Message, send: Send, request_headers: Headers) -> None: | ||
| if message["type"] != "http.response.start": | ||
| await send(message) | ||
| return | ||
|
|
||
| message.setdefault("headers", []) | ||
| headers = MutableHeaders(scope=message) | ||
| headers.update(self.simple_headers) | ||
| origin = request_headers["Origin"] | ||
| has_cookie = "cookie" in request_headers | ||
|
|
||
| # If request includes any cookie headers, then we must respond | ||
| # with the specific origin instead of '*'. | ||
| if self.allow_all_origins and has_cookie: | ||
| self.allow_explicit_origin(headers, origin) | ||
|
|
||
| # If we only allow specific origins, then we have to mirror back | ||
| # the Origin header in the response. | ||
| elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): | ||
| self.allow_explicit_origin(headers, origin) | ||
|
|
||
| await send(message) | ||
|
|
||
| @staticmethod | ||
| def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: | ||
| headers["Access-Control-Allow-Origin"] = origin | ||
| headers.add_vary_header("Origin") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... with ??? 🤔