diff --git a/dvc/config.py b/dvc/config.py index ad2070939c..0fd312aa14 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -186,6 +186,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_GDRIVE_CLIENT_SECRET = "gdrive_client_secret" SECTION_GDRIVE_USER_CREDENTIALS_FILE = "gdrive_user_credentials_file" + SECTION_REMOTE_CHECKSUM_JOBS = "checksum_jobs" SECTION_REMOTE_REGEX = r'^\s*remote\s*"(?P.*)"\s*$' SECTION_REMOTE_FMT = 'remote "{}"' SECTION_REMOTE_URL = "url" @@ -214,6 +215,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_GCP_PROJECTNAME: str, SECTION_CACHE_TYPE: supported_cache_type, Optional(SECTION_CACHE_PROTECTED, default=False): Bool, + SECTION_REMOTE_CHECKSUM_JOBS: All(Coerce(int), Range(1)), SECTION_REMOTE_USER: str, SECTION_REMOTE_PORT: Coerce(int), SECTION_REMOTE_KEY_FILE: str, diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 36409645eb..9f47779c8f 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -86,11 +86,7 @@ def __init__(self, repo, config): self.repo = repo self._check_requires(config) - - core = config.get(Config.SECTION_CORE, {}) - self.checksum_jobs = core.get( - Config.SECTION_CORE_CHECKSUM_JOBS, self.CHECKSUM_JOBS - ) + self.checksum_jobs = self._get_checksum_jobs(config) self.protected = False self.no_traverse = config.get(Config.SECTION_REMOTE_NO_TRAVERSE, True) @@ -142,6 +138,19 @@ def _check_requires(self, config): ).format(url, missing, " ".join(missing), self.scheme) raise RemoteMissingDepsError(msg) + def _get_checksum_jobs(self, config): + checksum_jobs = config.get(Config.SECTION_REMOTE_CHECKSUM_JOBS) + if checksum_jobs: + return checksum_jobs + + if self.repo: + core = self.repo.config.config.get(Config.SECTION_CORE, {}) + return core.get( + Config.SECTION_CORE_CHECKSUM_JOBS, self.CHECKSUM_JOBS + ) + + return self.CHECKSUM_JOBS + def __repr__(self): return "{class_name}: '{path_info}'".format( class_name=type(self).__name__, diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py index bdc8951073..09dc7e62b6 100644 --- a/tests/unit/remote/test_gdrive.py +++ b/tests/unit/remote/test_gdrive.py @@ -1,6 +1,7 @@ import pytest import os +from dvc.config import Config from dvc.remote.gdrive import ( RemoteGDrive, GDriveAccessTokenRefreshError, @@ -14,6 +15,7 @@ class Repo(object): tmp_dir = "" + config = Config() class TestRemoteGDrive(object): diff --git a/tests/unit/remote/test_remote.py b/tests/unit/remote/test_remote.py new file mode 100644 index 0000000000..b2fa73aff7 --- /dev/null +++ b/tests/unit/remote/test_remote.py @@ -0,0 +1,41 @@ +from dvc.remote import Remote + + +def set_config_opts(dvc, commands): + list(map(lambda args: dvc.config.set(*args), commands)) + + +def test_remote_with_checksum_jobs(dvc): + set_config_opts( + dvc, + [ + ('remote "with_checksum_jobs"', "url", "s3://bucket/name"), + ('remote "with_checksum_jobs"', "checksum_jobs", 100), + ("core", "checksum_jobs", 200), + ], + ) + + remote = Remote(dvc, name="with_checksum_jobs") + assert remote.checksum_jobs == 100 + + +def test_remote_without_checksum_jobs(dvc): + set_config_opts( + dvc, + [ + ('remote "without_checksum_jobs"', "url", "s3://bucket/name"), + ("core", "checksum_jobs", "200"), + ], + ) + + remote = Remote(dvc, name="without_checksum_jobs") + assert remote.checksum_jobs == 200 + + +def test_remote_without_checksum_jobs_default(dvc): + set_config_opts( + dvc, [('remote "without_checksum_jobs"', "url", "s3://bucket/name")] + ) + + remote = Remote(dvc, name="without_checksum_jobs") + assert remote.checksum_jobs == remote.CHECKSUM_JOBS