diff --git a/pydrive2/files.py b/pydrive2/files.py index 49ceb454..c9e8a0b7 100644 --- a/pydrive2/files.py +++ b/pydrive2/files.py @@ -5,6 +5,7 @@ from googleapiclient import errors from googleapiclient.http import MediaIoBaseUpload from googleapiclient.http import MediaIoBaseDownload +from googleapiclient.http import DEFAULT_CHUNK_SIZE from functools import wraps from .apiattr import ApiAttribute @@ -97,6 +98,81 @@ def _GetList(self): return result +class IoBuffer(object): + """Lightweight retention of one chunk.""" + + def __init__(self, encoding): + self.encoding = encoding + self.chunk = None + + def write(self, chunk): + self.chunk = chunk + + def read(self): + return ( + self.chunk.decode(self.encoding) + if self.chunk and self.encoding + else self.chunk + ) + + +class MediaIoReadable(object): + def __init__( + self, + request, + encoding=None, + pre_buffer=True, + remove_prefix=b"", + chunksize=DEFAULT_CHUNK_SIZE, + ): + """File-like wrapper around MediaIoBaseDownload. + + :param pre_buffer: Whether to read one chunk into an internal buffer + immediately in order to raise any potential errors. + :param remove_prefix: Bytes prefix to remove from internal pre_buffer. + :raises: ApiRequestError + """ + self.done = False + self._fd = IoBuffer(encoding) + self.downloader = MediaIoBaseDownload( + self._fd, request, chunksize=chunksize + ) + self._pre_buffer = False + if pre_buffer: + self.read() + if remove_prefix: + chunk = io.BytesIO(self._fd.chunk) + GoogleDriveFile._RemovePrefix(chunk, remove_prefix) + self._fd.chunk = chunk.getvalue() + self._pre_buffer = True + + def read(self): + """ + :returns: bytes or str -- chunk (or None if done) + :raises: ApiRequestError + """ + if self._pre_buffer: + self._pre_buffer = False + return self._fd.read() + if self.done: + return None + try: + _, self.done = self.downloader.next_chunk() + except errors.HttpError as error: + raise ApiRequestError(error) + return self._fd.read() + + def __iter__(self): + """ + :raises: ApiRequestError + """ + while True: + chunk = self.read() + if chunk is None: + break + yield chunk + + class GoogleDriveFile(ApiAttributeMixin, ApiResource): """Google Drive File instance. @@ -247,12 +323,7 @@ def GetContentFile( raise FileNotUploadedError() def download(fd, request): - # Ensures thread safety. Similar to other places where we call - # `.execute(http=self.http)` to pass a client from the thread - # local storage. - if self.http: - request.http = self.http - downloader = MediaIoBaseDownload(fd, request) + downloader = MediaIoBaseDownload(fd, self._WrapRequest(request)) done = False while done is False: status, done = downloader.next_chunk() @@ -260,9 +331,10 @@ def download(fd, request): callback(status.resumable_progress, status.total_size) with open(filename, mode="w+b") as fd: - # Ideally would use files.export_media instead if - # metadata.get("mimeType").startswith("application/vnd.google-apps.") - # but that would first require a slow call to FetchMetadata() + # Should use files.export_media instead of files.get_media if + # metadata["mimeType"].startswith("application/vnd.google-apps."). + # But that would first require a slow call to FetchMetadata(). + # We prefer to try-except for speed. try: download(fd, files.get_media(fileId=file_id)) except errors.HttpError as error: @@ -284,13 +356,66 @@ def download(fd, request): if mimetype == "text/plain" and remove_bom: fd.seek(0) - boms = [ - bom[mimetype] - for bom in MIME_TYPE_TO_BOM.values() - if mimetype in bom - ] - if boms: - self._RemovePrefix(fd, boms[0]) + bom = self._GetBOM(mimetype) + if bom: + self._RemovePrefix(fd, bom) + + @LoadAuth + def GetContentIOBuffer( + self, + mimetype=None, + encoding=None, + remove_bom=False, + chunksize=DEFAULT_CHUNK_SIZE, + ): + """Get a file-like object which has a buffered read() method. + + :param mimetype: mimeType of the file. + :type mimetype: str + :param encoding: The encoding to use when decoding the byte string. + :type encoding: str + :param remove_bom: Whether to remove the byte order marking. + :type remove_bom: bool + :param chunksize: default read()/iter() chunksize. + :type chunksize: int + :returns: MediaIoReadable -- file-like object. + :raises: ApiRequestError, FileNotUploadedError + """ + files = self.auth.service.files() + file_id = self.metadata.get("id") or self.get("id") + if not file_id: + raise FileNotUploadedError() + + # Should use files.export_media instead of files.get_media if + # metadata["mimeType"].startswith("application/vnd.google-apps."). + # But that would first require a slow call to FetchMetadata(). + # We prefer to try-except for speed. + try: + request = self._WrapRequest(files.get_media(fileId=file_id)) + return MediaIoReadable( + request, encoding=encoding, chunksize=chunksize + ) + except ApiRequestError as exc: + if ( + exc.error["code"] != 403 + or exc.GetField("reason") != "fileNotDownloadable" + ): + raise exc + mimetype = mimetype or "text/plain" + request = self._WrapRequest( + files.export_media(fileId=file_id, mimeType=mimetype) + ) + remove_prefix = ( + self._GetBOM(mimetype) + if mimetype == "text/plain" and remove_bom + else b"" + ) + return MediaIoReadable( + request, + encoding=encoding, + remove_prefix=remove_prefix, + chunksize=chunksize, + ) @LoadAuth def FetchMetadata(self, fields=None, fetch_all=False): @@ -446,6 +571,16 @@ def DeletePermission(self, permission_id): """ return self._DeletePermission(permission_id) + def _WrapRequest(self, request): + """Replaces request.http with self.http. + + Ensures thread safety. Similar to other places where we call + `.execute(http=self.http)` to pass a client from the thread local storage. + """ + if self.http: + request.http = self.http + return request + @LoadAuth def _FilesInsert(self, param=None): """Upload a new file using Files.insert(). @@ -663,6 +798,13 @@ def _DeletePermission(self, permission_id): self.metadata["permissions"] = permissions return True + @staticmethod + def _GetBOM(mimetype): + """Based on download mime type (ignores Google Drive mime type)""" + for bom in MIME_TYPE_TO_BOM.values(): + if mimetype in bom: + return bom[mimetype] + @staticmethod def _RemovePrefix(file_object, prefix, block_size=BLOCK_SIZE): """Deletes passed prefix by shifting content of passed file object by to diff --git a/pydrive2/test/test_file.py b/pydrive2/test/test_file.py index f246080a..fe6aab27 100644 --- a/pydrive2/test/test_file.py +++ b/pydrive2/test/test_file.py @@ -280,6 +280,24 @@ def test_10_Files_Download_Service(self): self.DeleteUploadedFiles(drive, [file1["id"]]) + def test_11_Files_Get_Content_Buffer(self): + drive = GoogleDrive(self.ga) + file1 = drive.CreateFile() + filename = self.getTempFile() + content = "hello world!\ngoodbye, cruel world!" + file1["title"] = filename + file1.SetContentString(content) + pydrive_retry(file1.Upload) # Files.insert + + buffer1 = pydrive_retry(file1.GetContentIOBuffer) + self.assertEqual(file1.metadata["title"], filename) + self.assertEqual(b"".join(iter(buffer1)).decode("ascii"), content) + + buffer2 = pydrive_retry(file1.GetContentIOBuffer, encoding="ascii") + self.assertEqual("".join(iter(buffer2)), content) + + self.DeleteUploadedFiles(drive, [file1["id"]]) + # Tests for Trash/UnTrash/Delete. # =============================== @@ -610,6 +628,23 @@ def test_Gfile_Conversion_Add_Remove_BOM(self): self.assertNotEqual(content_bom, content_no_bom) self.assertTrue(len(content_bom) > len(content_no_bom)) + buffer_bom = pydrive_retry( + file1.GetContentIOBuffer, + mimetype="text/plain", + encoding="utf-8", + ) + buffer_bom = u"".join(iter(buffer_bom)) + self.assertEqual(content_bom, buffer_bom) + + buffer_no_bom = pydrive_retry( + file1.GetContentIOBuffer, + mimetype="text/plain", + remove_bom=True, + encoding="utf-8", + ) + buffer_no_bom = u"".join(iter(buffer_no_bom)) + self.assertEqual(content_no_bom, buffer_no_bom) + finally: self.cleanup_gfile_conversion_test( file1, file_name, downloaded_file_name