Skip to content

Commit

Permalink
repofs: use underlying fs.download to download files (#6401)
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop committed Aug 10, 2021
1 parent 1f82782 commit e99e6cd
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 19 deletions.
22 changes: 15 additions & 7 deletions dvc/fs/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def _get_granular_hash(
return obj.hash_info
raise FileNotFoundError

def open( # type: ignore
self, path: PathInfo, mode="r", encoding=None, remote=None, **kwargs
): # pylint: disable=arguments-differ
def _get_fs_path(self, path: PathInfo, remote=None):
try:
outs = self._find_outs(path, strict=False)
except OutputNotFoundError as exc:
Expand Down Expand Up @@ -92,16 +90,20 @@ def open( # type: ignore
else:
checksum = out.hash_info.value
remote_info = remote_odb.hash_to_path_info(checksum)
return remote_odb.fs.open(
remote_info, mode=mode, encoding=encoding
)
return remote_odb.fs, remote_info

if out.is_dir_checksum:
checksum = self._get_granular_hash(path, out).value
cache_path = out.odb.hash_to_path_info(checksum).url
else:
cache_path = out.cache_path
return open(cache_path, mode=mode, encoding=encoding)
return out.odb.fs, cache_path

def open( # type: ignore
self, path: PathInfo, mode="r", encoding=None, **kwargs
): # pylint: disable=arguments-renamed
fs, fspath = self._get_fs_path(path, **kwargs)
return fs.open(fspath, mode=mode, encoding=encoding)

def exists(self, path): # pylint: disable=arguments-renamed
try:
Expand Down Expand Up @@ -253,3 +255,9 @@ def info(self, path_info):
ret[obj.hash_info.name] = obj.hash_info.value

return ret

def _download(self, from_info, to_file, **kwargs):
fs, path = self._get_fs_path(from_info)
fs._download( # pylint: disable=protected-access
path, to_file, **kwargs
)
14 changes: 14 additions & 0 deletions dvc/fs/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,17 @@ def walk_files(self, path_info, **kwargs):
for file in files:
# NOTE: os.path.join is ~5.5 times slower
yield f"{root}{os.sep}{file}"

def _download(
self, from_info, to_file, name=None, no_progress_bar=False, **kwargs
):
import shutil

from dvc.progress import Tqdm

with open(to_file, "wb+") as to_fobj:
with Tqdm.wrapattr(
to_fobj, "write", desc=name, disable=no_progress_bar
) as wrapped:
with self.open(from_info, "rb", **kwargs) as from_fobj:
shutil.copyfileobj(from_fobj, wrapped)
25 changes: 13 additions & 12 deletions dvc/fs/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,19 +453,20 @@ def walk_files(self, path_info, **kwargs):
for fname in files:
yield PathInfo(root) / fname

def _download(
self, from_info, to_file, name=None, no_progress_bar=False, **kwargs
):
import shutil

from dvc.progress import Tqdm
def _download(self, from_info, to_file, **kwargs):
fs, dvc_fs = self._get_fs_pair(from_info)
try:
fs._download( # pylint: disable=protected-access
from_info, to_file, **kwargs
)
return
except FileNotFoundError:
if not dvc_fs:
raise

with open(to_file, "wb+") as to_fobj:
with Tqdm.wrapattr(
to_fobj, "write", desc=name, disable=no_progress_bar
) as wrapped:
with self.open(from_info, "rb", **kwargs) as from_fobj:
shutil.copyfileobj(from_fobj, wrapped)
dvc_fs._download( # pylint: disable=protected-access
from_info, to_file, **kwargs
)

def metadata(self, path):
abspath = os.path.abspath(path)
Expand Down

0 comments on commit e99e6cd

Please sign in to comment.