diff --git a/.gitignore b/.gitignore index 24d900226..739b2d424 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ __pycache__ .vscode/ .pytest_cache/ venv/ +.venv/ dist/ docs/_build diff --git a/storages/backends/s3boto3.py b/storages/backends/s3boto3.py index 5122aa937..a6becc1e5 100644 --- a/storages/backends/s3boto3.py +++ b/storages/backends/s3boto3.py @@ -23,6 +23,7 @@ from storages.base import BaseStorage from storages.compress import CompressedFileMixin from storages.compress import CompressStorageMixin +from storages.utils import ReadBytesWrapper from storages.utils import check_location from storages.utils import clean_name from storages.utils import get_available_overwrite_name @@ -432,6 +433,11 @@ def _save(self, name, content): if is_seekable(content): content.seek(0, os.SEEK_SET) + + # wrap content so read() always returns bytes. This is required for passing it + # to obj.upload_fileobj() or self._compress_content() + content = ReadBytesWrapper(content) + if (self.gzip and params['ContentType'] in self.gzip_content_types and 'ContentEncoding' not in params): diff --git a/storages/utils.py b/storages/utils.py index a44c524c9..144935527 100644 --- a/storages/utils.py +++ b/storages/utils.py @@ -5,6 +5,7 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import SuspiciousFileOperation +from django.core.files.utils import FileProxyMixin from django.utils.encoding import force_bytes @@ -125,3 +126,35 @@ def get_available_overwrite_name(name, max_length): def is_seekable(file_object): return not hasattr(file_object, 'seekable') or file_object.seekable() + + +class ReadBytesWrapper(FileProxyMixin): + """ + A wrapper for a file-like object, that makes read() always return bytes. + """ + def __init__(self, file, encoding=None): + """ + :param file: The file-like object to wrap. + :param encoding: Specify the encoding to use when file.read() returns strings. 
+ If not provided will default to file.encoding, or if that's not available, + to utf-8. + """ + self.file = file + self._encoding = ( + encoding + or getattr(file, "encoding", None) + or "utf-8" + ) + + def read(self, *args, **kwargs): + content = self.file.read(*args, **kwargs) + + if not isinstance(content, bytes): + content = content.encode(self._encoding) + return content + + def close(self): + self.file.close() + + def readable(self): + return True diff --git a/tests/settings.py b/tests/settings.py index adeb99786..244f94de6 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -10,3 +10,6 @@ SECRET_KEY = 'hailthesunshine' USE_TZ = True + +# the following test settings are required for moto to work. +AWS_STORAGE_BUCKET_NAME = "test-bucket" diff --git a/tests/test_files/windows-1252-encoded.txt b/tests/test_files/windows-1252-encoded.txt new file mode 100644 index 000000000..19f52934b --- /dev/null +++ b/tests/test_files/windows-1252-encoded.txt @@ -0,0 +1 @@ +™€‰ \ No newline at end of file diff --git a/tests/test_s3boto3.py b/tests/test_s3boto3.py index bf3dfe1e9..f1d49fa2d 100644 --- a/tests/test_s3boto3.py +++ b/tests/test_s3boto3.py @@ -8,14 +8,17 @@ from unittest import skipIf from urllib.parse import urlparse +import boto3 import boto3.s3.transfer from botocore.exceptions import ClientError from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.files.base import ContentFile +from django.core.files.base import File from django.test import TestCase from django.test import override_settings from django.utils.timezone import is_aware +from moto import mock_s3 from storages.backends import s3boto3 from tests.utils import NonSeekableContentFile @@ -32,11 +35,11 @@ def setUp(self): self.storage._connections.connection = mock.MagicMock() def test_s3_session(self): - settings.AWS_S3_SESSION_PROFILE = "test_profile" - with mock.patch('boto3.Session') as mock_session: - storage = s3boto3.S3Boto3Storage() - _ = 
storage.connection - mock_session.assert_called_once_with(profile_name="test_profile") + with override_settings(AWS_S3_SESSION_PROFILE="test_profile"): + with mock.patch('boto3.Session') as mock_session: + storage = s3boto3.S3Boto3Storage() + _ = storage.connection + mock_session.assert_called_once_with(profile_name="test_profile") def test_pickle_with_bucket(self): """ @@ -94,7 +97,7 @@ def test_storage_save(self): obj = self.storage.bucket.Object.return_value obj.upload_fileobj.assert_called_with( - content, + mock.ANY, ExtraArgs={ 'ContentType': 'text/plain', }, @@ -112,7 +115,7 @@ def test_storage_save_non_seekable(self): obj = self.storage.bucket.Object.return_value obj.upload_fileobj.assert_called_with( - content, + mock.ANY, ExtraArgs={ 'ContentType': 'text/plain', }, @@ -131,7 +134,7 @@ def test_storage_save_with_default_acl(self): obj = self.storage.bucket.Object.return_value obj.upload_fileobj.assert_called_with( - content, + mock.ANY, ExtraArgs={ 'ContentType': 'text/plain', 'ACL': 'private', @@ -152,7 +155,7 @@ def test_storage_object_parameters_not_overwritten_by_default(self): obj = self.storage.bucket.Object.return_value obj.upload_fileobj.assert_called_with( - content, + mock.ANY, ExtraArgs={ 'ContentType': 'text/plain', 'ACL': 'private', @@ -172,7 +175,7 @@ def test_content_type(self): obj = self.storage.bucket.Object.return_value obj.upload_fileobj.assert_called_with( - content, + mock.ANY, ExtraArgs={ 'ContentType': 'image/jpeg', }, @@ -187,8 +190,8 @@ def test_storage_save_gzipped(self): content = ContentFile("I am gzip'd") self.storage.save(name, content) obj = self.storage.bucket.Object.return_value - obj.upload_fileobj.assert_called_with( - content, + obj.upload_fileobj.assert_called_once_with( + mock.ANY, ExtraArgs={ 'ContentType': 'application/octet-stream', 'ContentEncoding': 'gzip', @@ -208,7 +211,7 @@ def get_object_parameters(name): obj = self.storage.bucket.Object.return_value obj.upload_fileobj.assert_called_with( - content, + 
mock.ANY, ExtraArgs={ "ContentType": "application/gzip", }, @@ -223,8 +226,8 @@ def test_storage_save_gzipped_non_seekable(self): content = NonSeekableContentFile("I am gzip'd") self.storage.save(name, content) obj = self.storage.bucket.Object.return_value - obj.upload_fileobj.assert_called_with( - content, + obj.upload_fileobj.assert_called_once_with( + mock.ANY, ExtraArgs={ 'ContentType': 'application/octet-stream', 'ContentEncoding': 'gzip', @@ -287,7 +290,7 @@ def test_compress_content_len(self): Test that file returned by _compress_content() is readable. """ self.storage.gzip = True - content = ContentFile("I should be gzip'd") + content = ContentFile(b"I should be gzip'd") content = self.storage._compress_content(content) self.assertTrue(len(content.read()) > 0) @@ -569,7 +572,7 @@ def test_storage_listdir_base(self): self.storage._connections.connection.meta.client.get_paginator.return_value = paginator dirs, files = self.storage.listdir('') - paginator.paginate.assert_called_with(Bucket=None, Delimiter='/', Prefix='') + paginator.paginate.assert_called_with(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delimiter='/', Prefix='') self.assertEqual(dirs, ['some', 'other']) self.assertEqual(files, ['2.txt', '4.txt']) @@ -594,7 +597,7 @@ def test_storage_listdir_subdir(self): self.storage._connections.connection.meta.client.get_paginator.return_value = paginator dirs, files = self.storage.listdir('some/') - paginator.paginate.assert_called_with(Bucket=None, Delimiter='/', Prefix='some/') + paginator.paginate.assert_called_with(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delimiter='/', Prefix='some/') self.assertEqual(dirs, ['path']) self.assertEqual(files, ['2.txt']) @@ -615,7 +618,7 @@ def test_storage_listdir_empty(self): self.storage._connections.connection.meta.client.get_paginator.return_value = paginator dirs, files = self.storage.listdir('dir/') - paginator.paginate.assert_called_with(Bucket=None, Delimiter='/', Prefix='dir/') + 
paginator.paginate.assert_called_with(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delimiter='/', Prefix='dir/') self.assertEqual(dirs, []) self.assertEqual(files, []) @@ -865,3 +868,90 @@ def test_closed(self): with self.subTest("is True after close"): f.close() self.assertTrue(f.closed) + + +@mock_s3 +class S3Boto3StorageTestsWithMoto(TestCase): + """ + Using mock_s3 as a class decorator automatically decorates methods, + but NOT classmethods or staticmethods. + """ + + def setUp(cls): + super().setUp() + + cls.storage = s3boto3.S3Boto3Storage() + cls.bucket = cls.storage.connection.Bucket(settings.AWS_STORAGE_BUCKET_NAME) + cls.bucket.create() + + def test_save_bytes_file(self): + self.storage.save("bytes_file.txt", File(io.BytesIO(b"foo1"))) + + self.assertEqual( + b"foo1", + self.bucket.Object("bytes_file.txt").get()['Body'].read(), + ) + + def test_save_string_file(self): + self.storage.save("string_file.txt", File(io.StringIO("foo2"))) + + self.assertEqual( + b"foo2", + self.bucket.Object("string_file.txt").get()['Body'].read(), + ) + + def test_save_bytes_content_file(self): + self.storage.save("bytes_content.txt", ContentFile(b"foo3")) + + self.assertEqual( + b"foo3", + self.bucket.Object("bytes_content.txt").get()['Body'].read(), + ) + + def test_save_string_content_file(self): + self.storage.save("string_content.txt", ContentFile("foo4")) + + self.assertEqual( + b"foo4", + self.bucket.Object("string_content.txt").get()['Body'].read(), + ) + + def test_content_type_guess(self): + """ + Test saving a file where the ContentType is guessed from the filename. 
+ """ + name = 'test_image.jpg' + content = ContentFile(b'data') + content.content_type = None + self.storage.save(name, content) + + s3_object_fetched = self.bucket.Object(name).get() + self.assertEqual(b"data", s3_object_fetched['Body'].read()) + self.assertEqual(s3_object_fetched["ContentType"], "image/jpeg") + + def test_content_type_attribute(self): + """ + Test saving a file with a custom content type attribute. + """ + content = ContentFile(b'data') + content.content_type = "test/foo" + self.storage.save("test_file", content) + + s3_object_fetched = self.bucket.Object("test_file").get() + self.assertEqual(b"data", s3_object_fetched['Body'].read()) + self.assertEqual(s3_object_fetched["ContentType"], "test/foo") + + def test_content_type_not_detectable(self): + """ + Test saving a file with no detectable content type. + """ + content = ContentFile(b'data') + content.content_type = None + self.storage.save("test_file", content) + + s3_object_fetched = self.bucket.Object("test_file").get() + self.assertEqual(b"data", s3_object_fetched['Body'].read()) + self.assertEqual( + s3_object_fetched["ContentType"], + s3boto3.S3Boto3Storage.default_content_type, + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index b60827652..17a7c93e6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,6 @@ import datetime +import io +import os.path import pathlib from django.conf import settings @@ -136,3 +138,77 @@ def test_truncates_away_filename_raises(self): name = 'parent/child.txt' with self.assertRaises(SuspiciousFileOperation): gaon(name, len(name) - 5) + + +class TestReadBytesWrapper(TestCase): + def test_with_bytes_file(self): + file = io.BytesIO(b"abcd") + file_wrapped = utils.ReadBytesWrapper(file) + + # test read() with default args + self.assertEqual(b"abcd", file_wrapped.read()) + + # test seek() with default args + self.assertEqual(0, file_wrapped.seek(0)) + self.assertEqual(b"abcd", file_wrapped.read()) + + # test read() with custom args + 
file_wrapped.seek(0) + self.assertEqual(b"ab", file_wrapped.read(2)) + + # test seek() with custom args + self.assertEqual(1, file_wrapped.seek(-1, io.SEEK_CUR)) + self.assertEqual(b"bcd", file_wrapped.read()) + + def test_with_string_file(self): + file = io.StringIO("wxyz") + file_wrapped = utils.ReadBytesWrapper(file) + + # test read() with default args + self.assertEqual(b"wxyz", file_wrapped.read()) + + # test seek() with default args + self.assertEqual(0, file_wrapped.seek(0)) + self.assertEqual(b"wxyz", file_wrapped.read()) + + # test read() with custom args + file_wrapped.seek(0) + self.assertEqual(b"wx", file_wrapped.read(2)) + + # test seek() with custom args + self.assertEqual(2, file_wrapped.seek(0, io.SEEK_CUR)) + self.assertEqual(b"yz", file_wrapped.read()) + + # I chose the characters ™€‰ for the following tests because they produce different + # bytes when encoding with utf-8 vs windows-1252 vs utf-16 + + def test_with_string_file_specified_encoding(self): + content = "\u2122\u20AC\u2030" + file = io.StringIO(content) + file_wrapped = utils.ReadBytesWrapper(file, encoding="utf-16") + + # test read() returns specified encoding + self.assertEqual(file_wrapped.read(), content.encode("utf-16")) + + def test_with_string_file_detect_encoding(self): + content = "\u2122\u20AC\u2030" + with open( + file=os.path.join(os.path.dirname(__file__), "test_files", "windows-1252-encoded.txt"), + mode="r", + encoding="windows-1252", + ) as file: + self.assertEqual(file.read(), content) + file.seek(0) + + file_wrapped = utils.ReadBytesWrapper(file) + + # test read() returns encoding detected from file object. 
+ self.assertEqual(file_wrapped.read(), content.encode("windows-1252")) + + def test_with_string_file_fallback_encoding(self): + content = "\u2122\u20AC\u2030" + file = io.StringIO(content) + file_wrapped = utils.ReadBytesWrapper(file) + + # test read() returns fallback utf-8 encoding + self.assertEqual(file_wrapped.read(), content.encode("utf-8")) diff --git a/tox.ini b/tox.ini index 3531648ef..33a55c545 100644 --- a/tox.ini +++ b/tox.ini @@ -22,6 +22,7 @@ deps = pytest pytest-cov rsa + moto extras = azure boto3