Skip to content

Commit

Permalink
Azure: improve handling of special characters in filenames (#752)
Browse files Browse the repository at this point in the history
* Remove the file-name special-character cleaner (fixes #609)

* tests
  • Loading branch information
nitely authored and jschneier committed Sep 9, 2019
1 parent c175abf commit 2fbaa70
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 46 deletions.
32 changes: 10 additions & 22 deletions storages/backends/azure_storage.py
@@ -1,7 +1,6 @@
from __future__ import unicode_literals

import mimetypes
import re
from datetime import datetime, timedelta
from tempfile import SpooledTemporaryFile

Expand All @@ -13,7 +12,7 @@
from django.core.files.storage import Storage
from django.utils import timezone
from django.utils.deconstruct import deconstructible
from django.utils.encoding import force_bytes, force_text
from django.utils.encoding import filepath_to_uri, force_bytes

from storages.utils import (
clean_name, get_available_overwrite_name, safe_join, setting,
Expand Down Expand Up @@ -101,10 +100,7 @@ def _get_valid_path(s):
# * must not end with dot or slash
# * can contain any character
# * must escape URL reserved characters
# We allow a subset of this to avoid
# illegal file names. We must ensure it is idempotent.
s = force_text(s).strip().replace(' ', '_')
s = re.sub(r'(?u)[^-\w./]', '', s)
# (not needed here since the azure client will do that)
s = s.strip('./')
if len(s) > _AZURE_NAME_MAX_LEN:
raise ValueError(
Expand All @@ -120,12 +116,6 @@ def _get_valid_path(s):
return s


def _clean_name_dance(name):
    """Return ``name`` as a cleaned, validated blob path.

    Backslashes are converted to forward slashes first. The name is then
    passed through ``clean_name`` both before and after ``_get_valid_path``,
    because the validation step may itself return something like
    ``foo/../bar`` that needs one more cleaning pass.
    """
    posix_name = name.replace('\\', '/')
    validated = _get_valid_path(clean_name(posix_name))
    return clean_name(validated)


# Max len according to azure's docs
_AZURE_NAME_MAX_LEN = 1024

Expand Down Expand Up @@ -198,29 +188,27 @@ def azure_protocol(self):
else:
return 'http'

def _path(self, name):
name = _clean_name_dance(name)
def _normalize_name(self, name):
try:
return safe_join(self.location, name)
except ValueError:
raise SuspiciousOperation("Attempted access to '%s' denied." % name)

def _get_valid_path(self, name):
# Must be idempotent
return _get_valid_path(self._path(name))
return _get_valid_path(
self._normalize_name(
clean_name(name)))

def _open(self, name, mode="rb"):
    """Return an ``AzureStorageFile`` for ``name`` tied to this storage."""
    azure_file = AzureStorageFile(name, mode, self)
    return azure_file

def get_valid_name(self, name):
    """Return ``name`` after normalization by ``_clean_name_dance``."""
    valid_name = _clean_name_dance(name)
    return valid_name

def get_available_name(self, name, max_length=_AZURE_NAME_MAX_LEN):
"""
Returns a filename that's free on the target storage system, and
available for new content to be written to.
"""
name = self.get_valid_name(name)
name = clean_name(name)
if self.overwrite_files:
return get_available_overwrite_name(name, max_length)
return super(AzureStorage, self).get_available_name(name, max_length)
Expand Down Expand Up @@ -248,7 +236,7 @@ def size(self, name):
return properties.content_length

def _save(self, name, content):
name_only = self.get_valid_name(name)
cleaned_name = clean_name(name)
name = self._get_valid_path(name)
guessed_type, content_encoding = mimetypes.guess_type(name)
content_type = (
Expand All @@ -270,7 +258,7 @@ def _save(self, name, content):
content_encoding=content_encoding),
max_connections=self.upload_max_conn,
timeout=self.timeout)
return name_only
return cleaned_name

def _expire_at(self, expire):
# azure expects time in UTC
Expand All @@ -292,7 +280,7 @@ def url(self, name, expire=None):
make_blob_url_kwargs['protocol'] = self.azure_protocol
return self.custom_service.make_blob_url(
container_name=self.azure_container,
blob_name=name,
blob_name=filepath_to_uri(name),
**make_blob_url_kwargs)

def get_modified_time(self, name):
Expand Down
57 changes: 50 additions & 7 deletions tests/integration/test_azure.py
Expand Up @@ -29,30 +29,30 @@ def setUp(self, *args):
self.storage.azure_container, public_access=False, fail_on_exist=False)

def test_save(self):
expected_name = "some_blob_Ϊ.txt"
expected_name = "some blob Ϊ.txt"
self.assertFalse(self.storage.exists(expected_name))
stream = io.BytesIO(b'Im a stream')
name = self.storage.save('some blob Ϊ.txt', stream)
name = self.storage.save(expected_name, stream)
self.assertEqual(name, expected_name)
self.assertTrue(self.storage.exists(expected_name))

def test_delete(self):
self.storage.location = 'path'
expected_name = "some_blob_Ϊ.txt"
expected_name = "some blob Ϊ.txt"
self.assertFalse(self.storage.exists(expected_name))
stream = io.BytesIO(b'Im a stream')
name = self.storage.save('some blob Ϊ.txt', stream)
name = self.storage.save(expected_name, stream)
self.assertEqual(name, expected_name)
self.assertTrue(self.storage.exists(expected_name))
self.storage.delete(expected_name)
self.assertFalse(self.storage.exists(expected_name))

def test_size(self):
self.storage.location = 'path'
expected_name = "some_path/some_blob_Ϊ.txt"
expected_name = "some path/some blob Ϊ.txt"
self.assertFalse(self.storage.exists(expected_name))
stream = io.BytesIO(b'Im a stream')
name = self.storage.save('some path/some blob Ϊ.txt', stream)
name = self.storage.save(expected_name, stream)
self.assertEqual(name, expected_name)
self.assertTrue(self.storage.exists(expected_name))
self.assertEqual(self.storage.size(expected_name), len(b'Im a stream'))
Expand All @@ -64,6 +64,15 @@ def test_url(self):
# has some query-string
self.assertTrue("/test/my_file.txt?" in self.storage.url("my_file.txt"))

def test_url_unsafe_chars(self):
    """URL-unsafe characters in a blob name come back percent-encoded."""
    name = "my?file <foo>.txt"
    expected = "/test/my%3Ffile%20%3Cfoo%3E.txt"
    self.assertTrue(
        self.storage.url(name).endswith(expected))
    # Once an expiration is configured, the encoded path must be
    # followed by some query-string.
    self.storage.expiration_secs = 360
    # assertIn reports both operands on failure, unlike assertTrue(a in b).
    self.assertIn("{}?".format(expected), self.storage.url(name))

def test_url_custom_endpoint(self):
storage = azure_storage.AzureStorage()
storage.is_emulated = True
Expand Down Expand Up @@ -107,7 +116,7 @@ def test_open_read(self):
stream = io.BytesIO()
self.storage.service.get_blob_to_stream(
container_name=self.storage.azure_container,
blob_name='root/path/some_file.txt',
blob_name='root/path/some file.txt',
stream=stream,
max_connections=1,
timeout=10)
Expand Down Expand Up @@ -184,6 +193,11 @@ class AzureStorageExpiry(azure_storage.AzureStorage):
expiration_secs = 360


class AzureStorageSpecialChars(azure_storage.AzureStorage):
    """Azure storage variant whose ``get_valid_name`` keeps names untouched."""

    def get_valid_name(self, name):
        # Hand the name back as-is so no special characters are stripped.
        return name


class FooFileForm(forms.Form):

foo_file = forms.FileField()
Expand Down Expand Up @@ -265,3 +279,32 @@ def test_model_form(self):
self.assertEqual(fh.read(), b'foo content')
finally:
fh.close()

def test_name_clean_issue_609(self):
    """
    Should strip special characters when using the default storage
    """
    simple_file = SimpleFileModel()
    simple_file.foo_file = SimpleUploadedFile(
        name='foo%?:;~bar.txt',
        content=b'foo content')
    simple_file.save()
    # Both the stored name and the URL must carry the stripped file name.
    self.assertEqual(simple_file.foo_file.name, 'foo_uploads/foobar.txt')
    # assertIn shows the actual URL on failure, unlike assertTrue(a in b).
    self.assertIn('foobar.txt', simple_file.foo_file.url)

@override_settings(
    DEFAULT_FILE_STORAGE='tests.integration.test_azure.AzureStorageSpecialChars')
def test_name_clean_issue_609_with_special_chars(self):
    """
    Should not strip special chars
    """
    name = 'foo%?:;~bar.txt'
    simple_file = SimpleFileModel()
    simple_file.foo_file = SimpleUploadedFile(
        name=name,
        content=b'foo content')
    simple_file.save()
    # The stored name keeps every special character verbatim ...
    self.assertEqual(
        simple_file.foo_file.name, 'foo_uploads/{}'.format(name))
    # ... while the URL carries them percent-encoded. assertIn shows the
    # actual URL on failure, unlike assertTrue(a in b).
    self.assertIn(
        'foo_uploads/foo%25%3F%3A%3B~bar.txt', simple_file.foo_file.url)
44 changes: 27 additions & 17 deletions tests/test_azure.py
Expand Up @@ -45,13 +45,13 @@ def test_get_valid_path(self):
self.storage._get_valid_path("path\\to\\somewhere"),
"path/to/somewhere")
self.assertEqual(
self.storage._get_valid_path("some/$/path"), "some/path")
self.storage._get_valid_path("some/$/path"), "some/$/path")
self.assertEqual(
self.storage._get_valid_path("/$/path"), "path")
self.storage._get_valid_path("/$/path"), "$/path")
self.assertEqual(
self.storage._get_valid_path("path/$/"), "path")
self.storage._get_valid_path("path/$/"), "path/$")
self.assertEqual(
self.storage._get_valid_path("path/$/$/$/path"), "path/path")
self.storage._get_valid_path("path/$/$/$/path"), "path/$/$/$/path")
self.assertEqual(
self.storage._get_valid_path("some///path"), "some/path")
self.assertEqual(
Expand All @@ -67,24 +67,23 @@ def test_get_valid_path(self):
self.assertRaises(ValueError, self.storage._get_valid_path, "/../")
self.assertRaises(ValueError, self.storage._get_valid_path, "..")
self.assertRaises(ValueError, self.storage._get_valid_path, "///")
self.assertRaises(ValueError, self.storage._get_valid_path, "!!!")
self.assertRaises(ValueError, self.storage._get_valid_path, "a" * 1025)
self.assertRaises(ValueError, self.storage._get_valid_path, "a/a" * 257)

def test_get_valid_path_idempotency(self):
self.assertEqual(
self.storage._get_valid_path("//$//a//$//"), "a")
self.storage._get_valid_path("//$//a//$//"), "$/a/$")
self.assertEqual(
self.storage._get_valid_path(
self.storage._get_valid_path("//$//a//$//")),
self.storage._get_valid_path("//$//a//$//"))
some_path = "some path/some long name & then some.txt"
self.assertEqual(
self.storage._get_valid_path("some path/some long name & then some.txt"),
"some_path/some_long_name__then_some.txt")
self.storage._get_valid_path(some_path), some_path)
self.assertEqual(
self.storage._get_valid_path(
self.storage._get_valid_path("some path/some long name & then some.txt")),
self.storage._get_valid_path("some path/some long name & then some.txt"))
self.storage._get_valid_path(some_path)),
self.storage._get_valid_path(some_path))

