Skip to content

Commit

Permalink
Add Cors middleware
Browse files Browse the repository at this point in the history
Close #154
Close #239
Update #238
Update #207
Update #90
  • Loading branch information
mar10 committed Dec 7, 2021
1 parent 7f94f84 commit 5ee64a9
Show file tree
Hide file tree
Showing 14 changed files with 453 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Drop Python 2 support
- Drop support for Python syntax in config files (wsgidav.conf)
- Drop support for Microsoft Web Folders (option `dir_browser.ms_mount`).
- CORS support
- Provider root paths are relative to configuration file
- DAVCollection, DAVNonCollection, DAVProvider are now ABCs.
- API enforces some named keyword args (`..., *, ...`)
Expand Down
40 changes: 35 additions & 5 deletions sample_wsgidav.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ host: 0.0.0.0
# Server port (default: 8080, use --port on command line)
port: 8080

#: Add custom response headers (list of header-name / header-value tuples):
# response_headers:
# - ["Access-Control-Allow-Origin", "http://example.org"]

# Transfer block size in bytes
block_size: 8192

Expand Down Expand Up @@ -94,7 +90,8 @@ hotfixes:
#: See here for an example how to add custom middlewares:
#: https://wsgidav.readthedocs.io/en/latest/user_guide_configure.html#middleware-stack
middleware_stack:
- wsgidav.mw.debug_filter.WsgiDavDebugFilter
- wsgidav.mw.cors.Cors
# - wsgidav.mw.debug_filter.WsgiDavDebugFilter
- wsgidav.error_printer.ErrorPrinter
- wsgidav.http_authenticator.HTTPAuthenticator
- wsgidav.dir_browser.WsgiDavDirBrowser
Expand Down Expand Up @@ -182,6 +179,39 @@ pam_dc:
encoding: 'utf-8'
resetcreds: true

# ----------------------------------------------------------------------------
# CORS
# (Requires `wsgidav.mw.cors.Cors`, which is enabled by default.)
cors:
#: List of allowed Origins or '*'
#: Default: false, i.e. prevent CORS
allow_origin: null
# allow_origin: '*'
# allow_origin:
# - 'https://example.com'
# - 'https://localhost:8081'

#: List or comma-separated string of allowed methods (returned as
#: response to preflight request)
allow_methods:
# allow_methods: POST,HEAD
#: List or comma-separated string of allowed header names (returned as
#: response to preflight request)
allow_headers:
# - X-PINGOTHER
#: List or comma-separated string of allowed headers that JavaScript in
#: browsers is allowed to access.
expose_headers:
#: Set to true to allow responses on requests with credentials flag set
allow_credentials: false
#: Time in seconds for how long the response to the preflight request can
#: be cached (default: 5)
max_age: 600
#: Add custom response headers (dict of header-name -> header-value items)
#: (This is not related to CORS or required to implement CORS functionality)
add_always:
# 'X-Foo-Header: 'qux'

# ----------------------------------------------------------------------------
# Property Manager
# null: (default) no support for dead properties
Expand Down
31 changes: 31 additions & 0 deletions tests/fixtures/share/cors.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html>

<head>
<meta charset='utf-8'>
<meta http-equiv='X-UA-Compatible' content='IE=edge'>
<title>WsgiDAV CORS Test</title>
<meta name='viewport' content='width=device-width, initial-scale=1'>
<link rel='stylesheet' type='text/css' media='screen' href='main.css'>
<script defer src='main.js'></script>
<!-- <script defer src='///localhost:5000/main.js'></script> -->
</head>

<body>
<p>
<ul>
<li>Load ./main.js</li>
<li>Load ./main.js from :5000</li>
<li>Local image: <img src="logo.png"></li>
<li>image on :5000: <img src="///localhost:5000/logo.png"></li>
<li>Edit docx locally:
<a href="Lotosblütenstengel (蓮花莖).docx">Lotosblütenstengel (蓮花莖).docx</a>
</li>
<li>Edit docx on :5000:
<a href="///localhost:5000/Lotosblütenstengel (蓮花莖).docx">Lotosblütenstengel (蓮花莖).docx :5000</a>
</li>
</ul>
</p>
</body>

