diff --git a/databricks_cli/dbfs/api.py b/databricks_cli/dbfs/api.py index eb363a4a..0be47f57 100644 --- a/databricks_cli/dbfs/api.py +++ b/databricks_cli/dbfs/api.py @@ -86,6 +86,8 @@ class DbfsErrorCodes(object): class DbfsApi(object): + MULTIPART_UPLOAD_LIMIT = 2147483648 + def __init__(self, api_client): self.client = DbfsService(api_client) @@ -113,16 +115,24 @@ def get_status(self, dbfs_path, headers=None): json = self.client.get_status(dbfs_path.absolute_path, headers=headers) return FileInfo.from_json(json) + # Method makes multipart/form-data file upload for files <2GB. + # Otherwise uses create, add-block, close methods for streaming upload. def put_file(self, src_path, dbfs_path, overwrite, headers=None): - handle = self.client.create(dbfs_path.absolute_path, overwrite, headers=headers)['handle'] - with open(src_path, 'rb') as local_file: - while True: - contents = local_file.read(BUFFER_SIZE_BYTES) - if len(contents) == 0: - break - # add_block should not take a bytes object. - self.client.add_block(handle, b64encode(contents).decode(), headers=headers) - self.client.close(handle, headers=headers) + # Files of size >= 2 GiB fall back to the streaming upload. + if os.path.getsize(src_path) < self.MULTIPART_UPLOAD_LIMIT: + self.client.put(dbfs_path.absolute_path, src_path=src_path, + overwrite=overwrite, headers=headers) + else: + handle = self.client.create(dbfs_path.absolute_path, overwrite, + headers=headers)['handle'] + with open(src_path, 'rb') as local_file: + while True: + contents = local_file.read(BUFFER_SIZE_BYTES) + if len(contents) == 0: + break + # add_block should not take a bytes object. 
+ self.client.add_block(handle, b64encode(contents).decode(), headers=headers) + self.client.close(handle, headers=headers) def get_file(self, dbfs_path, dst_path, overwrite, headers=None): if os.path.exists(dst_path) and not overwrite: diff --git a/databricks_cli/sdk/api_client.py b/databricks_cli/sdk/api_client.py index 70501e43..f1af7fec 100644 --- a/databricks_cli/sdk/api_client.py +++ b/databricks_cli/sdk/api_client.py @@ -109,7 +109,7 @@ def close(self): # helper functions starting here - def perform_query(self, method, path, data = {}, headers = None): + def perform_query(self, method, path, data = {}, headers = None, files=None): """set up connection and perform query""" if headers is None: headers = self.default_headers @@ -125,8 +125,13 @@ def perform_query(self, method, path, data = {}, headers = None): resp = self.session.request(method, self.url + path, params = translated_data, verify = self.verify, headers = headers) else: - resp = self.session.request(method, self.url + path, data = json.dumps(data), - verify = self.verify, headers = headers) + if files is None: + resp = self.session.request(method, self.url + path, data = json.dumps(data), + verify = self.verify, headers = headers) + else: + # Multipart file upload + resp = self.session.request(method, self.url + path, files = files, data = data, + verify = self.verify, headers = headers) try: resp.raise_for_status() except requests.exceptions.HTTPError as e: diff --git a/databricks_cli/sdk/service.py b/databricks_cli/sdk/service.py index c3cc1a7b..9489a760 100755 --- a/databricks_cli/sdk/service.py +++ b/databricks_cli/sdk/service.py @@ -23,6 +23,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import os + + class JobsService(object): def __init__(self, client): self.client = client @@ -519,25 +522,35 @@ def list_test(self, path, headers=None): _data['path'] = path return self.client.perform_query('GET', '/dbfs-testing/list', data=_data, headers=headers) - def put(self, path, contents=None, overwrite=None, headers=None): + def put(self, path, contents=None, overwrite=None, headers=None, src_path=None): _data = {} + _files = None if path is not None: _data['path'] = path if contents is not None: _data['contents'] = contents if overwrite is not None: _data['overwrite'] = overwrite - return self.client.perform_query('POST', '/dbfs/put', data=_data, headers=headers) + if src_path is not None: + headers = {'Content-Type': None} + filename = os.path.basename(src_path) + _files = {'file': (filename, open(src_path, 'rb'), 'multipart/form-data')} + return self.client.perform_query('POST', '/dbfs/put', data=_data, headers=headers, files=_files) - def put_test(self, path, contents=None, overwrite=None, headers=None): + def put_test(self, path, contents=None, overwrite=None, headers=None, src_path=None): _data = {} + _files = None if path is not None: _data['path'] = path if contents is not None: _data['contents'] = contents if overwrite is not None: _data['overwrite'] = overwrite - return self.client.perform_query('POST', '/dbfs-testing/put', data=_data, headers=headers) + if src_path is not None: + headers = {'Content-Type': None} + filename = os.path.basename(src_path) + _files = {'file': (filename, open(src_path, 'rb'), 'multipart/form-data')} + return self.client.perform_query('POST', '/dbfs-testing/put', data=_data, headers=headers, files=_files) def mkdirs(self, path, headers=None): _data = {} diff --git a/tests/dbfs/test_api.py b/tests/dbfs/test_api.py index c8de6899..ded0c976 100644 --- a/tests/dbfs/test_api.py +++ b/tests/dbfs/test_api.py @@ -135,6 +135,20 @@ def test_put_file(self, dbfs_api, tmpdir): api_mock.create.return_value = {'handle': test_handle}
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True) + # Should not call add-block since file is < 2GB + assert api_mock.add_block.call_count == 0 + + # Files >= 2GB should use create, add_block, close stream upload. + def test_put_large_file(self, dbfs_api, tmpdir): + test_file_path = os.path.join(tmpdir.strpath, 'test') + with open(test_file_path, 'wt') as f: + f.write('test') + api_mock = dbfs_api.client + # Make streaming upload threshold 2 bytes for testing. + dbfs_api.MULTIPART_UPLOAD_LIMIT = 2 + test_handle = 0 + api_mock.create.return_value = {'handle': test_handle} + dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True) assert api_mock.add_block.call_count == 1 assert test_handle == api_mock.add_block.call_args[0][0] assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1]