prefetch mvp #986
Changes from 2 commits
@@ -1,9 +1,53 @@
-from typing import Union
+from typing import Any, Callable, Iterable, Union
+import queue
+import sys
+import threading
+
+import numpy as np
 
 import torch
 from torch.utils.data import DataLoader
 
 
-class BatchLimitLoaderWrapper:
+class ILoaderWrapper:
+    def __init__(self, loader: DataLoader):
+        self.origin = loader
+
+    def __getattr__(self, key):
+        """
+        Gets the attribute ``key``.
+        Looks for ``key`` on the wrapped ``origin`` loader and raises
+        ``NotImplementedError`` if it is not found there.  Note that
+        ``__getattr__`` only runs after normal lookup on the wrapper
+        itself has already failed, so looking ``key`` up on ``self``
+        again here would recurse forever.
+
+        Args:
+            key: attribute key
+
+        Returns:
+            attribute value
+
+        Raises:
+            NotImplementedError: if the attribute is found neither on
+                the wrapper nor on ``origin``
+        """
+        value = getattr(self.origin, key, None)
+        if value is not None:
+            return value
+        raise NotImplementedError()
+
+    def __len__(self) -> int:
+        """Returns the length of the wrapped loader.
+
+        Returns:
+            int: length of the wrapped loader
+        """
+        return len(self.origin)
+
+
+class BatchLimitLoaderWrapper(ILoaderWrapper):
     """
     Loader wrapper. Limits number of batches used per each iteration.
 
@@ -50,6 +94,7 @@ def __init__(self, loader: DataLoader, num_batches: Union[int, float]):
             num_batches (Union[int, float]): number of batches to use (int),
                 or portion of iterator (float, should be in [0;1] range)
         """
+        super().__init__(loader)
         assert isinstance(num_batches, (int, float)), (
             "Expected ``num_batches`` type is int/float"
             f"but got {type(num_batches)}"
@@ -61,36 +106,10 @@ def __init__(self, loader: DataLoader, num_batches: Union[int, float]):
             )
             num_batches = int(len(loader) * num_batches)
 
-        self.origin = loader
         self.iterator = iter(self.origin)
         self.iteration_index = 0
         self.num_batches = num_batches
 
-    def __getattr__(self, key):
-        """
-        Gets attribute by ``key``.
-        Firstly, looks at the ``origin`` for the appropriate ``key``.
-        If none founds - looks at the wrappers attributes.
-        If could not found anything - raises ``NotImplementedError``.
-
-        Args:
-            key: attribute key
-
-        Returns:
-            attribute value
-
-        Raises:
-            NotImplementedError: if could not find attribute in ``origin``
-                or ``wrapper``
-        """
-        value = getattr(self.origin, key, None)
-        if value is not None:
-            return value
-        value = getattr(self, key, None)
-        if value is not None:
-            return value
-        raise NotImplementedError()
-
     def __iter__(self):
         """Iterator.
 
@@ -115,13 +134,165 @@ def __next__(self):
         batch = next(self.iterator)
         return batch
 
-    def __len__(self) -> int:
-        """Returns length of the wrapper loader.
-
-        Returns:
-            int: length of the wrapper loader
-        """
-        return len(self.origin)
+def _any2cuda_non_blocking(value: Any):
+    # based on catalyst.utils.torch.any2device
+    # but with cuda non_blocking trick
+    if isinstance(value, dict):
+        return {k: _any2cuda_non_blocking(v) for k, v in value.items()}
+    elif isinstance(value, (tuple, list)):
+        return [_any2cuda_non_blocking(v) for v in value]
+    elif torch.is_tensor(value):
+        return value.cuda(non_blocking=True)
+    elif (
+        isinstance(value, (np.ndarray, np.void))
+        and value.dtype.fields is not None
+    ):
+        return {
+            k: _any2cuda_non_blocking(value[k])
+            for k in value.dtype.fields.keys()
+        }
+    elif isinstance(value, np.ndarray):
+        return torch.tensor(value).cuda(non_blocking=True)
+    # fallback: pass unsupported types (strings, scalars, ...) through
+    # unchanged instead of silently returning ``None``
+    return value
+
+
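One note on the ``non_blocking=True`` trick above: host-to-device copies only actually overlap with computation when the source tensors sit in pinned (page-locked) memory, so the wrapped loader should usually be created with ``pin_memory=True``. A minimal sketch (the dataset is a stand-in):

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.rand(1024, 16), torch.rand(1024))
    # pinned host memory lets .cuda(non_blocking=True) run asynchronously
    loader = DataLoader(
        dataset, batch_size=32, num_workers=2, pin_memory=True,
    )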
+def _map_loop(
+    func: Callable,
+    iterable: Iterable,
+    result_queue: queue.Queue,
+    error_queue: queue.Queue,
+    done_event: threading.Event,
+):
+    try:
+        for x in iterable:
+            result = func(x)
+            result_queue.put(result)
+    except BaseException:
+        error_queue.put(sys.exc_info())
+    finally:
+        done_event.set()
+
+
+def _prefetch_map(
+    func: Callable,
+    iterable: Iterable,
+    num_prefetches: int = 1,
+    timeout: int = 2,
+) -> Iterable:
+    result_queue = queue.Queue(num_prefetches)
+    error_queue = queue.Queue(1)
+    done_event = threading.Event()
+    map_thread = threading.Thread(
+        target=_map_loop,
+        args=(func, iterable, result_queue, error_queue, done_event),
+    )
+    map_thread.daemon = True
+    map_thread.start()
+    while not (done_event.is_set() and result_queue.empty()):
+        try:
+            result = result_queue.get(timeout=timeout)
+        except queue.Empty:
+            continue
+        yield result
+    if error_queue.full():
+        raise error_queue.get()[1]
+
+
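A standalone sketch of how ``_prefetch_map`` behaves (``slow_square`` is a made-up stand-in for the cuda-copy step): results stream out in order while the background thread keeps producing up to ``num_prefetches`` items ahead, and an exception raised inside ``func`` is re-raised in the consuming loop once the result queue drains.

.. code-block:: python

    import time

    def slow_square(x):  # hypothetical slow per-item work
        time.sleep(0.01)
        return x * x

    # up to 4 results are computed ahead of the consumer
    for y in _prefetch_map(slow_square, range(10), num_prefetches=4):
        print(y)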
+def _prefetch_loader(loader: DataLoader, num_prefetches: int) -> Iterable:
+    if torch.cuda.is_available():
+        loader = _prefetch_map(
+            _any2cuda_non_blocking, loader, num_prefetches=num_prefetches,
+        )
+    return loader
+
+
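Worth noting that ``_prefetch_loader`` degrades gracefully: without CUDA it hands back the original loader untouched, so the wrapper is a no-op passthrough on CPU-only machines. A quick sketch of that behavior:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(TensorDataset(torch.rand(8, 2)), batch_size=4)
    if not torch.cuda.is_available():
        # the very same object comes back; no background thread is started
        assert _prefetch_loader(loader, num_prefetches=2) is loader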
+class BatchPrefetchLoaderWrapper(ILoaderWrapper):
+    """
+    Loader wrapper. Prefetches specified number of batches on the GPU.
+
+    Base usage:
+
+    .. code-block:: python
+
+        import torch
+        from torch.utils.data import DataLoader, TensorDataset
+        from catalyst.data import BatchPrefetchLoaderWrapper
+
+        num_samples, num_features = int(1e4), int(1e1)
+        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
+        dataset = TensorDataset(X, y)
+        loader = DataLoader(dataset, batch_size=32, num_workers=1)
+        loader = BatchPrefetchLoaderWrapper(loader)
+
+    Minimal working example:
+
+    .. code-block:: python
+
+        import os
+        import torch
+        from torch.nn import functional as F
+        from torch.utils.data import DataLoader
+        from catalyst import dl, metrics
+        from catalyst.data.cv import ToTensor
+        from catalyst.contrib.datasets import MNIST
+        from catalyst.data import BatchPrefetchLoaderWrapper
+
+        class CustomRunner(dl.Runner):
+
+            def predict_batch(self, batch):
+                # model inference step
+                return self.model(
+                    batch[0].to(self.device).view(batch[0].size(0), -1)
+                )
+
+            def _handle_batch(self, batch):
+                # model train/valid step
+                x, y = batch
+                y_hat = self.model(x.view(x.size(0), -1))
+
+                loss = F.cross_entropy(y_hat, y)
+                accuracy01, accuracy03 = metrics.accuracy(
+                    y_hat, y, topk=(1, 3)
+                )
+                self.batch_metrics.update(
+                    {
+                        "loss": loss,
+                        "accuracy01": accuracy01,
+                        "accuracy03": accuracy03,
+                    }
+                )
+
+                if self.is_train_loader:
+                    loss.backward()
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+
+        model = torch.nn.Linear(28 * 28, 10)
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
+
+        batch_size = 32
+        loaders = {
+            "train": DataLoader(
+                MNIST(
+                    os.getcwd(), train=True,
+                    download=True, transform=ToTensor(),
+                ),
+                batch_size=batch_size,
+            ),
+            "valid": DataLoader(
+                MNIST(
+                    os.getcwd(), train=False,
+                    download=True, transform=ToTensor(),
+                ),
+                batch_size=batch_size,
+            ),
+        }
+        loaders = {
+            k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()
+        }
+
+
+        runner = CustomRunner()
+        # model training
+        runner.train(
+            model=model,
+            optimizer=optimizer,
+            loaders=loaders,
+            logdir="./logs",
+            num_epochs=5,
+            verbose=True,
+            load_best_on_end=True,
+        )
+        # model inference
+        for prediction in runner.predict_loader(loader=loaders["valid"]):
+            assert prediction.detach().cpu().numpy().shape[-1] == 10
+        # model tracing
+        traced_model = runner.trace(loader=loaders["valid"])
+
+    """
+
+    def __init__(self, loader: DataLoader, num_prefetches: int = None):
+        """
+        Args:
+            loader: loader to wrap
+            num_prefetches: number of batches to prefetch
+                (defaults to ``loader.batch_size``)
+        """
+        super().__init__(loader)
+        self.num_prefetches = num_prefetches or loader.batch_size
+
+    def __iter__(self):
+        """Returns an iterator over prefetched batches."""
+        return _prefetch_loader(self.origin, self.num_prefetches)
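Since ``num_prefetches`` defaults to ``loader.batch_size``, up to a batch-size worth of batches can be staged on the GPU at once; for large batches a smaller explicit depth may be the safer call (a usage judgment, not something this PR prescribes). Reusing a plain ``DataLoader`` like the one from the "Base usage" example:

.. code-block:: python

    # keep at most 2 batches staged ahead instead of loader.batch_size
    loader = BatchPrefetchLoaderWrapper(loader, num_prefetches=2)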
 
 
-__all__ = ["BatchLimitLoaderWrapper"]
+__all__ = ["BatchLimitLoaderWrapper", "BatchPrefetchLoaderWrapper"]
[pep8] reported by reviewdog 🐶
WPS424 Found `except BaseException`
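The broad ``except BaseException`` in ``_map_loop`` looks deliberate: the worker thread has to forward *any* failure, including ``KeyboardInterrupt`` and ``SystemExit``, to the consuming iterator; with a narrower ``except Exception`` the generator would simply stop yielding, with the traceback only printed on the worker thread's stderr. A minimal sketch of that propagation (``boom`` is made up for illustration):

.. code-block:: python

    def boom(x):
        raise KeyboardInterrupt  # not caught by a bare ``except Exception``

    try:
        list(_prefetch_map(boom, range(3)))
    except KeyboardInterrupt:
        print("worker failure reached the consumer")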