Skip to content
Merged
174 changes: 158 additions & 16 deletions pydrive2/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from googleapiclient import errors
from googleapiclient.http import MediaIoBaseUpload
from googleapiclient.http import MediaIoBaseDownload
from googleapiclient.http import DEFAULT_CHUNK_SIZE
from functools import wraps

from .apiattr import ApiAttribute
Expand Down Expand Up @@ -97,6 +98,81 @@ def _GetList(self):
return result


class IoBuffer(object):
    """Minimal write target that keeps only the most recently written chunk.

    Each ``write()`` replaces the previously stored chunk; ``read()`` hands
    the stored chunk back, decoded when an encoding was configured.
    """

    def __init__(self, encoding):
        """
        :param encoding: codec name used by ``read()`` to decode the chunk,
            or a falsy value to hand chunks back undecoded.
        """
        self.chunk = None
        self.encoding = encoding

    def write(self, chunk):
        """Store ``chunk``, discarding whatever was stored before."""
        self.chunk = chunk

    def read(self):
        """Return the stored chunk.

        :returns: the decoded chunk when both the chunk and the encoding are
            truthy; otherwise the stored chunk as-is (possibly None).
        """
        data = self.chunk
        # Nothing to decode (no/empty chunk) or no codec requested.
        if not data or not self.encoding:
            return data
        return data.decode(self.encoding)


class MediaIoReadable(object):
    """File-like, chunk-at-a-time reader on top of MediaIoBaseDownload.

    Every ``read()`` triggers at most one ``next_chunk()`` download and
    returns that chunk. When an encoding is given, chunks are decoded with an
    incremental decoder so that a multi-byte character split across two
    downloaded chunks is decoded correctly instead of raising
    ``UnicodeDecodeError``.
    """

    def __init__(
        self,
        request,
        encoding=None,
        pre_buffer=True,
        remove_prefix=b"",
        chunksize=DEFAULT_CHUNK_SIZE,
    ):
        """File-like wrapper around MediaIoBaseDownload.

        :param request: media request passed on to MediaIoBaseDownload.
        :param encoding: codec used to decode chunks; a falsy value makes
            ``read()`` return raw bytes.
        :param pre_buffer: Whether to read one chunk into an internal buffer
            immediately in order to raise any potential errors.
        :param remove_prefix: Bytes prefix to remove from internal pre_buffer.
        :param chunksize: chunk size passed on to MediaIoBaseDownload.
        :raises: ApiRequestError
        """
        # Function-scope import: only needed for the text-decoding path.
        import codecs

        self.done = False
        self._fd = IoBuffer(encoding)
        self.downloader = MediaIoBaseDownload(
            self._fd, request, chunksize=chunksize
        )
        # A stateful decoder keeps incomplete trailing byte sequences and
        # prepends them to the next chunk.
        self._decoder = (
            codecs.getincrementaldecoder(encoding)() if encoding else None
        )
        self._pre_buffer = False
        if pre_buffer:
            # Fetch raw bytes only; decoding happens once, in read(), after
            # the prefix (if any) has been stripped.
            self._fetch()
            if remove_prefix:
                chunk = io.BytesIO(self._fd.chunk)
                GoogleDriveFile._RemovePrefix(chunk, remove_prefix)
                self._fd.chunk = chunk.getvalue()
            self._pre_buffer = True

    def _fetch(self):
        """Download the next raw chunk into the internal buffer.

        :raises: ApiRequestError
        """
        try:
            _, self.done = self.downloader.next_chunk()
        except errors.HttpError as error:
            raise ApiRequestError(error)

    def _decode(self, chunk):
        """Decode ``chunk`` if an encoding was requested.

        :returns: bytes when no encoding is set, decoded text otherwise, or
            None when there is no chunk at all.
        """
        if chunk is None or self._decoder is None:
            return chunk
        # On the last chunk flush the decoder so truncated input still fails
        # loudly instead of being dropped silently.
        return self._decoder.decode(chunk, final=self.done)

    def read(self):
        """
        :returns: bytes or str -- chunk (or None if done)
        :raises: ApiRequestError
        """
        if self._pre_buffer:
            # Hand out the chunk that __init__ already downloaded.
            self._pre_buffer = False
            return self._decode(self._fd.chunk)
        if self.done:
            return None
        self._fetch()
        return self._decode(self._fd.chunk)

    def __iter__(self):
        """
        :raises: ApiRequestError
        """
        while True:
            chunk = self.read()
            if chunk is None:
                break
            yield chunk


