-
Notifications
You must be signed in to change notification settings - Fork 400
/
libcloud_object_store.py
185 lines (157 loc) · 7.58 KB
/
libcloud_object_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Utility for uploading to and downloading from cloud object stores."""
import io
import os
import pathlib
import uuid
from typing import Any, Callable, Dict, Optional, Union
from requests.exceptions import ConnectionError
from urllib3.exceptions import ProtocolError
from composer.utils.import_helpers import MissingConditionalImportError
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError
__all__ = ['LibcloudObjectStore']
class LibcloudObjectStore(ObjectStore):
    """Utility for uploading to and downloading from object (blob) stores, such as Amazon S3.

    .. rubric:: Example

    Here's an example for an Amazon S3 bucket named ``MY_CONTAINER``:

    >>> from composer.utils import LibcloudObjectStore
    >>> object_store = LibcloudObjectStore(
    ...     provider="s3",
    ...     container="MY_CONTAINER",
    ...     provider_kwargs={
    ...         "key": "AKIA...",
    ...         "secret": "*********",
    ...     }
    ... )
    >>> object_store
    <composer.utils.object_store.libcloud_object_store.LibcloudObjectStore object at ...>

    Args:
        provider (str): Cloud provider to use. Valid options are:

            * :mod:`~libcloud.storage.drivers.atmos`
            * :mod:`~libcloud.storage.drivers.auroraobjects`
            * :mod:`~libcloud.storage.drivers.azure_blobs`
            * :mod:`~libcloud.storage.drivers.backblaze_b2`
            * :mod:`~libcloud.storage.drivers.cloudfiles`
            * :mod:`~libcloud.storage.drivers.digitalocean_spaces`
            * :mod:`~libcloud.storage.drivers.google_storage`
            * :mod:`~libcloud.storage.drivers.ktucloud`
            * :mod:`~libcloud.storage.drivers.local`
            * :mod:`~libcloud.storage.drivers.minio`
            * :mod:`~libcloud.storage.drivers.nimbus`
            * :mod:`~libcloud.storage.drivers.ninefold`
            * :mod:`~libcloud.storage.drivers.oss`
            * :mod:`~libcloud.storage.drivers.rgw`
            * :mod:`~libcloud.storage.drivers.s3`

            .. seealso:: :doc:`Full list of libcloud providers <libcloud:storage/supported_providers>`
        container (str): The name of the container (i.e. bucket) to use.
        chunk_size (int, optional): Chunk size (in bytes) used when streaming uploads and
            downloads. Default: ``1_048_576`` (1 MiB).
        provider_kwargs (Dict[str, Any], optional): Keyword arguments to pass into the constructor
            for the specified provider. These arguments would usually include the cloud region
            and credentials.

            Common keys are:

            * ``key`` (str): API key or username to be used (required).
            * ``secret`` (str): Secret password to be used (required).
            * ``secure`` (bool): Whether to use HTTPS or HTTP. Note: Some providers only support HTTPS, and it is on by default.
            * ``host`` (str): Override hostname used for connections.
            * ``port`` (int): Override port used for connections.
            * ``api_version`` (str): Optional API version. Only used by drivers which support multiple API versions.
            * ``region`` (str): Optional driver region. Only used by drivers which support multiple regions.

            .. seealso:: :class:`libcloud.storage.base.StorageDriver`
    """

    def __init__(self,
                 provider: str,
                 container: str,
                 chunk_size: int = 1_024 * 1_024,
                 provider_kwargs: Optional[Dict[str, Any]] = None) -> None:
        try:
            from libcloud.storage.providers import get_driver
        except ImportError as e:
            raise MissingConditionalImportError('libcloud', 'apache-libcloud') from e
        provider_cls = get_driver(provider)
        if provider_kwargs is None:
            provider_kwargs = {}
        self.chunk_size = chunk_size
        self._provider_name = provider
        self._provider = provider_cls(**provider_kwargs)
        self._container = self._provider.get_container(container)

    def get_uri(self, object_name: str):
        """Return a ``provider://container/object`` URI for ``object_name``."""
        return f'{self._provider_name}://{self._container.name}/{object_name}'

    def upload_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        callback: Optional[Callable[[int, int], None]] = None,
    ):
        """Upload the file at ``filename`` to ``object_name``, streaming in ``chunk_size`` chunks.

        Args:
            object_name (str): Destination object name within the container.
            filename (str | pathlib.Path): Local file to upload.
            callback ((int, int) -> None, optional): Invoked with (bytes so far, total bytes)
                as the upload progresses.
        """
        with open(filename, 'rb') as f:
            # fstat on the open descriptor gives the size for progress reporting
            stream = iterate_with_callback(_file_to_iterator(f, self.chunk_size),
                                           os.fstat(f.fileno()).st_size, callback)
            try:
                self._provider.upload_object_via_stream(
                    stream,
                    container=self._container,
                    object_name=object_name,
                )
            except Exception as e:
                self._ensure_transient_errors_are_wrapped(e)

    def _ensure_transient_errors_are_wrapped(self, exc: Exception):
        """Re-raise ``exc``, wrapping transient network/provider errors in ObjectStoreTransientError.

        Always raises; never returns normally.
        """
        from libcloud.common.types import LibcloudError
        if isinstance(exc, (LibcloudError, ProtocolError, TimeoutError, ConnectionError)):
            if isinstance(exc, LibcloudError):
                # The S3 driver does not encode the error code in an easy-to-parse manner
                # So first checking if the error code is non-transient
                is_transient_error = any(x in str(exc) for x in ('408', '409', '425', '429', '500', '503', '504'))
                if not is_transient_error:
                    raise exc
            raise ObjectStoreTransientError() from exc
        raise exc

    def _get_object(self, object_name: str):
        """Get object from object store.

        Args:
            object_name (str): The name of the object.

        Raises:
            FileNotFoundError: If the object does not exist in the container.
        """
        from libcloud.storage.types import ObjectDoesNotExistError
        try:
            return self._provider.get_object(self._container.name, object_name)
        except ObjectDoesNotExistError as e:
            raise FileNotFoundError(f'Object not found: {self.get_uri(object_name)}') from e
        except Exception as e:
            self._ensure_transient_errors_are_wrapped(e)

    def get_object_size(self, object_name: str) -> int:
        """Return the size, in bytes, of ``object_name``."""
        return self._get_object(object_name).size

    def download_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        overwrite: bool = False,
        callback: Optional[Callable[[int, int], None]] = None,
    ):
        """Download ``object_name`` to ``filename``.

        Args:
            object_name (str): Object to download.
            filename (str | pathlib.Path): Local destination path.
            overwrite (bool, optional): Whether to overwrite an existing file. Default: ``False``.
            callback ((int, int) -> None, optional): Invoked with (bytes so far, total bytes)
                as the download progresses.

        Raises:
            FileExistsError: If ``filename`` exists and ``overwrite`` is ``False``.
            FileNotFoundError: If the object does not exist.
        """
        if os.path.exists(filename) and not overwrite:
            # If the file already exists, short-circuit and skip the download.
            # BUGFIX: the message previously never interpolated the offending path.
            raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.')
        obj = self._get_object(object_name)
        # Download first to a tempfile, and then rename, in case if the file gets corrupted in transit
        tmp_filepath = str(filename) + f'.{uuid.uuid4()}.tmp'
        try:
            with open(tmp_filepath, 'wb+') as f:
                stream = self._provider.download_object_as_stream(obj, chunk_size=self.chunk_size)
                for chunk in iterate_with_callback(stream, obj.size, callback):
                    f.write(chunk)
        except Exception as e:
            # The download failed for some reason. Make a best-effort attempt to remove the temporary file.
            try:
                os.remove(tmp_filepath)
            except OSError:
                pass
            self._ensure_transient_errors_are_wrapped(e)
        # The download was successful.
        if overwrite:
            os.replace(tmp_filepath, filename)
        else:
            os.rename(tmp_filepath, filename)
def _file_to_iterator(f: io.IOBase, chunk_size: int):
while True:
byte = f.read(chunk_size)
if byte == b'':
break
yield byte