Merge pull request #3429 from yuyu2172/iterator-order-sampler
Add order_sampler option to Iterators
hvy committed May 17, 2018
2 parents 3538f63 + 6cabc9d commit 9577a88
Showing 9 changed files with 560 additions and 58 deletions.
3 changes: 3 additions & 0 deletions chainer/iterators/__init__.py
@@ -2,3 +2,6 @@
from chainer.iterators.multiprocess_iterator import MultiprocessIterator # NOQA
from chainer.iterators.multithread_iterator import MultithreadIterator # NOQA
from chainer.iterators.serial_iterator import SerialIterator # NOQA

from chainer.iterators.order_samplers import OrderSampler # NOQA
from chainer.iterators.order_samplers import ShuffleOrderSampler # NOQA
77 changes: 54 additions & 23 deletions chainer/iterators/multiprocess_iterator.py
Expand Up @@ -11,6 +11,7 @@
import six

from chainer.dataset import iterator
from chainer.iterators.order_samplers import ShuffleOrderSampler


_response_time = 1.
@@ -42,12 +43,20 @@ class MultiprocessIterator(iterator.Iterator):
Otherwise, it stops iteration at the end of the first epoch.
shuffle (bool): If ``True``, the order of examples is shuffled at the
beginning of each epoch. Otherwise, examples are extracted in the
order of indexes.
order of indexes. If ``None`` and no ``order_sampler`` is given,
the behavior is the same as the case with ``shuffle=True``.
n_processes (int): Number of worker processes. The number of CPUs is
used by default.
n_prefetch (int): Number of prefetch batches.
shared_mem (int): The size of shared memory used per data element.
If ``None``, the size is adjusted automatically.
order_sampler (callable): A callable that generates the order
of the indices to sample in the next epoch when an epoch finishes.
This function should take two arguments: the current order
and the current position of the iterator.
It should return the next order. The size of the order
should remain constant.
This option cannot be used when ``shuffle`` is not ``None``.
"""

@@ -56,8 +65,9 @@ class MultiprocessIterator(iterator.Iterator):
_comm = None
_thread = None

def __init__(self, dataset, batch_size, repeat=True, shuffle=True,
n_processes=None, n_prefetch=1, shared_mem=None):
def __init__(self, dataset, batch_size, repeat=True, shuffle=None,
n_processes=None, n_prefetch=1, shared_mem=None,
order_sampler=None):
self.dataset = dataset
self.batch_size = batch_size
self.repeat = repeat
@@ -67,12 +77,27 @@ def __init__(self, dataset, batch_size, repeat=True, shuffle=True,
self.n_prefetch = max(n_prefetch, 1)
self.shared_mem = shared_mem

if self.shuffle is not None:
if order_sampler is not None:
raise ValueError('`shuffle` is not `None` and a custom '
'`order_sampler` is set. Please set '
'`shuffle` to `None` to use the custom '
'order sampler.')
else:
if self.shuffle:
order_sampler = ShuffleOrderSampler()
else:
if order_sampler is None:
order_sampler = ShuffleOrderSampler()
self.order_sampler = order_sampler

self._comm = _Communicator(self.n_prefetch)
self.reset()

self._prefetch_loop = _PrefetchLoop(
self.dataset, self.batch_size, self.repeat, self.shuffle,
self.n_processes, self.n_prefetch, self.shared_mem, self._comm,
self.dataset, self.batch_size, self.repeat,
self.n_processes, self.n_prefetch, self.shared_mem,
self._comm, self.order_sampler,
self._interruption_testing)
# defer launching prefetch thread until creating the worker pool,
# not to leave a background thread in forked processes.
@@ -124,8 +149,9 @@ def __exit__(self, exc_type, exc_value, traceback):

def __copy__(self):
other = MultiprocessIterator(
self.dataset, self.batch_size, self.repeat, self.shuffle,
self.n_processes, self.n_prefetch, self.shared_mem)
self.dataset, self.batch_size, self.repeat, shuffle=None,
n_processes=self.n_processes, n_prefetch=self.n_prefetch,
shared_mem=self.shared_mem, order_sampler=self.order_sampler)

other.current_position = self.current_position
other.epoch = self.epoch
@@ -138,7 +164,7 @@ def __copy__(self):

@property
def epoch_detail(self):
return self.epoch + self.current_position / len(self.dataset)
return self.epoch + self.current_position / self._epoch_size

@property
def previous_epoch_detail(self):
@@ -161,7 +187,7 @@ def serialize(self, serializer):
except KeyError:
# guess previous_epoch_detail for older version
self._previous_epoch_detail = self.epoch + \
(self.current_position - self.batch_size) / len(self.dataset)
(self.current_position - self.batch_size) / self._epoch_size
if self.epoch_detail > 0:
self._previous_epoch_detail = max(
self._previous_epoch_detail, 0.)
@@ -179,13 +205,21 @@ def reset(self):
self.is_new_epoch = False
# use -1 instead of None internally.
self._previous_epoch_detail = -1.
if self.shuffle:
self._order = numpy.random.permutation(len(self.dataset))
if self.order_sampler:
self._order = self.order_sampler(
numpy.arange(len(self.dataset)), 0)
else:
self._order = None

self._set_prefetch_state()

@property
def _epoch_size(self):
if self._order is None:
return len(self.dataset)
else:
return len(self._order)

def _set_prefetch_state(self):
prefetch_state = _PrefetchState(
current_position=self.current_position,
@@ -265,28 +299,21 @@ def put(self, batch, prefetch_state, reset_count):

class _PrefetchLoop(object):

def __init__(self, dataset, batch_size, repeat, shuffle,
def __init__(self, dataset, batch_size, repeat,
n_processes, n_prefetch, mem_size, comm,
order_sampler,
_interruption_testing):
self.dataset = dataset
self.batch_size = batch_size
self.repeat = repeat
self.shuffle = shuffle
self.n_processes = n_processes
self.mem_size = mem_size
self.comm = comm
self.order_sampler = order_sampler

self._allocate_shared_memory()
self._pool = None

# Use a distinct RandomState in the thread
# for deterministic random number generation.
# To support 32-bit platform and numpy < 1.11,
# the seed is taken in a verbose manner.
seed = numpy.asscalar(
numpy.random.randint(-(1 << 31), 1 << 31, 1).astype('uint32'))
self._random = numpy.random.RandomState(seed)

self._interruption_testing = _interruption_testing

def measure_required(self):
@@ -365,9 +392,9 @@ def _task(self):
return True

def _proceed(self):
n = len(self.dataset)
(pos, epoch, is_new_epoch,
previous_epoch_detail, order) = self.prefetch_state
n = len(order) if order is not None else len(self.dataset)

if pos < self.batch_size and epoch > 0 and not self.repeat:
return None # stop iteration
@@ -392,7 +419,11 @@ def _proceed(self):
else:
indices = order[pos:n]
if self.repeat:
order = self._random.permutation(n)
new_order = self.order_sampler(order, pos)
if len(new_order) != n:
raise ValueError('The size of order does not match '
'the size of the previous order.')
order = new_order
indices = \
numpy.concatenate((indices, order[:new_pos]))
epoch += 1
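For reference, a minimal sketch of how the new ``order_sampler`` argument can be used with ``MultiprocessIterator``. The toy ``dataset`` and the ``interleave_order_sampler`` function are illustrative, not part of this change; the only contract, per the docstring above, is that the callable takes the current order and the current position and returns a new order of the same length.

import numpy

import chainer


# Hypothetical toy dataset; any sequence of examples works.
dataset = [(numpy.float32(i), i % 2) for i in range(10)]


def interleave_order_sampler(current_order, current_position):
    # Return a new order with the same length as ``current_order``:
    # here, all even indices first, then all odd indices.
    n = len(current_order)
    return numpy.concatenate(
        [numpy.arange(0, n, 2), numpy.arange(1, n, 2)])


# ``shuffle`` is left as ``None``; combining it with a custom
# ``order_sampler`` raises ValueError.
it = chainer.iterators.MultiprocessIterator(
    dataset, batch_size=4, order_sampler=interleave_order_sampler)
batch = next(it)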
53 changes: 44 additions & 9 deletions chainer/iterators/multithread_iterator.py
@@ -5,6 +5,7 @@
import six

from chainer.dataset import iterator
from chainer.iterators.order_samplers import ShuffleOrderSampler


class MultithreadIterator(iterator.Iterator):
@@ -28,13 +29,21 @@ class MultithreadIterator(iterator.Iterator):
Otherwise, it stops iteration at the end of the first epoch.
shuffle (bool): If ``True``, the order of examples is shuffled at the
beginning of each epoch. Otherwise, examples are extracted in the
order of indexes.
order of indexes. If ``None`` and no ``order_sampler`` is given,
the behavior is the same as the case with ``shuffle=True``.
n_threads (int): Number of worker threads.
order_sampler (callable): A callable that generates the order
of the indices to sample in the next epoch when an epoch finishes.
This function should take two arguments: the current order
and the current position of the iterator.
It should return the next order. The size of the order
should remain constant.
This option cannot be used when ``shuffle`` is not ``None``.
"""

def __init__(self, dataset, batch_size, repeat=True, shuffle=True,
n_threads=1):
def __init__(self, dataset, batch_size, repeat=True, shuffle=None,
n_threads=1, order_sampler=None):
self.dataset = dataset
self.batch_size = batch_size
self._repeat = repeat
@@ -43,6 +52,20 @@ def __init__(self, dataset, batch_size, repeat=True, shuffle=True,
self.current_position = 0
self.epoch = 0

if self._shuffle is not None:
if order_sampler is not None:
raise ValueError('`shuffle` is not `None` and a custom '
'`order_sampler` is set. Please set '
'`shuffle` to `None` to use the custom '
'order sampler.')
else:
if self._shuffle:
order_sampler = ShuffleOrderSampler()
else:
if order_sampler is None:
order_sampler = ShuffleOrderSampler()
self.order_sampler = order_sampler

self.n_threads = n_threads
self._pool = None

@@ -52,8 +75,9 @@ def reset(self):
self.current_position = 0
self.epoch = 0
self.is_new_epoch = False
if self._shuffle:
self._order = numpy.random.permutation(len(self.dataset))
if self.order_sampler:
self._order = self.order_sampler(
numpy.arange(len(self.dataset)), 0)
else:
self._order = None

@@ -93,7 +117,7 @@ def __next__(self):

@property
def epoch_detail(self):
return self.epoch + self.current_position / len(self.dataset)
return self.epoch + self.current_position / self._epoch_size

@property
def previous_epoch_detail(self):
@@ -120,7 +144,7 @@ def _invoke_prefetch(self):
return
if self._pool is None:
self._pool = pool.ThreadPool(self.n_threads)
n = len(self.dataset)
n = self._epoch_size
i = self.current_position

order = self._order
@@ -143,8 +167,11 @@
# iterator may be serialized before the prefetched data are
# consumed by the user, in which case an inconsistency
# appears.
order = order.copy()
numpy.random.shuffle(order)
new_order = self.order_sampler(order, i)
if len(new_order) != len(order):
raise ValueError('The size of order does not match '
'the size of the previous order.')
order = new_order

self._next = self._pool.map_async(MultithreadIterator._read, args)
self._next_state = (i, epoch, is_new_epoch, order)
@@ -160,3 +187,11 @@ def _get(self):
(self.current_position, self.epoch,
self.is_new_epoch, self._order) = self._next_state
return batch

@property
def _epoch_size(self):
if self._order is None:
epoch_size = len(self.dataset)
else:
epoch_size = len(self._order)
return epoch_size
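The same mutual-exclusion check is added to both iterators: a non-``None`` ``shuffle`` combined with a custom ``order_sampler`` raises ``ValueError``. A doctest-style sketch of that behavior (``dataset`` here is the tiny example reused from the ``ShuffleOrderSampler`` docstring below):

>>> dataset = [(1, 2), (3, 4)]
>>> chainer.iterators.MultithreadIterator(
...     dataset, 1, shuffle=True,
...     order_sampler=chainer.iterators.ShuffleOrderSampler())
Traceback (most recent call last):
    ...
ValueError: `shuffle` is not `None` and a custom `order_sampler` is set. Please set `shuffle` to `None` to use the custom order sampler.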
59 changes: 59 additions & 0 deletions chainer/iterators/order_samplers.py
@@ -0,0 +1,59 @@
import numpy


class OrderSampler(object):

"""Base class of all order samplers.
Every order sampler subclass has to provide a method
:meth:`__call__`.
This method is called by an iterator before a new epoch,
and it should return a new index order for the next epoch.
"""

def __call__(self, current_order, current_position):
"""Sample the next order.
Args:
current_order (numpy.ndarray): 1-D array of indices.
The length should be the same as the length of the
dataset to sample data from.
current_position (int): The current position of an iterator.
Returns:
numpy.ndarray:
1-D array of indices. This is the order in which
examples are sampled from a dataset in the next epoch.
"""
raise NotImplementedError


class ShuffleOrderSampler(OrderSampler):

"""Sampler that generates random orders.
This is expected to be used together with Chainer's iterators.
An order sampler is called by an iterator every epoch.
The two initializations below create basically the same objects.
>>> dataset = [(1, 2), (3, 4)]
>>> it = chainer.iterators.MultiprocessIterator(dataset, 1, shuffle=True)
>>> it = chainer.iterators.MultiprocessIterator(
... dataset, 1, order_sampler=chainer.iterators.ShuffleOrderSampler())
Args:
random_state (numpy.random.RandomState): Pseudo-random number
generator.
"""

def __init__(self, random_state=None):
if random_state is None:
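# ``numpy.random.random`` is a bound method of NumPy's global
# ``RandomState`` instance; ``__self__`` retrieves that instance.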
random_state = numpy.random.random.__self__
self._random = random_state

def __call__(self, current_order, current_position):
return self._random.permutation(len(current_order))
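Beyond shuffling, ``OrderSampler`` can be subclassed for custom epoch orderings. Below is a minimal sketch of a subclass (the class name and behavior are illustrative, not part of this commit) that returns the identity order every epoch, which is roughly equivalent to ``shuffle=False``:

import numpy

from chainer.iterators import OrderSampler


class SequentialOrderSampler(OrderSampler):

    """Illustrative sampler that visits examples in their original order."""

    def __call__(self, current_order, current_position):
        # The returned order must have the same length as ``current_order``.
        return numpy.arange(len(current_order))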