</html>
4 changes: 4 additions & 0 deletions tests/fixtures/share/data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"value": "foo",
"values": [1, 2, 3]
}
Binary file added tests/fixtures/share/logo.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions tests/fixtures/share/main.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
body {
background-color: antiquewhite;
}
23 changes: 23 additions & 0 deletions tests/fixtures/share/main.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
(function () {
console.log("main.js loaded.");
console.log("Loading 'data.json'...", window.origin, window.location);
fetch("data.json")
.then((response) => {
return response.json();
})
.then((data) => {
console.log("Loading 'data.json': ", data);
});

// Calling a cross-origin target with a non-tandrad header should trigger a
// preflight request:
fetch("//localhost:5000/data.json", {
headers: { "X-PINGOTHER": "pingpong" },
})
.then((response) => {
return response.json();
})
.then((data) => {
console.log("Loading 'data.json' from :5000: ", data);
});
})();
22 changes: 22 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
check_tags,
checked_etag,
deep_update,
fix_path,
get_dict_value,
get_module_logger,
init_logging,
is_child_uri,
Expand All @@ -23,6 +25,7 @@
pop_path,
removeprefix,
shift_path,
update_headers_in_place,
)


Expand Down Expand Up @@ -117,6 +120,14 @@ def testBasics(self):
assert parse_if_match_header(' "abc" , def ') == ["abc", "def"]
assert parse_if_match_header(' W/"abc" , def ') == ["abc", "def"]

self.assertRaises(ValueError, fix_path, "a/b", "/root/x")
assert fix_path("a/b", "/root/x", must_exist=False) == "/root/x/a/b"
assert fix_path("/a/b", "/root/x", must_exist=False) == "/a/b"

headers = [("foo", "bar"), ("baz", "qux")]
update_headers_in_place(headers, [("Foo", "bar2"), ("New", "new_val")])
assert headers == [("Foo", "bar2"), ("baz", "qux"), ("New", "new_val")]

d_org = {"b": True, "d": {"i": 1, "t": (1, 2)}}
d_new = {}
assert deep_update(d_org.copy(), d_new) == d_org
Expand Down Expand Up @@ -147,6 +158,17 @@ def testBasics(self):
}
assert deep_update({"user_mapping": {}}, d_new) == d_new

d = {"b": True, "d": {"i": 1, "t": (1, 2)}}
assert get_dict_value(d, "b") is True
assert get_dict_value(d, "d.i") == 1
assert get_dict_value(d, "d.i", default="def") == 1
assert get_dict_value(d, "d.q", default="def") == "def"
assert get_dict_value(d, "d.q.v", default="def") == "def"
assert get_dict_value(d, "q.q.q", default="def") == "def"
assert get_dict_value(d, "d.t.[1]") == 2
self.assertRaises(IndexError, get_dict_value, d, "d.t.[2]")
self.assertRaises(KeyError, get_dict_value, d, "d.q")


class LoggerTest(unittest.TestCase):
"""Test configurable logging."""
Expand Down
2 changes: 2 additions & 0 deletions wsgidav/default_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from wsgidav.dir_browser import WsgiDavDirBrowser
from wsgidav.error_printer import ErrorPrinter
from wsgidav.http_authenticator import HTTPAuthenticator
from wsgidav.mw.cors import Cors
from wsgidav.request_resolver import RequestResolver

