From b94903202be0784a8c125a3f9ec7772b63de7675 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 13:44:00 +0900 Subject: [PATCH 01/16] remote: move constants into trees --- dvc/remote/__init__.py | 71 ++++------ dvc/remote/azure.py | 24 ++-- dvc/remote/base.py | 272 +++++++++++++++++++------------------ dvc/remote/gdrive.py | 27 ++-- dvc/remote/gs.py | 20 +-- dvc/remote/hdfs.py | 25 ++-- dvc/remote/http.py | 22 +-- dvc/remote/https.py | 4 +- dvc/remote/local.py | 159 ++++++++++------------ dvc/remote/oss.py | 52 ++++--- dvc/remote/s3.py | 16 +-- dvc/remote/ssh/__init__.py | 33 ++--- 12 files changed, 326 insertions(+), 399 deletions(-) diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index 4ea41c4ee9..90c49c6f65 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -1,45 +1,36 @@ import posixpath from urllib.parse import urlparse -from dvc.remote.azure import AzureCache, AzureRemote -from dvc.remote.gdrive import GDriveRemote -from dvc.remote.gs import GSCache, GSRemote -from dvc.remote.hdfs import HDFSCache, HDFSRemote -from dvc.remote.http import HTTPRemote -from dvc.remote.https import HTTPSRemote -from dvc.remote.local import LocalCache, LocalRemote -from dvc.remote.oss import OSSRemote -from dvc.remote.s3 import S3Cache, S3Remote -from dvc.remote.ssh import SSHCache, SSHRemote +from dvc.remote.azure import AzureRemoteTree +from dvc.remote.gdrive import GDriveRemoteTree +from dvc.remote.gs import GSRemoteTree +from dvc.remote.hdfs import HDFSRemoteTree +from dvc.remote.http import HTTPRemoteTree +from dvc.remote.https import HTTPSRemoteTree +from dvc.remote.local import LocalRemoteTree +from dvc.remote.oss import OSSRemoteTree +from dvc.remote.s3 import S3RemoteTree +from dvc.remote.ssh import SSHRemoteTree -CACHES = [ - AzureCache, - GSCache, - HDFSCache, - S3Cache, - SSHCache, - # LocalCache is the default -] - -REMOTES = [ - AzureRemote, - GDriveRemote, - GSRemote, - HDFSRemote, - HTTPRemote, - HTTPSRemote, - S3Remote, - SSHRemote, - OSSRemote, - # NOTE: LocalRemote is the default +TREES = [ + AzureRemoteTree, + GDriveRemoteTree, + GSRemoteTree, + HDFSRemoteTree, + HTTPRemoteTree, + HTTPSRemoteTree, + S3RemoteTree, + SSHRemoteTree, + OSSRemoteTree, + # NOTE: LocalRemoteTree is the default ] -def _get(remote_conf, remotes, default): - for remote in remotes: - if remote.supported(remote_conf): - return remote - return default +def get_cloud_tree(remote_conf, remotes): + for tree_cls in TREES: + if tree_cls.supported(remote_conf): + return tree_cls + return LocalRemoteTree def _get_conf(repo, **kwargs): @@ -51,16 +42,6 @@ def _get_conf(repo, **kwargs): return _resolve_remote_refs(repo.config, remote_conf) -def Remote(repo, **kwargs): - remote_conf = _get_conf(repo, **kwargs) - return _get(remote_conf, REMOTES, LocalRemote)(repo, remote_conf) - - -def Cache(repo, **kwargs): - remote_conf = _get_conf(repo, **kwargs) - return _get(remote_conf, CACHES, LocalCache)(repo, remote_conf) - - def _resolve_remote_refs(config, remote_conf): # Support for cross referenced remotes. # This will merge the settings, shadowing base ref with remote_conf. 
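[Illustration, not part of the patch: a minimal standalone sketch of the cross-referenced remote resolution that `_resolve_remote_refs` performs above. The config shape follows DVC's `repo.config["remote"]` section; the URLs and the "region" key are invented for the example.]

    import posixpath
    from urllib.parse import urlparse

    # Assumed shape of repo.config["remote"]; values are illustrative.
    config = {
        "remote": {"base": {"url": "s3://bucket/path", "region": "us-east-1"}}
    }
    remote_conf = {"url": "remote://base/subdir", "region": "eu-west-1"}

    parsed = urlparse(remote_conf["url"])
    if parsed.scheme == "remote":
        base = config["remote"][parsed.netloc]
        url = posixpath.join(base["url"], parsed.path.lstrip("/"))
        # remote_conf shadows the base ref, so "region" stays "eu-west-1"
        remote_conf = {**base, **remote_conf, "url": url}

    assert remote_conf == {"url": "s3://bucket/path/subdir", "region": "eu-west-1"}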
diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index 46606d11a9..c162072a88 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -7,17 +7,22 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin +from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) class AzureRemoteTree(BaseRemoteTree): + scheme = Schemes.AZURE PATH_CLS = CloudURLInfo + REQUIRES = {"azure-storage-blob": "azure.storage.blob"} + PARAM_CHECKSUM = "etag" + COPY_POLL_SECONDS = 5 + LIST_OBJECT_PAGE_SIZE = 5000 - def __init__(self, remote, config): - super().__init__(remote, config) + def __init__(self, repo, config): + super().__init__(repo, config) url = config.get("url", "azure://") self.path_info = self.PATH_CLS(url) @@ -132,16 +137,3 @@ def _download( to_file, progress_callback=pbar.update_to, ) - - -class AzureRemote(BaseRemote): - scheme = Schemes.AZURE - REQUIRES = {"azure-storage-blob": "azure.storage.blob"} - TREE_CLS = AzureRemoteTree - PARAM_CHECKSUM = "etag" - COPY_POLL_SECONDS = 5 - LIST_OBJECT_PAGE_SIZE = 5000 - - -class AzureCache(AzureRemote, CacheMixin): - pass diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 625f0fcd18..721df03fbc 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -82,15 +82,90 @@ def wrapper(remote_obj, *args, **kwargs): class BaseRemoteTree: - SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} + scheme = "base" + REQUIRES = {} PATH_CLS = URLInfo + JOBS = 4 * cpu_count() + + PARAM_RELPATH = "relpath" CHECKSUM_DIR_SUFFIX = ".dir" + CHECKSUM_JOBS = max(1, min(4, cpu_count() // 2)) + DEFAULT_VERIFY = False + LIST_OBJECT_PAGE_SIZE = 1000 + TRAVERSE_WEIGHT_MULTIPLIER = 5 + TRAVERSE_PREFIX_LEN = 3 + TRAVERSE_THRESHOLD_SIZE = 500000 + CAN_TRAVERSE = True + + SHARED_MODE_MAP = {None: (None, None), "group": (None, None)} + CHECKSUM_DIR_SUFFIX = ".dir" + + state = StateNoop() + + def __init__(self, repo, config): + self.repo = repo + self._check_requires(config) - def __init__(self, remote, config): - self.remote = remote shared = config.get("shared") self._file_mode, self._dir_mode = self.SHARED_MODE_MAP[shared] + self.checksum_jobs = ( + config.get("checksum_jobs") + or (self.repo and self.repo.config["core"].get("checksum_jobs")) + or self.CHECKSUM_JOBS + ) + self.verify = config.get("verify", self.DEFAULT_VERIFY) + + @classmethod + def get_missing_deps(cls): + import importlib + + missing = [] + for package, module in cls.REQUIRES.items(): + try: + importlib.import_module(module) + except ImportError: + missing.append(package) + + return missing + + def _check_requires(self, config): + missing = self.get_missing_deps() + if not missing: + return + + url = config.get("url", f"{self.scheme}://") + msg = ( + "URL '{}' is supported but requires these missing " + "dependencies: {}. If you have installed dvc using pip, " + "choose one of these options to proceed: \n" + "\n" + " 1) Install specific missing dependencies:\n" + " pip install {}\n" + " 2) Install dvc package that includes those missing " + "dependencies: \n" + " pip install 'dvc[{}]'\n" + " 3) Install dvc package with all possible " + "dependencies included: \n" + " pip install 'dvc[all]'\n" + "\n" + "If you have installed dvc from a binary package and you " + "are still seeing this message, please report it to us " + "using https://github.com/iterative/dvc/issues. Thank you!" 
+ ).format(url, missing, " ".join(missing), self.scheme) + raise RemoteMissingDepsError(msg) + + @classmethod + def supported(cls, config): + if isinstance(config, (str, bytes)): + url = config + else: + url = config["url"] + + # NOTE: silently skipping remote, calling code should handle that + parsed = urlparse(url) + return parsed.scheme == cls.scheme + @property def file_mode(self): return self._file_mode @@ -99,17 +174,9 @@ def file_mode(self): def dir_mode(self): return self._dir_mode - @property - def scheme(self): - return self.remote.scheme - - @property - def state(self): - return self.remote.state - @property def cache(self): - return self.remote.cache + return getattr(self.repo.cache, self.scheme) def open(self, path_info, mode="r", encoding=None): if hasattr(self, "_generate_download_url"): @@ -172,6 +239,17 @@ def hardlink(self, from_info, to_info): def reflink(self, from_info, to_info): raise RemoteActionNotImplemented("reflink", self.scheme) + @staticmethod + def protect(path_info): + pass + + def is_protected(self, path_info): + return False + + @staticmethod + def unprotect(path_info): + pass + @classmethod def is_dir_checksum(cls, checksum): if not checksum: @@ -234,7 +312,7 @@ def _calculate_checksums(self, file_infos, tree): ) as pbar: worker = pbar.wrap_fn(tree.get_file_checksum) with ThreadPoolExecutor( - max_workers=self.remote.checksum_jobs + max_workers=self.checksum_jobs ) as executor: tasks = executor.map(worker, file_infos) checksums = dict(zip(file_infos, tasks)) @@ -259,7 +337,7 @@ def _collect_dir(self, path_info, tree, **kwargs): result = [ { - self.remote.PARAM_CHECKSUM: checksums[fi], + self.PARAM_CHECKSUM: checksums[fi], # NOTE: this is lossy transformation: # "hey\there" -> "hey/there" # "hey/there" -> "hey/there" @@ -268,24 +346,20 @@ def _collect_dir(self, path_info, tree, **kwargs): # # Yes, this is a BUG, as long as we permit "/" in # filenames on Windows and "\" on Unix - self.remote.PARAM_RELPATH: fi.relative_to( - path_info - ).as_posix(), + self.PARAM_RELPATH: fi.relative_to(path_info).as_posix(), } for fi in file_infos ] # Sorting the list by path to ensure reproducibility - return sorted(result, key=itemgetter(self.remote.PARAM_RELPATH)) + return sorted(result, key=itemgetter(self.PARAM_RELPATH)) def _save_dir_info(self, dir_info, path_info): checksum, tmp_info = self._get_dir_info_checksum(dir_info) new_info = self.cache.checksum_to_path_info(checksum) if self.cache.changed_cache_file(checksum): self.cache.tree.makedirs(new_info.parent) - self.cache.tree.move( - tmp_info, new_info, mode=self.remote.CACHE_MODE - ) + self.cache.tree.move(tmp_info, new_info, mode=self.CACHE_MODE) if self.exists(path_info): self.state.save(path_info, checksum) @@ -418,124 +492,44 @@ def _download_file( move(tmp_file, to_info, mode=file_mode) -class BaseRemote: - """Base cloud remote class.""" +class Remote: + """Cloud remote class.""" - scheme = "base" - REQUIRES = {} - JOBS = 4 * cpu_count() INDEX_CLS = RemoteIndex - TREE_CLS = BaseRemoteTree - PARAM_RELPATH = "relpath" - CHECKSUM_DIR_SUFFIX = ".dir" - CHECKSUM_JOBS = max(1, min(4, cpu_count() // 2)) - DEFAULT_CACHE_TYPES = ["copy"] - DEFAULT_VERIFY = False - LIST_OBJECT_PAGE_SIZE = 1000 - TRAVERSE_WEIGHT_MULTIPLIER = 5 - TRAVERSE_PREFIX_LEN = 3 - TRAVERSE_THRESHOLD_SIZE = 500000 - CAN_TRAVERSE = True - - CACHE_MODE = None - - state = StateNoop() - - def __init__(self, repo, config): + def __init__(self, repo, config, tree): self.repo = repo - - self._check_requires(config) - - self.checksum_jobs = ( - 
config.get("checksum_jobs") - or (self.repo and self.repo.config["core"].get("checksum_jobs")) - or self.CHECKSUM_JOBS - ) - self.verify = config.get("verify", self.DEFAULT_VERIFY) - self._dir_info = {} - - self.cache_types = config.get("type") or copy(self.DEFAULT_CACHE_TYPES) - self.cache_type_confirmed = False + self.tree = tree url = config.get("url") - if url: + if self.scheme != "local" and url: index_name = hashlib.sha256(url.encode("utf-8")).hexdigest() - self.index = self.INDEX_CLS( + self.index = self.RemoteIndex( self.repo, index_name, dir_suffix=self.CHECKSUM_DIR_SUFFIX ) else: self.index = RemoteIndexNoop() - self.tree = self.TREE_CLS(self, config) - @property def path_info(self): return self.tree.path_info - @classmethod - def get_missing_deps(cls): - import importlib - - missing = [] - for package, module in cls.REQUIRES.items(): - try: - importlib.import_module(module) - except ImportError: - missing.append(package) - - return missing - - def _check_requires(self, config): - missing = self.get_missing_deps() - if not missing: - return - - url = config.get("url", f"{self.scheme}://") - msg = ( - "URL '{}' is supported but requires these missing " - "dependencies: {}. If you have installed dvc using pip, " - "choose one of these options to proceed: \n" - "\n" - " 1) Install specific missing dependencies:\n" - " pip install {}\n" - " 2) Install dvc package that includes those missing " - "dependencies: \n" - " pip install 'dvc[{}]'\n" - " 3) Install dvc package with all possible " - "dependencies included: \n" - " pip install 'dvc[all]'\n" - "\n" - "If you have installed dvc from a binary package and you " - "are still seeing this message, please report it to us " - "using https://github.com/iterative/dvc/issues. Thank you!" - ).format(url, missing, " ".join(missing), self.scheme) - raise RemoteMissingDepsError(msg) - def __repr__(self): return "{class_name}: '{path_info}'".format( class_name=type(self).__name__, path_info=self.path_info or "No path", ) - @classmethod - def supported(cls, config): - if isinstance(config, (str, bytes)): - url = config - else: - url = config["url"] - - # NOTE: silently skipping remote, calling code should handle that - parsed = urlparse(url) - return parsed.scheme == cls.scheme - @property def cache(self): return getattr(self.repo.cache, self.scheme) - @classmethod - def is_dir_checksum(cls, checksum): - return cls.TREE_CLS.is_dir_checksum(checksum) + @property + def scheme(self): + return self.tree.scheme + + def is_dir_checksum(self, checksum): + return self.tree.is_dir_checksum(checksum) def get_checksum(self, path_info, **kwargs): return self.tree.get_checksum(path_info, **kwargs) @@ -561,17 +555,6 @@ def save_info(self, path_info, tree=None, **kwargs): def open(self, *args, **kwargs): return self.tree.open(*args, **kwargs) - @staticmethod - def protect(path_info): - pass - - def is_protected(self, path_info): - return False - - @staticmethod - def unprotect(path_info): - pass - def list_paths(self, prefix=None, progress_callback=None): if prefix: if len(prefix) > 2: @@ -613,7 +596,7 @@ def all(self, jobs=None, name=None): ) ) - if not self.CAN_TRAVERSE: + if not self.tree.CAN_TRAVERSE: return self.list_checksums() remote_size, remote_checksums = self._estimate_remote_size(name=name) @@ -667,7 +650,7 @@ def checksums_exist(self, checksums, jobs=None, name=None): if not checksums: return indexed_checksums - if len(checksums) == 1 or not self.CAN_TRAVERSE: + if len(checksums) == 1 or not self.tree.CAN_TRAVERSE: remote_checksums = 
self._list_checksums_exists( checksums, jobs, name ) @@ -879,8 +862,35 @@ def _remove_unpacked_dir(self, checksum): pass -class CacheMixin: - """BaseRemote extensions for cache link/checkout operations.""" +class CloudCache: + """Cloud cache class.""" + + DEFAULT_CACHE_TYPES = ["copy"] + CACHE_MODE = None + + def __init__(self, repo, config, tree): + self.repo = repo + self.tree = tree + + self.cache_types = tree.config.get("type") or copy( + self.DEFAULT_CACHE_TYPES + ) + self.cache_type_confirmed = False + self._dir_info = {} + + @property + def cache(self): + return getattr(self.repo.cache, self.scheme) + + @property + def scheme(self): + return self.tree.scheme + + def is_dir_checksum(self, checksum): + return self.tree.is_dir_checksum(checksum) + + def get_checksum(self, path_info, **kwargs): + return self.tree.get_checksum(path_info, **kwargs) # Override to return path as a string instead of PathInfo for clouds # which support string paths (see local) @@ -1037,7 +1047,7 @@ def _save_file(self, path_info, tree, checksum, save_link=True, **kwargs): path_info ): # Default relink procedure involves unneeded copy - self.unprotect(path_info) + self.tree.unprotect(path_info) else: self.tree.remove(path_info) self.link(cache_info, path_info) @@ -1141,7 +1151,7 @@ def changed_cache_file(self, checksum): """ # Prefer string path over PathInfo when possible due to performance cache_info = self.checksum_to_path(checksum) - if self.is_protected(cache_info): + if self.tree.is_protected(cache_info): logger.debug( "Assuming '%s' is unchanged since it is read-only", cache_info ) @@ -1162,7 +1172,7 @@ def changed_cache_file(self, checksum): if actual.split(".")[0] == checksum.split(".")[0]: # making cache file read-only so we don't need to check it # next time - self.protect(cache_info) + self.tree.protect(cache_info) return False if self.tree.exists(cache_info): diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py index 31f63cedc3..fa7194d4e6 100644 --- a/dvc/remote/gdrive.py +++ b/dvc/remote/gdrive.py @@ -14,7 +14,7 @@ from dvc.exceptions import DvcException, FileMissingError from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree +from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes from dvc.utils import format_link, tmp_fname from dvc.utils.stream import IterStream @@ -85,7 +85,13 @@ def __init__(self, url): class GDriveRemoteTree(BaseRemoteTree): + scheme = Schemes.GDRIVE PATH_CLS = GDriveURLInfo + REQUIRES = {"pydrive2": "pydrive2"} + DEFAULT_VERIFY = True + # Always prefer traverse for GDrive since API usage quotas are a concern. 
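[Illustration, not part of the patch: the weighting logic shown earlier means a full traverse is chosen once len(checksums) reaches roughly traverse_pages * TRAVERSE_WEIGHT_MULTIPLIER. With the base-class constants (LIST_OBJECT_PAGE_SIZE = 1000, TRAVERSE_WEIGHT_MULTIPLIER = 5), a remote estimated at 1,000,000 files gives traverse_pages = 1000 and, being over TRAVERSE_THRESHOLD_SIZE = 500,000, a weight of 5000: querying 300 checksums goes through object_exists, while 50,000 would trigger a traverse. GDrive's multiplier of 1 just below tilts that decision toward traversing, since per-object existence checks are the expensive operation under its API quotas, and its 2-character prefix keeps prefix listings to 16**2 = 256 requests instead of 16**3 = 4096.]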
+ TRAVERSE_WEIGHT_MULTIPLIER = 1 + TRAVERSE_PREFIX_LEN = 2 GDRIVE_CREDENTIALS_DATA = "GDRIVE_CREDENTIALS_DATA" DEFAULT_USER_CREDENTIALS_FILE = "gdrive-user-credentials.json" @@ -93,8 +99,8 @@ class GDriveRemoteTree(BaseRemoteTree): DEFAULT_GDRIVE_CLIENT_ID = "710796635688-iivsgbgsb6uv1fap6635dhvuei09o66c.apps.googleusercontent.com" # noqa: E501 DEFAULT_GDRIVE_CLIENT_SECRET = "a1Fz59uTpVNeG_VGuSKDLJXv" - def __init__(self, remote, config): - super().__init__(remote, config) + def __init__(self, repo, config): + super().__init__(repo, config) self.path_info = self.PATH_CLS(config["url"]) @@ -122,13 +128,12 @@ def __init__(self, remote, config): self._client_secret = config.get("gdrive_client_secret") self._validate_config() self._gdrive_user_credentials_path = ( - tmp_fname(os.path.join(self.remote.repo.tmp_dir, "")) + tmp_fname(os.path.join(self.repo.tmp_dir, "")) if os.getenv(GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA) else config.get( "gdrive_user_credentials_file", os.path.join( - self.remote.repo.tmp_dir, - self.DEFAULT_USER_CREDENTIALS_FILE, + self.repo.tmp_dir, self.DEFAULT_USER_CREDENTIALS_FILE, ), ) ) @@ -574,13 +579,3 @@ def _upload(self, from_file, to_info, name=None, no_progress_bar=False): def _download(self, from_info, to_file, name=None, no_progress_bar=False): item_id = self._get_item_id(from_info) self._gdrive_download_file(item_id, to_file, name, no_progress_bar) - - -class GDriveRemote(BaseRemote): - scheme = Schemes.GDRIVE - REQUIRES = {"pydrive2": "pydrive2"} - TREE_CLS = GDriveRemoteTree - DEFAULT_VERIFY = True - # Always prefer traverse for GDrive since API usage quotas are a concern. - TRAVERSE_WEIGHT_MULTIPLIER = 1 - TRAVERSE_PREFIX_LEN = 2 diff --git a/dvc/remote/gs.py b/dvc/remote/gs.py index 341589739d..5796403102 100644 --- a/dvc/remote/gs.py +++ b/dvc/remote/gs.py @@ -9,7 +9,7 @@ from dvc.exceptions import DvcException from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin +from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) @@ -65,10 +65,13 @@ def _upload_to_bucket( class GSRemoteTree(BaseRemoteTree): + scheme = Schemes.GS PATH_CLS = CloudURLInfo + REQUIRES = {"google-cloud-storage": "google.cloud.storage"} + PARAM_CHECKSUM = "md5" - def __init__(self, remote, config): - super().__init__(remote, config) + def __init__(self, repo, config): + super().__init__(repo, config) url = config.get("url", "gs:///") self.path_info = self.PATH_CLS(url) @@ -193,14 +196,3 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): disable=no_progress_bar, ) as wrapped: blob.download_to_file(wrapped) - - -class GSRemote(BaseRemote): - scheme = Schemes.GS - REQUIRES = {"google-cloud-storage": "google.cloud.storage"} - TREE_CLS = GSRemoteTree - PARAM_CHECKSUM = "md5" - - -class GSCache(GSRemote, CacheMixin): - pass diff --git a/dvc/remote/hdfs.py b/dvc/remote/hdfs.py index 629b5767bd..fa3f62e460 100644 --- a/dvc/remote/hdfs.py +++ b/dvc/remote/hdfs.py @@ -11,15 +11,21 @@ from dvc.scheme import Schemes from dvc.utils import fix_env, tmp_fname -from .base import BaseRemote, BaseRemoteTree, CacheMixin, RemoteCmdError +from .base import BaseRemoteTree, RemoteCmdError from .pool import get_connection logger = logging.getLogger(__name__) class HDFSRemoteTree(BaseRemoteTree): - def __init__(self, remote, config): - super().__init__(remote, config) + scheme = Schemes.HDFS + REQUIRES = {"pyarrow": "pyarrow"} + REGEX = 
r"^hdfs://((?P.*)@)?.*$" + PARAM_CHECKSUM = "checksum" + TRAVERSE_PREFIX_LEN = 2 + + def __init__(self, repo, config): + super().__init__(repo, config) self.path_info = None url = config.get("url") @@ -173,16 +179,3 @@ def _download(self, from_info, to_file, **_kwargs): with self.hdfs(from_info) as hdfs: with open(to_file, "wb+") as fobj: hdfs.download(from_info.path, fobj) - - -class HDFSRemote(BaseRemote): - scheme = Schemes.HDFS - REGEX = r"^hdfs://((?P.*)@)?.*$" - PARAM_CHECKSUM = "checksum" - REQUIRES = {"pyarrow": "pyarrow"} - TREE_CLS = HDFSRemoteTree - TRAVERSE_PREFIX_LEN = 2 - - -class HDFSCache(HDFSRemote, CacheMixin): - pass diff --git a/dvc/remote/http.py b/dvc/remote/http.py index 941b6de25c..b28e7670df 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -8,7 +8,7 @@ from dvc.exceptions import DvcException, HTTPError from dvc.path_info import HTTPURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree +from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) @@ -24,15 +24,18 @@ def ask_password(host, user): class HTTPRemoteTree(BaseRemoteTree): + scheme = Schemes.HTTP PATH_CLS = HTTPURLInfo + PARAM_CHECKSUM = "etag" + CAN_TRAVERSE = False SESSION_RETRIES = 5 SESSION_BACKOFF_FACTOR = 0.1 REQUEST_TIMEOUT = 10 CHUNK_SIZE = 2 ** 16 - def __init__(self, remote, config): - super().__init__(remote, config) + def __init__(self, repo, config): + super().__init__(repo, config) url = config.get("url") if url: @@ -185,16 +188,3 @@ def chunks(): def _content_length(response): res = response.headers.get("Content-Length") return int(res) if res else None - - -class HTTPRemote(BaseRemote): - scheme = Schemes.HTTP - PARAM_CHECKSUM = "etag" - CAN_TRAVERSE = False - TREE_CLS = HTTPRemoteTree - - def list_paths(self, prefix=None, progress_callback=None): - raise NotImplementedError - - def gc(self): - raise NotImplementedError diff --git a/dvc/remote/https.py b/dvc/remote/https.py index 370648d495..1cfad246fa 100644 --- a/dvc/remote/https.py +++ b/dvc/remote/https.py @@ -1,7 +1,7 @@ from dvc.scheme import Schemes -from .http import HTTPRemote +from .http import HTTPRemoteTree -class HTTPSRemote(HTTPRemote): +class HTTPSRemoteTree(HTTPRemoteTree): scheme = Schemes.HTTPS diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 9e6b7b6a2e..bb3a156d60 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -16,12 +16,11 @@ STATUS_MAP, STATUS_MISSING, STATUS_NEW, - BaseRemote, BaseRemoteTree, - CacheMixin, + CloudCache, + Remote, index_locked, ) -from dvc.remote.index import RemoteIndexNoop from dvc.scheme import Schemes from dvc.scm.tree import WorkingTree, is_working_tree from dvc.system import System @@ -39,17 +38,25 @@ class LocalRemoteTree(BaseRemoteTree): - SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)} + scheme = Schemes.LOCAL PATH_CLS = PathInfo + PARAM_CHECKSUM = "md5" + PARAM_PATH = "path" + DEFAULT_CACHE_TYPES = ["reflink", "copy"] + TRAVERSE_PREFIX_LEN = 2 + UNPACKED_DIR_SUFFIX = ".unpacked" + + CACHE_MODE = 0o444 + SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)} - def __init__(self, remote, config): - super().__init__(remote, config) + def __init__(self, repo, config): + super().__init__(repo, config) url = config.get("url") self.path_info = self.PATH_CLS(url) if url else None @property - def repo(self): - return self.remote.repo + def state(self): + return self.repo.state @cached_property def work_tree(self): @@ -195,78 +202,6 @@ def 
reflink(self, from_info, to_info): os.chmod(tmp_info, self.file_mode) os.rename(tmp_info, to_info) - def get_file_checksum(self, path_info): - return file_md5(path_info)[0] - - @staticmethod - def getsize(path_info): - return os.path.getsize(path_info) - - def _upload( - self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs - ): - makedirs(to_info.parent, exist_ok=True) - - tmp_file = tmp_fname(to_info) - copyfile( - from_file, tmp_file, name=name, no_progress_bar=no_progress_bar - ) - - self.remote.protect(tmp_file) - os.rename(tmp_file, to_info) - - @staticmethod - def _download( - from_info, to_file, name=None, no_progress_bar=False, **_kwargs - ): - copyfile( - from_info, to_file, no_progress_bar=no_progress_bar, name=name - ) - - -def _log_exceptions(func, operation): - @wraps(func) - def wrapper(from_info, to_info, *args, **kwargs): - try: - func(from_info, to_info, *args, **kwargs) - return 0 - except Exception as exc: - # NOTE: this means we ran out of file descriptors and there is no - # reason to try to proceed, as we will hit this error anyways. - if isinstance(exc, OSError) and exc.errno == errno.EMFILE: - raise - - logger.exception( - "failed to %s '%s' to '%s'", operation, from_info, to_info - ) - return 1 - - return wrapper - - -class LocalRemote(BaseRemote): - scheme = Schemes.LOCAL - INDEX_CLS = RemoteIndexNoop - TREE_CLS = LocalRemoteTree - - PARAM_CHECKSUM = "md5" - PARAM_PATH = "path" - DEFAULT_CACHE_TYPES = ["reflink", "copy"] - TRAVERSE_PREFIX_LEN = 2 - UNPACKED_DIR_SUFFIX = ".unpacked" - - CACHE_MODE = 0o444 - - @property - def state(self): - return self.repo.state - - def get(self, md5): - if not md5: - return None - - return self.checksum_to_path_info(md5).url - def _unprotect_file(self, path): if System.is_symlink(path) or System.is_hardlink(path): logger.debug(f"Unprotecting '{path}'") @@ -333,6 +268,62 @@ def is_protected(self, path_info): return stat.S_IMODE(mode) == self.CACHE_MODE + def get_file_checksum(self, path_info): + return file_md5(path_info)[0] + + @staticmethod + def getsize(path_info): + return os.path.getsize(path_info) + + def _upload( + self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs + ): + makedirs(to_info.parent, exist_ok=True) + + tmp_file = tmp_fname(to_info) + copyfile( + from_file, tmp_file, name=name, no_progress_bar=no_progress_bar + ) + + self.protect(tmp_file) + os.rename(tmp_file, to_info) + + @staticmethod + def _download( + from_info, to_file, name=None, no_progress_bar=False, **_kwargs + ): + copyfile( + from_info, to_file, no_progress_bar=no_progress_bar, name=name + ) + + +def _log_exceptions(func, operation): + @wraps(func) + def wrapper(from_info, to_info, *args, **kwargs): + try: + func(from_info, to_info, *args, **kwargs) + return 0 + except Exception as exc: + # NOTE: this means we ran out of file descriptors and there is no + # reason to try to proceed, as we will hit this error anyways. 
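[Illustration, not part of the patch: errno.EMFILE is "Too many open files", the per-process file-descriptor limit. Once it fires, every remaining transfer in the batch would fail the same way, so re-raising immediately is cheaper than logging one failure per file, which is what the wrapper otherwise does.]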
+ if isinstance(exc, OSError) and exc.errno == errno.EMFILE: + raise + + logger.exception( + "failed to %s '%s' to '%s'", operation, from_info, to_info + ) + return 1 + + return wrapper + + +class LocalRemote(Remote): + def get(self, md5): + if not md5: + return None + + return self.checksum_to_path_info(md5).url + def list_paths(self, prefix=None, progress_callback=None): assert self.path_info is not None if prefix: @@ -354,10 +345,10 @@ def _remove_unpacked_dir(self, checksum): self.tree.remove(path_info) -class LocalCache(LocalRemote, CacheMixin): - def __init__(self, repo, config): - super().__init__(repo, config) - self.cache_dir = config.get("url") +class LocalCache(CloudCache): + def __init__(self, tree): + super().__init__(tree) + self.cache_dir = tree.config.get("url") @property def cache_dir(self): diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index 8dcee9d584..11b6f10002 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -6,14 +6,37 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree +from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) class OSSRemoteTree(BaseRemoteTree): + """ + oss2 document: + https://www.alibabacloud.com/help/doc-detail/32026.htm + + + Examples + ---------- + $ dvc remote add myremote oss://my-bucket/path + Set key id, key secret and endpoint using modify command + $ dvc remote modify myremote oss_key_id my-key-id + $ dvc remote modify myremote oss_key_secret my-key-secret + $ dvc remote modify myremote oss_endpoint endpoint + or environment variables + $ export OSS_ACCESS_KEY_ID="my-key-id" + $ export OSS_ACCESS_KEY_SECRET="my-key-secret" + $ export OSS_ENDPOINT="endpoint" + """ + + scheme = Schemes.OSS PATH_CLS = CloudURLInfo + REQUIRES = {"oss2": "oss2"} + PARAM_CHECKSUM = "etag" + COPY_POLL_SECONDS = 5 + LIST_OBJECT_PAGE_SIZE = 100 def __init__(self, repo, config): super().__init__(repo, config) @@ -105,30 +128,3 @@ def _download( self.oss_service.get_object_to_file( from_info.path, to_file, progress_callback=pbar.update_to ) - - -class OSSRemote(BaseRemote): - """ - oss2 document: - https://www.alibabacloud.com/help/doc-detail/32026.htm - - - Examples - ---------- - $ dvc remote add myremote oss://my-bucket/path - Set key id, key secret and endpoint using modify command - $ dvc remote modify myremote oss_key_id my-key-id - $ dvc remote modify myremote oss_key_secret my-key-secret - $ dvc remote modify myremote oss_endpoint endpoint - or environment variables - $ export OSS_ACCESS_KEY_ID="my-key-id" - $ export OSS_ACCESS_KEY_SECRET="my-key-secret" - $ export OSS_ENDPOINT="endpoint" - """ - - scheme = Schemes.OSS - REQUIRES = {"oss2": "oss2"} - PARAM_CHECKSUM = "etag" - COPY_POLL_SECONDS = 5 - LIST_OBJECT_PAGE_SIZE = 100 - TREE_CLS = OSSRemoteTree diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 5ed0d67091..19670a9787 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -8,14 +8,17 @@ from dvc.exceptions import DvcException, ETagMismatchError from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin +from dvc.remote.base import BaseRemoteTree from dvc.scheme import Schemes logger = logging.getLogger(__name__) class S3RemoteTree(BaseRemoteTree): + scheme = Schemes.S3 PATH_CLS = CloudURLInfo + REQUIRES = {"boto3": "boto3"} + PARAM_CHECKSUM = "etag" def __init__(self, repo, config): super().__init__(repo, config) @@ -334,14 
+337,3 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): self.s3.download_file( from_info.bucket, from_info.path, to_file, Callback=pbar.update ) - - -class S3Remote(BaseRemote): - scheme = Schemes.S3 - REQUIRES = {"boto3": "boto3"} - PARAM_CHECKSUM = "etag" - TREE_CLS = S3RemoteTree - - -class S3Cache(S3Remote, CacheMixin): - pass diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index f5784f0ab3..9ec5cdcbbf 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -14,7 +14,7 @@ import dvc.prompt as prompt from dvc.progress import Tqdm -from dvc.remote.base import BaseRemote, BaseRemoteTree, CacheMixin +from dvc.remote.base import BaseRemoteTree, Remote from dvc.remote.pool import get_connection from dvc.scheme import Schemes from dvc.utils import to_chunks @@ -34,6 +34,18 @@ def ask_password(host, user, port): class SSHRemoteTree(BaseRemoteTree): + scheme = Schemes.SSH + REQUIRES = {"paramiko": "paramiko"} + JOBS = 4 + + PARAM_CHECKSUM = "md5" + # At any given time some of the connections will go over network and + # paramiko stuff, so we would ideally have it double of server processors. + # We use conservative setting of 4 instead to not exhaust max sessions. + CHECKSUM_JOBS = 4 + DEFAULT_CACHE_TYPES = ["copy"] + TRAVERSE_PREFIX_LEN = 2 + DEFAULT_PORT = 22 TIMEOUT = 1800 @@ -257,20 +269,7 @@ def _upload(self, from_file, to_info, name=None, no_progress_bar=False): ) -class SSHRemote(BaseRemote): - scheme = Schemes.SSH - REQUIRES = {"paramiko": "paramiko"} - JOBS = 4 - TREE_CLS = SSHRemoteTree - - PARAM_CHECKSUM = "md5" - # At any given time some of the connections will go over network and - # paramiko stuff, so we would ideally have it double of server processors. - # We use conservative setting of 4 instead to not exhaust max sessions. 
- CHECKSUM_JOBS = 4 - DEFAULT_CACHE_TYPES = ["copy"] - TRAVERSE_PREFIX_LEN = 2 - +class SSHRemote(Remote): def list_paths(self, prefix=None, progress_callback=None): if prefix: root = posixpath.join(self.path_info.path, prefix[:2]) @@ -344,7 +343,3 @@ def exists_with_progress(chunks): in_remote = itertools.chain.from_iterable(results) ret = list(itertools.compress(checksums, in_remote)) return ret - - -class SSHCache(SSHRemote, CacheMixin): - pass From 319b90c8cf6c1c675396a3cc943215c9cb937b2b Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 14:53:39 +0900 Subject: [PATCH 02/16] remote: replace Remote() Cache() helper functions --- dvc/remote/__init__.py | 21 ++++++++++++++++++--- dvc/remote/base.py | 34 +++++++++++++++++++++------------- dvc/remote/local.py | 6 ------ 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/dvc/remote/__init__.py b/dvc/remote/__init__.py index 90c49c6f65..d375ac202f 100644 --- a/dvc/remote/__init__.py +++ b/dvc/remote/__init__.py @@ -2,15 +2,16 @@ from urllib.parse import urlparse from dvc.remote.azure import AzureRemoteTree +from dvc.remote.base import Remote from dvc.remote.gdrive import GDriveRemoteTree from dvc.remote.gs import GSRemoteTree from dvc.remote.hdfs import HDFSRemoteTree from dvc.remote.http import HTTPRemoteTree from dvc.remote.https import HTTPSRemoteTree -from dvc.remote.local import LocalRemoteTree +from dvc.remote.local import LocalRemote, LocalRemoteTree from dvc.remote.oss import OSSRemoteTree from dvc.remote.s3 import S3RemoteTree -from dvc.remote.ssh import SSHRemoteTree +from dvc.remote.ssh import SSHRemote, SSHRemoteTree TREES = [ AzureRemoteTree, @@ -26,7 +27,7 @@ ] -def get_cloud_tree(remote_conf, remotes): +def _get_tree(remote_conf): for tree_cls in TREES: if tree_cls.supported(remote_conf): return tree_cls @@ -71,3 +72,17 @@ def _resolve_remote_refs(config, remote_conf): base = config["remote"][parsed.netloc] url = posixpath.join(base["url"], parsed.path.lstrip("/")) return {**base, **remote_conf, "url": url} + + +def get_cloud_tree(repo, **kwargs): + remote_conf = _get_conf(repo, **kwargs) + return _get_tree(remote_conf)(repo, remote_conf) + + +def get_remote(repo, **kwargs): + tree = get_cloud_tree(repo, **kwargs) + if tree.scheme == "local": + return LocalRemote(tree) + if tree.scheme == "ssh": + return SSHRemote(tree) + return Remote(tree) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 721df03fbc..8ba4447bca 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -104,6 +104,8 @@ class BaseRemoteTree: def __init__(self, repo, config): self.repo = repo + self.config = config + self._check_requires(config) shared = config.get("shared") @@ -303,6 +305,13 @@ def get_dir_checksum(self, path_info, tree, **kwargs): dir_info = self._collect_dir(path_info, tree, **kwargs) return self._save_dir_info(dir_info, path_info) + def save_info(self, path_info, tree=None, **kwargs): + return { + self.PARAM_CHECKSUM: self.get_checksum( + path_info, tree=tree, **kwargs + ) + } + def _calculate_checksums(self, file_infos, tree): file_infos = list(file_infos) with Tqdm( @@ -493,18 +502,21 @@ def _download_file( class Remote: - """Cloud remote class.""" + """Cloud remote class. - INDEX_CLS = RemoteIndex + Provides methods for indexing and garbage collecting trees which contain + DVC remotes. 
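[Illustration, not part of the patch: after this change a Remote is built around a tree instead of subclassing it, so call sites look roughly like the sketch below, where path_info and checksums stand for caller-supplied values:

    remote = get_remote(repo, name="myremote")
    remote.tree.exists(path_info)      # storage operations live on the tree
    remote.checksums_exist(checksums)  # index-aware queries stay on the wrapper
]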
+ """ - def __init__(self, repo, config, tree): - self.repo = repo + def __init__(self, tree): self.tree = tree + self.repo = tree.repo + config = tree.config url = config.get("url") if self.scheme != "local" and url: index_name = hashlib.sha256(url.encode("utf-8")).hexdigest() - self.index = self.RemoteIndex( + self.index = RemoteIndex( self.repo, index_name, dir_suffix=self.CHECKSUM_DIR_SUFFIX ) else: @@ -545,12 +557,8 @@ def path_to_checksum(self, path): return "".join(parts) - def save_info(self, path_info, tree=None, **kwargs): - return { - self.PARAM_CHECKSUM: self.tree.get_checksum( - path_info, tree=tree, **kwargs - ) - } + def save_info(self, path_info, **kwargs): + return self.tree.save_info(path_info, **kwargs) def open(self, *args, **kwargs): return self.tree.open(*args, **kwargs) @@ -868,9 +876,9 @@ class CloudCache: DEFAULT_CACHE_TYPES = ["copy"] CACHE_MODE = None - def __init__(self, repo, config, tree): - self.repo = repo + def __init__(self, tree): self.tree = tree + self.repo = tree.repo self.cache_types = tree.config.get("type") or copy( self.DEFAULT_CACHE_TYPES diff --git a/dvc/remote/local.py b/dvc/remote/local.py index bb3a156d60..8256389bc2 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -318,12 +318,6 @@ def wrapper(from_info, to_info, *args, **kwargs): class LocalRemote(Remote): - def get(self, md5): - if not md5: - return None - - return self.checksum_to_path_info(md5).url - def list_paths(self, prefix=None, progress_callback=None): assert self.path_info is not None if prefix: From e808fbf0c84c51fc925e74763d2f1665f23f4b56 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 15:01:13 +0900 Subject: [PATCH 03/16] output: update for remote changes --- dvc/output/__init__.py | 16 ++++++++-------- dvc/output/base.py | 29 +++++++++++++++++------------ dvc/output/gs.py | 4 ++-- dvc/output/hdfs.py | 4 ++-- dvc/output/local.py | 9 +++++---- dvc/output/s3.py | 4 ++-- dvc/output/ssh.py | 5 +++-- 7 files changed, 39 insertions(+), 32 deletions(-) diff --git a/dvc/output/__init__.py b/dvc/output/__init__.py index 74a550a8f5..57e6eff82d 100644 --- a/dvc/output/__init__.py +++ b/dvc/output/__init__.py @@ -10,10 +10,10 @@ from dvc.output.local import LocalOutput from dvc.output.s3 import S3Output from dvc.output.ssh import SSHOutput -from dvc.remote import Remote -from dvc.remote.hdfs import HDFSRemote -from dvc.remote.local import LocalRemote -from dvc.remote.s3 import S3Remote +from dvc.remote import get_remote +from dvc.remote.hdfs import HDFSRemoteTree +from dvc.remote.local import LocalRemoteTree +from dvc.remote.s3 import S3RemoteTree from dvc.scheme import Schemes OUTS = [ @@ -47,9 +47,9 @@ # so when a few types of outputs share the same name, we only need # specify it once. 
CHECKSUMS_SCHEMA = { - LocalRemote.PARAM_CHECKSUM: CHECKSUM_SCHEMA, - S3Remote.PARAM_CHECKSUM: CHECKSUM_SCHEMA, - HDFSRemote.PARAM_CHECKSUM: CHECKSUM_SCHEMA, + LocalRemoteTree.PARAM_CHECKSUM: CHECKSUM_SCHEMA, + S3RemoteTree.PARAM_CHECKSUM: CHECKSUM_SCHEMA, + HDFSRemoteTree.PARAM_CHECKSUM: CHECKSUM_SCHEMA, } SCHEMA = CHECKSUMS_SCHEMA.copy() @@ -66,7 +66,7 @@ def _get( parsed = urlparse(p) if parsed.scheme == "remote": - remote = Remote(stage.repo, name=parsed.netloc) + remote = get_remote(stage.repo, name=parsed.netloc) return OUTS_MAP[remote.scheme]( stage, p, diff --git a/dvc/output/base.py b/dvc/output/base.py index e57ae6e8fa..cc95fa655f 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -12,7 +12,7 @@ DvcException, RemoteCacheRequiredError, ) -from dvc.remote.base import BaseRemote +from dvc.remote.base import BaseRemoteTree, Remote logger = logging.getLogger(__name__) @@ -47,7 +47,8 @@ def __init__(self, path): class BaseOutput: IS_DEPENDENCY = False - REMOTE = BaseRemote + REMOTE_CLS = Remote + TREE_CLS = BaseRemoteTree PARAM_PATH = "path" PARAM_CACHE = "cache" @@ -105,7 +106,11 @@ def __init__( self.repo = stage.repo if stage else None self.def_path = path self.info = info - self.remote = remote or self.REMOTE(self.repo, {}) + if remote: + self.remote = remote + else: + tree = self.TREE_CLS(self.repo, {}) + self.remote = self.REMOTE_CLS(tree) self.use_cache = False if self.IS_DEPENDENCY else cache self.metric = False if self.IS_DEPENDENCY else metric self.plot = False if self.IS_DEPENDENCY else plot @@ -119,7 +124,7 @@ def _parse_path(self, remote, path): if remote: parsed = urlparse(path) return remote.path_info / parsed.path.lstrip("/") - return self.REMOTE.TREE_CLS.PATH_CLS(path) + return self.TREE_CLS.PATH_CLS(path) def __repr__(self): return "{class_name}: '{def_path}'".format( @@ -131,7 +136,7 @@ def __str__(self): @property def scheme(self): - return self.REMOTE.scheme + return self.TREE_CLS.scheme @property def is_in_repo(self): @@ -154,7 +159,7 @@ def dir_cache(self): @classmethod def supported(cls, url): - return cls.REMOTE.supported(url) + return cls.TREE_CLS.supported(url) @property def cache_path(self): @@ -162,15 +167,15 @@ def cache_path(self): @property def checksum_type(self): - return self.remote.PARAM_CHECKSUM + return self.remote.tree.PARAM_CHECKSUM @property def checksum(self): - return self.info.get(self.remote.PARAM_CHECKSUM) + return self.info.get(self.remote.tree.PARAM_CHECKSUM) @checksum.setter def checksum(self, checksum): - self.info[self.remote.PARAM_CHECKSUM] = checksum + self.info[self.remote.tree.PARAM_CHECKSUM] = checksum def get_checksum(self): return self.remote.get_checksum(self.path_info) @@ -357,7 +362,7 @@ def get_files_number(self, filter_info=None): def unprotect(self): if self.exists: - self.remote.unprotect(self.path_info) + self.remote.tree.unprotect(self.path_info) def get_dir_cache(self, **kwargs): if not self.is_dir_checksum: @@ -417,8 +422,8 @@ def collect_used_dir_cache( filter_path = str(filter_info) if filter_info else None is_win = os.name == "nt" for entry in self.dir_cache: - checksum = entry[self.remote.PARAM_CHECKSUM] - entry_relpath = entry[self.remote.PARAM_RELPATH] + checksum = entry[self.remote.tree.PARAM_CHECKSUM] + entry_relpath = entry[self.remote.tree.PARAM_RELPATH] if is_win: entry_relpath = entry_relpath.replace("/", os.sep) entry_path = os.path.join(path, entry_relpath) diff --git a/dvc/output/gs.py b/dvc/output/gs.py index 00b3a14f9b..ccab57a4f7 100644 --- a/dvc/output/gs.py +++ b/dvc/output/gs.py @@ -1,6 
+1,6 @@
 from dvc.output.s3 import S3Output
-from dvc.remote.gs import GSRemote
+from dvc.remote.gs import GSRemoteTree
 class GSOutput(S3Output):
-    REMOTE = GSRemote
+    TREE_CLS = GSRemoteTree
diff --git a/dvc/output/hdfs.py b/dvc/output/hdfs.py
index 6c32787393..fd44193a30 100644
--- a/dvc/output/hdfs.py
+++ b/dvc/output/hdfs.py
@@ -1,6 +1,6 @@
 from dvc.output.base import BaseOutput
-from dvc.remote.hdfs import HDFSRemote
+from dvc.remote.hdfs import HDFSRemoteTree
 class HDFSOutput(BaseOutput):
-    REMOTE = HDFSRemote
+    TREE_CLS = HDFSRemoteTree
diff --git a/dvc/output/local.py b/dvc/output/local.py
index 2c097ebb55..259253bdb9 100644
--- a/dvc/output/local.py
+++ b/dvc/output/local.py
@@ -5,7 +5,7 @@
 from dvc.exceptions import DvcException
 from dvc.istextfile import istextfile
 from dvc.output.base import BaseOutput
-from dvc.remote.local import LocalRemote
+from dvc.remote.local import LocalRemote, LocalRemoteTree
 from dvc.utils import relpath
 from dvc.utils.fs import path_isin
@@ -13,7 +13,8 @@
 class LocalOutput(BaseOutput):
-    REMOTE = LocalRemote
+    REMOTE_CLS = LocalRemote
+    TREE_CLS = LocalRemoteTree
     sep = os.sep
     def __init__(self, stage, path, *args, **kwargs):
@@ -33,12 +34,12 @@ def _parse_path(self, remote, path):
         #
         # FIXME: if we have Windows path containing / or posix one with \
         # then we have #2059 bug and can't really handle that.
-        p = self.REMOTE.TREE_CLS.PATH_CLS(path)
+        p = self.TREE_CLS.PATH_CLS(path)
         if not p.is_absolute():
             p = self.stage.wdir / p
         abs_p = os.path.abspath(os.path.normpath(p))
-        return self.REMOTE.TREE_CLS.PATH_CLS(abs_p)
+        return self.TREE_CLS.PATH_CLS(abs_p)
     def __str__(self):
         if not self.is_in_repo:
diff --git a/dvc/output/s3.py b/dvc/output/s3.py
index dbd6ee8995..92be85e510 100644
--- a/dvc/output/s3.py
+++ b/dvc/output/s3.py
@@ -1,6 +1,6 @@
 from dvc.output.base import BaseOutput
-from dvc.remote.s3 import S3Remote
+from dvc.remote.s3 import S3RemoteTree
 class S3Output(BaseOutput):
-    REMOTE = S3Remote
+    TREE_CLS = S3RemoteTree
diff --git a/dvc/output/ssh.py b/dvc/output/ssh.py
index 29d0593f9c..27d3e6d436 100644
--- a/dvc/output/ssh.py
+++ b/dvc/output/ssh.py
@@ -1,6 +1,7 @@
 from dvc.output.base import BaseOutput
-from dvc.remote.ssh import SSHRemote
+from dvc.remote.ssh import SSHRemote, SSHRemoteTree
 class SSHOutput(BaseOutput):
-    REMOTE = SSHRemote
+    REMOTE_CLS = SSHRemote
+    TREE_CLS = SSHRemoteTree
From 05b13915f63f26dd6483060453df55ff79dd55d3 Mon Sep 17 00:00:00 2001
From: Peter Rowlands
Date: Mon, 15 Jun 2020 15:01:39 +0900
Subject: [PATCH 04/16] cache/data_cloud: update for remote changes
---
 dvc/cache.py      | 12 ++++++++----
 dvc/data_cloud.py |  4 ++--
 2 files changed, 10 insertions(+), 6 deletions(-)
diff --git a/dvc/cache.py b/dvc/cache.py
index a03912c470..acc2719042 100644
--- a/dvc/cache.py
+++ b/dvc/cache.py
@@ -28,13 +28,15 @@ def _make_remote_property(name):
     """
     def getter(self):
-        from dvc.remote import Cache as CloudCache
+        from dvc.remote import get_cloud_tree
+        from dvc.remote.base import CloudCache
         remote = self.config.get(name)
         if not remote:
             return None
-        return CloudCache(self.repo, name=remote)
+        tree = get_cloud_tree(self.repo, name=remote)
+        return CloudCache(tree)
     getter.__name__ = name
     return cached_property(getter)
@@ -50,7 +52,8 @@ class Cache:
     CACHE_DIR = "cache"
     def __init__(self, repo):
-        from dvc.remote import Cache as CloudCache
+        from dvc.remote import get_cloud_tree
+        from dvc.remote.local import LocalCache
         self.repo = repo
         self.config = config = repo.config["cache"]
@@ -62,7 +65,8 @@ def __init__(self, repo):
         else:
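[Illustration, not part of the patch: the cache directory is treated like any other remote URL here; get_cloud_tree() resolves the "url" built below from config["dir"] to a LocalRemoteTree, which LocalCache then wraps.]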
settings = {**config, "url": config["dir"]} - self.local = CloudCache(repo, **settings) + tree = get_cloud_tree(repo, **settings) + self.local = LocalCache(tree) s3 = _make_remote_property("s3") gs = _make_remote_property("gs") diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 5e1757bb13..6e6503276e 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -3,7 +3,7 @@ import logging from dvc.config import NoRemoteError -from dvc.remote import Remote +from dvc.remote import get_remote logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def get_remote(self, name=None, command=""): raise NoRemoteError(error_msg) def _init_remote(self, name): - return Remote(self.repo, name=name) + return get_remote(self.repo, name=name) def push( self, cache, jobs=None, remote=None, show_checksums=False, From 53651c657368bed661e5dc24b15de0c6214c736c Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 15:03:35 +0900 Subject: [PATCH 05/16] dependency: update for remote changes --- dvc/dependency/__init__.py | 4 ++-- dvc/dependency/azure.py | 4 ++-- dvc/dependency/http.py | 4 ++-- dvc/dependency/https.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py index 1a778dbfc0..1d1730c061 100644 --- a/dvc/dependency/__init__.py +++ b/dvc/dependency/__init__.py @@ -12,7 +12,7 @@ from dvc.dependency.s3 import S3Dependency from dvc.dependency.ssh import SSHDependency from dvc.output.base import BaseOutput -from dvc.remote import Remote +from dvc.remote import get_remote from dvc.scheme import Schemes from .repo import RepoDependency @@ -54,7 +54,7 @@ def _get(stage, p, info): parsed = urlparse(p) if p else None if parsed and parsed.scheme == "remote": - remote = Remote(stage.repo, name=parsed.netloc) + remote = get_remote(stage.repo, name=parsed.netloc) return DEP_MAP[remote.scheme](stage, p, info, remote=remote) if info and info.get(RepoDependency.PARAM_REPO): diff --git a/dvc/dependency/azure.py b/dvc/dependency/azure.py index c119d15efb..da7b0727c3 100644 --- a/dvc/dependency/azure.py +++ b/dvc/dependency/azure.py @@ -1,7 +1,7 @@ from dvc.dependency.base import BaseDependency from dvc.output.base import BaseOutput -from dvc.remote.azure import AzureRemote +from dvc.remote.azure import AzureRemoteTree class AzureDependency(BaseDependency, BaseOutput): - REMOTE = AzureRemote + TREE_CLS = AzureRemoteTree diff --git a/dvc/dependency/http.py b/dvc/dependency/http.py index 653a2f0138..e8f0bf1e3d 100644 --- a/dvc/dependency/http.py +++ b/dvc/dependency/http.py @@ -1,7 +1,7 @@ from dvc.dependency.base import BaseDependency from dvc.output.base import BaseOutput -from dvc.remote.http import HTTPRemote +from dvc.remote.http import HTTPRemoteTree class HTTPDependency(BaseDependency, BaseOutput): - REMOTE = HTTPRemote + TREE_CLS = HTTPRemoteTree diff --git a/dvc/dependency/https.py b/dvc/dependency/https.py index b8ab2922f7..e95ac83f67 100644 --- a/dvc/dependency/https.py +++ b/dvc/dependency/https.py @@ -1,7 +1,7 @@ -from dvc.remote.https import HTTPSRemote +from dvc.remote.https import HTTPSRemoteTree from .http import HTTPDependency class HTTPSDependency(HTTPDependency): - REMOTE = HTTPSRemote + TREE_CLS = HTTPSRemoteTree From 4e1b5fa84da591c887e9355d305caf4e0e3331f2 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 15:06:24 +0900 Subject: [PATCH 06/16] stage: update for remote changes --- dvc/stage/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dvc/stage/utils.py 
b/dvc/stage/utils.py index 939803e8bf..ab7f61f7f2 100644 --- a/dvc/stage/utils.py +++ b/dvc/stage/utils.py @@ -8,7 +8,8 @@ from dvc.utils.fs import path_isin from ..dependency import ParamsDependency -from ..remote import LocalRemote, S3Remote +from ..remote.local import LocalRemoteTree +from ..remote.s3 import S3RemoteTree from ..utils import dict_md5, format_link, relpath from .exceptions import ( MissingDataSource, @@ -132,8 +133,8 @@ def stage_dump_eq(stage_cls, old_d, new_d): new_d.pop(stage_cls.PARAM_MD5, None) outs = old_d.get(stage_cls.PARAM_OUTS, []) for out in outs: - out.pop(LocalRemote.PARAM_CHECKSUM, None) - out.pop(S3Remote.PARAM_CHECKSUM, None) + out.pop(LocalRemoteTree.PARAM_CHECKSUM, None) + out.pop(S3RemoteTree.PARAM_CHECKSUM, None) # outs and deps are lists of dicts. To check equality, we need to make # them independent of the order, so, we convert them to dicts. From fbc2838876695a945edace4d3e9917804dd196e3 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 16:11:24 +0900 Subject: [PATCH 07/16] remote: move remaining non-index methods into tree --- dvc/remote/base.py | 351 +++++++++++++++++++------------------ dvc/remote/local.py | 32 ++-- dvc/remote/ssh/__init__.py | 4 +- 3 files changed, 200 insertions(+), 187 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 8ba4447bca..f919381793 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -305,6 +305,17 @@ def get_dir_checksum(self, path_info, tree, **kwargs): dir_info = self._collect_dir(path_info, tree, **kwargs) return self._save_dir_info(dir_info, path_info) + def checksum_to_path_info(self, checksum): + return self.path_info / checksum[0:2] / checksum[2:] + + def path_to_checksum(self, path): + parts = self.PATH_CLS(path).parts[-2:] + + if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): + raise ValueError(f"Bad cache file path '{path}'") + + return "".join(parts) + def save_info(self, path_info, tree=None, **kwargs): return { self.PARAM_CHECKSUM: self.get_checksum( @@ -463,7 +474,7 @@ def _download_dir( dir_mode=dir_mode, ) ) - with ThreadPoolExecutor(max_workers=self.remote.JOBS) as executor: + with ThreadPoolExecutor(max_workers=self.JOBS) as executor: futures = [ executor.submit(download_files, from_info, to_info) for from_info, to_info in zip(from_infos, to_infos) @@ -500,69 +511,6 @@ def _download_file( move(tmp_file, to_info, mode=file_mode) - -class Remote: - """Cloud remote class. - - Provides methods for indexing and garbage collecting trees which contain - DVC remotes. 
- """ - - def __init__(self, tree): - self.tree = tree - self.repo = tree.repo - - config = tree.config - url = config.get("url") - if self.scheme != "local" and url: - index_name = hashlib.sha256(url.encode("utf-8")).hexdigest() - self.index = RemoteIndex( - self.repo, index_name, dir_suffix=self.CHECKSUM_DIR_SUFFIX - ) - else: - self.index = RemoteIndexNoop() - - @property - def path_info(self): - return self.tree.path_info - - def __repr__(self): - return "{class_name}: '{path_info}'".format( - class_name=type(self).__name__, - path_info=self.path_info or "No path", - ) - - @property - def cache(self): - return getattr(self.repo.cache, self.scheme) - - @property - def scheme(self): - return self.tree.scheme - - def is_dir_checksum(self, checksum): - return self.tree.is_dir_checksum(checksum) - - def get_checksum(self, path_info, **kwargs): - return self.tree.get_checksum(path_info, **kwargs) - - def checksum_to_path_info(self, checksum): - return self.path_info / checksum[0:2] / checksum[2:] - - def path_to_checksum(self, path): - parts = self.tree.PATH_CLS(path).parts[-2:] - - if not (len(parts) == 2 and parts[0] and len(parts[0]) == 2): - raise ValueError(f"Bad cache file path '{path}'") - - return "".join(parts) - - def save_info(self, path_info, **kwargs): - return self.tree.save_info(path_info, **kwargs) - - def open(self, *args, **kwargs): - return self.tree.open(*args, **kwargs) - def list_paths(self, prefix=None, progress_callback=None): if prefix: if len(prefix) > 2: @@ -572,14 +520,14 @@ def list_paths(self, prefix=None, progress_callback=None): else: path_info = self.path_info if progress_callback: - for file_info in self.tree.walk_files(path_info): + for file_info in self.walk_files(path_info): progress_callback() yield file_info.path else: - yield from self.tree.walk_files(path_info) + yield from self.walk_files(path_info) def list_checksums(self, prefix=None, progress_callback=None): - """Iterate over remote checksums. + """Iterate over checksums in this tree. If `prefix` is specified, only checksums which begin with `prefix` will be returned. @@ -593,7 +541,7 @@ def list_checksums(self, prefix=None, progress_callback=None): ) def all(self, jobs=None, name=None): - """Iterate over all checksums in the remote. + """Iterate over all checksums in this tree. Checksums will be fetched in parallel threads according to prefix (except for small remotes) and a progress bar will be displayed. @@ -604,107 +552,14 @@ def all(self, jobs=None, name=None): ) ) - if not self.tree.CAN_TRAVERSE: + if not self.CAN_TRAVERSE: return self.list_checksums() - remote_size, remote_checksums = self._estimate_remote_size(name=name) - return self._list_checksums_traverse( + remote_size, remote_checksums = self.estimate_remote_size(name=name) + return self.list_checksums_traverse( remote_size, remote_checksums, jobs, name ) - def checksums_exist(self, checksums, jobs=None, name=None): - """Check if the given checksums are stored in the remote. - - There are two ways of performing this check: - - - Traverse method: Get a list of all the files in the remote - (traversing the cache directory) and compare it with - the given checksums. Cache entries will be retrieved in parallel - threads according to prefix (i.e. entries starting with, "00...", - "01...", and so on) and a progress bar will be displayed. - - - Exists method: For each given checksum, run the `exists` - method and filter the checksums that aren't on the remote. - This is done in parallel threads. 
- It also shows a progress bar when performing the check. - - The reason for such an odd logic is that most of the remotes - take much shorter time to just retrieve everything they have under - a certain prefix (e.g. s3, gs, ssh, hdfs). Other remotes that can - check if particular file exists much quicker, use their own - implementation of checksums_exist (see ssh, local). - - Which method to use will be automatically determined after estimating - the size of the remote cache, and comparing the estimated size with - len(checksums). To estimate the size of the remote cache, we fetch - a small subset of cache entries (i.e. entries starting with "00..."). - Based on the number of entries in that subset, the size of the full - cache can be estimated, since the cache is evenly distributed according - to checksum. - - Returns: - A list with checksums that were found in the remote - """ - # Remotes which do not use traverse prefix should override - # checksums_exist() (see ssh, local) - assert self.TRAVERSE_PREFIX_LEN >= 2 - - checksums = set(checksums) - indexed_checksums = set(self.index.intersection(checksums)) - checksums -= indexed_checksums - logger.debug( - "Matched '{}' indexed checksums".format(len(indexed_checksums)) - ) - if not checksums: - return indexed_checksums - - if len(checksums) == 1 or not self.tree.CAN_TRAVERSE: - remote_checksums = self._list_checksums_exists( - checksums, jobs, name - ) - return list(indexed_checksums) + remote_checksums - - # Max remote size allowed for us to use traverse method - remote_size, remote_checksums = self._estimate_remote_size( - checksums, name - ) - - traverse_pages = remote_size / self.LIST_OBJECT_PAGE_SIZE - # For sufficiently large remotes, traverse must be weighted to account - # for performance overhead from large lists/sets. - # From testing with S3, for remotes with 1M+ files, object_exists is - # faster until len(checksums) is at least 10k~100k - if remote_size > self.TRAVERSE_THRESHOLD_SIZE: - traverse_weight = traverse_pages * self.TRAVERSE_WEIGHT_MULTIPLIER - else: - traverse_weight = traverse_pages - if len(checksums) < traverse_weight: - logger.debug( - "Large remote ('{}' checksums < '{}' traverse weight), " - "using object_exists for remaining checksums".format( - len(checksums), traverse_weight - ) - ) - return ( - list(indexed_checksums) - + list(checksums & remote_checksums) - + self._list_checksums_exists( - checksums - remote_checksums, jobs, name - ) - ) - - logger.debug( - "Querying '{}' checksums via traverse".format(len(checksums)) - ) - remote_checksums = set( - self._list_checksums_traverse( - remote_size, remote_checksums, jobs, name - ) - ) - return list(indexed_checksums) + list( - checksums & set(remote_checksums) - ) - def _checksums_with_limit( self, limit, prefix=None, progress_callback=None ): @@ -728,8 +583,8 @@ def _max_estimation_size(self, checksums): * self.LIST_OBJECT_PAGE_SIZE, ) - def _estimate_remote_size(self, checksums=None, name=None): - """Estimate remote cache size based on number of entries beginning with + def estimate_remote_size(self, checksums=None, name=None): + """Estimate tree size based on number of entries beginning with "00..." prefix. 
""" prefix = "0" * self.TRAVERSE_PREFIX_LEN @@ -763,10 +618,10 @@ def update(n=1): logger.debug(f"Estimated remote size: {remote_size} files") return remote_size, remote_checksums - def _list_checksums_traverse( + def list_checksums_traverse( self, remote_size, remote_checksums, jobs=None, name=None ): - """Iterate over all checksums in the remote cache. + """Iterate over all checksums found in this tree. Checksums are fetched in parallel according to prefix, except in cases where the remote size is very small. @@ -818,7 +673,10 @@ def list_with_update(prefix): in_remote = executor.map(list_with_update, traverse_prefixes,) yield from itertools.chain.from_iterable(in_remote) - def _list_checksums_exists(self, checksums, jobs=None, name=None): + def list_checksums_exists(self, checksums, jobs=None, name=None): + """Return list of the specified checksums which exist in this tree. + Checksums will be queried individually. + """ logger.debug( "Querying {} checksums via object_exists".format(len(checksums)) ) @@ -830,7 +688,7 @@ def _list_checksums_exists(self, checksums, jobs=None, name=None): ) as pbar: def exists_with_progress(path_info): - ret = self.tree.exists(path_info) + ret = self.exists(path_info) pbar.update_msg(str(path_info)) return ret @@ -840,6 +698,159 @@ def exists_with_progress(path_info): ret = list(itertools.compress(checksums, in_remote)) return ret + +class Remote: + """Cloud remote class. + + Provides methods for indexing and garbage collecting trees which contain + DVC remotes. + """ + + def __init__(self, tree): + self.tree = tree + self.repo = tree.repo + + config = tree.config + url = config.get("url") + if self.scheme != "local" and url: + index_name = hashlib.sha256(url.encode("utf-8")).hexdigest() + self.index = RemoteIndex( + self.repo, index_name, dir_suffix=self.tree.CHECKSUM_DIR_SUFFIX + ) + else: + self.index = RemoteIndexNoop() + + @property + def path_info(self): + return self.tree.path_info + + def __repr__(self): + return "{class_name}: '{path_info}'".format( + class_name=type(self).__name__, + path_info=self.path_info or "No path", + ) + + @property + def cache(self): + return getattr(self.repo.cache, self.scheme) + + @property + def scheme(self): + return self.tree.scheme + + def is_dir_checksum(self, checksum): + return self.tree.is_dir_checksum(checksum) + + def get_checksum(self, path_info, **kwargs): + return self.tree.get_checksum(path_info, **kwargs) + + def checksum_to_path_info(self, checksum): + return self.tree.checksum_to_path_info(checksum) + + def path_to_checksum(self, path): + return self.tree.path_to_checksum(path) + + def save_info(self, path_info, **kwargs): + return self.tree.save_info(path_info, **kwargs) + + def open(self, *args, **kwargs): + return self.tree.open(*args, **kwargs) + + def checksums_exist(self, checksums, jobs=None, name=None): + """Check if the given checksums are stored in the remote. + + There are two ways of performing this check: + + - Traverse method: Get a list of all the files in the remote + (traversing the cache directory) and compare it with + the given checksums. Cache entries will be retrieved in parallel + threads according to prefix (i.e. entries starting with, "00...", + "01...", and so on) and a progress bar will be displayed. + + - Exists method: For each given checksum, run the `exists` + method and filter the checksums that aren't on the remote. + This is done in parallel threads. + It also shows a progress bar when performing the check. 
+
+        The reason for this somewhat odd logic is that most remotes
+        take much less time to simply retrieve everything they have under
+        a certain prefix (e.g. s3, gs, ssh, hdfs). Remotes that can
+        check whether a particular file exists much more quickly use their
+        own implementation of checksums_exist (see ssh, local).
+
+        Which method to use will be automatically determined after estimating
+        the size of the remote cache and comparing the estimated size with
+        len(checksums). To estimate the size of the remote cache, we fetch
+        a small subset of cache entries (i.e. entries starting with "00...").
+        Based on the number of entries in that subset, the size of the full
+        cache can be estimated, since the cache is evenly distributed according
+        to checksum.
+
+        Returns:
+            A list of checksums that were found in the remote
+        """
+        # Remotes which do not use traverse prefix should override
+        # checksums_exist() (see ssh, local)
+        assert self.tree.TRAVERSE_PREFIX_LEN >= 2
+
+        checksums = set(checksums)
+        indexed_checksums = set(self.index.intersection(checksums))
+        checksums -= indexed_checksums
+        logger.debug(
+            "Matched '{}' indexed checksums".format(len(indexed_checksums))
+        )
+        if not checksums:
+            return indexed_checksums
+
+        if len(checksums) == 1 or not self.tree.CAN_TRAVERSE:
+            remote_checksums = self.tree.list_checksums_exists(
+                checksums, jobs, name
+            )
+            return list(indexed_checksums) + remote_checksums
+
+        # Max remote size allowed for us to use traverse method
+        remote_size, remote_checksums = self.tree.estimate_remote_size(
+            checksums, name
+        )
+
+        traverse_pages = remote_size / self.tree.LIST_OBJECT_PAGE_SIZE
+        # For sufficiently large remotes, traverse must be weighted to account
+        # for performance overhead from large lists/sets.
+        # From testing with S3, for remotes with 1M+ files, object_exists is
+        # faster until len(checksums) is at least 10k~100k
+        if remote_size > self.tree.TRAVERSE_THRESHOLD_SIZE:
+            traverse_weight = (
+                traverse_pages * self.tree.TRAVERSE_WEIGHT_MULTIPLIER
+            )
+        else:
+            traverse_weight = traverse_pages
+        if len(checksums) < traverse_weight:
+            logger.debug(
+                "Large remote ('{}' checksums < '{}' traverse weight), "
+                "using object_exists for remaining checksums".format(
+                    len(checksums), traverse_weight
+                )
+            )
+            return (
+                list(indexed_checksums)
+                + list(checksums & remote_checksums)
+                + self.tree.list_checksums_exists(
+                    checksums - remote_checksums, jobs, name
+                )
+            )
+
+        logger.debug(
+            "Querying '{}' checksums via traverse".format(len(checksums))
+        )
+        remote_checksums = set(
+            self.tree.list_checksums_traverse(
+                remote_size, remote_checksums, jobs, name
+            )
+        )
+        return list(indexed_checksums) + list(
+            checksums & set(remote_checksums)
+        )
+
     @index_locked
     def gc(self, named_cache, jobs=None):
         used = set(named_cache.scheme_keys("local"))
diff --git a/dvc/remote/local.py b/dvc/remote/local.py
index 8256389bc2..58dda14cfb 100644
--- a/dvc/remote/local.py
+++ b/dvc/remote/local.py
@@ -296,6 +296,23 @@ def _download(
         from_info, to_file, no_progress_bar=no_progress_bar, name=name
     )
 
+    def list_paths(self, prefix=None, progress_callback=None):
+        assert self.path_info is not None
+        if prefix:
+            path_info = self.path_info / prefix[:2]
+            if not self.tree.exists(path_info):
+                return
+        else:
+            path_info = self.path_info
+        # NOTE: use utils.fs walk_files since tree.walk_files will not follow
+        # symlinks
+        if progress_callback:
+            for path in walk_files(path_info):
+                progress_callback()
+                yield path
+        else:
+            yield from walk_files(path_info)
+
+    def 
_log_exceptions(func, operation): @wraps(func) @@ -318,21 +335,6 @@ def wrapper(from_info, to_info, *args, **kwargs): class LocalRemote(Remote): - def list_paths(self, prefix=None, progress_callback=None): - assert self.path_info is not None - if prefix: - path_info = self.path_info / prefix[:2] - if not self.tree.exists(path_info): - return - else: - path_info = self.path_info - if progress_callback: - for path in walk_files(path_info): - progress_callback() - yield path - else: - yield from walk_files(path_info) - def _remove_unpacked_dir(self, checksum): info = self.checksum_to_path_info(checksum) path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 9ec5cdcbbf..45d871ccf3 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -268,8 +268,6 @@ def _upload(self, from_file, to_info, name=None, no_progress_bar=False): no_progress_bar=no_progress_bar, ) - -class SSHRemote(Remote): def list_paths(self, prefix=None, progress_callback=None): if prefix: root = posixpath.join(self.path_info.path, prefix[:2]) @@ -286,6 +284,8 @@ def list_paths(self, prefix=None, progress_callback=None): else: yield from ssh.walk_files(root) + +class SSHRemote(Remote): def batch_exists(self, path_infos, callback): def _exists(chunk_and_channel): chunk, channel = chunk_and_channel From 3a737d0ed22ed487fe7c9c347753f08d190192ea Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 16:50:42 +0900 Subject: [PATCH 08/16] tests: update remote unit tests --- tests/remotes.py | 15 ++++--- tests/unit/remote/ssh/test_ssh.py | 44 +++++++++--------- tests/unit/remote/test_azure.py | 24 +++++----- tests/unit/remote/test_base.py | 64 ++++++++++++++------------- tests/unit/remote/test_gdrive.py | 14 +++--- tests/unit/remote/test_gs.py | 20 ++++----- tests/unit/remote/test_http.py | 30 ++++++------- tests/unit/remote/test_local.py | 38 ++++++++-------- tests/unit/remote/test_oss.py | 12 ++--- tests/unit/remote/test_remote.py | 28 ++++++------ tests/unit/remote/test_remote_tree.py | 4 +- tests/unit/remote/test_s3.py | 15 +++---- 12 files changed, 158 insertions(+), 150 deletions(-) diff --git a/tests/remotes.py b/tests/remotes.py index cff71a173f..c1000f5fac 100644 --- a/tests/remotes.py +++ b/tests/remotes.py @@ -7,9 +7,10 @@ from moto.s3 import mock_s3 -from dvc.remote.gdrive import GDriveRemote, GDriveRemoteTree -from dvc.remote.gs import GSRemote -from dvc.remote.s3 import S3Remote +from dvc.remote.base import Remote +from dvc.remote.gdrive import GDriveRemoteTree +from dvc.remote.gs import GSRemoteTree +from dvc.remote.s3 import S3RemoteTree from dvc.utils import env2bool from tests.basic_env import TestDvc @@ -79,7 +80,7 @@ class S3Mocked(S3): @contextmanager def remote(cls, repo): with mock_s3(): - yield S3Remote(repo, {"url": cls.get_url()}) + yield Remote(S3RemoteTree(repo, {"url": cls.get_url()})) @staticmethod def put_objects(remote, objects): @@ -127,7 +128,7 @@ def get_url(): @classmethod @contextmanager def remote(cls, repo): - yield GSRemote(repo, {"url": cls.get_url()}) + yield Remote(GSRemoteTree(repo, {"url": cls.get_url()})) @staticmethod def put_objects(remote, objects): @@ -150,8 +151,8 @@ def create_dir(dvc, url): "gdrive_service_account_p12_file_path": "test.p12", "gdrive_use_service_account": True, } - remote = GDriveRemote(dvc, config) - remote.tree._gdrive_create_dir("root", remote.path_info.path) + tree = GDriveRemoteTree(dvc, config) + tree._gdrive_create_dir("root", 
tree.path_info.path) @staticmethod def get_storagepath(): diff --git a/tests/unit/remote/ssh/test_ssh.py b/tests/unit/remote/ssh/test_ssh.py index 288a1d7b13..27a3fd92bc 100644 --- a/tests/unit/remote/ssh/test_ssh.py +++ b/tests/unit/remote/ssh/test_ssh.py @@ -5,7 +5,7 @@ import pytest from mock import mock_open, patch -from dvc.remote.ssh import SSHRemote, SSHRemoteTree +from dvc.remote.ssh import SSHRemoteTree from dvc.system import System from tests.remotes import SSHMocked @@ -20,21 +20,21 @@ def test_url(dvc): url = f"ssh://{user}@{host}:{port}{path}" config = {"url": url} - remote = SSHRemote(dvc, config) - assert remote.path_info == url + tree = SSHRemoteTree(dvc, config) + assert tree.path_info == url # SCP-like URL ssh://[user@]host.xz:/absolute/path url = f"ssh://{user}@{host}:{path}" config = {"url": url} - remote = SSHRemote(dvc, config) - assert remote.tree.path_info == url + tree = SSHRemoteTree(dvc, config) + assert tree.path_info == url def test_no_path(dvc): config = {"url": "ssh://127.0.0.1"} - remote = SSHRemote(dvc, config) - assert remote.tree.path_info.path == "" + tree = SSHRemoteTree(dvc, config) + assert tree.path_info.path == "" mock_ssh_config = """ @@ -67,11 +67,11 @@ def test_no_path(dvc): def test_ssh_host_override_from_config( mock_file, mock_exists, dvc, config, expected_host ): - remote = SSHRemote(dvc, config) + tree = SSHRemoteTree(dvc, config) mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) - assert remote.tree.path_info.host == expected_host + assert tree.path_info.host == expected_host @pytest.mark.parametrize( @@ -95,11 +95,11 @@ def test_ssh_host_override_from_config( read_data=mock_ssh_config, ) def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): - remote = SSHRemote(dvc, config) + tree = SSHRemoteTree(dvc, config) mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) - assert remote.tree.path_info.user == expected_user + assert tree.path_info.user == expected_user @pytest.mark.parametrize( @@ -120,11 +120,11 @@ def test_ssh_user(mock_file, mock_exists, dvc, config, expected_user): read_data=mock_ssh_config, ) def test_ssh_port(mock_file, mock_exists, dvc, config, expected_port): - remote = SSHRemote(dvc, config) + tree = SSHRemoteTree(dvc, config) mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) - assert remote.path_info.port == expected_port + assert tree.path_info.port == expected_port @pytest.mark.parametrize( @@ -155,11 +155,11 @@ def test_ssh_port(mock_file, mock_exists, dvc, config, expected_port): read_data=mock_ssh_config, ) def test_ssh_keyfile(mock_file, mock_exists, dvc, config, expected_keyfile): - remote = SSHRemote(dvc, config) + tree = SSHRemoteTree(dvc, config) mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) - assert remote.tree.keyfile == expected_keyfile + assert tree.keyfile == expected_keyfile @pytest.mark.parametrize( @@ -177,11 +177,11 @@ def test_ssh_keyfile(mock_file, mock_exists, dvc, config, expected_keyfile): read_data=mock_ssh_config, ) def test_ssh_gss_auth(mock_file, mock_exists, dvc, config, expected_gss_auth): - remote = SSHRemote(dvc, config) + tree = SSHRemoteTree(dvc, config) mock_exists.assert_called_with(SSHRemoteTree.ssh_config_filename()) 
mock_file.assert_called_with(SSHRemoteTree.ssh_config_filename()) - assert remote.tree.gss_auth == expected_gss_auth + assert tree.gss_auth == expected_gss_auth def test_hardlink_optimization(dvc, tmp_dir, ssh_server): @@ -194,12 +194,12 @@ def test_hardlink_optimization(dvc, tmp_dir, ssh_server): "user": user, "keyfile": ssh_server.test_creds["key_filename"], } - remote = SSHRemote(dvc, config) + tree = SSHRemoteTree(dvc, config) - from_info = remote.path_info / "empty" - to_info = remote.path_info / "link" + from_info = tree.path_info / "empty" + to_info = tree.path_info / "link" - with remote.open(from_info, "wb"): + with tree.open(from_info, "wb"): pass if os.name == "nt": @@ -207,5 +207,5 @@ def test_hardlink_optimization(dvc, tmp_dir, ssh_server): else: link_path = to_info.path - remote.tree.hardlink(from_info, to_info) + tree.hardlink(from_info, to_info) assert not System.is_hardlink(link_path) diff --git a/tests/unit/remote/test_azure.py b/tests/unit/remote/test_azure.py index c9009e9f71..0c6edaac06 100644 --- a/tests/unit/remote/test_azure.py +++ b/tests/unit/remote/test_azure.py @@ -1,7 +1,7 @@ import pytest from dvc.path_info import PathInfo -from dvc.remote.azure import AzureRemote +from dvc.remote.azure import AzureRemoteTree from tests.remotes import Azure container_name = "container-name" @@ -18,18 +18,18 @@ def test_init_env_var(monkeypatch, dvc): monkeypatch.setenv("AZURE_STORAGE_CONNECTION_STRING", connection_string) config = {"url": "azure://"} - remote = AzureRemote(dvc, config) - assert remote.tree.path_info == "azure://" + container_name - assert remote.tree.connection_string == connection_string + tree = AzureRemoteTree(dvc, config) + assert tree.path_info == "azure://" + container_name + assert tree.connection_string == connection_string def test_init(dvc): prefix = "some/prefix" url = f"azure://{container_name}/{prefix}" config = {"url": url, "connection_string": connection_string} - remote = AzureRemote(dvc, config) - assert remote.tree.path_info == url - assert remote.tree.connection_string == connection_string + tree = AzureRemoteTree(dvc, config) + assert tree.path_info == url + assert tree.connection_string == connection_string def test_get_file_checksum(tmp_dir): @@ -38,11 +38,11 @@ def test_get_file_checksum(tmp_dir): tmp_dir.gen("foo", "foo") - remote = AzureRemote(None, {}) - to_info = remote.tree.PATH_CLS(Azure.get_url()) - remote.tree.upload(PathInfo("foo"), to_info) - assert remote.tree.exists(to_info) - checksum = remote.tree.get_file_checksum(to_info) + tree = AzureRemoteTree(None, {}) + to_info = tree.PATH_CLS(Azure.get_url()) + tree.upload(PathInfo("foo"), to_info) + assert tree.exists(to_info) + checksum = tree.get_file_checksum(to_info) assert checksum assert isinstance(checksum, str) assert checksum.strip("'").strip('"') == checksum diff --git a/tests/unit/remote/test_base.py b/tests/unit/remote/test_base.py index 9bbabb901c..e560452a62 100644 --- a/tests/unit/remote/test_base.py +++ b/tests/unit/remote/test_base.py @@ -4,7 +4,12 @@ import pytest from dvc.path_info import PathInfo -from dvc.remote.base import BaseRemote, RemoteCmdError, RemoteMissingDepsError +from dvc.remote.base import ( + BaseRemoteTree, + Remote, + RemoteCmdError, + RemoteMissingDepsError, +) class _CallableOrNone: @@ -15,14 +20,13 @@ def __eq__(self, other): CallableOrNone = _CallableOrNone() -REMOTE_CLS = BaseRemote def test_missing_deps(dvc): requires = {"missing": "missing"} - with mock.patch.object(REMOTE_CLS, "REQUIRES", requires): + with 
mock.patch.object(BaseRemoteTree, "REQUIRES", requires): with pytest.raises(RemoteMissingDepsError): - REMOTE_CLS(dvc, {}) + BaseRemoteTree(dvc, {}) def test_cmd_error(dvc): @@ -33,44 +37,44 @@ def test_cmd_error(dvc): err = "sed: expression #1, char 2: extra characters after command" with mock.patch.object( - REMOTE_CLS.TREE_CLS, + BaseRemoteTree, "remove", side_effect=RemoteCmdError("base", cmd, ret, err), ): with pytest.raises(RemoteCmdError): - REMOTE_CLS(dvc, config).tree.remove("file") + BaseRemoteTree(dvc, config).remove("file") -@mock.patch.object(BaseRemote, "_list_checksums_traverse") -@mock.patch.object(BaseRemote, "_list_checksums_exists") +@mock.patch.object(BaseRemoteTree, "list_checksums_traverse") +@mock.patch.object(BaseRemoteTree, "list_checksums_exists") def test_checksums_exist(object_exists, traverse, dvc): - remote = BaseRemote(dvc, {}) + remote = Remote(BaseRemoteTree(dvc, {})) # remote does not support traverse - remote.CAN_TRAVERSE = False + remote.tree.CAN_TRAVERSE = False with mock.patch.object( - remote, "list_checksums", return_value=list(range(256)) + remote.tree, "list_checksums", return_value=list(range(256)) ): checksums = set(range(1000)) remote.checksums_exist(checksums) object_exists.assert_called_with(checksums, None, None) traverse.assert_not_called() - remote.CAN_TRAVERSE = True + remote.tree.CAN_TRAVERSE = True # large remote, small local object_exists.reset_mock() traverse.reset_mock() with mock.patch.object( - remote, "list_checksums", return_value=list(range(256)) + remote.tree, "list_checksums", return_value=list(range(256)) ): checksums = list(range(1000)) remote.checksums_exist(checksums) # verify that _cache_paths_with_max() short circuits # before returning all 256 remote checksums max_checksums = math.ceil( - remote._max_estimation_size(checksums) - / pow(16, remote.TRAVERSE_PREFIX_LEN) + remote.tree._max_estimation_size(checksums) + / pow(16, remote.tree.TRAVERSE_PREFIX_LEN) ) assert max_checksums < 256 object_exists.assert_called_with( @@ -81,15 +85,15 @@ def test_checksums_exist(object_exists, traverse, dvc): # large remote, large local object_exists.reset_mock() traverse.reset_mock() - remote.JOBS = 16 + remote.tree.JOBS = 16 with mock.patch.object( - remote, "list_checksums", return_value=list(range(256)) + remote.tree, "list_checksums", return_value=list(range(256)) ): checksums = list(range(1000000)) remote.checksums_exist(checksums) object_exists.assert_not_called() traverse.assert_called_with( - 256 * pow(16, remote.TRAVERSE_PREFIX_LEN), + 256 * pow(16, remote.tree.TRAVERSE_PREFIX_LEN), set(range(256)), None, None, @@ -97,18 +101,18 @@ def test_checksums_exist(object_exists, traverse, dvc): @mock.patch.object( - BaseRemote, "list_checksums", return_value=[], + BaseRemoteTree, "list_checksums", return_value=[], ) @mock.patch.object( - BaseRemote, "path_to_checksum", side_effect=lambda x: x, + BaseRemoteTree, "path_to_checksum", side_effect=lambda x: x, ) def test_list_checksums_traverse(path_to_checksum, list_checksums, dvc): - remote = BaseRemote(dvc, {}) - remote.tree.path_info = PathInfo("foo") + tree = BaseRemoteTree(dvc, {}) + tree.path_info = PathInfo("foo") # parallel traverse - size = 256 / remote.JOBS * remote.LIST_OBJECT_PAGE_SIZE - list(remote._list_checksums_traverse(size, {0})) + size = 256 / tree.JOBS * tree.LIST_OBJECT_PAGE_SIZE + list(tree.list_checksums_traverse(size, {0})) for i in range(1, 16): list_checksums.assert_any_call( prefix=f"{i:03x}", progress_callback=CallableOrNone @@ -121,20 +125,20 @@ def 
test_list_checksums_traverse(path_to_checksum, list_checksums, dvc): # default traverse (small remote) size -= 1 list_checksums.reset_mock() - list(remote._list_checksums_traverse(size - 1, {0})) + list(tree.list_checksums_traverse(size - 1, {0})) list_checksums.assert_called_with( prefix=None, progress_callback=CallableOrNone ) def test_list_checksums(dvc): - remote = BaseRemote(dvc, {}) - remote.tree.path_info = PathInfo("foo") + tree = BaseRemoteTree(dvc, {}) + tree.path_info = PathInfo("foo") with mock.patch.object( - remote, "list_paths", return_value=["12/3456", "bar"] + tree, "list_paths", return_value=["12/3456", "bar"] ): - checksums = list(remote.list_checksums()) + checksums = list(tree.list_checksums()) assert checksums == ["123456"] @@ -143,4 +147,4 @@ def test_list_checksums(dvc): [(None, False), ("", False), ("3456.dir", True), ("3456", False)], ) def test_is_dir_checksum(checksum, result): - assert BaseRemote.is_dir_checksum(checksum) == result + assert BaseRemoteTree.is_dir_checksum(checksum) == result diff --git a/tests/unit/remote/test_gdrive.py b/tests/unit/remote/test_gdrive.py index b8b5ed83e8..5574c72908 100644 --- a/tests/unit/remote/test_gdrive.py +++ b/tests/unit/remote/test_gdrive.py @@ -2,7 +2,7 @@ import pytest -from dvc.remote.gdrive import GDriveAuthError, GDriveRemote, GDriveRemoteTree +from dvc.remote.gdrive import GDriveAuthError, GDriveRemoteTree USER_CREDS_TOKEN_REFRESH_ERROR = '{"access_token": "", "client_id": "", "client_secret": "", "refresh_token": "", "token_expiry": "", "token_uri": "https://oauth2.googleapis.com/token", "user_agent": null, "revoke_uri": "https://oauth2.googleapis.com/revoke", "id_token": null, "id_token_jwt": null, "token_response": {"access_token": "", "expires_in": 3600, "scope": "https://www.googleapis.com/auth/drive.appdata https://www.googleapis.com/auth/drive", "token_type": "Bearer"}, "scopes": ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/drive.appdata"], "token_info_uri": "https://oauth2.googleapis.com/tokeninfo", "invalid": true, "_class": "OAuth2Credentials", "_module": "oauth2client.client"}' # noqa: E501 @@ -17,21 +17,21 @@ class TestRemoteGDrive: } def test_init(self, dvc): - remote = GDriveRemote(dvc, self.CONFIG) - assert str(remote.tree.path_info) == self.CONFIG["url"] + tree = GDriveRemoteTree(dvc, self.CONFIG) + assert str(tree.path_info) == self.CONFIG["url"] def test_drive(self, dvc): - remote = GDriveRemote(dvc, self.CONFIG) + tree = GDriveRemoteTree(dvc, self.CONFIG) os.environ[ GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_TOKEN_REFRESH_ERROR with pytest.raises(GDriveAuthError): - remote.tree._drive + tree._drive os.environ[GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA] = "" - remote = GDriveRemote(dvc, self.CONFIG) + tree = GDriveRemoteTree(dvc, self.CONFIG) os.environ[ GDriveRemoteTree.GDRIVE_CREDENTIALS_DATA ] = USER_CREDS_MISSED_KEY_ERROR with pytest.raises(GDriveAuthError): - remote.tree._drive + tree._drive diff --git a/tests/unit/remote/test_gs.py b/tests/unit/remote/test_gs.py index cee6a004b3..1953410caa 100644 --- a/tests/unit/remote/test_gs.py +++ b/tests/unit/remote/test_gs.py @@ -2,7 +2,7 @@ import pytest import requests -from dvc.remote.gs import GSRemote, dynamic_chunk_size +from dvc.remote.gs import GSRemoteTree, dynamic_chunk_size BUCKET = "bucket" PREFIX = "prefix" @@ -17,17 +17,17 @@ def test_init(dvc): - remote = GSRemote(dvc, CONFIG) - assert remote.tree.path_info == URL - assert remote.tree.projectname == PROJECT - assert remote.tree.credentialpath == 
CREDENTIALPATH + tree = GSRemoteTree(dvc, CONFIG) + assert tree.path_info == URL + assert tree.projectname == PROJECT + assert tree.credentialpath == CREDENTIALPATH @mock.patch("google.cloud.storage.Client.from_service_account_json") def test_gs(mock_client, dvc): - remote = GSRemote(dvc, CONFIG) - assert remote.tree.credentialpath - remote.tree.gs() + tree = GSRemoteTree(dvc, CONFIG) + assert tree.credentialpath + tree.gs() mock_client.assert_called_once_with(CREDENTIALPATH) @@ -35,8 +35,8 @@ def test_gs(mock_client, dvc): def test_gs_no_credspath(mock_client, dvc): config = CONFIG.copy() del config["credentialpath"] - remote = GSRemote(dvc, config) - remote.tree.gs() + tree = GSRemoteTree(dvc, config) + tree.gs() mock_client.assert_called_with(PROJECT) diff --git a/tests/unit/remote/test_http.py b/tests/unit/remote/test_http.py index b8f0d15d6c..16949c4849 100644 --- a/tests/unit/remote/test_http.py +++ b/tests/unit/remote/test_http.py @@ -2,7 +2,7 @@ from dvc.exceptions import HTTPError from dvc.path_info import URLInfo -from dvc.remote.http import HTTPRemote +from dvc.remote.http import HTTPRemoteTree from tests.utils.httpd import StaticFileServer @@ -11,10 +11,10 @@ def test_download_fails_on_error_code(dvc): url = f"http://localhost:{httpd.server_port}/" config = {"url": url} - remote = HTTPRemote(dvc, config) + tree = HTTPRemoteTree(dvc, config) with pytest.raises(HTTPError): - remote.tree._download(URLInfo(url) / "missing.txt", "missing.txt") + tree._download(URLInfo(url) / "missing.txt", "missing.txt") def test_public_auth_method(dvc): @@ -25,9 +25,9 @@ def test_public_auth_method(dvc): "password": "", } - remote = HTTPRemote(dvc, config) + tree = HTTPRemoteTree(dvc, config) - assert remote.tree._auth_method() is None + assert tree._auth_method() is None def test_basic_auth_method(dvc): @@ -44,10 +44,10 @@ def test_basic_auth_method(dvc): "password": password, } - remote = HTTPRemote(dvc, config) + tree = HTTPRemoteTree(dvc, config) - assert remote.tree._auth_method() == auth - assert isinstance(remote.tree._auth_method(), HTTPBasicAuth) + assert tree._auth_method() == auth + assert isinstance(tree._auth_method(), HTTPBasicAuth) def test_digest_auth_method(dvc): @@ -64,10 +64,10 @@ def test_digest_auth_method(dvc): "password": password, } - remote = HTTPRemote(dvc, config) + tree = HTTPRemoteTree(dvc, config) - assert remote.tree._auth_method() == auth - assert isinstance(remote.tree._auth_method(), HTTPDigestAuth) + assert tree._auth_method() == auth + assert isinstance(tree._auth_method(), HTTPDigestAuth) def test_custom_auth_method(dvc): @@ -81,8 +81,8 @@ def test_custom_auth_method(dvc): "password": password, } - remote = HTTPRemote(dvc, config) + tree = HTTPRemoteTree(dvc, config) - assert remote.tree._auth_method() is None - assert header in remote.tree.headers - assert remote.tree.headers[header] == password + assert tree._auth_method() is None + assert header in tree.headers + assert tree.headers[header] == password diff --git a/tests/unit/remote/test_local.py b/tests/unit/remote/test_local.py index fe2cb016d3..28d3716189 100644 --- a/tests/unit/remote/test_local.py +++ b/tests/unit/remote/test_local.py @@ -5,7 +5,8 @@ from dvc.cache import NamedCache from dvc.path_info import PathInfo -from dvc.remote.local import LocalCache +from dvc.remote.index import RemoteIndexNoop +from dvc.remote.local import LocalCache, LocalRemoteTree def test_status_download_optimization(mocker, dvc): @@ -13,7 +14,7 @@ def test_status_download_optimization(mocker, dvc): And the desired files to 
fetch are already on the local cache, Don't check the existence of the desired files on the remote cache """ - cache = LocalCache(dvc, {}) + cache = LocalCache(LocalRemoteTree(dvc, {})) infos = NamedCache() infos.add("local", "acbd18db4cc2f85cedef654fccc4a4d8", "foo") @@ -25,6 +26,7 @@ def test_status_download_optimization(mocker, dvc): other_remote = mocker.Mock() other_remote.url = "other_remote" other_remote.checksums_exist.return_value = [] + other_remote.index = RemoteIndexNoop() cache.status(infos, other_remote, download=True) @@ -33,8 +35,8 @@ def test_status_download_optimization(mocker, dvc): @pytest.mark.parametrize("link_name", ["hardlink", "symlink"]) def test_is_protected(tmp_dir, dvc, link_name): - cache = LocalCache(dvc, {}) - link_method = getattr(cache.tree, link_name) + tree = LocalRemoteTree(dvc, {}) + link_method = getattr(tree, link_name) (tmp_dir / "foo").write_text("foo") @@ -43,47 +45,47 @@ def test_is_protected(tmp_dir, dvc, link_name): link_method(foo, link) - assert not cache.is_protected(foo) - assert not cache.is_protected(link) + assert not tree.is_protected(foo) + assert not tree.is_protected(link) - cache.protect(foo) + tree.protect(foo) - assert cache.is_protected(foo) - assert cache.is_protected(link) + assert tree.is_protected(foo) + assert tree.is_protected(link) - cache.unprotect(link) + tree.unprotect(link) - assert not cache.is_protected(link) + assert not tree.is_protected(link) if os.name == "nt" and link_name == "hardlink": # NOTE: NTFS doesn't allow deleting read-only files, which forces us to # set write perms on the link, which propagates to the source. - assert not cache.is_protected(foo) + assert not tree.is_protected(foo) else: - assert cache.is_protected(foo) + assert tree.is_protected(foo) @pytest.mark.parametrize("err", [errno.EPERM, errno.EACCES]) def test_protect_ignore_errors(tmp_dir, mocker, err): tmp_dir.gen("foo", "foo") foo = PathInfo("foo") - cache = LocalCache(None, {}) + tree = LocalRemoteTree(None, {}) - cache.protect(foo) + tree.protect(foo) mock_chmod = mocker.patch( "os.chmod", side_effect=OSError(err, "something") ) - cache.protect(foo) + tree.protect(foo) assert mock_chmod.called def test_protect_ignore_erofs(tmp_dir, mocker): tmp_dir.gen("foo", "foo") foo = PathInfo("foo") - cache = LocalCache(None, {}) + tree = LocalRemoteTree(None, {}) mock_chmod = mocker.patch( "os.chmod", side_effect=OSError(errno.EROFS, "read-only fs") ) - cache.protect(foo) + tree.protect(foo) assert mock_chmod.called diff --git a/tests/unit/remote/test_oss.py b/tests/unit/remote/test_oss.py index 3bffe14a43..c88804876d 100644 --- a/tests/unit/remote/test_oss.py +++ b/tests/unit/remote/test_oss.py @@ -1,4 +1,4 @@ -from dvc.remote.oss import OSSRemote +from dvc.remote.oss import OSSRemoteTree bucket_name = "bucket-name" endpoint = "endpoint" @@ -15,8 +15,8 @@ def test_init(dvc): "oss_key_secret": key_secret, "oss_endpoint": endpoint, } - remote = OSSRemote(dvc, config) - assert remote.tree.path_info == url - assert remote.tree.endpoint == endpoint - assert remote.tree.key_id == key_id - assert remote.tree.key_secret == key_secret + tree = OSSRemoteTree(dvc, config) + assert tree.path_info == url + assert tree.endpoint == endpoint + assert tree.key_id == key_id + assert tree.key_secret == key_secret diff --git a/tests/unit/remote/test_remote.py b/tests/unit/remote/test_remote.py index 34dc6948c6..a748bde825 100644 --- a/tests/unit/remote/test_remote.py +++ b/tests/unit/remote/test_remote.py @@ -1,6 +1,8 @@ import pytest -from dvc.remote import 
GSRemote, Remote, S3Remote +from dvc.remote import get_cloud_tree +from dvc.remote.gs import GSRemoteTree +from dvc.remote.s3 import S3RemoteTree def test_remote_with_checksum_jobs(dvc): @@ -10,32 +12,32 @@ def test_remote_with_checksum_jobs(dvc): } dvc.config["core"]["checksum_jobs"] = 200 - remote = Remote(dvc, name="with_checksum_jobs") - assert remote.checksum_jobs == 100 + tree = get_cloud_tree(dvc, name="with_checksum_jobs") + assert tree.checksum_jobs == 100 def test_remote_without_checksum_jobs(dvc): dvc.config["remote"]["without_checksum_jobs"] = {"url": "s3://bucket/name"} dvc.config["core"]["checksum_jobs"] = 200 - remote = Remote(dvc, name="without_checksum_jobs") - assert remote.checksum_jobs == 200 + tree = get_cloud_tree(dvc, name="without_checksum_jobs") + assert tree.checksum_jobs == 200 def test_remote_without_checksum_jobs_default(dvc): dvc.config["remote"]["without_checksum_jobs"] = {"url": "s3://bucket/name"} - remote = Remote(dvc, name="without_checksum_jobs") - assert remote.checksum_jobs == remote.CHECKSUM_JOBS + tree = get_cloud_tree(dvc, name="without_checksum_jobs") + assert tree.checksum_jobs == tree.CHECKSUM_JOBS -@pytest.mark.parametrize("remote_cls", [GSRemote, S3Remote]) -def test_makedirs_not_create_for_top_level_path(remote_cls, dvc, mocker): - url = f"{remote_cls.scheme}://bucket/" - remote = remote_cls(dvc, {"url": url}) +@pytest.mark.parametrize("tree_cls", [GSRemoteTree, S3RemoteTree]) +def test_makedirs_not_create_for_top_level_path(tree_cls, dvc, mocker): + url = f"{tree_cls.scheme}://bucket/" + tree = tree_cls(dvc, {"url": url}) mocked_client = mocker.PropertyMock() # we use remote clients with same name as scheme to interact with remote - mocker.patch.object(remote_cls.TREE_CLS, remote.scheme, mocked_client) + mocker.patch.object(tree_cls, tree.scheme, mocked_client) - remote.tree.makedirs(remote.path_info) + tree.makedirs(tree.path_info) assert not mocked_client.called diff --git a/tests/unit/remote/test_remote_tree.py b/tests/unit/remote/test_remote_tree.py index e9b8dd213e..a52416ad13 100644 --- a/tests/unit/remote/test_remote_tree.py +++ b/tests/unit/remote/test_remote_tree.py @@ -3,7 +3,7 @@ import pytest from dvc.path_info import PathInfo -from dvc.remote.s3 import S3Remote, S3RemoteTree +from dvc.remote.s3 import S3RemoteTree from dvc.utils.fs import walk_files from tests.remotes import GCP, S3Mocked @@ -91,7 +91,7 @@ def test_copy_preserve_etag_across_buckets(remote, dvc): s3 = remote.tree.s3 s3.create_bucket(Bucket="another") - another = S3Remote(dvc, {"url": "s3://another", "region": "us-east-1"}) + another = S3RemoteTree(dvc, {"url": "s3://another", "region": "us-east-1"}) from_info = remote.path_info / "foo" to_info = another.path_info / "foo" diff --git a/tests/unit/remote/test_s3.py b/tests/unit/remote/test_s3.py index d61c0d735d..4623bd4144 100644 --- a/tests/unit/remote/test_s3.py +++ b/tests/unit/remote/test_s3.py @@ -1,7 +1,7 @@ import pytest from dvc.config import ConfigError -from dvc.remote.s3 import S3Remote +from dvc.remote.s3 import S3RemoteTree bucket_name = "bucket-name" prefix = "some/prefix" @@ -20,9 +20,9 @@ def grants(): def test_init(dvc): config = {"url": url} - remote = S3Remote(dvc, config) + tree = S3RemoteTree(dvc, config) - assert remote.tree.path_info == url + assert tree.path_info == url def test_grants(dvc): @@ -33,8 +33,7 @@ def test_grants(dvc): "grant_write_acp": "id=write-acp-permission-id", "grant_full_control": "id=full-control-permission-id", } - remote = S3Remote(dvc, config) - tree = remote.tree + 
tree = S3RemoteTree(dvc, config) assert ( tree.extra_args["GrantRead"] @@ -52,9 +51,9 @@ def test_grants_mutually_exclusive_acl_error(dvc, grants): config = {"url": url, "acl": "public-read", grant_option: grant_value} with pytest.raises(ConfigError): - S3Remote(dvc, config) + S3RemoteTree(dvc, config) def test_sse_kms_key_id(dvc): - remote = S3Remote(dvc, {"url": url, "sse_kms_key_id": "key"}) - assert remote.tree.extra_args["SSEKMSKeyId"] == "key" + tree = S3RemoteTree(dvc, {"url": url, "sse_kms_key_id": "key"}) + assert tree.extra_args["SSEKMSKeyId"] == "key" From f5d7f0a9011443fefccbd8e34cb4e1bee1d5b9b8 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 16:51:29 +0900 Subject: [PATCH 09/16] remote: LocalCache needs to lock remote index, not cache index --- dvc/remote/local.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 58dda14cfb..0eb52fd0f8 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -19,7 +19,6 @@ BaseRemoteTree, CloudCache, Remote, - index_locked, ) from dvc.scheme import Schemes from dvc.scm.tree import WorkingTree, is_working_tree @@ -222,7 +221,7 @@ def _unprotect_file(self, path): "a symlink or a hardlink.".format(path) ) - os.chmod(path, self.tree.file_mode) + os.chmod(path, self.file_mode) def _unprotect_dir(self, path): assert is_working_tree(self.repo.tree) @@ -341,6 +340,15 @@ def _remove_unpacked_dir(self, checksum): self.tree.remove(path_info) +def sync_index_locked(f): + @wraps(f) + def wrapper(cache_obj, named_cache, remote, *args, **kwargs): + with remote.index: + return f(cache_obj, named_cache, remote, *args, **kwargs) + + return wrapper + + class LocalCache(CloudCache): def __init__(self, tree): super().__init__(tree) @@ -398,7 +406,7 @@ def _verify_link(self, path_info, link_type): super()._verify_link(path_info, link_type) - @index_locked + @sync_index_locked def status( self, named_cache, @@ -686,7 +694,7 @@ def _dir_upload(func, futures, from_info, to_info, name): return 1 return func(from_info, to_info, name) - @index_locked + @sync_index_locked def push(self, named_cache, remote, jobs=None, show_checksums=False): return self._process( named_cache, @@ -696,7 +704,7 @@ def push(self, named_cache, remote, jobs=None, show_checksums=False): download=False, ) - @index_locked + @sync_index_locked def pull(self, named_cache, remote, jobs=None, show_checksums=False): return self._process( named_cache, From 13a22b0332c9db931c60ba78f4572828159fa0d7 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 17:17:10 +0900 Subject: [PATCH 10/16] fix bugs from moved PARAM_CHECKSUM/PARAM_RELPATH --- dvc/remote/base.py | 39 +++++++++++++++++++++++++-------------- dvc/repo/tree.py | 2 +- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index f919381793..851961d9c5 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -905,6 +905,13 @@ def cache(self): def scheme(self): return self.tree.scheme + @property + def state(self): + return self.tree.state + + def open(self, *args, **kwargs): + return self.tree.open(*args, **kwargs) + def is_dir_checksum(self, checksum): return self.tree.is_dir_checksum(checksum) @@ -916,6 +923,9 @@ def get_checksum(self, path_info, **kwargs): def checksum_to_path(self, checksum): return self.checksum_to_path_info(checksum) + def checksum_to_path_info(self, checksum): + return self.tree.checksum_to_path_info(checksum) + def get_dir_cache(self, checksum): 
assert checksum @@ -951,7 +961,8 @@ def load_dir_cache(self, checksum): # only need to convert it for Windows for info in d: # NOTE: here is a BUG, see comment to .as_posix() below - info[self.PARAM_RELPATH] = info[self.PARAM_RELPATH].replace( + relpath = info[self.tree.PARAM_RELPATH] + info[self.tree.PARAM_RELPATH] = relpath.replace( "/", self.tree.PATH_CLS.sep ) @@ -1091,7 +1102,7 @@ def _save_file(self, path_info, tree, checksum, save_link=True, **kwargs): callback(1) self.state.save(cache_info, checksum) - return {self.PARAM_CHECKSUM: checksum} + return {self.tree.PARAM_CHECKSUM: checksum} def _cache_is_copy(self, path_info): """Checks whether cache uses copies.""" @@ -1120,8 +1131,8 @@ def _save_dir(self, path_info, tree, checksum, save_link=True, **kwargs): for entry in Tqdm( dir_info, desc="Saving " + path_info.name, unit="file" ): - entry_info = path_info / entry[self.PARAM_RELPATH] - entry_checksum = entry[self.PARAM_CHECKSUM] + entry_info = path_info / entry[self.tree.PARAM_RELPATH] + entry_checksum = entry[self.tree.PARAM_CHECKSUM] self._save_file( entry_info, tree, entry_checksum, save_link=False, **kwargs ) @@ -1133,7 +1144,7 @@ def _save_dir(self, path_info, tree, checksum, save_link=True, **kwargs): cache_info = self.checksum_to_path_info(checksum) self.state.save(cache_info, checksum) - return {self.PARAM_CHECKSUM: checksum} + return {self.tree.PARAM_CHECKSUM: checksum} def save(self, path_info, tree, checksum_info, save_link=True, **kwargs): if path_info.scheme != self.scheme: @@ -1143,7 +1154,7 @@ def save(self, path_info, tree, checksum_info, save_link=True, **kwargs): if not checksum_info: checksum_info = self.save_info(path_info, tree=tree, **kwargs) - checksum = checksum_info[self.PARAM_CHECKSUM] + checksum = checksum_info[self.tree.PARAM_CHECKSUM] return self._save(path_info, tree, checksum, save_link, **kwargs) def _save(self, path_info, tree, checksum, save_link=True, **kwargs): @@ -1205,10 +1216,10 @@ def _changed_dir_cache(self, checksum, path_info=None, filter_info=None): return True for entry in self.get_dir_cache(checksum): - entry_checksum = entry[self.PARAM_CHECKSUM] + entry_checksum = entry[self.tree.PARAM_CHECKSUM] if path_info and filter_info: - entry_info = path_info / entry[self.PARAM_RELPATH] + entry_info = path_info / entry[self.tree.PARAM_RELPATH] if not entry_info.isin_or_eq(filter_info): continue @@ -1287,15 +1298,15 @@ def _checkout_dir( logger.debug("Linking directory '%s'.", path_info) for entry in dir_info: - relative_path = entry[self.PARAM_RELPATH] - entry_checksum = entry[self.PARAM_CHECKSUM] + relative_path = entry[self.tree.PARAM_RELPATH] + entry_checksum = entry[self.tree.PARAM_CHECKSUM] entry_cache_info = self.checksum_to_path_info(entry_checksum) entry_info = path_info / relative_path if filter_info and not entry_info.isin_or_eq(filter_info): continue - entry_checksum_info = {self.PARAM_CHECKSUM: entry_checksum} + entry_checksum_info = {self.tree.PARAM_CHECKSUM: entry_checksum} if relink or self.changed(entry_info, entry_checksum_info): modified = True self.safe_remove(entry_info, force=force) @@ -1319,7 +1330,7 @@ def _remove_redundant_files(self, path_info, dir_info, force): existing_files = set(self.tree.walk_files(path_info)) needed_files = { - path_info / entry[self.PARAM_RELPATH] for entry in dir_info + path_info / entry[self.tree.PARAM_RELPATH] for entry in dir_info } redundant_files = existing_files - needed_files for path in redundant_files: @@ -1339,7 +1350,7 @@ def checkout( if path_info.scheme not in ["local", self.scheme]: 
raise NotImplementedError - checksum = checksum_info.get(self.PARAM_CHECKSUM) + checksum = checksum_info.get(self.tree.PARAM_CHECKSUM) failed = None skip = False if not checksum: @@ -1414,6 +1425,6 @@ def get_files_number(self, path_info, checksum, filter_info): return len(self.get_dir_cache(checksum)) return ilen( - filter_info.isin_or_eq(path_info / entry[self.PARAM_CHECKSUM]) + filter_info.isin_or_eq(path_info / entry[self.tree.PARAM_CHECKSUM]) for entry in self.get_dir_cache(checksum) ) diff --git a/dvc/repo/tree.py b/dvc/repo/tree.py index 7b6dd29a3f..77867808c8 100644 --- a/dvc/repo/tree.py +++ b/dvc/repo/tree.py @@ -194,7 +194,7 @@ def walk(self, top, topdown=True, download_callback=None, **kwargs): download_callback(downloaded) for entry in dir_cache: - entry_relpath = entry[out.remote.PARAM_RELPATH] + entry_relpath = entry[out.remote.tree.PARAM_RELPATH] if os.name == "nt": entry_relpath = entry_relpath.replace("/", os.sep) path_info = out.path_info / entry_relpath From fe1a6ba777d3c7d61135315819079d62e4042509 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 17:17:43 +0900 Subject: [PATCH 11/16] tests: repo tree open needs state context --- tests/unit/repo/test_repo_tree.py | 5 +++-- tests/unit/repo/test_tree.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/unit/repo/test_repo_tree.py b/tests/unit/repo/test_repo_tree.py index 9b00d3641c..cef3628265 100644 --- a/tests/unit/repo/test_repo_tree.py +++ b/tests/unit/repo/test_repo_tree.py @@ -22,8 +22,9 @@ def test_open(tmp_dir, dvc): (tmp_dir / "foo").unlink() tree = RepoTree(dvc) - with tree.open("foo", "r") as fobj: - assert fobj.read() == "foo" + with dvc.state: + with tree.open("foo", "r") as fobj: + assert fobj.read() == "foo" def test_open_dirty_hash(tmp_dir, dvc): diff --git a/tests/unit/repo/test_tree.py b/tests/unit/repo/test_tree.py index 50191224f0..8af75ebf67 100644 --- a/tests/unit/repo/test_tree.py +++ b/tests/unit/repo/test_tree.py @@ -22,8 +22,9 @@ def test_open(tmp_dir, dvc): (tmp_dir / "foo").unlink() tree = DvcTree(dvc) - with tree.open("foo", "r") as fobj: - assert fobj.read() == "foo" + with dvc.state: + with tree.open("foo", "r") as fobj: + assert fobj.read() == "foo" def test_open_dirty_hash(tmp_dir, dvc): From 143628d6e537c322103edd1530184ee97105e52c Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 17:52:18 +0900 Subject: [PATCH 12/16] remote: make Remote.gc a class method * operates on obj.tree, so that GC can be done with both remote and cache instances --- dvc/remote/base.py | 46 +++++++++++++++++++-------------- dvc/remote/local.py | 37 ++++++++++++-------------- dvc/repo/gc.py | 21 ++++++++------- tests/func/remote/test_index.py | 6 ++--- 4 files changed, 58 insertions(+), 52 deletions(-) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 851961d9c5..6eb5e29684 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -73,10 +73,11 @@ def __init__(self, checksum): def index_locked(f): @wraps(f) - def wrapper(remote_obj, *args, **kwargs): - remote = kwargs.get("remote", remote_obj) - with remote.index: - return f(remote_obj, *args, **kwargs) + def wrapper(obj, named_cache, remote, *args, **kwargs): + if hasattr(remote, "index"): + with remote.index: + return f(obj, named_cache, remote, *args, **kwargs) + return f(obj, named_cache, remote, *args, **kwargs) return wrapper @@ -698,6 +699,9 @@ def exists_with_progress(path_info): ret = list(itertools.compress(checksums, in_remote)) return ret + def _remove_unpacked_dir(self, 
checksum): + pass + class Remote: """Cloud remote class. @@ -706,15 +710,17 @@ class Remote: DVC remotes. """ + INDEX_CLS = RemoteIndex + def __init__(self, tree): self.tree = tree self.repo = tree.repo config = tree.config url = config.get("url") - if self.scheme != "local" and url: + if url: index_name = hashlib.sha256(url.encode("utf-8")).hexdigest() - self.index = RemoteIndex( + self.index = self.INDEX_CLS( self.repo, index_name, dir_suffix=self.tree.CHECKSUM_DIR_SUFFIX ) else: @@ -851,34 +857,34 @@ def checksums_exist(self, checksums, jobs=None, name=None): checksums & set(remote_checksums) ) + @classmethod @index_locked - def gc(self, named_cache, jobs=None): + def gc(cls, named_cache, remote, jobs=None): + tree = remote.tree used = set(named_cache.scheme_keys("local")) - if self.scheme != "": - used.update(named_cache.scheme_keys(self.scheme)) + if tree.scheme != "": + used.update(named_cache.scheme_keys(tree.scheme)) removed = False # checksums must be sorted to ensure we always remove .dir files first for checksum in sorted( - self.all(jobs, str(self.path_info)), - key=self.is_dir_checksum, + tree.all(jobs, str(tree.path_info)), + key=tree.is_dir_checksum, reverse=True, ): if checksum in used: continue - path_info = self.checksum_to_path_info(checksum) - if self.is_dir_checksum(checksum): + path_info = tree.checksum_to_path_info(checksum) + if tree.is_dir_checksum(checksum): # backward compatibility - self._remove_unpacked_dir(checksum) - self.tree.remove(path_info) + tree._remove_unpacked_dir(checksum) + tree.remove(path_info) removed = True - if removed: - self.index.clear() - return removed - def _remove_unpacked_dir(self, checksum): - pass + if removed and hasattr(remote, "index"): + remote.index.clear() + return removed class CloudCache: diff --git a/dvc/remote/local.py b/dvc/remote/local.py index 0eb52fd0f8..e2f0cc3a94 100644 --- a/dvc/remote/local.py +++ b/dvc/remote/local.py @@ -19,7 +19,9 @@ BaseRemoteTree, CloudCache, Remote, + index_locked, ) +from dvc.remote.index import RemoteIndexNoop from dvc.scheme import Schemes from dvc.scm.tree import WorkingTree, is_working_tree from dvc.system import System @@ -299,7 +301,7 @@ def list_paths(self, prefix=None, progress_callback=None): assert self.path_info is not None if prefix: path_info = self.path_info / prefix[:2] - if not self.tree.exists(path_info): + if not self.exists(path_info): return else: path_info = self.path_info @@ -312,6 +314,11 @@ def list_paths(self, prefix=None, progress_callback=None): else: yield from walk_files(path_info) + def _remove_unpacked_dir(self, checksum): + info = self.checksum_to_path_info(checksum) + path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) + self.remove(path_info) + def _log_exceptions(func, operation): @wraps(func) @@ -334,19 +341,7 @@ def wrapper(from_info, to_info, *args, **kwargs): class LocalRemote(Remote): - def _remove_unpacked_dir(self, checksum): - info = self.checksum_to_path_info(checksum) - path_info = info.with_name(info.name + self.UNPACKED_DIR_SUFFIX) - self.tree.remove(path_info) - - -def sync_index_locked(f): - @wraps(f) - def wrapper(cache_obj, named_cache, remote, *args, **kwargs): - with remote.index: - return f(cache_obj, named_cache, remote, *args, **kwargs) - - return wrapper + INDEX_CLS = RemoteIndexNoop class LocalCache(CloudCache): @@ -406,7 +401,7 @@ def _verify_link(self, path_info, link_type): super()._verify_link(path_info, link_type) - @sync_index_locked + @index_locked def status( self, named_cache, @@ -508,7 +503,7 @@ def 
_indexed_dir_checksums(self, named_cache, remote, dir_md5s):
         indexed_dir_exists = set()
         if indexed_dirs:
             indexed_dir_exists.update(
-                remote._list_checksums_exists(indexed_dirs)
+                remote.tree.list_checksums_exists(indexed_dirs)
             )
             missing_dirs = indexed_dirs.difference(indexed_dir_exists)
             if missing_dirs:
@@ -520,7 +515,9 @@ def _indexed_dir_checksums(self, named_cache, remote, dir_md5s):
 
         # Check if non-indexed (new) dir checksums exist on remote
         dir_exists = dir_md5s.intersection(indexed_dir_exists)
-        dir_exists.update(remote._list_checksums_exists(dir_md5s - dir_exists))
+        dir_exists.update(
+            remote.tree.list_checksums_exists(dir_md5s - dir_exists)
+        )
 
         # If .dir checksum exists on the remote, assume directory contents
         # still exists on the remote
@@ -600,7 +597,7 @@ def _process(
         desc = "Uploading"
 
         if jobs is None:
-            jobs = remote.JOBS
+            jobs = remote.tree.JOBS
 
         dir_status, file_status, dir_contents = self._status(
             named_cache,
@@ -694,7 +691,7 @@ def _dir_upload(func, futures, from_info, to_info, name):
             return 1
         return func(from_info, to_info, name)
 
-    @sync_index_locked
+    @index_locked
     def push(self, named_cache, remote, jobs=None, show_checksums=False):
         return self._process(
             named_cache,
@@ -704,7 +701,7 @@ def push(self, named_cache, remote, jobs=None, show_checksums=False):
-    @sync_index_locked
+    @index_locked
     def pull(self, named_cache, remote, jobs=None, show_checksums=False):
         return self._process(
             named_cache,
diff --git a/dvc/repo/gc.py b/dvc/repo/gc.py
index 086d248ab0..96ca40aece 100644
--- a/dvc/repo/gc.py
+++ b/dvc/repo/gc.py
@@ -8,8 +8,10 @@
 logger = logging.getLogger(__name__)
 
 
-def _do_gc(typ, func, clist, jobs=None):
-    removed = func(clist, jobs=jobs)
+def _do_gc(typ, remote, clist, jobs=None):
+    from dvc.remote.base import Remote
+
+    removed = Remote.gc(clist, remote, jobs=jobs)
 
     if not removed:
         logger.info(f"No unused '{typ}' cache to remove.")
@@ -74,22 +76,23 @@ def gc(
             )
         )
 
-    _do_gc("local", self.cache.local.gc, used, jobs)
+    # treat caches as remotes for garbage collection
+    _do_gc("local", self.cache.local, used, jobs)
 
     if self.cache.s3:
-        _do_gc("s3", self.cache.s3.gc, used, jobs)
+        _do_gc("s3", self.cache.s3, used, jobs)
 
     if self.cache.gs:
-        _do_gc("gs", self.cache.gs.gc, used, jobs)
+        _do_gc("gs", self.cache.gs, used, jobs)
 
     if self.cache.ssh:
-        _do_gc("ssh", self.cache.ssh.gc, used, jobs)
+        _do_gc("ssh", self.cache.ssh, used, jobs)
 
     if self.cache.hdfs:
-        _do_gc("hdfs", self.cache.hdfs.gc, used, jobs)
+        _do_gc("hdfs", self.cache.hdfs, used, jobs)
 
     if self.cache.azure:
-        _do_gc("azure", self.cache.azure.gc, used, jobs)
+        _do_gc("azure", self.cache.azure, used, jobs)
 
     if cloud:
-        _do_gc("remote", self.cloud.get_remote(remote, "gc -c").gc, used, jobs)
+        _do_gc("remote", self.cloud.get_remote(remote, "gc -c"), used, jobs)
diff --git a/tests/func/remote/test_index.py b/tests/func/remote/test_index.py
index 517d1ec9ed..9df9cf7e9c 100644
--- a/tests/func/remote/test_index.py
+++ b/tests/func/remote/test_index.py
@@ -3,7 +3,7 @@
 import pytest
 
 from dvc.exceptions import DownloadError, UploadError
-from dvc.remote.base import BaseRemote
+from dvc.remote.base import Remote
 from dvc.remote.index import RemoteIndex
 from dvc.remote.local import LocalRemote, LocalRemoteTree
 from dvc.utils.fs import remove
@@ -16,9 +16,9 @@ def remote(tmp_dir, dvc, tmp_path_factory, mocker):
     dvc.config["core"]["remote"] = "upstream"
 
     # patch checksums_exist since the LocalRemote normally overrides
-    # BaseRemote.checksums_exist.
+    # Remote.checksums_exist. 
def checksums_exist(self, *args, **kwargs): - return BaseRemote.checksums_exist(self, *args, **kwargs) + return Remote.checksums_exist(self, *args, **kwargs) mocker.patch.object(LocalRemote, "checksums_exist", checksums_exist) From a7109e356a66fecefc40c1480bfbb281659977fc Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 19:07:33 +0900 Subject: [PATCH 13/16] bug fixes --- dvc/data_cloud.py | 2 +- dvc/remote/base.py | 8 ++++++-- dvc/remote/ssh/__init__.py | 12 +++++++----- dvc/repo/tree.py | 4 ++-- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/dvc/data_cloud.py b/dvc/data_cloud.py index 6e6503276e..f6afd28b51 100644 --- a/dvc/data_cloud.py +++ b/dvc/data_cloud.py @@ -85,7 +85,7 @@ def pull( cache, jobs=jobs, remote=remote, show_checksums=show_checksums ) - if not remote.verify: + if not remote.tree.verify: self._save_pulled_checksums(cache) return downloaded_items_num diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 6eb5e29684..448ba36a26 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -903,6 +903,10 @@ def __init__(self, tree): self.cache_type_confirmed = False self._dir_info = {} + @property + def path_info(self): + return self.tree.path_info + @property def cache(self): return getattr(self.repo.cache, self.scheme) @@ -999,7 +1003,7 @@ def changed(self, path_info, checksum_info): logger.debug("'%s' doesn't exist.", path_info) return True - checksum = checksum_info.get(self.PARAM_CHECKSUM) + checksum = checksum_info.get(self.tree.PARAM_CHECKSUM) if checksum is None: logger.debug("hash value for '%s' is missing.", path_info) return True @@ -1159,7 +1163,7 @@ def save(self, path_info, tree, checksum_info, save_link=True, **kwargs): ) if not checksum_info: - checksum_info = self.save_info(path_info, tree=tree, **kwargs) + checksum_info = self.tree.save_info(path_info, tree=tree, **kwargs) checksum = checksum_info[self.tree.PARAM_CHECKSUM] return self._save(path_info, tree, checksum, save_link, **kwargs) diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 45d871ccf3..db355c6fd1 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -273,7 +273,7 @@ def list_paths(self, prefix=None, progress_callback=None): root = posixpath.join(self.path_info.path, prefix[:2]) else: root = self.path_info.path - with self.tree.ssh(self.path_info) as ssh: + with self.ssh(self.path_info) as ssh: if prefix and not ssh.exists(root): return # If we simply return an iterator then with above closes instantly @@ -320,8 +320,8 @@ def checksums_exist(self, checksums, jobs=None, name=None): faster than current approach (relying on exists(path_info)) applied in remote/base. 
""" - if not self.CAN_TRAVERSE: - return list(set(checksums) & set(self.all())) + if not self.tree.CAN_TRAVERSE: + return list(set(checksums) & set(self.tree.all())) # possibly prompt for credentials before "Querying" progress output self.tree.ensure_credentials() @@ -336,9 +336,11 @@ def checksums_exist(self, checksums, jobs=None, name=None): def exists_with_progress(chunks): return self.batch_exists(chunks, callback=pbar.update_msg) - with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + with ThreadPoolExecutor( + max_workers=jobs or self.tree.JOBS + ) as executor: path_infos = [self.checksum_to_path_info(x) for x in checksums] - chunks = to_chunks(path_infos, num_chunks=self.JOBS) + chunks = to_chunks(path_infos, num_chunks=self.tree.JOBS) results = executor.map(exists_with_progress, chunks) in_remote = itertools.chain.from_iterable(results) ret = list(itertools.compress(checksums, in_remote)) diff --git a/dvc/repo/tree.py b/dvc/repo/tree.py index 77867808c8..1c83155b1a 100644 --- a/dvc/repo/tree.py +++ b/dvc/repo/tree.py @@ -48,11 +48,11 @@ def _get_granular_checksum(self, path, out, remote=None): raise FileNotFoundError dir_cache = out.get_dir_cache(remote=remote) for entry in dir_cache: - entry_relpath = entry[out.remote.PARAM_RELPATH] + entry_relpath = entry[out.remote.tree.PARAM_RELPATH] if os.name == "nt": entry_relpath = entry_relpath.replace("/", os.sep) if path == out.path_info / entry_relpath: - return entry[out.remote.PARAM_CHECKSUM] + return entry[out.remote.tree.PARAM_CHECKSUM] raise FileNotFoundError def open(self, path, mode="r", encoding="utf-8", remote=None): From a5ff610fb4c321f44b4b1433804f04e67ca39208 Mon Sep 17 00:00:00 2001 From: Peter Rowlands Date: Mon, 15 Jun 2020 19:08:48 +0900 Subject: [PATCH 14/16] tests: update test_data_cloud --- tests/func/test_data_cloud.py | 80 +++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 37 deletions(-) diff --git a/tests/func/test_data_cloud.py b/tests/func/test_data_cloud.py index 02611f4e15..eda784b6c2 100644 --- a/tests/func/test_data_cloud.py +++ b/tests/func/test_data_cloud.py @@ -12,19 +12,16 @@ from dvc.data_cloud import DataCloud from dvc.external_repo import clean_repos from dvc.main import main -from dvc.remote import ( - AzureRemote, - GDriveRemote, - GSRemote, - HDFSRemote, - HTTPRemote, - LocalRemote, - OSSRemote, - S3Remote, - SSHRemote, -) +from dvc.remote.azure import AzureRemoteTree from dvc.remote.base import STATUS_DELETED, STATUS_NEW, STATUS_OK +from dvc.remote.gdrive import GDriveRemoteTree +from dvc.remote.gs import GSRemoteTree +from dvc.remote.hdfs import HDFSRemoteTree +from dvc.remote.http import HTTPRemoteTree from dvc.remote.local import LocalRemoteTree +from dvc.remote.oss import OSSRemoteTree +from dvc.remote.s3 import S3RemoteTree +from dvc.remote.ssh import SSHRemoteTree from dvc.stage.exceptions import StageNotFound from dvc.utils import file_md5 from dvc.utils.fs import remove @@ -50,19 +47,19 @@ class TestDataCloud(TestDvc): def _test_cloud(self, config, cl): self.dvc.config = config cloud = DataCloud(self.dvc) - self.assertIsInstance(cloud.get_remote(), cl) + self.assertIsInstance(cloud.get_remote().tree, cl) def test(self): config = copy.deepcopy(TEST_CONFIG) clist = [ - ("s3://mybucket/", S3Remote), - ("gs://mybucket/", GSRemote), - ("ssh://user@localhost:/", SSHRemote), - ("http://localhost:8000/", HTTPRemote), - ("azure://ContainerName=mybucket;conn_string;", AzureRemote), - ("oss://mybucket/", OSSRemote), - (TestDvc.mkdtemp(), LocalRemote), + 
("s3://mybucket/", S3RemoteTree), + ("gs://mybucket/", GSRemoteTree), + ("ssh://user@localhost:/", SSHRemoteTree), + ("http://localhost:8000/", HTTPRemoteTree), + ("azure://ContainerName=mybucket;conn_string;", AzureRemoteTree), + ("oss://mybucket/", OSSRemoteTree), + (TestDvc.mkdtemp(), LocalRemoteTree), ] for scheme, cl in clist: @@ -101,7 +98,9 @@ def _setup_cloud(self): self.dvc.config = config self.cloud = DataCloud(self.dvc) - self.assertIsInstance(self.cloud.get_remote(), self._get_cloud_class()) + self.assertIsInstance( + self.cloud.get_remote().tree, self._get_cloud_class() + ) def _test_cloud(self): self._setup_cloud() @@ -187,7 +186,7 @@ def test(self): class TestS3Remote(S3, TestDataCloudBase): def _get_cloud_class(self): - return S3Remote + return S3RemoteTree class TestGDriveRemote(GDrive, TestDataCloudBase): @@ -208,10 +207,10 @@ def _setup_cloud(self): self.dvc.config = config self.cloud = DataCloud(self.dvc) remote = self.cloud.get_remote() - self.assertIsInstance(remote, self._get_cloud_class()) + self.assertIsInstance(remote.tree, self._get_cloud_class()) def _get_cloud_class(self): - return GDriveRemote + return GDriveRemoteTree class TestGSRemote(GCP, TestDataCloudBase): @@ -228,25 +227,27 @@ def _setup_cloud(self): self.dvc.config = config self.cloud = DataCloud(self.dvc) - self.assertIsInstance(self.cloud.get_remote(), self._get_cloud_class()) + self.assertIsInstance( + self.cloud.get_remote().tree, self._get_cloud_class() + ) def _get_cloud_class(self): - return GSRemote + return GSRemoteTree class TestAzureRemote(Azure, TestDataCloudBase): def _get_cloud_class(self): - return AzureRemote + return AzureRemoteTree class TestOSSRemote(OSS, TestDataCloudBase): def _get_cloud_class(self): - return OSSRemote + return OSSRemoteTree class TestLocalRemote(Local, TestDataCloudBase): def _get_cloud_class(self): - return LocalRemote + return LocalRemoteTree @pytest.mark.usefixtures("ssh_server") @@ -271,7 +272,9 @@ def _setup_cloud(self): self.dvc.config = config self.cloud = DataCloud(self.dvc) - self.assertIsInstance(self.cloud.get_remote(), self._get_cloud_class()) + self.assertIsInstance( + self.cloud.get_remote().tree, self._get_cloud_class() + ) def get_url(self): user = self.ssh_server.test_creds["username"] @@ -281,12 +284,12 @@ def _get_keyfile(self): return self.ssh_server.test_creds["key_filename"] def _get_cloud_class(self): - return SSHRemote + return SSHRemoteTree class TestHDFSRemote(HDFS, TestDataCloudBase): def _get_cloud_class(self): - return HDFSRemote + return HDFSRemoteTree @pytest.mark.usefixtures("http_server") @@ -300,7 +303,7 @@ def get_url(self): return super().get_url(self.http_server.server_port) def _get_cloud_class(self): - return HTTPRemote + return HTTPRemoteTree class TestDataCloudCLIBase(TestDvc): @@ -544,7 +547,7 @@ def main(self, args): self.assertEqual(ret, 0) def _get_cloud_class(self): - return LocalRemote + return LocalRemoteTree def _prepare_repo(self): remote = self.cloud.get_remote() @@ -577,8 +580,11 @@ def _clear_local_cache(self): def _test_recursive_fetch(self, data_md5, data_sub_md5): self._clear_local_cache() - local_cache_data_path = self.dvc.cache.local.get(data_md5) - local_cache_data_sub_path = self.dvc.cache.local.get(data_sub_md5) + local_cache = self.dvc.cache.local + local_cache_data_path = local_cache.checksum_to_path_info(data_md5) + local_cache_data_sub_path = local_cache.checksum_to_path_info( + data_sub_md5 + ) self.assertFalse(os.path.exists(local_cache_data_path)) 
         self.assertFalse(os.path.exists(local_cache_data_sub_path))
 
@@ -590,8 +596,8 @@ def _test_recursive_fetch(self, data_md5, data_sub_md5):
 
     def _test_recursive_push(self, data_md5, data_sub_md5):
         remote = self.cloud.get_remote()
-        cloud_data_path = remote.get(data_md5)
-        cloud_data_sub_path = remote.get(data_sub_md5)
+        cloud_data_path = remote.checksum_to_path_info(data_md5)
+        cloud_data_sub_path = remote.checksum_to_path_info(data_sub_md5)
 
         self.assertFalse(os.path.exists(cloud_data_path))
         self.assertFalse(os.path.exists(cloud_data_sub_path))

From 12488f84c199b5af2ce080f4142125d6af91d946 Mon Sep 17 00:00:00 2001
From: Peter Rowlands
Date: Mon, 15 Jun 2020 19:45:56 +0900
Subject: [PATCH 15/16] bug fixes

---
 dvc/command/version.py | 8 ++++----
 dvc/output/hdfs.py     | 2 +-
 dvc/remote/base.py     | 7 +++++--
 dvc/remote/local.py    | 4 +++-
 dvc/repo/__init__.py   | 2 +-
 5 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/dvc/command/version.py b/dvc/command/version.py
index e53e36d571..02d38e81fb 100644
--- a/dvc/command/version.py
+++ b/dvc/command/version.py
@@ -122,12 +122,12 @@ def get_linktype_support_info(repo):
 
     @staticmethod
     def get_supported_remotes():
-        from dvc.remote import REMOTES
+        from dvc.remote import TREES
 
         supported_remotes = []
-        for remote in REMOTES:
-            if not remote.get_missing_deps():
-                supported_remotes.append(remote.scheme)
+        for tree_cls in TREES:
+            if not tree_cls.get_missing_deps():
+                supported_remotes.append(tree_cls.scheme)
 
         return ", ".join(supported_remotes)
 
diff --git a/dvc/output/hdfs.py b/dvc/output/hdfs.py
index fd44193a30..a6632db308 100644
--- a/dvc/output/hdfs.py
+++ b/dvc/output/hdfs.py
@@ -3,4 +3,4 @@
 
 
 class HDFSOutput(BaseOutput):
-    REMOTE = HDFSRemoteTree
+    TREE_CLS = HDFSRemoteTree
diff --git a/dvc/remote/base.py b/dvc/remote/base.py
index 448ba36a26..71b8100f7a 100644
--- a/dvc/remote/base.py
+++ b/dvc/remote/base.py
@@ -98,6 +98,7 @@ class BaseRemoteTree:
     TRAVERSE_THRESHOLD_SIZE = 500000
     CAN_TRAVERSE = True
 
+    CACHE_MODE = None
     SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}
     CHECKSUM_DIR_SUFFIX = ".dir"
 
@@ -380,7 +381,9 @@ def _save_dir_info(self, dir_info, path_info):
         new_info = self.cache.checksum_to_path_info(checksum)
         if self.cache.changed_cache_file(checksum):
             self.cache.tree.makedirs(new_info.parent)
-            self.cache.tree.move(tmp_info, new_info, mode=self.CACHE_MODE)
+            self.cache.tree.move(
+                tmp_info, new_info, mode=self.cache.CACHE_MODE
+            )
 
         if self.exists(path_info):
             self.state.save(path_info, checksum)
@@ -891,7 +894,7 @@ class CloudCache:
     """Cloud cache class."""
 
     DEFAULT_CACHE_TYPES = ["copy"]
-    CACHE_MODE = None
+    CACHE_MODE = BaseRemoteTree.CACHE_MODE
 
     def __init__(self, tree):
         self.tree = tree
diff --git a/dvc/remote/local.py b/dvc/remote/local.py
index e2f0cc3a94..38889d35fe 100644
--- a/dvc/remote/local.py
+++ b/dvc/remote/local.py
@@ -43,7 +43,6 @@ class LocalRemoteTree(BaseRemoteTree):
     PATH_CLS = PathInfo
     PARAM_CHECKSUM = "md5"
     PARAM_PATH = "path"
-    DEFAULT_CACHE_TYPES = ["reflink", "copy"]
     TRAVERSE_PREFIX_LEN = 2
     UNPACKED_DIR_SUFFIX = ".unpacked"
 
@@ -345,6 +344,9 @@ class LocalRemote(Remote):
 
 
 class LocalCache(CloudCache):
+    DEFAULT_CACHE_TYPES = ["reflink", "copy"]
+    CACHE_MODE = LocalRemoteTree.CACHE_MODE
+
     def __init__(self, tree):
         super().__init__(tree)
         self.cache_dir = tree.config.get("url")
diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py
index dd0b1528ac..acf2ebc9f8 100644
--- a/dvc/repo/__init__.py
+++ b/dvc/repo/__init__.py
@@ -187,7 +187,7 @@ def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False):
         return Repo(root_dir)
 
     def unprotect(self, target):
-        return self.cache.local.unprotect(PathInfo(target))
+        return self.cache.local.tree.unprotect(PathInfo(target))
 
     def _ignore(self):
         flist = [self.config.files["local"], self.tmp_dir]

From aef4701ee47a4c53e7015572f327b1398a5c8235 Mon Sep 17 00:00:00 2001
From: Peter Rowlands
Date: Mon, 15 Jun 2020 19:46:23 +0900
Subject: [PATCH 16/16] tests: update func tests

---
 tests/func/test_add.py      |  4 ++--
 tests/func/test_cache.py    | 10 +++++-----
 tests/func/test_checkout.py | 17 +++++++++--------
 tests/func/test_gc.py       |  7 ++++---
 tests/func/test_repro.py    |  6 +++---
 tests/func/test_s3.py       | 10 ++++++----
 tests/func/test_stage.py    |  6 +++---
 tests/func/test_tree.py     |  2 +-
 8 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/tests/func/test_add.py b/tests/func/test_add.py
index 1669974d34..5a9965d9dd 100644
--- a/tests/func/test_add.py
+++ b/tests/func/test_add.py
@@ -21,7 +21,7 @@
 )
 from dvc.main import main
 from dvc.output.base import OutputAlreadyTrackedError, OutputIsStageFileError
-from dvc.remote.local import LocalRemote, LocalRemoteTree
+from dvc.remote.local import LocalRemoteTree
 from dvc.repo import Repo as DvcRepo
 from dvc.stage import Stage
 from dvc.system import System
@@ -583,7 +583,7 @@ def test_readding_dir_should_not_unprotect_all(tmp_dir, dvc, mocker):
     dvc.add("dir")
     tmp_dir.gen("dir/new_file", "new_file_content")
 
-    unprotect_spy = mocker.spy(LocalRemote, "unprotect")
+    unprotect_spy = mocker.spy(LocalRemoteTree, "unprotect")
     dvc.add("dir")
 
     assert not unprotect_spy.mock.called
diff --git a/tests/func/test_cache.py b/tests/func/test_cache.py
index 841e06f1cd..7f7972dcf7 100644
--- a/tests/func/test_cache.py
+++ b/tests/func/test_cache.py
@@ -30,14 +30,14 @@ def setUp(self):
         self.create(self.cache2, "2")
 
     def test_all(self):
-        md5_list = list(Cache(self.dvc).local.all())
+        md5_list = list(Cache(self.dvc).local.tree.all())
         self.assertEqual(len(md5_list), 2)
         self.assertIn(self.cache1_md5, md5_list)
         self.assertIn(self.cache2_md5, md5_list)
 
     def test_get(self):
-        cache = Cache(self.dvc).local.get(self.cache1_md5)
-        self.assertEqual(cache, self.cache1)
+        cache = Cache(self.dvc).local.checksum_to_path_info(self.cache1_md5)
+        self.assertEqual(os.fspath(cache), self.cache1)
 
 
 class TestCacheLoadBadDirCache(TestDvc):
@@ -47,13 +47,13 @@ def _do_test(self, ret):
 
     def test(self):
         checksum = "123.dir"
-        fname = self.dvc.cache.local.get(checksum)
+        fname = os.fspath(self.dvc.cache.local.checksum_to_path_info(checksum))
         self.create(fname, "not,json")
         with pytest.raises(DirCacheError):
             self.dvc.cache.local.load_dir_cache(checksum)
 
         checksum = "234.dir"
-        fname = self.dvc.cache.local.get(checksum)
+        fname = os.fspath(self.dvc.cache.local.checksum_to_path_info(checksum))
         self.create(fname, '{"a": "b"}')
         self._do_test(self.dvc.cache.local.load_dir_cache(checksum))
 
diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py
index b417cd8e0d..e83de5c804 100644
--- a/tests/func/test_checkout.py
+++ b/tests/func/test_checkout.py
@@ -16,8 +16,9 @@
     NoOutputOrStageError,
 )
 from dvc.main import main
-from dvc.remote import S3Cache, S3Remote
-from dvc.remote.local import LocalRemote
+from dvc.remote.base import CloudCache, Remote
+from dvc.remote.local import LocalRemoteTree
+from dvc.remote.s3 import S3RemoteTree
 from dvc.repo import Repo as DvcRepo
 from dvc.stage import Stage
 from dvc.stage.exceptions import StageFileDoesNotExistError
@@ -99,8 +100,8 @@ def test(self):
         # NOTE: modifying cache file for one of the files inside the directory
         # to check if dvc will detect that the cache is corrupted.
         entry = self.dvc.cache.local.load_dir_cache(out.checksum)[0]
-        checksum = entry[self.dvc.cache.local.PARAM_CHECKSUM]
-        cache = self.dvc.cache.local.get(checksum)
+        checksum = entry[self.dvc.cache.local.tree.PARAM_CHECKSUM]
+        cache = os.fspath(self.dvc.cache.local.checksum_to_path_info(checksum))
 
         os.chmod(cache, 0o644)
         with open(cache, "w+") as fobj:
@@ -305,8 +306,8 @@ def test(self):
 class TestCheckoutMissingMd5InStageFile(TestRepro):
     def test(self):
         d = load_yaml(self.file1_stage)
-        del d[Stage.PARAM_OUTS][0][LocalRemote.PARAM_CHECKSUM]
-        del d[Stage.PARAM_DEPS][0][LocalRemote.PARAM_CHECKSUM]
+        del d[Stage.PARAM_OUTS][0][LocalRemoteTree.PARAM_CHECKSUM]
+        del d[Stage.PARAM_DEPS][0][LocalRemoteTree.PARAM_CHECKSUM]
         dump_yaml(self.file1_stage, d)
 
         with pytest.raises(CheckoutError):
@@ -755,9 +756,9 @@ def test_checkout_recursive(tmp_dir, dvc):
     not S3.should_test(), reason="Only run with S3 credentials"
 )
 def test_checkout_for_external_outputs(tmp_dir, dvc):
-    dvc.cache.s3 = S3Cache(dvc, {"url": S3.get_url()})
+    dvc.cache.s3 = CloudCache(S3RemoteTree(dvc, {"url": S3.get_url()}))
 
-    remote = S3Remote(dvc, {"url": S3.get_url()})
+    remote = Remote(S3RemoteTree(dvc, {"url": S3.get_url()}))
     file_path = remote.path_info / "foo"
     remote.tree.s3.put_object(
         Bucket=remote.path_info.bucket, Key=file_path.path, Body="foo"
diff --git a/tests/func/test_gc.py b/tests/func/test_gc.py
index 0685ed972d..52c16adf9b 100644
--- a/tests/func/test_gc.py
+++ b/tests/func/test_gc.py
@@ -8,7 +8,7 @@
 
 from dvc.exceptions import CollectCacheError
 from dvc.main import main
-from dvc.remote.local import LocalRemote, LocalRemoteTree
+from dvc.remote.local import LocalRemoteTree
 from dvc.repo import Repo as DvcRepo
 from dvc.utils.fs import remove
 from tests.basic_env import TestDir, TestDvcGit
@@ -21,7 +21,8 @@ def setUp(self):
         self.dvc.add(self.FOO)
         self.dvc.add(self.DATA_DIR)
         self.good_cache = [
-            self.dvc.cache.local.get(md5) for md5 in self.dvc.cache.local.all()
+            self.dvc.cache.local.checksum_to_path_info(md5)
+            for md5 in self.dvc.cache.local.tree.all()
         ]
 
         self.bad_cache = []
@@ -216,7 +217,7 @@ def test_gc_no_unpacked_dir(tmp_dir, dvc):
     os.remove("dir.dvc")
 
     unpackeddir = (
-        dir_stages[0].outs[0].cache_path + LocalRemote.UNPACKED_DIR_SUFFIX
+        dir_stages[0].outs[0].cache_path + LocalRemoteTree.UNPACKED_DIR_SUFFIX
    )
 
     # older (pre 1.0) versions of dvc used to generate this dir
diff --git a/tests/func/test_repro.py b/tests/func/test_repro.py
index 473bd46db1..75d62ece71 100644
--- a/tests/func/test_repro.py
+++ b/tests/func/test_repro.py
@@ -26,7 +26,7 @@
 from dvc.main import main
 from dvc.output.base import BaseOutput
 from dvc.path_info import URLInfo
-from dvc.remote.local import LocalRemote, LocalRemoteTree
+from dvc.remote.local import LocalRemoteTree
 from dvc.repo import Repo as DvcRepo
 from dvc.stage import Stage
 from dvc.stage.exceptions import StageFileDoesNotExistError
@@ -782,8 +782,8 @@ def test(self):
 class TestReproMissingMd5InStageFile(TestRepro):
     def test(self):
         d = load_yaml(self.file1_stage)
-        del d[Stage.PARAM_OUTS][0][LocalRemote.PARAM_CHECKSUM]
-        del d[Stage.PARAM_DEPS][0][LocalRemote.PARAM_CHECKSUM]
+        del d[Stage.PARAM_OUTS][0][LocalRemoteTree.PARAM_CHECKSUM]
+        del d[Stage.PARAM_DEPS][0][LocalRemoteTree.PARAM_CHECKSUM]
         dump_yaml(self.file1_stage, d)
 
         stages = self.dvc.reproduce(self.file1_stage)
diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py
index 58f93a5b0c..679ccaed9a 100644
--- a/tests/func/test_s3.py
+++ b/tests/func/test_s3.py
@@ -5,7 +5,8 @@
 import pytest
 from moto import mock_s3
 
-from dvc.remote.s3 import S3Cache, S3Remote, S3RemoteTree
+from dvc.remote.base import CloudCache
+from dvc.remote.s3 import S3RemoteTree
 from tests.remotes import S3
 
 # from https://github.com/spulec/moto/blob/v1.3.5/tests/test_s3/test_s3.py#L40
@@ -54,7 +55,8 @@ def test_copy_singlepart_preserve_etag():
     ],
 )
 def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm):
-    cache = S3Cache(dvc, {"url": str(base_info.parent)})
+    tree = S3RemoteTree(dvc, {"url": str(base_info.parent)})
+    cache = CloudCache(tree)
     s3 = cache.tree.s3
     s3.create_bucket(Bucket=base_info.bucket)
     s3.put_object(
@@ -69,8 +71,8 @@ def test_link_created_on_non_nested_path(base_info, tmp_dir, dvc, scm):
 @mock_s3
 def test_makedirs_doesnot_try_on_top_level_paths(tmp_dir, dvc, scm):
     base_info = S3RemoteTree.PATH_CLS("s3://bucket/")
-    remote = S3Remote(dvc, {"url": str(base_info)})
-    remote.tree.makedirs(base_info)
+    tree = S3RemoteTree(dvc, {"url": str(base_info)})
+    tree.makedirs(base_info)
 
 
 def _upload_multipart(s3, Bucket, Key):
diff --git a/tests/func/test_stage.py b/tests/func/test_stage.py
index ec202dc60e..f2c00113b9 100644
--- a/tests/func/test_stage.py
+++ b/tests/func/test_stage.py
@@ -6,7 +6,7 @@
 from dvc.dvcfile import SingleStageFile
 from dvc.main import main
 from dvc.output.local import LocalOutput
-from dvc.remote.local import LocalRemote
+from dvc.remote.local import LocalRemoteTree
 from dvc.repo import Repo
 from dvc.stage import PipelineStage, Stage
 from dvc.stage.exceptions import StageFileFormatError
@@ -54,8 +54,8 @@ def test_empty_list():
 
 def test_list():
     lst = [
-        {LocalOutput.PARAM_PATH: "foo", LocalRemote.PARAM_CHECKSUM: "123"},
-        {LocalOutput.PARAM_PATH: "bar", LocalRemote.PARAM_CHECKSUM: None},
+        {LocalOutput.PARAM_PATH: "foo", LocalRemoteTree.PARAM_CHECKSUM: "123"},
+        {LocalOutput.PARAM_PATH: "bar", LocalRemoteTree.PARAM_CHECKSUM: None},
         {LocalOutput.PARAM_PATH: "baz"},
     ]
     d = {Stage.PARAM_DEPS: lst}
diff --git a/tests/func/test_tree.py b/tests/func/test_tree.py
index daffd364e1..7c7dd0ed59 100644
--- a/tests/func/test_tree.py
+++ b/tests/func/test_tree.py
@@ -192,7 +192,7 @@ def test_repotree_walk_fetch(tmp_dir, dvc, scm, setup_remote):
     assert os.path.exists(out.cache_path)
 
     for entry in out.dir_cache:
-        checksum = entry[out.remote.PARAM_CHECKSUM]
+        checksum = entry[out.remote.tree.PARAM_CHECKSUM]
         assert os.path.exists(dvc.cache.local.checksum_to_path_info(checksum))
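
Taken together, patches 13-16 finish the mechanical migration: per-scheme constants (PARAM_CHECKSUM, PARAM_RELPATH, JOBS, CAN_TRAVERSE, UNPACKED_DIR_SUFFIX) now live on the *RemoteTree classes, Remote and CloudCache objects are thin wrappers around a tree, and call sites derive cache paths via checksum_to_path_info() instead of the old get() helper. The sketch below illustrates that shape; it is a simplified assumption for illustration (PurePosixPath stands in for DVC's path_info classes, and the class bodies are reduced to the delegation pattern), not the actual DVC implementation:

from pathlib import PurePosixPath


class BaseRemoteTree:
    """Storage-level operations; per-scheme constants live on the tree."""

    scheme = "base"
    PARAM_CHECKSUM = "md5"  # moved here from the old Remote classes

    def __init__(self, config):
        # Real trees build a PATH_CLS instance from config["url"]; a plain
        # PurePosixPath stands in for path_info in this sketch.
        self.path_info = PurePosixPath(config["url"])


class Remote:
    """Thin wrapper that delegates everything scheme-specific to its tree."""

    def __init__(self, tree):
        self.tree = tree

    @property
    def path_info(self):
        return self.tree.path_info

    def checksum_to_path_info(self, checksum):
        # DVC-style cache layout: "d41d8c..." -> <url>/d4/1d8c...
        return self.path_info / checksum[:2] / checksum[2:]


remote = Remote(BaseRemoteTree({"url": "/tmp/dvc-cache"}))
assert remote.tree.PARAM_CHECKSUM == "md5"
print(remote.checksum_to_path_info("d41d8cd98f00b204e9800998ecf8427e"))
# -> /tmp/dvc-cache/d4/1d8cd98f00b204e9800998ecf8427e

With this shape, call sites read constants through the wrapped tree (remote.tree.PARAM_CHECKSUM) and build cache paths through the wrapper (remote.checksum_to_path_info(md5)), which is exactly the substitution most hunks in these patches perform.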