Skip to content

Commit

Permalink
Add always_send, enable by default
Browse files Browse the repository at this point in the history
  • Loading branch information
corydolphin committed May 22, 2016
1 parent bddb13c commit 5582b3c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 131 deletions.
38 changes: 29 additions & 9 deletions flask_cors/core.py
Expand Up @@ -39,7 +39,8 @@
'CORS_EXPOSE_HEADERS', 'CORS_SUPPORTS_CREDENTIALS',
'CORS_MAX_AGE', 'CORS_SEND_WILDCARD',
'CORS_AUTOMATIC_OPTIONS', 'CORS_VARY_HEADER',
'CORS_RESOURCES', 'CORS_INTERCEPT_EXCEPTIONS']
'CORS_RESOURCES', 'CORS_INTERCEPT_EXCEPTIONS',
'CORS_ALWAYS_SEND']
# Attribute added to request object by decorator to indicate that CORS
# was evaluated, in case the decorator and extension are both applied
# to a view.
Expand All @@ -58,7 +59,8 @@
automatic_options=True,
vary_header=True,
resources=r'/*',
intercept_exceptions=True)
intercept_exceptions=True,
always_send=True)


def parse_resources(resources):
Expand Down Expand Up @@ -108,7 +110,7 @@ def get_regexp_pattern(regexp):
return str(regexp)


def get_cors_origin(options, request_origin):
def get_cors_origins(options, request_origin):
origins = options.get('origins')
wildcard = r'.*' in origins

Expand All @@ -120,18 +122,32 @@ def get_cors_origin(options, request_origin):
# If the allowed origins is an asterisk or 'wildcard', always match
if wildcard and options.get('send_wildcard'):
LOG.debug("Allowed origins are set to '*'. Sending wildcard CORS header.")
return '*'
return ['*']
# If the value of the Origin header is a case-sensitive match
# for any of the values in list of origins
elif try_match_any(request_origin, origins):
LOG.debug("The request's Origin header matches. Sending CORS headers.", )
# Add a single Access-Control-Allow-Origin header, with either
# the value of the Origin header or the string "*" as value.
# -- W3Spec
return request_origin
return [request_origin]
else:
LOG.debug("The request's Origin header does not match any of allowed origins.")
return None


elif options.get('always_send'):
if wildcard:
# If wildcard is in the origins, even if 'send_wildcard' is False,
# simply send the wildcard. It is the most-likely to be correct
# thing to do (the only other option is to return nothing, which)
# pretty is probably not whawt you want if you specify origins as
# '*'
return ['*']
else:
# Return all origins that are not regexes.
return sorted([o for o in origins if not probably_regex(o)])

# Terminate these steps, return the original request untouched.
else:
LOG.debug("The request did not contain an 'Origin' header. This means the browser or client did not request CORS, ensure the Origin Header is set.")
Expand All @@ -154,13 +170,15 @@ def get_allow_headers(options, acl_request_headers):


def get_cors_headers(options, request_headers, request_method, response_headers):
origin_to_set = get_cors_origin(options, request_headers.get('Origin'))
origins_to_set = get_cors_origins(options, request_headers.get('Origin'))
headers = MultiDict()

if origin_to_set is None: # CORS is not enabled for this route
if not origins_to_set: # CORS is not enabled for this route
return headers

headers[ACL_ORIGIN] = origin_to_set
for origin in origins_to_set:
headers.add(ACL_ORIGIN, origin)

headers[ACL_EXPOSE_HEADERS] = options.get('expose_headers')

if options.get('supports_credentials'):
Expand Down Expand Up @@ -191,7 +209,9 @@ def get_cors_headers(options, request_headers, request_method, response_headers)
# origins that can be matched.
if headers[ACL_ORIGIN] == '*':
pass
elif len(options.get('origins')) > 1 or any(map(probably_regex, options.get('origins'))):
elif (len(options.get('origins')) > 1 or
len(origins_to_set) > 1 or
any(map(probably_regex, options.get('origins')))):
headers.add('Vary', 'Origin')

return MultiDict((k, v) for k, v in headers.items() if v)
Expand Down
3 changes: 0 additions & 3 deletions tests/decorator/test_credentials.py
Expand Up @@ -42,9 +42,6 @@ def test_credentials_supported(self):
resp = self.get('/test_credentials_supported', origin='www.example.com')
self.assertEquals(resp.headers.get(ACL_CREDENTIALS), 'true')

resp = self.get('/test_credentials_supported')
self.assertEquals(resp.headers.get(ACL_CREDENTIALS), None )

def test_default(self):
''' The default behavior should be to disallow credentials.
'''
Expand Down
151 changes: 32 additions & 119 deletions tests/decorator/test_origins.py
Expand Up @@ -17,7 +17,6 @@

letters = 'abcdefghijklmnopqrstuvwxyz' # string.letters is not PY3 compatible


class OriginsTestCase(FlaskCorsTestCase):
def setUp(self):
self.app = Flask(__name__)
Expand All @@ -27,9 +26,19 @@ def setUp(self):
def wildcard():
return 'Welcome!'

@self.app.route('/test_always_send')
@cross_origin(always_send=True)
def test_always_send():
return 'Welcome!'

@self.app.route('/test_always_send_no_wildcard')
@cross_origin(always_send=True, send_wildcard=False)
def test_always_send_no_wildcard():
return 'Welcome!'

@self.app.route('/test_send_wildcard_with_origin')
@cross_origin(send_wildcard=True)
def send_wildcard():
def test_send_wildcard_with_origin():
return 'Welcome!'

