Skip to content

Commit

Permalink
S3: Improve Cors AllowedOrigin behaviour (#6007)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers committed Mar 3, 2023
1 parent 96b8e12 commit 8b058d9
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 10 deletions.
7 changes: 7 additions & 0 deletions moto/s3/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def __init__(self, message, name, value, *args, **kwargs):
super().__init__("InvalidArgument", message, *args, **kwargs)


class AccessForbidden(S3ClientError):
code = 403

def __init__(self, msg):
super().__init__("AccessForbidden", msg)


class BucketError(S3ClientError):
def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "bucket_error")
Expand Down
22 changes: 14 additions & 8 deletions moto/s3/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
PreconditionFailed,
InvalidRange,
LockNotEnabled,
AccessForbidden,
)
from .models import s3_backends
from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
Expand All @@ -59,6 +60,7 @@
parse_region_from_url,
compute_checksum,
ARCHIVE_STORAGE_CLASSES,
cors_matches_origin,
)
from xml.dom import minidom

Expand Down Expand Up @@ -298,7 +300,7 @@ def _bucket_response(self, request, full_url):
elif method == "POST":
return self._bucket_response_post(request, bucket_name)
elif method == "OPTIONS":
return self._response_options(bucket_name)
return self._response_options(request.headers, bucket_name)
else:
raise NotImplementedError(
f"Method {method} has not been implemented in the S3 backend yet"
Expand Down Expand Up @@ -343,7 +345,7 @@ def _bucket_response_head(self, bucket_name, querystring):
return 404, {}, ""
return 200, {}, ""

def _set_cors_headers(self, bucket):
def _set_cors_headers(self, headers, bucket):
"""
TODO: smarter way of matching the right CORS rule:
See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html
Expand All @@ -367,9 +369,13 @@ def _to_string(header: Union[List[str], str]) -> str:
cors_rule.allowed_methods
)
if cors_rule.allowed_origins is not None:
self.response_headers["Access-Control-Allow-Origin"] = _to_string(
cors_rule.allowed_origins
)
origin = headers.get("Origin")
if cors_matches_origin(origin, cors_rule.allowed_origins):
self.response_headers["Access-Control-Allow-Origin"] = origin
else:
raise AccessForbidden(
"CORSResponse: This CORS request is not allowed. This is usually because the evalution of Origin, request method / Access-Control-Request-Method or Access-Control-Request-Headers are not whitelisted by the resource's CORS spec."
)
if cors_rule.allowed_headers is not None:
self.response_headers["Access-Control-Allow-Headers"] = _to_string(
cors_rule.allowed_headers
Expand All @@ -383,7 +389,7 @@ def _to_string(header: Union[List[str], str]) -> str:
cors_rule.max_age_seconds
)

def _response_options(self, bucket_name):
def _response_options(self, headers, bucket_name):
# Return 200 with the headers from the bucket CORS configuration
self._authenticate_and_authorize_s3_action()
try:
Expand All @@ -395,7 +401,7 @@ def _response_options(self, bucket_name):
"",
) # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD

self._set_cors_headers(bucket)
self._set_cors_headers(headers, bucket)

return 200, self.response_headers, ""

Expand Down Expand Up @@ -1241,7 +1247,7 @@ def _key_response(self, request, full_url, headers):
return self._key_response_post(request, body, bucket_name, query, key_name)
elif method == "OPTIONS":
# OPTIONS response doesn't depend on the key_name: always return 200 with CORS headers
return self._response_options(bucket_name)
return self._response_options(request.headers, bucket_name)
else:
raise NotImplementedError(
f"Method {method} has not been implemented in the S3 backend yet"
Expand Down
13 changes: 12 additions & 1 deletion moto/s3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hashlib
from urllib.parse import urlparse, unquote, quote
from requests.structures import CaseInsensitiveDict
from typing import Union, Tuple
from typing import List, Union, Tuple
import sys
from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME

Expand Down Expand Up @@ -212,3 +212,14 @@ def _hash(fn, args) -> bytes:
except TypeError:
# The usedforsecurity-parameter is only available as of Python 3.9
return fn(*args).hexdigest().encode("utf-8")


def cors_matches_origin(origin_header: str, allowed_origins: List[str]) -> bool:
if "*" in allowed_origins:
return True
if origin_header in allowed_origins:
return True
for allowed in allowed_origins:
if re.match(allowed.replace(".", "\\.").replace("*", ".*"), origin_header):
return True
return False
17 changes: 17 additions & 0 deletions tests/test_s3/test_s3_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
clean_key_name,
undo_clean_key_name,
compute_checksum,
cors_matches_origin,
)
from unittest.mock import patch

Expand Down Expand Up @@ -141,3 +142,19 @@ def test_checksum_crc32():

def test_checksum_crc32c():
compute_checksum(b"somedata", "CRC32C").should.equal(b"MTM5MzM0Mzk1Mg==")


def test_cors_utils():
"Fancy string matching"
assert cors_matches_origin("a", ["a"])
assert cors_matches_origin("b", ["a", "b"])
assert not cors_matches_origin("c", [])
assert not cors_matches_origin("c", ["a", "b"])

assert cors_matches_origin("http://www.google.com", ["http://*.google.com"])
assert cors_matches_origin("http://www.google.com", ["http://www.*.com"])
assert cors_matches_origin("http://www.google.com", ["http://*"])
assert cors_matches_origin("http://www.google.com", ["*"])

assert not cors_matches_origin("http://www.google.com", ["http://www.*.org"])
assert not cors_matches_origin("http://www.google.com", ["https://*"])
92 changes: 91 additions & 1 deletion tests/test_s3/test_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import io
from urllib.parse import urlparse, parse_qs
import sure # noqa # pylint: disable=unused-import
import requests
import pytest
import xmltodict

from flask.testing import FlaskClient
import moto.server as server
from moto.moto_server.threaded_moto_server import ThreadedMotoServer
from unittest.mock import patch

"""
Expand Down Expand Up @@ -223,7 +225,7 @@ def test_s3_server_post_cors_exposed_header():
preflight_headers = {
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "origin, x-requested-with",
"Origin": "https://localhost:9000",
"Origin": "https://example.org",
}
# Returns 403 on non existing bucket
preflight_response = test_client.options(
Expand Down Expand Up @@ -257,3 +259,91 @@ def test_s3_server_post_cors_exposed_header():
for header_name, header_value in expected_cors_headers.items():
assert header_name in preflight_response.headers
assert preflight_response.headers[header_name] == header_value


def test_s3_server_post_cors_multiple_origins():
"""Test that Moto only responds with the Origin that we that hosts the server"""
# github.com/getmoto/moto/issues/6003

cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<CORSRule>
<AllowedOrigin>https://example.org</AllowedOrigin>
<AllowedOrigin>https://localhost:6789</AllowedOrigin>
<AllowedMethod>POST</AllowedMethod>
</CORSRule>
</CORSConfiguration>
"""

thread = ThreadedMotoServer(port="6789", verbose=False)
thread.start()

# Create the bucket
requests.put("http://testcors.localhost:6789/")
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)

# Test only our requested origin is returned
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "https://localhost:6789",
},
)
assert preflight_response.status_code == 200
assert (
preflight_response.headers["Access-Control-Allow-Origin"]
== "https://localhost:6789"
)
assert preflight_response.content == b""

# Verify a request with unknown origin fails
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "https://unknown.host",
},
)
assert preflight_response.status_code == 403
assert b"<Code>AccessForbidden</Code>" in preflight_response.content

# Verify we can use a wildcard anywhere in the origin
cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><CORSRule>
<AllowedOrigin>https://*.google.com</AllowedOrigin>
<AllowedMethod>POST</AllowedMethod>
</CORSRule></CORSConfiguration>"""
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
for origin in ["https://sth.google.com", "https://a.google.com"]:
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={"Access-Control-Request-Method": "POST", "Origin": origin},
)
assert preflight_response.status_code == 200
assert preflight_response.headers["Access-Control-Allow-Origin"] == origin

# Non-matching requests throw an error though - it does not act as a full wildcard
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={
"Access-Control-Request-Method": "POST",
"Origin": "sth.microsoft.com",
},
)
assert preflight_response.status_code == 403
assert b"<Code>AccessForbidden</Code>" in preflight_response.content

# Verify we can use a wildcard as the origin
cors_config_payload = """<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><CORSRule>
<AllowedOrigin>*</AllowedOrigin>
<AllowedMethod>POST</AllowedMethod>
</CORSRule></CORSConfiguration>"""
requests.put("http://testcors.localhost:6789/?cors", data=cors_config_payload)
for origin in ["https://a.google.com", "http://b.microsoft.com", "any"]:
preflight_response = requests.options(
"http://testcors.localhost:6789/test",
headers={"Access-Control-Request-Method": "POST", "Origin": origin},
)
assert preflight_response.status_code == 200
assert preflight_response.headers["Access-Control-Allow-Origin"] == origin

thread.stop()

0 comments on commit 8b058d9

Please sign in to comment.