class GoogleDriveFile(ApiAttributeMixin, ApiResource):
"""Google Drive File instance.

Expand Down Expand Up @@ -247,22 +323,18 @@ def GetContentFile(
raise FileNotUploadedError()

def download(fd, request):
# Ensures thread safety. Similar to other places where we call
# `.execute(http=self.http)` to pass a client from the thread
# local storage.
if self.http:
request.http = self.http
downloader = MediaIoBaseDownload(fd, request)
downloader = MediaIoBaseDownload(fd, self._WrapRequest(request))
done = False
while done is False:
status, done = downloader.next_chunk()
if callback:
callback(status.resumable_progress, status.total_size)

with open(filename, mode="w+b") as fd:
# Ideally would use files.export_media instead if
# metadata.get("mimeType").startswith("application/vnd.google-apps.")
# but that would first require a slow call to FetchMetadata()
# Should use files.export_media instead of files.get_media if
# metadata["mimeType"].startswith("application/vnd.google-apps.").
# But that would first require a slow call to FetchMetadata().
# We prefer to try-except for speed.
try:
download(fd, files.get_media(fileId=file_id))
except errors.HttpError as error:
Expand All @@ -284,13 +356,66 @@ def download(fd, request):

if mimetype == "text/plain" and remove_bom:
fd.seek(0)
boms = [
bom[mimetype]
for bom in MIME_TYPE_TO_BOM.values()
if mimetype in bom
]
if boms:
self._RemovePrefix(fd, boms[0])
bom = self._GetBOM(mimetype)
if bom:
self._RemovePrefix(fd, bom)

@LoadAuth
def GetContentIOBuffer(
    self,
    mimetype=None,
    encoding=None,
    remove_bom=False,
    chunksize=DEFAULT_CHUNK_SIZE,
):
    """Return a file-like object exposing a buffered read() method.

    A plain media download is attempted first. If that fails with a 403
    "fileNotDownloadable" error, the file is exported instead, using
    ``mimetype`` (or "text/plain" when none was given).

    :param mimetype: mimeType of the file.
    :type mimetype: str
    :param encoding: The encoding to use when decoding the byte string.
    :type encoding: str
    :param remove_bom: Whether to remove the byte order marking.
    :type remove_bom: bool
    :param chunksize: default read()/iter() chunksize.
    :type chunksize: int
    :returns: MediaIoReadable -- file-like object.
    :raises: ApiRequestError, FileNotUploadedError
    """
    files_api = self.auth.service.files()
    file_id = self.metadata.get("id") or self.get("id")
    if not file_id:
        raise FileNotUploadedError()

    # files.export_media is the right call when the Drive mimeType starts
    # with "application/vnd.google-apps.", but finding that out would need a
    # slow FetchMetadata() call first. Trying get_media and falling back on
    # failure is faster.
    try:
        media_request = self._WrapRequest(files_api.get_media(fileId=file_id))
        return MediaIoReadable(
            media_request, encoding=encoding, chunksize=chunksize
        )
    except ApiRequestError as exc:
        needs_export = (
            exc.error["code"] == 403
            and exc.GetField("reason") == "fileNotDownloadable"
        )
        if not needs_export:
            raise exc

        mimetype = mimetype or "text/plain"
        export_request = self._WrapRequest(
            files_api.export_media(fileId=file_id, mimeType=mimetype)
        )
        # The BOM is only stripped for plain-text exports, and only on
        # request.
        bom_prefix = b""
        if mimetype == "text/plain" and remove_bom:
            bom_prefix = self._GetBOM(mimetype)
        return MediaIoReadable(
            export_request,
            encoding=encoding,
            remove_prefix=bom_prefix,
            chunksize=chunksize,
        )

@LoadAuth
def FetchMetadata(self, fields=None, fetch_all=False):
Expand Down Expand Up @@ -446,6 +571,16 @@ def DeletePermission(self, permission_id):
"""
return self._DeletePermission(permission_id)

