Skip to content

Commit

Permalink
- Added static_cors argument to :func:serve(), allowing CORS to…
Browse files Browse the repository at this point in the history
… be configured for static files.

- Added ``--static-cors`` argument to :doc:`aiohttp-wsgi-serve <main>` command line interface.
  • Loading branch information
etianen committed Mar 29, 2018
1 parent 3f1e2e9 commit 7ac9a2d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
4 changes: 4 additions & 0 deletions aiohttp_wsgi/__main__.py
Expand Up @@ -107,6 +107,10 @@ def add_argument(name, *aliases, **kwargs):
"`path` must start with a slash, but not end with a slash."
),
)
add_argument(
"--static-cors",
type=str,
)
add_argument(
"--script-name",
)
Expand Down
30 changes: 28 additions & 2 deletions aiohttp_wsgi/wsgi.py
Expand Up @@ -97,7 +97,7 @@
from contextlib import contextmanager
from tempfile import SpooledTemporaryFile
from wsgiref.util import is_hop_by_hop
from aiohttp.web import Application, AppRunner, TCPSite, UnixSite, Response, HTTPRequestEntityTooLarge
from aiohttp.web import Application, AppRunner, TCPSite, UnixSite, Response, HTTPRequestEntityTooLarge, middleware
from aiohttp_wsgi.utils import parse_sockname


Expand Down Expand Up @@ -279,6 +279,18 @@ def format_path(path):
return path


def static_cors_middleware(*, static, static_cors):
@middleware
async def do_static_cors_middleware(request, handler):
response = await handler(request)
for path, _ in static:
if request.path.startswith(path):
response.headers["Access-Control-Allow-Origin"] = static_cors
break
return response
return do_static_cors_middleware


@contextmanager
def run_server(
application,
Expand All @@ -295,6 +307,7 @@ def run_server(
backlog=1024,
# aiohttp config.
static=(),
static_cors=None,
script_name="",
shutdown_timeout=60.0,
**kwargs
Expand All @@ -307,8 +320,9 @@ def run_server(
# Create aiohttp app.
app = Application()
# Add static routes.
static = [(format_path(path), dirname) for path, dirname in static]
for path, dirname in static:
app.router.add_static(format_path(path), dirname)
app.router.add_static(path, dirname)
# Add the wsgi application. This has to be last.
app.router.add_route(
"*",
Expand All @@ -320,6 +334,13 @@ def run_server(
**kwargs
).handle_request,
)
# Configure middleware.
if static_cors:
app.middlewares.append(static_cors_middleware(
static=static,
static_cors=static_cors,
))
# Start the app runner.
runner = AppRunner(app)
loop.run_until_complete(runner.setup())
# Set up the server.
Expand Down Expand Up @@ -379,6 +400,7 @@ def serve(application, **kwargs): # pragma: no cover
:param int unix_socket_perms: {unix_socket_perms}
:param int backlog: {backlog}
:param list static: {static}
:param list static_cors: {static_cors}
:param str script_name: {script_name}
:param int shutdown_timeout: {shutdown_timeout}
"""
Expand Down Expand Up @@ -421,6 +443,10 @@ def serve(application, **kwargs): # pragma: no cover
).format_map(DEFAULTS),
"backlog": "Socket connection backlog. Defaults to {backlog!r}.".format_map(DEFAULTS),
"static": "Static root mappings in the form (path, directory). Defaults to {static!r}".format_map(DEFAULTS),
"static_cors": (
"Set to '*' to enable CORS on static files for all origins, or a string to enable CORS for a specific origin. "
"Defaults to {static_cors!r}"
).format_map(DEFAULTS),
"script_name": (
"URL prefix for the WSGI application, should start with a slash, but not end with a slash. "
"Defaults to ``{script_name!r}``."
Expand Down
8 changes: 8 additions & 0 deletions docs/changelog.rst
Expand Up @@ -3,6 +3,14 @@ aiohttp-wsgi changelog

.. currentmodule:: aiohttp_wsgi


0.8.1
-----

- Added ``static_cors`` argument to :func:`serve()`, allowing CORS to be configured for static files.
- Added ``--static-cors`` argument to :doc:`aiohttp-wsgi-serve <main>` command line interface.


0.8.0
-----

Expand Down
7 changes: 7 additions & 0 deletions tests/test_static.py
Expand Up @@ -23,3 +23,10 @@ def testStaticHitMissing(self):
with self.run_server(noop_application, static=STATIC) as client:
response = client.request(path="/static/missing.txt")
self.assertEqual(response.status, 404)

def testStaticHitCors(self):
with self.run_server(noop_application, static=STATIC, static_cors="*") as client:
response = client.request(path="/static/text.txt")
self.assertEqual(response.status, 200)
self.assertEqual(response.content, b"Test file")
self.assertEqual(response.headers["Access-Control-Allow-Origin"], "*")

0 comments on commit 7ac9a2d

Please sign in to comment.