prefetch mvp #986

Merged
merged 6 commits on Nov 10, 2020
Changes from 1 commit
catalyst/data/__init__.py: 5 changes (4 additions, 1 deletion)
@@ -10,7 +10,10 @@
MetricLearningTrainDataset,
QueryGalleryDataset,
)
from catalyst.data.loader import BatchLimitLoaderWrapper
from catalyst.data.loader import (
BatchLimitLoaderWrapper,
BatchPrefetchLoaderWrapper,
)
from catalyst.data.reader import (
ReaderSpec,
ScalarReader,
catalyst/data/loader.py: 161 changes (126 additions, 35 deletions)
@@ -1,9 +1,53 @@
from typing import Union
from typing import Any, Callable, Iterable, Union
import queue
import sys
import threading

import numpy as np

import torch
from torch.utils.data import DataLoader


class BatchLimitLoaderWrapper:
class ILoaderWrapper:
def __init__(self, loader: DataLoader):
self.origin = loader

def __getattr__(self, key):
"""
Gets attribute by ``key``.
First, looks at the ``origin`` loader for the appropriate ``key``.
If it is not found there, looks at the wrapper's own attributes.
If nothing is found, raises ``NotImplementedError``.

Args:
key: attribute key

Returns:
attribute value

Raises:
NotImplementedError: if could not find attribute in ``origin``
or ``wrapper``
"""
value = getattr(self.origin, key, None)
if value is not None:
return value
value = getattr(self, key, None)
if value is not None:
return value
raise NotImplementedError()

def __len__(self) -> int:
"""Returns length of the wrapper loader.

Returns:
int: length of the wrapper loader
"""
return len(self.origin)


class BatchLimitLoaderWrapper(ILoaderWrapper):
"""
Loader wrapper. Limits number of batches used per each iteration.

@@ -50,6 +94,7 @@ def __init__(self, loader: DataLoader, num_batches: Union[int, float]):
num_batches (Union[int, float]): number of batches to use (int),
or portion of iterator (float, should be in [0;1] range)
"""
super().__init__(loader)
assert isinstance(num_batches, (int, float)), (
"Expected ``num_batches`` type is int/float"
f"but got {type(num_batches)}"
@@ -61,36 +106,10 @@ def __init__(self, loader: DataLoader, num_batches: Union[int, float]):
)
num_batches = int(len(loader) * num_batches)

self.origin = loader
self.iterator = iter(self.origin)
self.iteration_index = 0
self.num_batches = num_batches

def __getattr__(self, key):
"""
Gets attribute by ``key``.
Firstly, looks at the ``origin`` for the appropriate ``key``.
If none founds - looks at the wrappers attributes.
If could not found anything - raises ``NotImplementedError``.

Args:
key: attribute key

Returns:
attribute value

Raises:
NotImplementedError: if could not find attribute in ``origin``
or ``wrapper``
"""
value = getattr(self.origin, key, None)
if value is not None:
return value
value = getattr(self, key, None)
if value is not None:
return value
raise NotImplementedError()

def __iter__(self):
"""Iterator.

@@ -115,13 +134,85 @@ def __next__(self):
batch = next(self.iterator)
return batch

def __len__(self) -> int:
"""Returns length of the wrapper loader.

Returns:
int: length of the wrapper loader
"""
return len(self.origin)
def _any2cuda_non_blocking(value: Any):
# based on catalyst.utils.torch.any2device
# but with cuda non_blocking trick
if isinstance(value, dict):
return {k: _any2cuda_non_blocking(v) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return [_any2cuda_non_blocking(v) for v in value]
elif torch.is_tensor(value):
return value.cuda(non_blocking=True)
elif (
isinstance(value, (np.ndarray, np.void))
and value.dtype.fields is not None
):
return {
k: _any2cuda_non_blocking(value[k])
for k in value.dtype.fields.keys()
}
elif isinstance(value, np.ndarray):
return torch.tensor(value).cuda(non_blocking=True)


def _map_loop(
func: Callable,
iterable: Iterable,
result_queue: queue.Queue,
error_queue: queue.Queue,
done_event: threading.Event,
):
try:
for x in iterable:
result = func(x)
result_queue.put(result)
except BaseException:

[pep8] reported by reviewdog 🐶
WPS424 Found except BaseException

error_queue.put(sys.exc_info())
finally:
done_event.set()


def _prefetch_map(
func: Callable,
iterable: Iterable,
num_prefetches: int = 1,
timeout: int = 2,
) -> Iterable:
result_queue = queue.Queue(num_prefetches)
error_queue = queue.Queue(1)
done_event = threading.Event()
map_thread = threading.Thread(
target=_map_loop,
args=(func, iterable, result_queue, error_queue, done_event),
)
map_thread.daemon = True
map_thread.start()
while not (done_event.is_set() and result_queue.empty()):
try:
result = result_queue.get(timeout=timeout)
except queue.Empty:
continue
yield result
if error_queue.full():
raise error_queue.get()[1]


def _prefetch_loader(loader: DataLoader, num_prefetches: int) -> Iterable:
if torch.cuda.is_available():
loader = _prefetch_map(
_any2cuda_non_blocking, loader, num_prefetches=num_prefetches,
)
return loader


class BatchPrefetchLoaderWrapper(ILoaderWrapper):

[pep8] reported by reviewdog 🐶
D101 Missing docstring in public class

def __init__(self, loader: DataLoader, num_prefetches: int = None):

[pep8] reported by reviewdog 🐶
D107 Missing docstring in init

super().__init__(loader)
self.num_prefetches = num_prefetches or loader.batch_size

def __iter__(self):

[pep8] reported by reviewdog 🐶
D105 Missing docstring in magic method

return _prefetch_loader(self.origin, self.num_prefetches)


__all__ = ["BatchLimitLoaderWrapper"]
__all__ = ["BatchLimitLoaderWrapper", "BatchPrefetchLoaderWrapper"]
catalyst/utils/torch.py: 2 changes (1 addition, 1 deletion)
@@ -143,7 +143,7 @@ def any2device(value, device: Device):
k: any2device(value[k], device) for k in value.dtype.fields.keys()
}
elif isinstance(value, np.ndarray):
return torch.Tensor(value).to(device)
return torch.tensor(value, device=device)
return value
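
A short note on the one-line change above, with an illustrative snippet that is not part of the PR: torch.Tensor(ndarray) always builds a float32 tensor, silently casting integer data such as labels, whereas torch.tensor(value, device=device) preserves the source dtype and allocates on the target device in one step.

import numpy as np
import torch

labels = np.array([0, 1, 2], dtype=np.int64)

old = torch.Tensor(labels)                # dtype becomes torch.float32
new = torch.tensor(labels, device="cpu")  # dtype stays torch.int64

print(old.dtype, new.dtype)  # torch.float32 torch.int64
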