@self.app.route('/test_list')
Expand All @@ -49,31 +58,30 @@ def test_set():

@self.app.route('/test_subdomain_regex')
@cross_origin(origins=r"http?://\w*\.?example\.com:?\d*/?.*")
def _test_subdomain_regex():
def test_subdomain_regex():
return ''

@self.app.route('/test_compiled_subdomain_regex')
@cross_origin(origins=re.compile(r"http?://\w*\.?example\.com:?\d*/?.*"))
def _test_compiled_subdomain_regex():
def test_compiled_subdomain_regex():
return ''

@self.app.route('/test_regex_list')
@cross_origin(origins=[r".*.example.com", r".*.otherexample.com"])
def _test_regex_list():
def test_regex_list():
return ''

@self.app.route('/test_regex_mixed_list')
@cross_origin(origins=["http://example.com", r".*.otherexample.com"])
def _test_regex_mixed_list():
def test_regex_mixed_list():
return ''

def test_defaults_no_origin(self):
''' If there is no Origin header in the request, the
Access-Control-Allow-Origin header should not be included,
according to the w3 spec.
Access-Control-Allow-Origin header should be '*' by default.
'''
for resp in self.iter_responses('/'):
self.assertEqual(resp.headers.get(ACL_ORIGIN), None)
self.assertEqual(resp.headers.get(ACL_ORIGIN), '*')

def test_defaults_with_origin(self):
''' If there is an Origin header in the request, the
Expand All @@ -83,6 +91,21 @@ def test_defaults_with_origin(self):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.headers.get(ACL_ORIGIN), 'http://example.com')

def test_always_send_no_wildcard(self):
'''
If send_wildcard=False, but the there is '*' in the
allowed origins, we should send it anyways.
'''
for resp in self.iter_responses('/'):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.headers.get(ACL_ORIGIN), '*')

def test_always_send_no_wildcard_origins(self):
for resp in self.iter_responses('/'):
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.headers.get(ACL_ORIGIN), '*')


def test_send_wildcard_with_origin(self):
''' If there is an Origin header in the request, the
Access-Control-Allow-Origin header should be included.
Expand Down Expand Up @@ -166,115 +189,5 @@ def test_regex_mixed_list(self):
self.get('/test_regex_mixed_list', origin='http://example.com').headers.get(ACL_ORIGIN))


class AppConfigOriginsTestCase(AppConfigTest, OriginsTestCase):
def __init__(self, *args, **kwargs):
super(AppConfigOriginsTestCase, self).__init__(*args, **kwargs)

def test_defaults_no_origin(self):
@self.app.route('/')
@cross_origin()
def wildcard():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_defaults_no_origin()

def test_defaults_with_origin(self):
@self.app.route('/')
@cross_origin()
def wildcard():
return 'Welcome!'
super(AppConfigOriginsTestCase, self).test_defaults_with_origin()

def test_send_wildcard_with_origin(self):
@self.app.route('/test_send_wildcard_with_origin')
@cross_origin(send_wildcard=True)
def send_wildcard():
return 'Welcome!'
super(AppConfigOriginsTestCase, self).test_send_wildcard_with_origin()

def test_list_serialized(self):
self.app.config['CORS_ORIGINS'] = ["http://foo.com", "http://bar.com"]

@self.app.route('/test_list')
@cross_origin()
def test_list():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_list_serialized()

def test_string_serialized(self):
self.app.config['CORS_ORIGINS'] = "http://foo.com"

@self.app.route('/test_string')
@cross_origin()
def test_string():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_string_serialized()

def test_set_serialized(self):
self.app.config['CORS_ORIGINS'] = set(["http://foo.com",
"http://bar.com"])

@self.app.route('/test_set')
@cross_origin()
def test_set():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_set_serialized()

def test_not_matching_origins(self):
self.app.config['CORS_ORIGINS'] = ["http://foo.com", "http://bar.com"]

@self.app.route('/test_list')
@cross_origin()
def test_list():
return 'Welcome!'

super(AppConfigOriginsTestCase, self).test_not_matching_origins()

def test_regex_list(self):
@self.app.route('/test_regex_list')
@cross_origin()
def _test_regex_list():
return 'Welcome!'

self.app.config['CORS_ORIGINS'] = [r".*.example.com",
r".*.otherexample.com"]
super(AppConfigOriginsTestCase, self).test_regex_list()

def test_subdomain_regex(self):
self.app.config['CORS_ORIGINS'] = r"http?://\w*\.?example\.com:?\d*/?.*"

@self.app.route('/test_subdomain_regex')
@cross_origin()
def _test_subdomain_regex():
return ''

super(AppConfigOriginsTestCase, self).test_subdomain_regex()

def test_compiled_subdomain_regex(self):
self.app.config['CORS_ORIGINS'] = r"http?://\w*\.?example\.com:?\d*/?.*"

@self.app.route('/test_compiled_subdomain_regex')
@cross_origin()
def _test_compiled_subdomain_regex():
return ''

super(AppConfigOriginsTestCase, self).test_compiled_subdomain_regex()

def test_regex_mixed_list(self):
self.app.config['CORS_ORIGINS'] = ["http://example.com",
r".*.otherexample.com"]

@self.app.route('/test_regex_mixed_list')
@cross_origin()
def _test_regex_mixed_list():
return ''

super(AppConfigOriginsTestCase, self).test_regex_mixed_list()



if __name__ == "__main__":
unittest.main()

0 comments on commit 5582b3c

Please sign in to comment.