def test_get_available_name(self):
self.storage.overwrite_files = False
Expand All @@ -100,7 +99,7 @@ def test_get_available_name_first(self):
self.storage._service.exists.return_value = False
self.assertEqual(
self.storage.get_available_name('foo bar baz.txt'),
'foo_bar_baz.txt')
'foo bar baz.txt')
self.assertEqual(self.storage._service.exists.call_count, 1)

def test_get_available_name_max_len(self):
Expand All @@ -119,14 +118,25 @@ def test_get_available_invalid(self):
self.storage.overwrite_files = False
self.storage._service.exists.return_value = False
self.assertRaises(ValueError, self.storage.get_available_name, "")
self.assertRaises(ValueError, self.storage.get_available_name, "$$")
self.assertRaises(ValueError, self.storage.get_available_name, "/")
self.assertRaises(ValueError, self.storage.get_available_name, ".")
self.assertRaises(ValueError, self.storage.get_available_name, "///")
self.assertRaises(ValueError, self.storage.get_available_name, "...")

def test_url(self):
self.storage._custom_service.make_blob_url.return_value = 'ret_foo'
self.assertEqual(self.storage.url('some blob'), 'ret_foo')
self.storage._custom_service.make_blob_url.assert_called_once_with(
container_name=self.container_name,
blob_name='some_blob',
blob_name='some%20blob',
protocol='https')

