objects.transfer: minor refactoring, move lazy taskset inside custom executor
skshetry committed Sep 10, 2021
1 parent 61fe8bc commit d4207c0
Showing 2 changed files with 75 additions and 108 deletions.
149 changes: 41 additions & 108 deletions dvc/objects/transfer.py
@@ -1,20 +1,19 @@
 import errno
-import itertools
 import logging
-from concurrent import futures
-from concurrent.futures import ThreadPoolExecutor
 from functools import partial, wraps
-from typing import TYPE_CHECKING, Callable, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional
 
 from funcy import split
 
 from dvc.progress import Tqdm
+from dvc.utils.threadpool import ThreadPoolExecutor
 
 if TYPE_CHECKING:
     from dvc.hash_info import HashInfo
 
     from .db.base import ObjectDB
     from .db.index import ObjectDBIndexBase
+    from .tree import Tree
 
 logger = logging.getLogger(__name__)
 
@@ -38,109 +37,48 @@ def wrapper(path_info, *args, **kwargs):
     return wrapper
 
 
-def _transfer(
-    src, dest, dir_ids, file_ids, missing_ids, jobs, verify, move, **kwargs
-):
-    func = _log_exceptions(dest.add)
-    total = len(dir_ids) + len(file_ids)
-    if total == 0:
-        return 0
-    with Tqdm(total=total, unit="file", desc="Transferring") as pbar:
-        func = pbar.wrap_fn(func)
-        with ThreadPoolExecutor(max_workers=jobs) as executor:
-            processor = partial(
-                _create_tasks,
-                executor,
-                jobs,
-                func,
-                src,
-                verify,
-                move,
-            )
-            processor.add_func = func
-            _do_transfer(
-                src,
-                dest,
-                dir_ids,
-                file_ids,
-                missing_ids,
-                processor,
-                verify=verify,
-                move=move,
-                **kwargs,
-            )
-    return total
-
-
-def _create_tasks(executor, jobs, func, src, verify, move, obj_ids):
-    fails = 0
-    hash_iter = iter(obj_ids)
-
-    def submit(hash_info):
-        obj = src.get(hash_info)
-        return executor.submit(
-            func,
-            obj.path_info,
-            obj.fs,
-            obj.hash_info,
-            verify=verify,
-            move=move,
-        )
-
-    def create_taskset(amount):
-        return {
-            submit(hash_info)
-            for hash_info in itertools.islice(hash_iter, amount)
-        }
+def find_tree_by_obj_id(
+    odbs: Iterable[Optional["ObjectDB"]], obj_id: "HashInfo"
+) -> Optional["Tree"]:
+    from .errors import ObjectFormatError
+    from .tree import Tree
+
-
-    tasks = create_taskset(jobs * 5)
-    while tasks:
-        done, tasks = futures.wait(tasks, return_when=futures.FIRST_COMPLETED)
-        fails += sum(task.result() for task in done)
-        tasks.update(create_taskset(len(done)))
-    return fails
+    for odb in odbs:
+        if odb is not None:
+            try:
+                return Tree.load(odb, obj_id)
+            except (FileNotFoundError, ObjectFormatError):
+                pass
+    return None
 
 
 def _do_transfer(
     src: "ObjectDB",
     dest: "ObjectDB",
-    dir_ids: Iterable["HashInfo"],
-    file_ids: Iterable["HashInfo"],
+    obj_ids: Iterable["HashInfo"],
     missing_ids: Iterable["HashInfo"],
     processor: Callable,
     src_index: Optional["ObjectDBIndexBase"] = None,
     dest_index: Optional["ObjectDBIndexBase"] = None,
     cache_odb: Optional["ObjectDB"] = None,
-    **kwargs,
+    **kwargs: Any,
 ):
     from dvc.exceptions import FileTransferError
-    from dvc.objects.errors import ObjectFormatError
 
+    dir_ids, file_ids = split(lambda hash_info: hash_info.isdir, obj_ids)
     total_fails = 0
     succeeded_dir_objs = []
     all_file_ids = set(file_ids)
 
     for dir_hash in dir_ids:
-        from .tree import Tree
-
-        bound_file_ids = set()
-        dir_obj: Optional["Tree"] = None
-        for odb in (cache_odb, src):
-            if odb is not None:
-                try:
-                    dir_obj = Tree.load(odb, dir_hash)
-                    break
-                except (FileNotFoundError, ObjectFormatError):
-                    pass
+        dir_obj = find_tree_by_obj_id([cache_odb, src], dir_hash)
         assert dir_obj
-        entry_ids = {entry.hash_info for _, entry in dir_obj}
 
-        for file_hash in all_file_ids.copy():
-            if file_hash in entry_ids:
-                bound_file_ids.add(file_hash)
-                all_file_ids.remove(file_hash)
+        entry_ids = {entry.hash_info for _, entry in dir_obj}
+        bound_file_ids = all_file_ids & entry_ids
+        all_file_ids -= entry_ids
 
-        dir_fails = processor(bound_file_ids)
+        dir_fails = sum(processor(bound_file_ids))
         if dir_fails:
             logger.debug(
                 "failed to upload full contents of '%s', "
@@ -164,19 +102,13 @@ def _do_transfer(
                 dir_obj.name,
             )
         else:
-            raw_obj = src.get(dir_obj.hash_info)
-            is_dir_failed = processor.add_func(  # type: ignore[attr-defined]
-                raw_obj.path_info,
-                raw_obj.fs,
-                raw_obj.hash_info,
-                **kwargs,
-            )
+            is_dir_failed = sum(processor([dir_obj.hash_info]))
             total_fails += is_dir_failed
             if not is_dir_failed:
                 succeeded_dir_objs.append(dir_obj)
 
     # insert the rest
-    total_fails += processor(all_file_ids)
+    total_fails += sum(processor(all_file_ids))
     if total_fails:
         if src_index:
             src_index.clear()
@@ -222,18 +154,19 @@ def transfer(
     if not status.new:
         return 0
 
-    dir_ids, file_ids = split(lambda hash_info: hash_info.isdir, status.new)
-    if jobs is None:
-        jobs = dest.fs.jobs
-
-    return _transfer(
-        src,
-        dest,
-        set(dir_ids),
-        set(file_ids),
-        status.missing,
-        jobs,
-        verify,
-        move,
-        **kwargs,
-    )
+    def func(hash_info: "HashInfo") -> None:
+        obj = src.get(hash_info)
+        return dest.add(
+            obj.path_info, obj.fs, obj.hash_info, verify=verify, move=move
+        )
+
+    total = len(status.new)
+    jobs = jobs or dest.fs.jobs
+    with Tqdm(total=total, unit="file", desc="Transferring") as pbar:
+        with ThreadPoolExecutor(max_workers=jobs) as executor:
+            wrapped_func = pbar.wrap_fn(_log_exceptions(func))
+            processor = partial(executor.imap_unordered, wrapped_func)
+            _do_transfer(
+                src, dest, status.new, status.missing, processor, **kwargs
+            )
+    return total
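Note on the new control flow: processor is partial(executor.imap_unordered, wrapped_func), so calling it now returns a lazy iterator of per-file results (0 on success, 1 on a logged failure) instead of a ready failure count, which is why _do_transfer drains each call with sum(...). Below is a minimal sketch of that contract, assuming dvc.utils.threadpool from this commit is importable; fake_add and the hash strings are illustrative stand-ins, not part of the commit.

from functools import partial

from dvc.utils.threadpool import ThreadPoolExecutor


def fake_add(hash_info: str) -> int:
    # Stand-in for the wrapped dest.add call: return 0 on success
    # and 1 on failure, as the _log_exceptions wrapper does.
    return 0


with ThreadPoolExecutor(max_workers=4) as executor:
    processor = partial(executor.imap_unordered, fake_add)
    # Nothing is submitted until the iterator is consumed; sum()
    # drains it and tallies failures, exactly as _do_transfer does.
    fails = sum(processor(["abc123", "def456"]))

assert fails == 0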
34 changes: 34 additions & 0 deletions dvc/utils/threadpool.py
@@ -0,0 +1,34 @@
+from concurrent import futures
+from itertools import islice
+from typing import Any, Callable, Iterable, Iterator, Set, TypeVar
+
+_T = TypeVar("_T")
+
+
+class ThreadPoolExecutor(futures.ThreadPoolExecutor):
+    _max_workers: int
+
+    @property
+    def max_workers(self) -> int:
+        return self._max_workers
+
+    def imap_unordered(
+        self, fn: Callable[..., _T], *iterables: Iterable[Any]
+    ) -> Iterator[_T]:
+        """Lazier version of map that does not preserve ordering of results.
+
+        It does not create all the futures at once to reduce memory usage.
+        """
+
+        def create_taskset(n: int) -> Set[futures.Future]:
+            return {self.submit(fn, *args) for args in islice(it, n)}
+
+        it = zip(*iterables)
+        tasks = create_taskset(self.max_workers * 5)
+        while tasks:
+            done, tasks = futures.wait(
+                tasks, return_when=futures.FIRST_COMPLETED
+            )
+            for fut in done:
+                yield fut.result()
+            tasks.update(create_taskset(len(done)))
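For reference, a quick usage sketch of the new executor, assuming the module is importable as dvc.utils.threadpool; slow_double is an illustrative stand-in for an upload callback. Unlike the eager Executor.map, results arrive in completion order and at most max_workers * 5 futures are alive at any moment.

import time

from dvc.utils.threadpool import ThreadPoolExecutor


def slow_double(n: int) -> int:
    time.sleep(0.01)  # simulate I/O-bound work such as a file transfer
    return 2 * n


with ThreadPoolExecutor(max_workers=4) as executor:
    # Futures are submitted in a sliding window and refilled as tasks
    # complete, keeping memory bounded even for long input iterables.
    results = list(executor.imap_unordered(slow_double, range(100)))

assert sorted(results) == [2 * n for n in range(100)]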
