From 2c16b98d456cc1fa3da81478bc00bbf261e53091 Mon Sep 17 00:00:00 2001 From: Josh Schneier Date: Fri, 10 May 2024 19:43:27 -0400 Subject: [PATCH] [s3] Skip generating signed URLs if querystring_auth=False --- CHANGELOG.rst | 2 ++ storages/backends/s3.py | 65 +++++++++++++++------------------- tests/test_s3.py | 77 ++++++++++++++++++++--------------------- 3 files changed, 68 insertions(+), 76 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0242e847..3dff2947 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,9 +9,11 @@ S3 - Pull ``AWS_SESSION_TOKEN`` from the environment (`#1399`_) - Fix newline handling for text mode files (`#1381`_) +- Do not sign URLs when ``querystring_auth=False`` e.g public buckets or static files (`#1402`_) .. _#1399: https://github.com/jschneier/django-storages/pull/1399 .. _#1381: https://github.com/jschneier/django-storages/pull/1381 +.. _#1402: https://github.com/jschneier/django-storages/pull/1381 1.14.3 (2024-05-04) diff --git a/storages/backends/s3.py b/storages/backends/s3.py index 9822f2ef..3230f305 100644 --- a/storages/backends/s3.py +++ b/storages/backends/s3.py @@ -7,9 +7,7 @@ import warnings from datetime import datetime from datetime import timedelta -from urllib.parse import parse_qsl from urllib.parse import urlencode -from urllib.parse import urlsplit from django.contrib.staticfiles.storage import ManifestFilesMixin from django.core.exceptions import ImproperlyConfigured @@ -34,6 +32,7 @@ try: import boto3.session + import botocore import s3transfer.constants from boto3.s3.transfer import TransferConfig from botocore.config import Config @@ -330,6 +329,7 @@ def __init__(self, **settings): self._bucket = None self._connections = threading.local() + self._unsigned_connections = threading.local() if self.config is not None: warnings.warn( @@ -439,11 +439,13 @@ def get_default_settings(self): def __getstate__(self): state = self.__dict__.copy() state.pop("_connections", None) + state.pop("_unsigned_connections", None) state.pop("_bucket", None) return state def __setstate__(self, state): state["_connections"] = threading.local() + state["_unsigned_connections"] = threading.local() state["_bucket"] = None self.__dict__ = state @@ -462,6 +464,24 @@ def connection(self): ) return self._connections.connection + @property + def unsigned_connection(self): + unsigned_connection = getattr(self._unsigned_connections, "connection", None) + if unsigned_connection is None: + session = self._create_session() + config = self.client_config.merge( + Config(signature_version=botocore.UNSIGNED) + ) + self._unsigned_connections.connection = session.resource( + "s3", + region_name=self.region_name, + use_ssl=self.use_ssl, + endpoint_url=self.endpoint_url, + config=config, + verify=self.verify, + ) + return self._unsigned_connections.connection + def _create_session(self): """ If a user specifies a profile name and this class obtains access keys @@ -635,37 +655,6 @@ def get_modified_time(self, name): else: return make_naive(entry.last_modified) - def _strip_signing_parameters(self, url): - # Boto3 does not currently support generating URLs that are unsigned. Instead - # we take the signed URLs and strip any querystring params related to signing - # and expiration. - # Note that this may end up with URLs that are still invalid, especially if - # params are passed in that only work with signed URLs, e.g. response header - # params. - # The code attempts to strip all query parameters that match names of known - # parameters from v2 and v4 signatures, regardless of the actual signature - # version used. - split_url = urlsplit(url) - qs = parse_qsl(split_url.query, keep_blank_values=True) - blacklist = { - "x-amz-algorithm", - "x-amz-credential", - "x-amz-date", - "x-amz-expires", - "x-amz-signedheaders", - "x-amz-signature", - "x-amz-security-token", - "awsaccesskeyid", - "expires", - "signature", - } - filtered_qs = ((key, val) for key, val in qs if key.lower() not in blacklist) - # Note: Parameters that did not have a value in the original query string will - # have an '=' sign appended to it, e.g ?foo&bar becomes ?foo=&bar= - joined_qs = ("=".join(keyval) for keyval in filtered_qs) - split_url = split_url._replace(query="&".join(joined_qs)) - return split_url.geturl() - def url(self, name, parameters=None, expire=None, http_method=None): # Preserve the trailing slash after normalizing the path. name = self._normalize_name(clean_name(name)) @@ -691,12 +680,14 @@ def url(self, name, parameters=None, expire=None, http_method=None): params["Bucket"] = self.bucket.name params["Key"] = name - url = self.bucket.meta.client.generate_presigned_url( + + connection = ( + self.connection if self.querystring_auth else self.unsigned_connection + ) + url = connection.meta.client.generate_presigned_url( "get_object", Params=params, ExpiresIn=expire, HttpMethod=http_method ) - if self.querystring_auth: - return url - return self._strip_signing_parameters(url) + return url def get_available_name(self, name, max_length=None): """Overwrite existing file with the same name.""" diff --git a/tests/test_s3.py b/tests/test_s3.py index 12c2d009..6038bbe5 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -11,6 +11,7 @@ import boto3 import boto3.s3.transfer +import botocore from botocore.config import Config as ClientConfig from botocore.exceptions import ClientError from django.conf import settings @@ -35,25 +36,37 @@ class S3StorageTests(TestCase): def setUp(self): self.storage = s3.S3Storage() self.storage._connections.connection = mock.MagicMock() + self.storage._unsigned_connections.connection = mock.MagicMock() - def test_s3_session(self): + @mock.patch("boto3.Session") + def test_s3_session(self, session): with override_settings(AWS_S3_SESSION_PROFILE="test_profile"): - with mock.patch("boto3.Session") as mock_session: - storage = s3.S3Storage() - _ = storage.connection - mock_session.assert_called_once_with(profile_name="test_profile") + storage = s3.S3Storage() + _ = storage.connection + session.assert_called_once_with(profile_name="test_profile") - def test_client_config(self): + @mock.patch("boto3.Session.resource") + def test_client_config(self, resource): with override_settings( AWS_S3_CLIENT_CONFIG=ClientConfig(max_pool_connections=30) ): storage = s3.S3Storage() - with mock.patch("boto3.Session.resource") as mock_resource: - _ = storage.connection - mock_resource.assert_called_once() - self.assertEqual( - 30, mock_resource.call_args[1]["config"].max_pool_connections - ) + _ = storage.connection + resource.assert_called_once() + self.assertEqual(30, resource.call_args[1]["config"].max_pool_connections) + + @mock.patch("boto3.Session.resource") + def test_connection_unsiged(self, resource): + with override_settings(AWS_S3_ADDRESSING_STYLE="virtual"): + storage = s3.S3Storage() + _ = storage.unsigned_connection + resource.assert_called_once() + self.assertEqual( + botocore.UNSIGNED, resource.call_args[1]["config"].signature_version + ) + self.assertEqual( + "virtual", resource.call_args[1]["config"].s3["addressing_style"] + ) def test_pickle_with_bucket(self): """ @@ -664,10 +677,10 @@ def _test_storage_mtime(self, use_tz): def test_storage_url(self): name = "test_storage_size.txt" url = "http://aws.amazon.com/%s" % name - self.storage.bucket.meta.client.generate_presigned_url.return_value = url + self.storage.connection.meta.client.generate_presigned_url.return_value = url self.storage.bucket.name = "bucket" self.assertEqual(self.storage.url(name), url) - self.storage.bucket.meta.client.generate_presigned_url.assert_called_with( + self.storage.connection.meta.client.generate_presigned_url.assert_called_with( "get_object", Params={"Bucket": self.storage.bucket.name, "Key": name}, ExpiresIn=self.storage.querystring_expire, @@ -675,9 +688,8 @@ def test_storage_url(self): ) custom_expire = 123 - self.assertEqual(self.storage.url(name, expire=custom_expire), url) - self.storage.bucket.meta.client.generate_presigned_url.assert_called_with( + self.storage.connection.meta.client.generate_presigned_url.assert_called_with( "get_object", Params={"Bucket": self.storage.bucket.name, "Key": name}, ExpiresIn=custom_expire, @@ -685,16 +697,21 @@ def test_storage_url(self): ) custom_method = "HEAD" - self.assertEqual(self.storage.url(name, http_method=custom_method), url) - self.storage.bucket.meta.client.generate_presigned_url.assert_called_with( + self.storage.connection.meta.client.generate_presigned_url.assert_called_with( "get_object", Params={"Bucket": self.storage.bucket.name, "Key": name}, ExpiresIn=self.storage.querystring_expire, HttpMethod=custom_method, ) - def test_storage_url_custom_domain_signed_urls(self): + def test_url_unsigned(self): + self.storage.querystring_auth = False + self.storage.url("test_name") + self.storage.unsigned_connection.meta.client.generate_presigned_url.assert_called_once() + + @mock.patch("storages.backends.s3.datetime") + def test_storage_url_custom_domain_signed_urls(self, dt): key_id = "test-key" filename = "file.txt" pem = dedent( @@ -732,11 +749,8 @@ def test_storage_url_custom_domain_signed_urls(self): self.assertEqual(self.storage.url(filename), url) self.storage.querystring_auth = True - with mock.patch("storages.backends.s3.datetime") as mock_datetime: - mock_datetime.utcnow.return_value = datetime.datetime.utcfromtimestamp( - 0 - ) - self.assertEqual(self.storage.url(filename), signed_url) + dt.utcnow.return_value = datetime.datetime.utcfromtimestamp(0) + self.assertEqual(self.storage.url(filename), signed_url) def test_generated_url_is_encoded(self): self.storage.custom_domain = "mock.cloudfront.net" @@ -766,21 +780,6 @@ def test_custom_domain_parameters(self): self.assertEqual(parsed_url.path, "/filename.mp4") self.assertEqual(parsed_url.query, "version=10") - def test_strip_signing_parameters(self): - expected = "http://bucket.s3-aws-region.amazonaws.com/foo/bar" - self.assertEqual( - self.storage._strip_signing_parameters( - "%s?X-Amz-Date=12345678&X-Amz-Signature=Signature" % expected - ), - expected, - ) - self.assertEqual( - self.storage._strip_signing_parameters( - "%s?expires=12345678&signature=Signature" % expected - ), - expected, - ) - @skipIf(threading is None, "Test requires threading") def test_connection_threading(self): connections = []