__docformat__ = "reStructuredText"
Expand Down Expand Up @@ -48,6 +49,7 @@
"lock_storage": True, # True: use LockManager(lock_storage.LockStorageDict)
"middleware_stack": [
# WsgiDavDebugFilter,
Cors,
ErrorPrinter,
HTTPAuthenticator,
WsgiDavDirBrowser, # configured under dir_browser option (see below)
Expand Down
9 changes: 8 additions & 1 deletion wsgidav/mw/base_mw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"""
from abc import ABC, abstractmethod

from wsgidav.util import NO_DEFAULT, get_dict_value

__docformat__ = "reStructuredText"


Expand Down Expand Up @@ -37,5 +39,10 @@ def __repr__(self):
return "{}.{}".format(self.__module__, self.__class__.__name__)

def is_disabled(self):
"""Optionally return False to skip this module on startup."""
"""Optionally return True to skip this module on startup."""
return False

def get_config(self, key_path: str, default=NO_DEFAULT):
"""Optionally return True to skip this module on startup."""
res = get_dict_value(self.config, key_path, default)
return res
128 changes: 128 additions & 0 deletions wsgidav/mw/cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
# (c) 2009-2021 Martin Wendt and contributors; see WsgiDAV https://github.com/mar10/wsgidav
# Licensed under the MIT license:
# http://www.opensource.org/licenses/mit-license.php
"""
WSGI middleware used for CORS support (optional).
Respond to CORS preflight OPTIONS request and inject CORS headers.
"""
from wsgidav import util
from wsgidav.mw.base_mw import BaseMiddleware

__docformat__ = "reStructuredText"

_logger = util.get_module_logger(__name__)


class Cors(BaseMiddleware):
def __init__(self, wsgidav_app, next_app, config):
super().__init__(wsgidav_app, next_app, config)
opts = config.get("cors", None)
if opts is None:
opts = {}

allow_origins = opts.get("allow_origin")
if type(allow_origins) is str:
allow_origins = allow_origins.strip()
if allow_origins != "*":
allow_origins = [allow_origins]
elif allow_origins:
allow_origins = [ao.strip() for ao in allow_origins]

allow_headers = ",".join(util.to_set(opts.get("allow_headers")))
allow_methods = ",".join(util.to_set(opts.get("allow_methods")))
expose_headers = ",".join(util.to_set(opts.get("expose_headers")))
allow_credentials = opts.get("allow_credentials", False)
max_age = opts.get("max_age")
always_headers = opts.get("add_always")

add_always = []
if allow_credentials:
add_always.append(("Access-Control-Allow-Credentials", "true"))
if always_headers:
if type(always_headers) is not dict:
raise ValueError(
f"cors.add_always must be a list a dict: {always_headers}"
)
for n, v in always_headers.items():
add_always.append((n, v))

add_non_preflight = add_always[:]
if expose_headers:
add_always.append(("Access-Control-Expose-Headers", expose_headers))

add_preflight = add_always[:]
if allow_headers:
add_preflight.append(("Access-Control-Allow-Headers", allow_headers))
if allow_methods:
add_preflight.append(("Access-Control-Allow-Methods", allow_methods))
if max_age:
add_preflight.append(("Access-Control-Max-Age", str(int(max_age))))

self.non_preflight_headers = add_non_preflight
self.preflight_headers = add_preflight
#: Either '*' or al list of origins
self.allow_origins = allow_origins

def __repr__(self):
allow_origin = self.get_config("cors.allow_origin", None)
return f"{self.__module__}.{self.__class__.__name__}({allow_origin})"

def is_disabled(self):
"""Optionally return True to skip this module on startup."""
return not self.get_config("cors.allow_origin", False)

def __call__(self, environ, start_response):
method = environ["REQUEST_METHOD"].upper()
origin = environ.get("HTTP_ORIGIN")
ac_req_meth = environ.get("HTTP_ACCESS_CONTROL_REQUEST_METHOD")
ac_req_headers = environ.get("HTTP_ACCESS_CONTROL_REQUEST_HEADERS")

acao_headers = None
if self.allow_origins == "*":
acao_headers = [("Access-Control-Allow-Origin", "*")]
elif origin in self.allow_origins:
acao_headers = [
("Access-Control-Allow-Origin", origin),
("Vary", "Origin"),
]

if acao_headers:
_logger.debug(
f"Granted CORS {method} {environ['PATH_INFO']!r} "
f"{ac_req_meth!r}, headers: {ac_req_headers}, origin: {origin!r}"
)
else:
# Deny (still return 200 on preflight)
_logger.warning(
f"Denied CORS {method} {environ['PATH_INFO']!r} "
f"{ac_req_meth!r}, headers: {ac_req_headers}, origin: {origin!r}"
)

is_preflight = method == "OPTIONS" and ac_req_meth is not None

# Handle preflight request
if is_preflight:
# Always return 2xx, but only add Access-Control-Allow-Origin etc.
# if Origin is allowed
resp_headers = [
("Content-Length", "0"),
("Date", util.get_rfc1123_time()),
]
if acao_headers:
resp_headers += acao_headers + self.preflight_headers

start_response("204 No Content", resp_headers)
return [b""]

# non_preflight CORS request
def wrapped_start_response(status, headers, exc_info=None):
if acao_headers:
util.update_headers_in_place(
headers,
acao_headers + self.non_preflight_headers,
)
start_response(status, headers, exc_info)

return self.next_app(environ, wrapped_start_response)

0 comments on commit 5ee64a9

Please sign in to comment.