loader.py
from typing import Any, Callable, Iterable, Union
import queue
import sys
import threading

import numpy as np
import torch
from torch.utils.data import DataLoader

class ILoaderWrapper:
    """Loader wrapper interface."""

    def __init__(self, loader: DataLoader):
        self.origin = loader

    def __getattr__(self, key):
        """
        Gets attribute by ``key``.
        First, looks at the ``origin`` loader for the appropriate ``key``.
        If it is not found there, looks at the wrapper's own attributes.
        If the attribute cannot be found in either place,
        raises ``NotImplementedError``.

        Args:
            key: attribute key

        Returns:
            attribute value

        Raises:
            NotImplementedError: if the attribute is found neither in
                ``origin`` nor in the wrapper
        """
        value = getattr(self.origin, key, None)
        if value is not None:
            return value
        value = getattr(self, key, None)
        if value is not None:
            return value
        raise NotImplementedError()

    def __len__(self) -> int:
        """Returns length of the wrapped loader.

        Returns:
            int: length of the wrapped loader
        """
        return len(self.origin)
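
# A minimal sketch, not part of the original module, illustrating the
# attribute delegation implemented by ``ILoaderWrapper`` above: attributes
# that are not defined on the wrapper itself (e.g. ``batch_size``) are
# resolved on the wrapped ``origin`` loader, and ``__len__`` is forwarded too.
def _demo_loader_wrapper_delegation():
    from torch.utils.data import TensorDataset

    X, y = torch.rand(100, 10), torch.rand(100)
    loader = DataLoader(TensorDataset(X, y), batch_size=25)
    wrapped = ILoaderWrapper(loader)
    assert wrapped.batch_size == 25          # resolved via ``origin``
    assert len(wrapped) == len(loader) == 4  # ``__len__`` is forwarded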

class BatchLimitLoaderWrapper(ILoaderWrapper):
    """
    Loader wrapper. Limits the number of batches used per epoch.

    For example, if you have some loader and want to use only its first 5 batches:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.data.loader import BatchLimitLoaderWrapper

        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loader = BatchLimitLoaderWrapper(loader, num_batches=5)

    or if you would like to use only some portion of the DataLoader
    (we use 30% in the example below):

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.data.loader import BatchLimitLoaderWrapper

        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loader = BatchLimitLoaderWrapper(loader, num_batches=0.3)

    .. note::
        Generally speaking, this wrapper could be used with any iterator-like
        object. No ``DataLoader``-specific code is used.
    """

    def __init__(self, loader: DataLoader, num_batches: Union[int, float]):
        """
        Loader wrapper. Limits the number of batches used per epoch.

        Args:
            loader: torch dataloader.
            num_batches (Union[int, float]): number of batches to use (int),
                or portion of iterator (float, should be in [0; 1] range)
        """
        super().__init__(loader)
        assert isinstance(num_batches, (int, float)), (
            "Expected ``num_batches`` type is int/float, "
            f"but got {type(num_batches)}"
        )
        if isinstance(num_batches, float):
            assert 0.0 <= num_batches <= 1, (
                "Expected ``num_batches`` to be in range [0; 1], "
                f"but got {num_batches}"
            )
            num_batches = int(len(loader) * num_batches)

        self.iterator = iter(self.origin)
        self.iteration_index = 0
        self.num_batches = num_batches

    def __iter__(self):
        """Iterator.

        Returns:
            iterator object
        """
        self.iteration_index = 0
        self.iterator = iter(self.origin)
        return self

    def __next__(self):
        """Next batch.

        Returns:
            next batch
        """
        if self.iteration_index >= len(self.origin):
            raise StopIteration()
        self.iteration_index += 1
        # restart the underlying iterator every ``num_batches`` steps,
        # so only the first ``num_batches`` batches are ever used
        if self.iteration_index % self.num_batches == 0:
            self.iterator = iter(self.origin)
        batch = next(self.iterator)
        return batch
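
# A minimal sketch, not part of the original module, illustrating the note in
# the class docstring above: ``BatchLimitLoaderWrapper`` works with any sized
# iterable, not only a ``DataLoader``. Here a plain list stands in for a
# loader; the epoch length is unchanged, but only the first ``num_batches``
# items are ever yielded.
def _demo_batch_limit_on_plain_iterable():
    data = list(range(10))
    limited = BatchLimitLoaderWrapper(data, num_batches=3)
    seen = [item for item in limited]
    assert len(seen) == len(data)   # epoch length is unchanged
    assert set(seen) == {0, 1, 2}   # but only the first 3 items are used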

def _any2cuda_non_blocking(value: Any):
    # based on catalyst.utils.torch.any2device
    # but with cuda non_blocking trick
    if isinstance(value, dict):
        return {k: _any2cuda_non_blocking(v) for k, v in value.items()}
    elif isinstance(value, (tuple, list)):
        return [_any2cuda_non_blocking(v) for v in value]
    elif torch.is_tensor(value):
        return value.cuda(non_blocking=True)
    elif (
        isinstance(value, (np.ndarray, np.void))
        and value.dtype.fields is not None
    ):
        return {
            k: _any2cuda_non_blocking(value[k])
            for k in value.dtype.fields.keys()
        }
    elif isinstance(value, np.ndarray):
        return torch.tensor(value).cuda(non_blocking=True)
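
# A minimal sketch, not part of the original module, showing the recursive
# device transfer above: nested containers of tensors are moved to the GPU
# element-wise with ``non_blocking=True``. The demo is a no-op without CUDA.
def _demo_any2cuda_non_blocking():
    if not torch.cuda.is_available():
        return
    batch = {
        "features": torch.rand(4, 3),
        "targets": [torch.rand(4), torch.rand(4)],
    }
    cuda_batch = _any2cuda_non_blocking(batch)
    assert cuda_batch["features"].is_cuda
    assert all(t.is_cuda for t in cuda_batch["targets"])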

def _map_loop(
    func: Callable,
    iterable: Iterable,
    result_queue: queue.Queue,
    error_queue: queue.Queue,
    done_event: threading.Event,
):
    # producer loop: applies ``func`` to each element and pushes the result
    # into ``result_queue``; any exception is forwarded via ``error_queue``
    try:
        for x in iterable:
            result = func(x)
            result_queue.put(result)
    except BaseException:
        error_queue.put(sys.exc_info())
    finally:
        done_event.set()

def _prefetch_map(
    func: Callable,
    iterable: Iterable,
    num_prefetches: int = 1,
    timeout: int = 2,
) -> Iterable:
    # runs ``func`` over ``iterable`` in a daemon thread and yields results
    # from a bounded queue, so up to ``num_prefetches`` results are computed
    # ahead of the consumer
    result_queue = queue.Queue(num_prefetches)
    error_queue = queue.Queue(1)
    done_event = threading.Event()
    map_thread = threading.Thread(
        target=_map_loop,
        args=(func, iterable, result_queue, error_queue, done_event),
    )
    map_thread.daemon = True
    map_thread.start()
    while not (done_event.is_set() and result_queue.empty()):
        try:
            result = result_queue.get(timeout=timeout)
        except queue.Empty:
            continue
        yield result
    if error_queue.full():
        raise error_queue.get()[1]

def _prefetch_loader(loader: DataLoader, num_prefetches: int) -> Iterable:
    # moves batches to CUDA in a background thread when a GPU is available;
    # otherwise returns the loader unchanged
    if torch.cuda.is_available():
        loader = _prefetch_map(
            _any2cuda_non_blocking, loader, num_prefetches=num_prefetches,
        )
    return loader
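
# A minimal sketch, not part of the original module, exercising the prefetch
# helper on its own, without CUDA: ``_prefetch_map`` runs the mapping in a
# background thread and buffers up to ``num_prefetches`` results, letting the
# consumer overlap its work with the mapping. ``slow_double`` is a
# hypothetical stand-in for the host-to-device copy done by
# ``_any2cuda_non_blocking``.
def _demo_prefetch_map():
    import time

    def slow_double(x):
        time.sleep(0.01)  # simulate per-item work (e.g. a device transfer)
        return 2 * x

    doubled = list(_prefetch_map(slow_double, range(5), num_prefetches=2))
    assert doubled == [0, 2, 4, 6, 8]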

class BatchPrefetchLoaderWrapper(ILoaderWrapper):
    """
    Loader wrapper. Prefetches specified number of batches on the GPU.

    Base usage:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.data import BatchPrefetchLoaderWrapper

        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loader = BatchPrefetchLoaderWrapper(loader)

    Minimal working example:

    .. code-block:: python

        import os
        import torch
        from torch.nn import functional as F
        from torch.utils.data import DataLoader
        from catalyst import dl, metrics
        from catalyst.data.cv import ToTensor
        from catalyst.contrib.datasets import MNIST
        from catalyst.data import BatchPrefetchLoaderWrapper

        class CustomRunner(dl.Runner):

            def predict_batch(self, batch):
                # model inference step
                return self.model(batch[0].to(self.device).view(batch[0].size(0), -1))

            def _handle_batch(self, batch):
                # model train/valid step
                x, y = batch
                y_hat = self.model(x.view(x.size(0), -1))

                loss = F.cross_entropy(y_hat, y)
                accuracy01, accuracy03 = metrics.accuracy(y_hat, y, topk=(1, 3))
                self.batch_metrics.update(
                    {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03}
                )

                if self.is_train_loader:
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()

        model = torch.nn.Linear(28 * 28, 10)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

        batch_size = 32
        loaders = {
            "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=batch_size),
            "valid": DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=batch_size),
        }
        loaders = {k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()}

        runner = CustomRunner()
        # model training
        runner.train(
            model=model,
            optimizer=optimizer,
            loaders=loaders,
            logdir="./logs",
            num_epochs=5,
            verbose=True,
            load_best_on_end=True,
        )
        # model inference
        for prediction in runner.predict_loader(loader=loaders["valid"]):
            assert prediction.detach().cpu().numpy().shape[-1] == 10
        # model tracing
        traced_model = runner.trace(loader=loaders["valid"])
    """

    def __init__(self, loader: DataLoader, num_prefetches: int = None):
        """
        Args:
            loader: torch dataloader.
            num_prefetches: number of batches to prefetch ahead of the
                consumer; defaults to ``loader.batch_size``
        """
        super().__init__(loader)
        self.num_prefetches = num_prefetches or loader.batch_size

    def __iter__(self):
        """Returns an iterator over prefetched batches
        (moved to the GPU when CUDA is available)."""
        return _prefetch_loader(self.origin, self.num_prefetches)


__all__ = ["BatchLimitLoaderWrapper", "BatchPrefetchLoaderWrapper"]
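
# A hypothetical smoke-test entry point, not part of the original module:
# running this file directly exercises the illustrative ``_demo_*`` sketches
# defined above.
if __name__ == "__main__":
    _demo_loader_wrapper_delegation()
    _demo_batch_limit_on_plain_iterable()
    _demo_any2cuda_non_blocking()
    _demo_prefetch_map()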