def test_url_unsafe_chars(self):
    """``url`` passes a percent-encoded blob name to ``make_blob_url``."""
    unsafe_name = 'foo;?:@=&"<>#%{}|^~[]`bar/~!*()\''
    # Expected encoding: `/`, `~`, `!`, `*`, `(`, `)` and `'` stay literal.
    encoded_name = 'foo%3B%3F%3A%40%3D%26%22%3C%3E%23%25%7B%7D%7C%5E~%5B%5D%60bar/~!*()\''
    self.storage.custom_service.make_blob_url.return_value = 'ret_foo'

    result = self.storage.url(unsafe_name)

    self.assertEqual(result, 'ret_foo')
    self.storage.custom_service.make_blob_url.assert_called_once_with(
        container_name=self.container_name,
        blob_name=encoded_name,
        protocol='https')

def test_url_expire(self):
Expand All @@ -139,12 +149,12 @@ def test_url_expire(self):
self.assertEqual(self.storage.url('some blob', 100), 'ret_foo')
self.storage._custom_service.generate_blob_shared_access_signature.assert_called_once_with(
self.container_name,
'some_blob',
'some blob',
permission=BlobPermissions.READ,
expiry=fixed_time + timedelta(seconds=100))
self.storage._custom_service.make_blob_url.assert_called_once_with(
container_name=self.container_name,
blob_name='some_blob',
blob_name='some%20blob',
sas_token='foo_token',
protocol='https')

Expand Down Expand Up @@ -284,10 +294,10 @@ def test_storage_save(self):
content = ContentFile('new content')
with mock.patch('storages.backends.azure_storage.ContentSettings') as c_mocked:
c_mocked.return_value = 'content_settings_foo'
self.assertEqual(self.storage.save(name, content), 'test_storage_save.txt')
self.assertEqual(self.storage.save(name, content), name)
self.storage._service.create_blob_from_stream.assert_called_once_with(
container_name=self.container_name,
blob_name='test_storage_save.txt',
blob_name=name,
stream=content.file,
content_settings='content_settings_foo',
max_connections=2,
Expand Down

0 comments on commit 2fbaa70

Please sign in to comment.