-
Notifications
You must be signed in to change notification settings - Fork 400
/
sftp_object_store.py
242 lines (213 loc) · 10.4 KB
/
sftp_object_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Utility for uploading to and downloading from cloud object stores."""
from __future__ import annotations
import contextlib
import os
import pathlib
import urllib.parse
import uuid
from typing import Any, Callable, Dict, Optional, Union
from composer.utils.import_helpers import MissingConditionalImportError
from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError
__all__ = ['SFTPObjectStore']
try:
import paramiko.client
from paramiko import SSHClient
_PARAMIKO_AVAILABLE = True
except ImportError:
_PARAMIKO_AVAILABLE = False
def _set_kwarg(value: Any, kwargs: Dict[str, Any], arg_name: str, kwarg_name: str):
if kwarg_name in kwargs:
raise ValueError(f'The `{arg_name}` should be not be specified directly if also included via `connect_kwargs`')
kwargs[kwarg_name] = value
class SFTPObjectStore(ObjectStore):
    """Utility for uploading to and downloading from a server via SFTP.

    Args:
        host (str): The server to connect to.

            Also accepts a URI string in the form ``'sftp://username@host:port/./relative/path'``.
            For an absolute path, use a double `//` -- e.g. ``'sftp://username@host:port//absolute/path'``.
        port (int, optional): The server port to connect to. Defaults to ``22``.
        username (str, optional): The username (if not specified in the SSH config) needed to authenticate.
            Defaults to None.
        password (str, optional): The password (if required) needed to authenticate. Defaults to None.
        key_filename (pathlib.Path | str, optional): The filepath to a private key (if required) needed to
            authenticate. Defaults to None. Any keys specified here will be tried *in addition* to any keys
            specified in ``~/.ssh/`` or via a SSH agent.
        known_hosts_filename (pathlib.Path | str, optional): The filename of the known hosts file. If not specified,
            the default SSH known hosts will be used.
        missing_host_key_policy (str | paramiko.client.MissingHostKeyPolicy): The class name or instance of
            :class:`paramiko.client.MissingHostKeyPolicy` to use for a missing host key. Defaults to ``'RejectPolicy'``.

            Built-in options:

            * ``'RejectPolicy'`` (the default), which will reject any host key not authorized in the ``known_hosts_filename``.
            * ``'AutoAddPolicy'``, which will add any unknown host key.
            * ``'WarningPolicy'``, which will warn on an unknown host key.

            For custom logic, subclass :class:`paramiko.client.MissingHostKeyPolicy`, and provide an instance of this class.
        cwd (str, optional): The remote directory that object names are resolved against. Missing directories
            are created (via ``mkdir -p``) when an object is uploaded.
        connect_kwargs (Dict[str, Any], optional): Any additional kwargs to pass through to :meth:`.SSHClient.connect`.
            The dictionary is copied; the caller's object is not modified.

    Raises:
        MissingConditionalImportError: If ``paramiko`` is not installed.
        ValueError: If the URI is malformed, if a setting is given both in the URI / ``connect_kwargs`` and as an
            explicit argument, or if ``missing_host_key_policy`` does not name a host key policy class.
    """

    def __init__(
        self,
        host: str,
        port: int = 22,
        username: Optional[str] = None,
        password: Optional[str] = None,
        known_hosts_filename: Optional[Union[pathlib.Path, str]] = None,
        key_filename: Optional[Union[pathlib.Path, str]] = None,
        missing_host_key_policy: Union[str, paramiko.client.MissingHostKeyPolicy] = 'RejectPolicy',
        cwd: str = '',
        connect_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if not _PARAMIKO_AVAILABLE:
            raise MissingConditionalImportError(extra_deps_group='streaming', conda_package='paramiko')

        # ``host`` may be a bare hostname or a full ``sftp://`` URI; a non-empty scheme means URI.
        url = urllib.parse.urlsplit(host)
        if url.scheme != '':
            if url.scheme.lower() != 'sftp':
                raise ValueError('If specifying a URI, only the sftp scheme is supported.')
            if not url.hostname:
                raise ValueError('If specifying a URI, the URI must include the hostname.')
            host = url.hostname
            if url.username:
                if username is not None:
                    raise ValueError(
                        'If specifying the username in the `host`, then the `username` argument must be blank.')
                username = url.username
            if url.password:
                if password is not None:
                    raise ValueError(
                        'If specifying the password in the `host`, then the `password` argument must be blank.')
                password = url.password
            if url.port:
                if port != 22:
                    raise ValueError('If specifying the port in the `host`, then the `port` argument must be blank.')
                port = url.port
            if url.path:
                # strip the first left slash. Two slashes for absolute; 1 for relative
                assert url.path.startswith('/'), 'The path should always start with a `/`'
                cwd = url.path[1:]
            if url.query or url.fragment:
                raise ValueError('Query and fragment parameters are not supported as part of a URI.')

        # Work on a copy so the caller's dictionary is never mutated. Otherwise reusing the same
        # ``connect_kwargs`` for a second store would trip the duplicate check in ``_set_kwarg``.
        connect_kwargs = {} if connect_kwargs is None else dict(connect_kwargs)
        if host:
            _set_kwarg(host, connect_kwargs, arg_name='host', kwarg_name='hostname')
        if port:
            _set_kwarg(port, connect_kwargs, arg_name='port', kwarg_name='port')
        if username:
            _set_kwarg(username, connect_kwargs, arg_name='username', kwarg_name='username')
        if password:
            _set_kwarg(password, connect_kwargs, arg_name='password', kwarg_name='password')
        if key_filename:
            # paramiko documents ``key_filename`` as a ``str`` (or a list of ``str``), so normalize
            # path objects the same way ``known_hosts_filename`` is normalized below.
            _set_kwarg(str(key_filename), connect_kwargs, arg_name='key_filename', kwarg_name='key_filename')

        # Keep a trailing slash so the base URI and joined object paths are well formed.
        if cwd and not cwd.endswith('/'):
            cwd += '/'
        self.cwd = cwd

        # The password is deliberately left out of the URI.
        netloc = ''
        if username:
            netloc += f'{username}@'
        if host:
            netloc += host
        if port:
            netloc += f':{port}'
        self._base_uri = urllib.parse.urlunsplit((
            'sftp',  # scheme
            netloc,  # netloc
            '/' + cwd,  # path
            None,  # query
            None,  # fragment
        ))

        self.ssh_client = SSHClient()
        if known_hosts_filename is not None:
            known_hosts_filename = str(known_hosts_filename)
        if isinstance(missing_host_key_policy, str):
            # Resolve the policy by class name. Validate explicitly (not with ``assert``, which is
            # stripped under ``-O``) so any wrong name gives the same ValueError.
            policy_cls = getattr(paramiko.client, missing_host_key_policy, None)
            if not (isinstance(policy_cls, type) and issubclass(policy_cls, paramiko.client.MissingHostKeyPolicy)):
                raise ValueError(
                    "Invalid `missing_host_key_policy`. Must be 'AutoAddPolicy', 'RejectPolicy', or 'WarningPolicy'.")
            missing_host_key_policy = policy_cls()
        self.ssh_client.set_missing_host_key_policy(missing_host_key_policy)
        self.ssh_client.load_system_host_keys(known_hosts_filename)

        # Saved so a dropped connection can be re-established with identical settings.
        self._connect_kwargs = connect_kwargs
        self.ssh_client.connect(**connect_kwargs)
        try:
            self.sftp_client = self.ssh_client.open_sftp()
        except Exception:
            # Do not leak the established SSH connection if the SFTP session cannot be opened.
            self.ssh_client.close()
            raise

    def close(self):
        """Close the SFTP session and the underlying SSH connection."""
        self.sftp_client.close()
        self.ssh_client.close()

    def get_uri(self, object_name: str) -> str:
        """Return the ``sftp://`` URI for ``object_name`` (the base URI with the name appended)."""
        return self._base_uri + object_name

    def get_object_size(self, object_name: str) -> int:
        """Return the size reported by the server's ``stat`` for ``object_name``.

        Args:
            object_name (str): The object name, resolved against :attr:`cwd`.

        Returns:
            int: The ``st_size`` of the remote file.

        Raises:
            RuntimeError: If the server's ``stat`` reply carries no size.
            ObjectStoreTransientError: If the failure is classified as transient.
        """
        object_name = os.path.join(self.cwd, object_name)
        with self._handle_transient_errors():
            st_size = self.sftp_client.stat(object_name).st_size
        if st_size is None:
            raise RuntimeError('Cannot determine object size: stat(object_name).st_size is None')
        return st_size

    @contextlib.contextmanager
    def _handle_transient_errors(self):
        """Context manager that re-raises retryable failures as :class:`ObjectStoreTransientError`.

        If the connection is found to be dead after a failure, it is re-established before the
        transient error is raised. Any other exception propagates unchanged.
        """
        from paramiko import ChannelException, SSHException
        try:
            yield
        except Exception as e:
            if not self._is_cnx_alive():
                # If the connection dropped, then it's a transient error. Create a new one, and raise the exception to try again.
                self.close()
                self.ssh_client.connect(**self._connect_kwargs)
                self.sftp_client = self.ssh_client.open_sftp()
                raise ObjectStoreTransientError from e
            if isinstance(e, SSHException):
                if 'Server connection dropped:' in str(e):
                    raise ObjectStoreTransientError from e
            if isinstance(e, (TimeoutError, ConnectionError, EOFError, ChannelException)):
                raise ObjectStoreTransientError from e
            raise

    def _is_cnx_alive(self) -> bool:
        """Return whether both the SSH transport and the SFTP channel are currently usable.

        A missing transport or channel (for example after :meth:`close`) counts as "not alive"
        and does not raise, so the error handler can reconnect without masking the original error.
        """
        transport = self.ssh_client.get_transport()
        if transport is None:
            return False
        if not transport.is_active() or not transport.is_alive():
            return False
        channel = self.sftp_client.get_channel()
        if channel is None:
            return False
        return bool(channel.active and not channel.closed)

    def upload_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        callback: Optional[Callable[[int, int], None]] = None,
    ) -> None:
        """Upload the local file ``filename`` to ``object_name`` on the server.

        The remote parent directory is created first with ``mkdir -p``.

        Args:
            object_name (str): The destination object name, resolved against :attr:`cwd`.
            filename (str | pathlib.Path): The local file to upload.
            callback ((int, int) -> None, optional): Passed through to the SFTP client's ``put``.

        Raises:
            ObjectStoreTransientError: If the failure is classified as transient.
        """
        # Function-scope import, matching the local paramiko import in ``_handle_transient_errors``.
        import shlex

        remote_path = os.path.join(self.cwd, object_name)
        remote_dir = os.path.dirname(remote_path)
        with self._handle_transient_errors():
            if remote_dir:
                # Quote the directory: the command is run by the remote shell, so an unquoted
                # path with spaces or shell metacharacters would be split or interpreted.
                self.ssh_client.exec_command(f'mkdir -p {shlex.quote(remote_dir)}')
            self.sftp_client.put(str(filename), remote_path, callback=callback, confirm=True)

    def download_object(
        self,
        object_name: str,
        filename: Union[str, pathlib.Path],
        overwrite: bool = False,
        callback: Optional[Callable[[int, int], None]] = None,
    ) -> None:
        """Download ``object_name`` from the server to the local path ``filename``.

        The object is first written to a uniquely named temporary file next to ``filename`` and
        then moved into place, so a failed transfer does not leave a partial file at ``filename``.

        Args:
            object_name (str): The source object name, resolved against :attr:`cwd`.
            filename (str | pathlib.Path): The local destination. Parent directories are created.
            overwrite (bool, optional): Whether an existing ``filename`` may be replaced. Defaults to False.
            callback ((int, int) -> None, optional): Passed through to the SFTP client's ``get``.

        Raises:
            FileExistsError: If ``filename`` exists and ``overwrite`` is False.
            ObjectStoreTransientError: If the failure is classified as transient.
        """
        object_name = os.path.join(self.cwd, object_name)
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        if os.path.exists(filename) and not overwrite:
            raise FileExistsError(f'The file at {filename} already exists')

        tmp_path = str(filename) + f'.{uuid.uuid4()}.tmp'
        try:
            with self._handle_transient_errors():
                self.sftp_client.get(remotepath=object_name, localpath=tmp_path, callback=callback)
        except Exception:
            # Make a best effort attempt to clean up the temporary file
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        else:
            if overwrite:
                os.replace(tmp_path, filename)
            else:
                os.rename(tmp_path, filename)