From 3748164ddb902553a84390c0146b2e92b5a572dd Mon Sep 17 00:00:00 2001
From: Ruslan Kuprieiev
Date: Thu, 12 Apr 2018 12:54:37 +0300
Subject: [PATCH 1/2] cloud: aws: migrate to boto3

This allows us to abandon our own additional config parameters for the
AWS cloud (like region) and use the ones from .aws/config by default.
Plus, boto3 automatically switches to multipart upload if a file is
bigger than a certain size limit, allowing us to remove a lot of our
own code.

Fixes #650
Fixes #636
Fixes #601

Signed-off-by: Ruslan Kuprieiev
---
 .appveyor.yml                |   1 +
 dvc/cloud/aws.py             | 227 ++++++++---------------------------
 dvc/cloud/credentials_aws.py |  92 --------------
 dvc/utils.py                 |   2 +-
 requirements.txt             |   2 +-
 scripts/ci/install.sh        |  15 ++-
 setup.py                     |   2 +-
 tests/test_data_cloud.py     |  89 +++++++-------
 8 files changed, 107 insertions(+), 323 deletions(-)
 delete mode 100644 dvc/cloud/credentials_aws.py

diff --git a/.appveyor.yml b/.appveyor.yml
index a1d257cdd1..1662210da2 100644
--- a/.appveyor.yml
+++ b/.appveyor.yml
@@ -57,6 +57,7 @@ build: false
 before_test:
   - aws configure set aws_access_key_id "%aws_access_key_id%"
   - aws configure set aws_secret_access_key "%aws_secret_access_key%"
+  - aws configure set region us-east-2
   - openssl enc -d -aes-256-cbc -md md5 -k "%GCP_CREDS%" -in scripts\ci\gcp-creds.json.enc -out scripts\ci\gcp-creds.json & exit 0
   - pip install -r test-requirements.txt
diff --git a/dvc/cloud/aws.py b/dvc/cloud/aws.py
index fa2e81ed71..6360056d26 100644
--- a/dvc/cloud/aws.py
+++ b/dvc/cloud/aws.py
@@ -1,7 +1,9 @@
 import os
 import math
+import threading
 
-from boto.s3.connection import S3Connection
+import boto3
+import botocore
 try:
     import httplib
 except ImportError:
@@ -15,7 +17,6 @@
 from dvc.config import Config
 from dvc.logger import Logger
 from dvc.progress import progress
-from dvc.cloud.credentials_aws import AWSCredentials
 from dvc.cloud.base import DataCloudError, DataCloudBase
 
 
@@ -28,92 +29,59 @@ def sizeof_fmt(num, suffix='B'):
     return "%.1f%s%s" % (num, 'Y', suffix)
 
 
-def percent_cb(name, part_complete, part_total, offset=0, multipart_total=None):
+def percent_cb(name, complete, total):
     """ Callback for updating target progress """
-    complete = offset + part_complete
-    total = multipart_total if multipart_total != None else part_total
-
     Logger.debug('{}: {} transferred out of {}'.format(name,
                                                        sizeof_fmt(complete),
                                                        sizeof_fmt(total)))
     progress.update_target(name, complete, total)
 
 
-def create_cb(name, offset=0, multipart_total=None):
+def create_cb(name):
     """ Create callback function for multipart object """
-    return (lambda cur, tot: percent_cb(name, cur, tot, offset, multipart_total))
+    return (lambda cur, tot: percent_cb(name, cur, tot))
+
+
+class Callback(object):
+    def __init__(self, name, total):
+        self.name = name
+        self.total = total
+        self.current = 0
+        self.lock = threading.Lock()
+
+    def __call__(self, byts):
+        with self.lock:
+            self.current += byts
+            progress.update_target(self.name, self.current, self.total)
+
+
+class AWSKey(object):
+    def __init__(self, bucket, name):
+        self.name = name
+        self.bucket = bucket
 
 
 class DataCloudAWS(DataCloudBase):
     """ DataCloud class for Amazon Web Services """
 
     REGEX = r'^s3://(?P.*)$'
 
-    def __init__(self, cloud_settings):
-        super(DataCloudAWS, self).__init__(cloud_settings)
-        self._aws_creds = AWSCredentials(cloud_settings.cloud_config)
-
-    @property
-    def aws_region_host(self):
-        """ get the region host needed for s3 access
-
-        See notes http://docs.aws.amazon.com/general/latest/gr/rande.html#s3_region
-        """
-
-        region = self._cloud_settings.cloud_config.get(Config.SECTION_AWS_REGION, None)
-        if region is None or region == '':
-            return 's3.amazonaws.com'
-        if region == 'us-east-1':
-            return 's3.amazonaws.com'
-        return 's3.%s.amazonaws.com' % region
-
-    def credential_paths(self, default):
-        """
-        Try obtaining path to aws credentials from config file.
-        """
-        paths = []
-        credpath = self._cloud_settings.cloud_config.get(Config.SECTION_AWS_CREDENTIALPATH, None)
-        if credpath is not None and len(credpath) > 0:
-            credpath = os.path.expanduser(credpath)
-            if os.path.isfile(credpath):
-                paths.append(credpath)
-            else:
-                Logger.warn('AWS CredentialPath "%s" not found;'
-                            'falling back to default "%s"' % (credpath, default))
-                paths.append(default)
-        else:
-            paths.append(default)
-        return paths
-
     def connect(self):
-        if all([self._aws_creds.access_key_id,
-                self._aws_creds.secret_access_key,
-                self.aws_region_host]):
-            conn = S3Connection(self._aws_creds.access_key_id,
-                                self._aws_creds.secret_access_key,
-                                host=self.aws_region_host)
-        else:
-            conn = S3Connection()
-        self.bucket = conn.lookup(self.storage_bucket)
-        if self.bucket is None:
+        self.s3 = boto3.resource('s3')
+        bucket = self.s3.Bucket(self.storage_bucket)
+        if bucket is None:
             raise DataCloudError('Storage path {} is not setup correctly'.format(self.storage_bucket))
 
-    @staticmethod
-    def _upload_tracker(fname):
-        """
-        File name for upload tracker.
-        """
-        return fname + '.upload'
+    def create_cb_pull(self, name, key):
+        total = self.s3.Object(bucket_name=key.bucket, key=key.name).content_length
+        return Callback(name, total)
 
-    @staticmethod
-    def _download_tracker(fname):
-        """
-        File name for download tracker.
-        """
-        return fname + '.download'
+    def create_cb_push(self, name, fname):
+        total = os.path.getsize(fname)
+        return Callback(name, total)
 
     def _pull_key(self, key, fname, no_progress_bar=False):
         Logger.debug("Pulling key '{}' from bucket '{}' to file '{}'".format(key.name,
-                                                                             key.bucket.name,
+                                                                             key.bucket,
                                                                              fname))
 
         self._makedirs(fname)
@@ -124,19 +92,18 @@ def _pull_key(self, key, fname, no_progress_bar=False):
             Logger.debug('File "{}" matches with "{}".'.format(fname, key.name))
             return fname
 
-        Logger.debug('Downloading cache file from S3 "{}/{}" to "{}"'.format(key.bucket.name,
+        Logger.debug('Downloading cache file from S3 "{}/{}" to "{}"'.format(key.bucket,
                                                                              key.name,
                                                                              fname))
 
         if no_progress_bar:
             cb = None
         else:
-            cb = create_cb(name)
+            cb = self.create_cb_pull(name, key)
+
 
-        res_h = ResumableDownloadHandler(tracker_file_name=self._download_tracker(tmp_file),
-                                         num_retries=10)
         try:
-            key.get_contents_to_filename(tmp_file, cb=cb, res_download_handler=res_h)
+            self.s3.Object(key.bucket, key.name).download_file(tmp_file, Callback=cb)
         except Exception as exc:
             Logger.error('Failed to download "{}": {}'.format(key.name, exc))
             return None
@@ -152,122 +119,26 @@ def _pull_key(self, key, fname, no_progress_bar=False):
 
     def _get_key(self, path):
         key_name = self.cache_file_key(path)
-        return self.bucket.get_key(key_name)
-
-    def _new_key(self, path):
-        key_name = self.cache_file_key(path)
-        return self.bucket.new_key(key_name)
-
-    def _write_upload_tracker(self, fname, mp_id):
-        """
-        Write multipart id to upload tracker.
-        """
-        try:
-            open(self._upload_tracker(fname), 'w+').write(mp_id)
-        except Exception as exc:
-            Logger.debug("Failed to write upload tracker file for {}: {}".format(fname, exc))
-
-    def _unlink_upload_tracker(self, fname):
-        """
-        Remove upload tracker file.
-        """
         try:
-            os.unlink(self._upload_tracker(fname))
-        except Exception as exc:
-            Logger.debug("Failed to unlink upload tracker file for {}: {}".format(fname, exc))
-
-    def _resume_multipart(self, key, fname):
-        """
-        Try resuming multipart upload.
-        """
-        try:
-            mp_id = open(self._upload_tracker(fname), 'r').read()
-        except Exception as exc:
-            Logger.debug("Failed to read upload tracker file for {}: {}".format(fname, exc))
+            self.s3.Object(self.storage_bucket, key_name).get()
+            return AWSKey(self.storage_bucket, key_name)
+        except botocore.errorfactory.ClientError:
             return None
 
-        for part in key.bucket.get_all_multipart_uploads():
-            if part.id != mp_id:
-                continue
-
-            Logger.debug("Found existing multipart {}".format(mp_id))
-            return part
-
-        return None
-
-    def _create_multipart(self, key, fname):
-        """
-        Create multipart upload and save info to tracker file.
-        """
-        multipart = key.bucket.initiate_multipart_upload(key.name)
-        self._write_upload_tracker(fname, multipart.id)
-        return multipart
-
-    def _get_multipart(self, key, fname):
-        """
-        Try resuming multipart upload if supported.
-        """
-        multipart = self._resume_multipart(key, fname)
-        if multipart != None:
-            return multipart
-
-        return self._create_multipart(key, fname)
-
-    @staticmethod
-    def _skip_part(multipart, part_num, size):
-        """
-        Skip part of multipart upload if it has been already uploaded to the server.
-        """
-        for part in multipart.get_all_parts():
-            if part.part_number == part_num and part.size == size:# and p.etag and p.last_modified
-                Logger.debug("Skipping part #{}".format(str(part_num)))
-                return True
-        return False
-
-    def _push_multipart(self, key, fname):
-        """
-        Upload local file to cloud as a multipart upload.
-        """
-        multipart = self._get_multipart(key, fname)
-
-        source_size = os.stat(fname).st_size
-        chunk_size = 50*1024*1024
-        chunk_count = int(math.ceil(source_size / float(chunk_size)))
-
-        with open(fname, 'rb') as fobj:
-            for i in range(chunk_count):
-                offset = i * chunk_size
-                left = source_size - offset
-                size = min([chunk_size, left])
-                part_num = i + 1
-
-                if self._skip_part(multipart, part_num, size):
-                    continue
-
-                fobj.seek(offset)
-                name = os.path.relpath(fname, self._cloud_settings.cache.cache_dir)
-                cb = create_cb(name, offset, source_size)
-                multipart.upload_part_from_file(fp=fobj,
-                                                replace=False,
-                                                size=size,
-                                                num_cb=100,
-                                                part_num=part_num,
-                                                cb=cb)
-
-        if len(multipart.get_all_parts()) != chunk_count:
-            raise Exception("Couldn't upload all file parts")
-
-        multipart.complete_upload()
-        self._unlink_upload_tracker(fname)
+    def _new_key(self, path):
+        key_name = self.cache_file_key(path)
+        return AWSKey(self.storage_bucket, key_name)
 
     def _push_key(self, key, path):
         """ push, aws version """
+        name = os.path.relpath(path, self._cloud_settings.cache.cache_dir)
+        cb = self.create_cb_push(name, path)
         try:
-            self._push_multipart(key, path)
+            self.s3.Object(key.bucket, key.name).upload_file(path, Callback=cb)
         except Exception as exc:
             Logger.error('Failed to upload "{}": {}'.format(path, exc))
             return None
 
-        progress.finish_target(os.path.relpath(path, self._cloud_settings.cache.cache_dir))
+        progress.finish_target(name)
 
         return path
diff --git a/dvc/cloud/credentials_aws.py b/dvc/cloud/credentials_aws.py
deleted file mode 100644
index 1885a0ea11..0000000000
--- a/dvc/cloud/credentials_aws.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import os
-import configparser
-
-from dvc.config import Config
-from dvc.utils import cached_property
-from dvc.logger import Logger
-
-
-class AWSCredentials(object):
-    def __init__(self, cloud_config):
-        self._conf_credpath = cloud_config.get(Config.SECTION_AWS_CREDENTIALPATH, None)
-        self._conf_credsect = cloud_config.get(Config.SECTION_AWS_PROFILE, 'default')
-
-    @property
-    def access_key_id(self):
-        if self.creds:
-            return self.creds[0]
-        return None
-
-    @property
-    def secret_access_key(self):
-        if self.creds:
-            return self.creds[1]
-        return None
-
-    @cached_property
-    def creds(self):
-        return self._get_credentials()
-
-    def _get_credentials(self):
-        """ gets aws credentials, looking in various places
-
-        Params:
-
-        Searches:
-        1 any override in dvc.conf [AWS] CredentialPath;
-        2 ~/.aws/credentials
-
-
-        Returns:
-            if successfully found, (access_key_id, secret)
-            None otherwise
-        """
-
-        # FIX: It won't work in Windows.
-        default_path = os.path.expanduser('~/.aws/credentials')
-        default_sect = 'default'
-        default_cred_location = (default_path, default_sect)
-
-        cred_locations = self._credential_paths(default_cred_location)
-        for cred_location in cred_locations:
-            try:
-                path = cred_location[0]
-                section = cred_location[1]
-
-                cc = configparser.SafeConfigParser()
-
-                # use readfp(open( ... to aid mocking.
-                cc.readfp(open(path, 'r'))
-
-                if section in cc.keys():
-                    access_key = cc[section].get('aws_access_key_id', None)
-                    secret = cc[section].get('aws_secret_access_key', None)
-
-                    if access_key is not None and secret is not None:
-                        return (access_key, secret)
-                else:
-                    Logger.warn('Unable to find section {} in AWS credential file {}'.format(section, path))
-            except Exception as e:
-                pass
-
-        return None
-
-    def _credential_paths(self, default_cred_location):
-        results = []
-        if self._conf_credpath is not None and len(self._conf_credpath) > 0:
-            credpath = os.path.expanduser(self._conf_credpath)
-            if os.path.isfile(credpath):
-                results.append((credpath, self._conf_credsect))
-            else:
-                msg = 'AWS CredentialPath {} not found; falling back to default file {} and section {}'
-                Logger.warn(msg.format(credpath, default_cred_location[0], default_cred_location[1]))
-                results.append(default_cred_location)
-        else:
-            results.append(default_cred_location)
-        return results
-
-    def sanity_check(self):
-        creds = self._get_credentials()
-        if creds is None:
-            Logger.info("can't find aws credetials, assuming envirment variables or iam role")
-        # self._aws_creds = creds
diff --git a/dvc/utils.py b/dvc/utils.py
index 395acc52f8..b3196dac5c 100644
--- a/dvc/utils.py
+++ b/dvc/utils.py
@@ -84,7 +84,7 @@ def wrap(func, t):
     try:
         return func(t)
     except Exception as exc:
-        Logger.error('wrap', exc)
+        Logger.error('Error', exc)
         raise
diff --git a/requirements.txt b/requirements.txt
index 286627bef3..e8077d2ca2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-boto>=2.46.1
+boto3==1.7.4
 google-compute-engine>=2.4.1 #required by boto
 configparser>=3.5.0
 zc.lockfile>=1.2.1
diff --git a/scripts/ci/install.sh b/scripts/ci/install.sh
index 87551efb3d..b4849ff68b 100644
--- a/scripts/ci/install.sh
+++ b/scripts/ci/install.sh
@@ -8,9 +8,12 @@ pip install -r requirements.txt
 pip install -r test-requirements.txt
 git config --global user.email "dvctester@example.com"
 git config --global user.name "DVC Tester"
-mkdir ~/.aws
-printf "[default]\n" > ~/.aws/credentials
-printf "aws_access_key_id = $AWS_ACCESS_KEY_ID\n" >> ~/.aws/credentials
-printf "aws_secret_access_key = $AWS_SECRET_ACCESS_KEY\n" >> ~/.aws/credentials
-printf "[default]\n" > ~/.aws/config
-openssl enc -d -aes-256-cbc -md md5 -k $GCP_CREDS -in scripts/ci/gcp-creds.json.enc -out scripts/ci/gcp-creds.json || true
+
+if [[ "$TRAVIS_PULL_REQUEST" == "false" && \
+      "$TRAVIS_SECURE_ENV_VARS" == "true" ]]; then
+        aws configure set aws_access_key_id $AWS_ACCESS_KEY_ID
+        aws configure set aws_secret_access_key $AWS_SECRET_ACCESS_KEY
+        aws configure set region us-east-2
+
+        openssl enc -d -aes-256-cbc -md md5 -k $GCP_CREDS -in scripts/ci/gcp-creds.json.enc -out scripts/ci/gcp-creds.json
+fi
diff --git a/setup.py b/setup.py
index 0a59e02e65..9c2e01d75e 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 install_requires = [
-    "boto>=2.46.1",
+    "boto3==1.7.4",
     "google-compute-engine>=2.4.1", #required by boto
     "configparser>=3.5.0",
     "zc.lockfile>=1.2.1",
diff --git a/tests/test_data_cloud.py b/tests/test_data_cloud.py
index c511188062..171cd43a1d 100644
--- a/tests/test_data_cloud.py
+++ b/tests/test_data_cloud.py
@@ -20,8 +20,6 @@
                 TEST_SECTION: {Config.SECTION_REMOTE_URL: ''}}
 
 TEST_AWS_REPO_BUCKET = 'dvc-test'
-TEST_AWS_REPO_REGION = 'us-east-2'
-
 TEST_GCP_REPO_BUCKET = 'dvc-test'
 
 
@@ -122,8 +120,28 @@ def test_unsupported(self):
 
 
 class TestDataCloudBase(TestDvc):
+    def _should_test(self):
+        return False
+
+    def _get_url(self):
+        return None
+
+    @property
+    def cloud_class(self):
+        return None
+
     def _setup_cloud(self):
-        self.cloud = None
+        if not self._should_test():
+            return
+
+        repo = self._get_url()
+
+        config = TEST_CONFIG
+        config[TEST_SECTION][Config.SECTION_REMOTE_URL] = repo
+        cloud_settings = CloudSettings(cache=self.dvc.cache,
+                                       state=self.dvc.state,
+                                       cloud_config=config[TEST_SECTION])
+        self.cloud = self._get_cloud_class()(cloud_settings)
 
     def _test_cloud(self):
         self._setup_cloud()
@@ -193,61 +211,46 @@ def _test_cloud(self):
         status_dir = self.cloud.status(cache_dir)
         self.assertEqual(status_dir, STATUS_OK)
 
+    def test(self):
+        if self._should_test():
+            self._test_cloud()
+
 
-class TestDataCloudAWS(TestDataCloudBase):
-    def _setup_cloud(self):
-        if not _should_test_aws():
-            return
 
-        repo = get_aws_url()
+class TestDataCloudAWS(TestDataCloudBase):
+    def _should_test(self):
+        return _should_test_aws()
 
-        # Setup cloud
-        config = TEST_CONFIG
-        config[TEST_SECTION][Config.SECTION_REMOTE_URL] = repo
-        config[TEST_SECTION][Config.SECTION_AWS_REGION] = TEST_AWS_REPO_REGION
-        cloud_settings = CloudSettings(cache=self.dvc.cache,
-                                       state=self.dvc.state,
-                                       cloud_config=config[TEST_SECTION])
-        self.cloud = DataCloudAWS(cloud_settings)
+    def _get_url(self):
+        return get_aws_url()
 
-    def test(self):
-        if _should_test_aws():
-            self._test_cloud()
+    def _get_cloud_class(self):
+        return DataCloudAWS
 
 
 class TestDataCloudGCP(TestDataCloudBase):
-    def _setup_cloud(self):
-        if not _should_test_gcp():
-            return
+    def _should_test(self):
+        return _should_test_gcp()
 
-        repo = get_gcp_url()
+    def _get_url(self):
+        return get_gcp_url()
 
-        # Setup cloud
-        config = TEST_CONFIG
-        config[TEST_SECTION][Config.SECTION_REMOTE_URL] = repo
-        cloud_settings = CloudSettings(cache=self.dvc.cache,
-                                       state=self.dvc.state,
-                                       cloud_config=config[TEST_SECTION])
-        self.cloud = DataCloudGCP(cloud_settings)
-
-    def test(self):
-        if _should_test_gcp():
-            self._test_cloud()
+    def _get_cloud_class(self):
+        return DataCloudGCP
 
 
 class TestDataCloudLOCAL(TestDataCloudBase):
-    def _setup_cloud(self):
+    def _should_test(self):
+        return True
+
+    def _get_url(self):
         self.dname = get_local_url()
+        return self.dname
 
-        config = TEST_CONFIG
-        config[TEST_SECTION][Config.SECTION_REMOTE_URL] = self.dname
-        cloud_settings = CloudSettings(cache=self.dvc.cache,
-                                       state=self.dvc.state,
-                                       cloud_config=config[TEST_SECTION])
-        self.cloud = DataCloudLOCAL(cloud_settings)
+    def _get_cloud_class(self):
+        return DataCloudLOCAL
 
     def test(self):
-        self._test_cloud()
+        super(TestDataCloudLOCAL, self).test()
         self.assertTrue(os.path.isdir(self.dname))
 
 
@@ -352,7 +355,6 @@ def _test_compat(self):
         storagepath = get_aws_storagepath()
         self.main(['config', 'core.cloud', 'aws'])
         self.main(['config', 'aws.storagepath', storagepath])
-        self.main(['config', 'aws.region', TEST_AWS_REPO_REGION])
 
         self._test_cloud()
 
@@ -360,7 +362,6 @@ def _test(self):
         url = get_aws_url()
 
         self.main(['remote', 'add', TEST_REMOTE, url])
-        self.main(['remote', 'modify', TEST_REMOTE, Config.SECTION_AWS_REGION, TEST_AWS_REPO_REGION])
 
         self._test_cloud(TEST_REMOTE)
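
For reviewers unfamiliar with boto3's transfer hooks: upload_file() and
download_file() drive transfers from a pool of worker threads and report
progress by invoking the Callback argument with the number of bytes moved
since the previous call, which is why the new Callback class above guards
its running total with a lock. A minimal, self-contained sketch of the same
pattern (the bucket, key, and file names below are placeholders, not part
of this patch):

    import os
    import threading

    import boto3


    class TransferCallback(object):
        """Same shape as the Callback class added above: boto3 calls this
        from several transfer threads, each time passing the number of
        bytes moved since the last call."""

        def __init__(self, name, total):
            self.name = name
            self.total = total
            self.current = 0
            self.lock = threading.Lock()

        def __call__(self, byts):
            with self.lock:
                self.current += byts
                print('{}: {}/{} bytes'.format(self.name, self.current, self.total))


    def upload(path, bucket, key):
        # upload_file() transparently switches to a parallel multipart
        # upload for large files, which is what makes the hand-rolled
        # _push_multipart/upload-tracker code removed above unnecessary.
        s3 = boto3.resource('s3')
        cb = TransferCallback(os.path.basename(path), os.path.getsize(path))
        s3.Object(bucket, key).upload_file(path, Callback=cb)


    # upload('data.bin', 'dvc-test', 'cache/data.bin')
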
From 48cc9af6caa9e0b7411e36eb5142259a33eb13ef Mon Sep 17 00:00:00 2001
From: Ruslan Kuprieiev
Date: Thu, 12 Apr 2018 19:46:31 +0300
Subject: [PATCH 2/2] config: introduce config.local

Fixes #524

Signed-off-by: Ruslan Kuprieiev
---
 dvc/cli.py            | 20 ++++++++++++++++++++
 dvc/command/config.py |  6 +++++-
 dvc/config.py         |  6 ++++++
 dvc/project.py        |  3 ++-
 4 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/dvc/cli.py b/dvc/cli.py
index 09f5fd5d1f..722d011196 100644
--- a/dvc/cli.py
+++ b/dvc/cli.py
@@ -239,6 +239,11 @@ def parse_args(argv=None):
         nargs='?',
         default=None,
         help='Option value')
+    config_parser.add_argument(
+        '--local',
+        action='store_true',
+        default=False,
+        help='Use local config')
     config_parser.set_defaults(func=CmdConfig)
 
 
@@ -262,6 +267,11 @@ def parse_args(argv=None):
     remote_add_parser.add_argument(
         'url',
         help='Url')
+    remote_add_parser.add_argument(
+        '--local',
+        action='store_true',
+        default=False,
+        help='Use local config')
     remote_add_parser.set_defaults(func=CmdRemoteAdd)
 
 
@@ -272,6 +282,11 @@ def parse_args(argv=None):
     remote_remove_parser.add_argument(
         'name',
         help='Name')
+    remote_remove_parser.add_argument(
+        '--local',
+        action='store_true',
+        default=False,
+        help='Use local config')
     remote_remove_parser.set_defaults(func=CmdRemoteRemove)
 
 
@@ -294,6 +309,11 @@ def parse_args(argv=None):
         default=False,
         action='store_true',
         help='Unset option')
+    remote_modify_parser.add_argument(
+        '--local',
+        action='store_true',
+        default=False,
+        help='Use local config')
     remote_modify_parser.set_defaults(func=CmdRemoteModify)
 
 
diff --git a/dvc/command/config.py b/dvc/command/config.py
index 4bb216d804..dc669c6579 100644
--- a/dvc/command/config.py
+++ b/dvc/command/config.py
@@ -11,7 +11,11 @@ class CmdConfig(CmdBase):
     def __init__(self, args):
         self.args = args
         root_dir = self._find_root()
-        self.config_file = os.path.join(root_dir, Project.DVC_DIR, Config.CONFIG)
+        if args.local:
+            config = Config.CONFIG_LOCAL
+        else:
+            config = Config.CONFIG
+        self.config_file = os.path.join(root_dir, Project.DVC_DIR, config)
         # Using configobj because it doesn't
         # drop comments like configparser does.
         self.configobj = configobj.ConfigObj(self.config_file)
diff --git a/dvc/config.py b/dvc/config.py
index 156c77c302..b1832c0de7 100644
--- a/dvc/config.py
+++ b/dvc/config.py
@@ -21,6 +21,7 @@ def supported_url(url):
 
 class Config(object):
     CONFIG = 'config'
+    CONFIG_LOCAL = 'config.local'
 
     SECTION_CORE = 'core'
     SECTION_CORE_LOGLEVEL = 'loglevel'
@@ -95,13 +96,18 @@ class Config(object):
     def __init__(self, dvc_dir):
         self.dvc_dir = os.path.abspath(os.path.realpath(dvc_dir))
         self.config_file = os.path.join(dvc_dir, self.CONFIG)
+        self.config_local_file = os.path.join(dvc_dir, self.CONFIG_LOCAL)
 
         try:
             self._config = configobj.ConfigObj(self.config_file)
+            local = configobj.ConfigObj(self.config_local_file)
 
             # NOTE: schema doesn't support ConfigObj.Section validation, so we
             # need to convert our config to dict before passing it to schema.
             self._config = self._lower(self._config)
+            local = self._lower(local)
+            self._config.update(local)
+
             self._config = schema.Schema(self.SCHEMA).validate(self._config)
 
             # NOTE: now converting back to ConfigObj
diff --git a/dvc/project.py b/dvc/project.py
index 1537a411ab..cd15461253 100644
--- a/dvc/project.py
+++ b/dvc/project.py
@@ -69,7 +69,8 @@ def init(root_dir=os.curdir):
         scm = SCM(root_dir)
         scm.ignore_list([cache.cache_dir,
                          state.state_file,
-                         lock.lock_file])
+                         lock.lock_file,
+                         config.config_local_file])
 
         ignore_file = os.path.join(dvc_dir, scm.ignore_file())
         scm.add([config.config_file, ignore_file])
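
A note on the merge semantics introduced in dvc/config.py: assuming
_lower() returns plain nested dicts (as the schema NOTE above suggests),
self._config.update(local) is a shallow dict.update(), so a section that
appears in config.local replaces the whole matching section from config
rather than being merged key by key. A small standalone illustration of
that behavior (the section contents here are invented for illustration):

    import configobj

    # config and config.local as ConfigObj would parse them.
    config = configobj.ConfigObj({'core': {'cloud': 'aws', 'loglevel': 'info'}})
    local = configobj.ConfigObj({'core': {'loglevel': 'debug'}})

    merged = dict(config)
    merged.update(local)

    # dict.update() swaps in whole top-level sections, so 'core' now
    # comes entirely from config.local -- the 'cloud' key is gone:
    assert merged['core'] == {'loglevel': 'debug'}
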