def _WrapRequest(self, request):
"""Replaces request.http with self.http.

Ensures thread safety. Similar to other places where we call
`.execute(http=self.http)` to pass a client from the thread local storage.
"""
if self.http:
request.http = self.http
return request

@LoadAuth
def _FilesInsert(self, param=None):
"""Upload a new file using Files.insert().
Expand Down Expand Up @@ -663,6 +798,13 @@ def _DeletePermission(self, permission_id):
self.metadata["permissions"] = permissions
return True

@staticmethod
def _GetBOM(mimetype):
    """Look up the BOM for a download mime type.

    Based on download mime type (ignores Google Drive mime type).

    :param mimetype: download mimeType to look up.
    :returns: the entry for ``mimetype`` from the first MIME_TYPE_TO_BOM
        value that contains it, or None when no value contains it.
    """
    matches = (
        mapping[mimetype]
        for mapping in MIME_TYPE_TO_BOM.values()
        if mimetype in mapping
    )
    return next(matches, None)

@staticmethod
def _RemovePrefix(file_object, prefix, block_size=BLOCK_SIZE):
"""Deletes passed prefix by shifting content of passed file object by to
Expand Down
35 changes: 35 additions & 0 deletions pydrive2/test/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,24 @@ def test_10_Files_Download_Service(self):

self.DeleteUploadedFiles(drive, [file1["id"]])

def test_11_Files_Get_Content_Buffer(self):
    """Upload a small text file and read it back via GetContentIOBuffer,
    once as raw bytes and once decoded as ASCII text."""
    drive = GoogleDrive(self.ga)
    uploaded = drive.CreateFile()
    title = self.getTempFile()
    expected = "hello world!\ngoodbye, cruel world!"
    uploaded["title"] = title
    uploaded.SetContentString(expected)
    pydrive_retry(uploaded.Upload)  # Files.insert

    # Without an encoding the buffer yields bytes chunks.
    bytes_buffer = pydrive_retry(uploaded.GetContentIOBuffer)
    self.assertEqual(uploaded.metadata["title"], title)
    downloaded = b"".join(iter(bytes_buffer))
    self.assertEqual(downloaded.decode("ascii"), expected)

    # With an encoding the buffer yields already-decoded text chunks.
    text_buffer = pydrive_retry(
        uploaded.GetContentIOBuffer, encoding="ascii"
    )
    self.assertEqual("".join(iter(text_buffer)), expected)

    self.DeleteUploadedFiles(drive, [uploaded["id"]])

# Tests for Trash/UnTrash/Delete.
# ===============================

Expand Down Expand Up @@ -610,6 +628,23 @@ def test_Gfile_Conversion_Add_Remove_BOM(self):
self.assertNotEqual(content_bom, content_no_bom)
self.assertTrue(len(content_bom) > len(content_no_bom))

buffer_bom = pydrive_retry(
file1.GetContentIOBuffer,
mimetype="text/plain",
encoding="utf-8",
)
buffer_bom = u"".join(iter(buffer_bom))
self.assertEqual(content_bom, buffer_bom)

buffer_no_bom = pydrive_retry(
file1.GetContentIOBuffer,
mimetype="text/plain",
remove_bom=True,
encoding="utf-8",
)
buffer_no_bom = u"".join(iter(buffer_no_bom))
self.assertEqual(content_no_bom, buffer_no_bom)

finally:
self.cleanup_gfile_conversion_test(
file1, file_name, downloaded_file_name
Expand Down