diff --git a/corehq/blobs/s3db.py b/corehq/blobs/s3db.py index b57a6f96b4c6..3d8abd867026 100644 --- a/corehq/blobs/s3db.py +++ b/corehq/blobs/s3db.py @@ -1,5 +1,6 @@ from __future__ import absolute_import import os +import weakref from contextlib import contextmanager from threading import Lock @@ -8,7 +9,6 @@ from corehq.blobs.util import ClosingContextProxy import boto3 -from boto3.s3.transfer import S3Transfer, ReadFileChunk from botocore.client import Config from botocore.handlers import calculate_md5 from botocore.exceptions import ClientError @@ -40,20 +40,24 @@ def __init__(self, config): def put(self, content, basename="", bucket=DEFAULT_BUCKET): identifier = self.get_identifier(basename) path = self.get_path(identifier, bucket) - self._s3_bucket(create=True) - osutil = OpenFileOSUtils() - transfer = S3Transfer(self.db.meta.client, osutil=osutil) - transfer.upload_file(content, self.s3_bucket_name, path) + s3_bucket = self._s3_bucket(create=True) + if isinstance(content, BlobStream) and content.blob_db is self: + source = {"Bucket": self.s3_bucket_name, "Key": content.blob_path} + s3_bucket.copy(source, path) + obj = s3_bucket.Object(path) + # unfortunately cannot get content-md5 here + return BlobInfo(identifier, obj.content_length, None) content.seek(0) content_md5 = get_content_md5(content) - content_length = osutil.get_file_size(content) + content_length = get_file_size(content) + s3_bucket.upload_fileobj(content, path) return BlobInfo(identifier, content_length, "md5-" + content_md5) def get(self, identifier, bucket=DEFAULT_BUCKET): path = self.get_path(identifier, bucket) with maybe_not_found(throw=NotFound(identifier, bucket)): resp = self._s3_bucket().Object(path).get() - return ClosingContextProxy(resp["Body"]) # body stream + return BlobStream(resp["Body"], self, path) def delete(self, *args, **kw): identifier, bucket = self.get_args_for_delete(*args, **kw) @@ -78,9 +82,8 @@ def delete(self, *args, **kw): def copy_blob(self, content, info, bucket): self._s3_bucket(create=True) path = self.get_path(info.identifier, bucket) - osutil = OpenFileOSUtils() - transfer = S3Transfer(self.db.meta.client, osutil=osutil) - transfer.upload_file(content, self.s3_bucket_name, path) + self._s3_bucket().upload_fileobj(content, path) + def _s3_bucket(self, create=False): if create and not self._s3_bucket_exists: @@ -99,6 +102,18 @@ def get_path(self, identifier=None, bucket=DEFAULT_BUCKET): return safejoin(bucket, identifier) +class BlobStream(ClosingContextProxy): + + def __init__(self, stream, blob_db, blob_path): + super(BlobStream, self).__init__(stream) + self._blob_db = weakref.ref(blob_db) + self.blob_path = blob_path + + @property + def blob_db(self): + return self._blob_db() + + def safepath(path): if (path.startswith(("/", ".")) or "/../" in path or @@ -125,6 +140,17 @@ def get_content_md5(content): return params["headers"]["Content-MD5"] +def get_file_size(fileobj): + if not hasattr(fileobj, 'fileno'): + pos = fileobj.tell() + try: + fileobj.seek(0, os.SEEK_END) + return fileobj.tell() + finally: + fileobj.seek(pos) + return os.fstat(fileobj.fileno()).st_size + + @contextmanager def maybe_not_found(throw=None): try: @@ -134,149 +160,3 @@ def maybe_not_found(throw=None): raise if throw is not None: raise throw - - -class OpenFileOSUtils(object): - - def get_file_size(self, fileobj): - if not hasattr(fileobj, 'fileno'): - pos = fileobj.tell() - try: - fileobj.seek(0, os.SEEK_END) - return fileobj.tell() - finally: - fileobj.seek(pos) - return os.fstat(fileobj.fileno()).st_size - - def open_file_chunk_reader(self, fileobj, start_byte, size, callback): - full_size = self.get_file_size(fileobj) - return ReadOpenFileChunk(fileobj, start_byte, size, full_size, - callback, enable_callback=False) - - def open(self, filename, mode): - raise NotImplementedError - - def remove_file(self, filename): - raise NotImplementedError - - def rename_file(self, current_filename, new_filename): - raise NotImplementedError - - -class ReadOpenFileChunk(ReadFileChunk): - """Wrapper for OpenFileChunk that implements ReadFileChunk interface - """ - - def __init__(self, fileobj, start_byte, chunk_size, full_file_size, *args, **kw): - - class FakeFile: - - def seek(self, pos): - pass - - length = min(chunk_size, full_file_size - start_byte) - self._chunk = OpenFileChunk(fileobj, start_byte, length) - super(ReadOpenFileChunk, self).__init__( - FakeFile(), start_byte, chunk_size, full_file_size, *args, **kw) - assert self._size == length, (self._size, length) - - def __repr__(self): - return ("".format( - self._chunk.file, - self._start_byte, - self._size, - )) - - def read(self, amount=None): - data = self._chunk.read(amount) - if self._callback is not None and self._callback_enabled: - self._callback(len(data)) - return data - - def seek(self, where): - old_pos = self._chunk.tell() - self._chunk.seek(where) - if self._callback is not None and self._callback_enabled: - # To also rewind the callback() for an accurate progress report - self._callback(where - old_pos) - - def tell(self): - return self._chunk.tell() - - def close(self): - self._chunk.close() - self._chunk = None - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - self.close() - - -class OpenFileChunk(object): - """A wrapper for reading from a file-like object from multiple threads - - Each thread reading from the file-like object should have its own - private instance of this class. - """ - - init_lock = Lock() - file_locks = {} - - def __init__(self, fileobj, start_byte, length): - with self.init_lock: - try: - lock, refs = self.file_locks[fileobj] - except KeyError: - lock, refs = self.file_locks[fileobj] = (Lock(), set()) - refs.add(self) - self.lock = lock - self.file = fileobj - self.start = self.offset = start_byte - self.length = length - - def read(self, amount=None): - if self.offset >= self.start + self.length: - return b"" - with self.lock: - pos = self.file.tell() - self.file.seek(self.offset) - - if amount is None: - amount = self.length - amount = min(self.length - self.tell(), amount) - read = self.file.read(amount) - - self.offset = self.file.tell() - self.file.seek(pos) - assert self.offset - self.start >= 0, (self.start, self.offset) - assert self.offset <= self.start + self.length, \ - (self.start, self.length, self.offset) - return read - - def seek(self, pos): - assert pos >= 0, pos - self.offset = self.start + pos - - def tell(self): - return self.offset - self.start - - def close(self): - if self.file is None: - return - try: - with self.init_lock: - lock, refs = self.file_locks[self.file] - refs.remove(self) - if not refs: - self.file_locks.pop(self.file) - finally: - self.file = None - self.lock = None - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - self.close() diff --git a/corehq/blobs/tests/test_s3db.py b/corehq/blobs/tests/test_s3db.py index 0f52181ec383..63b64c008fb0 100644 --- a/corehq/blobs/tests/test_s3db.py +++ b/corehq/blobs/tests/test_s3db.py @@ -128,6 +128,14 @@ def test_put_with_double_dotted_name(self): with self.db.get(info.identifier) as fh: self.assertEqual(fh.read(), b"content") + def test_put_from_get_stream(self): + name = "form.xml" + old = self.db.put(StringIO(b"content"), name, "old_bucket") + with self.db.get(old.identifier, "old_bucket") as fh: + new = self.db.put(fh, name, "new_bucket") + with self.db.get(new.identifier, "new_bucket") as fh: + self.assertEqual(fh.read(), b"content") + def test_delete(self): name = "test.4" bucket = "doc.4" @@ -217,71 +225,3 @@ def test_empty_attachment_name(self): def test_bad_name(self, name, bucket=mod.DEFAULT_BUCKET): with self.assertRaises(mod.BadName): self.db.get(name, bucket) - - -class TestOpenFileChunk(TestCase): - - @classmethod - def setUpClass(cls): - cls.tmp_context = tempdir() - tmp = cls.tmp_context.__enter__() - cls.filepath = join(tmp, "file.txt") - with open(cls.filepath, "wb") as fh: - fh.write(b"data") - - @classmethod - def tearDownClass(cls): - cls.tmp_context.__exit__(None, None, None) - - def get_chunk(self, normal_file, start, length): - return mod.OpenFileChunk(normal_file, start, length) - - -@generate_cases([ - (0, 0, b"data"), - (1, 1, b"ata",), - (2, 2, b"ta",), - (3, 3, b"a"), - (4, 4, b""), - (5, 5, b""), -], TestOpenFileChunk) -def test_seek_tell_read(self, to, expect_tell, expect_read): - with open(self.filepath, "rb") as normal_file: - normal_file.seek(to) - - with self.get_chunk(normal_file, 0, 4) as chunk: - self.assertIn(normal_file, mod.OpenFileChunk.file_locks) - # chunk seek/tell/read should not affect normal_file - chunk.seek(to) - self.assertEqual(chunk.tell(), expect_tell) - self.assertEqual(chunk.read(), expect_read) - - self.assertEqual(normal_file.tell(), expect_tell) - self.assertEqual(normal_file.read(), expect_read) - self.assertNotIn(normal_file, mod.OpenFileChunk.file_locks) - - -@generate_cases([ - (0, (0, b"data"), (0, b"at")), - (1, (1, b"ata"), (1, b"t")), - (2, (2, b"ta",), (2, b"")), - (3, (3, b"a"), (3, b"")), - (4, (4, b""), (4, b"")), - (5, (5, b""), (5, b"")), -], TestOpenFileChunk) -def test_seek_tell_read_in_sub_chunk(self, to, expect_norm, expect_chunk): - with open(self.filepath, "rb") as normal_file: - normal_file.seek(to) - - with self.get_chunk(normal_file, 1, 2) as chunk: - # chunk seek/tell/read should not affect normal_file - chunk.seek(to) - self.assertEqual((chunk.tell(), chunk.read()), expect_chunk) - - self.assertEqual((normal_file.tell(), normal_file.read()), expect_norm) - - -class TestReadOpenFileChunk(TestOpenFileChunk): - - def get_chunk(self, normal_file, start, length): - return mod.ReadOpenFileChunk(normal_file, start, length, 4) diff --git a/corehq/blobs/util.py b/corehq/blobs/util.py index 26dba879d903..50763f959995 100644 --- a/corehq/blobs/util.py +++ b/corehq/blobs/util.py @@ -15,7 +15,7 @@ def __iter__(self): return iter(self._obj) def __enter__(self): - return self._obj + return self def __exit__(self, *args): self._obj.close() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 61fb8bdd91fc..b87f744151b5 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -84,7 +84,7 @@ git+git://github.com/dimagi/pyzxcvbn.git#egg=pyzxcvbn django-statici18n==1.1.5 django-simple-captcha==0.5.1 httpagentparser==1.7.8 -boto3==1.2.3 +boto3==1.4.0 simpleeval==0.8.7 laboratory==0.2.0 ConcurrentLogHandler==0.9.1