/
serial_iterator.py
132 lines (105 loc) · 4.34 KB
/
serial_iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import division
import numpy
from chainer.dataset import iterator
class SerialIterator(iterator.Iterator):

    """Dataset iterator that serially reads the examples.

    This is a simple implementation of :class:`~chainer.dataset.Iterator`
    that just visits each example in either the order of indexes or a shuffled
    order.

    To avoid unintentional performance degradation, the ``shuffle`` option is
    set to ``True`` by default. For validation, it is better to set it to
    ``False`` when the underlying dataset supports fast slicing. If the
    order of examples has an important meaning and the updater depends on the
    original order, this option should be set to ``False``.

    This iterator saves ``-1`` instead of ``None`` in snapshots since some
    serializers do not support ``None``.

    When ``shuffle`` is ``False``, a batch that does not cross the end of the
    dataset is whatever ``dataset[i:j]`` returns. A batch that wraps around
    to the beginning of the dataset (``repeat=True``) is always a ``list``.

    Args:
        dataset: Dataset to iterate. It must support ``len()``, integer
            indexing and slicing.
        batch_size (int): Number of examples within each batch.
        repeat (bool): If ``True``, it infinitely loops over the dataset.
            Otherwise, it stops iteration at the end of the first epoch.
        shuffle (bool): If ``True``, the order of examples is shuffled at the
            beginning of each epoch. Otherwise, examples are extracted in the
            order of indexes.

    """

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle

        # reset() creates all of the mutable iteration state.
        self.reset()

    def __next__(self):
        """Return the next batch and advance the iteration state.

        Returns:
            The next batch of examples. In shuffled mode this is a ``list``.
            In non-shuffled mode it is the dataset slice, or a ``list`` when
            the batch wraps around the end of the dataset.

        Raises:
            StopIteration: If ``repeat`` is ``False`` and the first epoch has
                already been completed.

        """
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        # Remember where we were before this batch is consumed.
        self._previous_epoch_detail = self.epoch_detail

        i = self.current_position
        i_end = i + self.batch_size
        N = len(self.dataset)

        if self._order is None:
            batch = self.dataset[i:i_end]
        else:
            batch = [self.dataset[index] for index in self._order[i:i_end]]

        if i_end >= N:
            # This batch reaches (or passes) the end of the dataset.
            if self._repeat:
                rest = i_end - N
                if self._order is not None:
                    # Draw a new order in place for the next epoch.
                    numpy.random.shuffle(self._order)
                if rest > 0:
                    # Fill up the batch from the head of the next epoch.
                    if self._order is None:
                        # The slice may be a tuple, an array or another
                        # sequence without ``extend``; make it a list first.
                        batch = list(batch)
                        batch.extend(self.dataset[:rest])
                    else:
                        batch.extend([self.dataset[index]
                                      for index in self._order[:rest]])
                self.current_position = rest
            else:
                self.current_position = 0

            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return batch

    # Python 2 iterator protocol alias.
    next = __next__

    @property
    def epoch_detail(self):
        """Completed epochs plus ``current_position / len(dataset)``."""
        return self.epoch + self.current_position / len(self.dataset)

    @property
    def previous_epoch_detail(self):
        """Value of ``epoch_detail`` before the latest batch, or ``None``.

        ``None`` is returned while the internal value is negative, which is
        the state set by :meth:`reset`.

        """
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def serialize(self, serializer):
        """Pass the iteration state through ``serializer``.

        ``current_position``, ``epoch``, ``is_new_epoch`` and
        ``previous_epoch_detail`` are replaced by the values the serializer
        returns for them.

        Args:
            serializer: Callable invoked as ``serializer(key, value)``.

        """
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
        if self._order is not None:
            # The return value is not used here; presumably the serializer
            # updates the array in place when loading -- TODO confirm.
            try:
                serializer('order', self._order)
            except KeyError:
                # Fall back to the alternative key name.
                serializer('_order', self._order)
        try:
            self._previous_epoch_detail = serializer(
                'previous_epoch_detail', self._previous_epoch_detail)
        except KeyError:
            # guess previous_epoch_detail for older version
            self._previous_epoch_detail = self.epoch + \
                (self.current_position - self.batch_size) / len(self.dataset)
            if self.epoch_detail > 0:
                # The guess may be negative; clamp it to the start.
                self._previous_epoch_detail = max(
                    self._previous_epoch_detail, 0.)
            else:
                # No batch has been consumed yet: use the "None" sentinel.
                self._previous_epoch_detail = -1.

    def reset(self):
        """Put the iterator back into its initial state.

        A new random permutation is drawn when shuffling is enabled;
        otherwise no explicit order is kept.

        """
        if self._shuffle:
            self._order = numpy.random.permutation(len(self.dataset))
        else:
            self._order = None

        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

        # use -1 instead of None internally.
        self._previous_epoch_detail = -1.

    @property
    def repeat(self):
        """bool: The ``repeat`` flag given to the constructor."""
        return self._repeat