# This code is copied from the official Chainer examples.
import chainer
# Dataset iterator to create a batch of sequences at different positions.
# This iterator returns a pair of current words and the next words. Each
# example is a part of sequences starting from the different offsets
# equally spaced within the whole sequence.
class ParallelSequentialIterator(chainer.dataset.Iterator):
    """Iterate over one long sequence as ``batch_size`` parallel streams.

    Each mini-batch is a list of ``(current, next)`` item pairs. The i-th
    pair is read from a stream that starts at offset
    ``i * len(dataset) // batch_size``; every stream advances by one
    position per call to ``__next__`` and wraps around at the end of the
    dataset.

    Args:
        dataset: Indexable sequence supporting ``len()`` (e.g. word IDs).
        batch_size: Number of parallel streams, i.e. pairs per mini-batch.
        repeat: If false, ``__next__`` raises ``StopIteration`` once
            ``iteration * batch_size`` reaches ``len(dataset)``.
    """

    def __init__(self, dataset, batch_size, repeat=True):
        self.dataset = dataset
        self.batch_size = batch_size  # batch size
        # Number of completed sweeps over the dataset. In this case, it is
        # incremented if every word is visited at least once after the last
        # increment.
        self.epoch = 0
        # True if the epoch is incremented at the last iteration.
        self.is_new_epoch = False
        self.repeat = repeat
        length = len(dataset)
        # Offsets maintain the position of each sequence in the mini-batch.
        self.offsets = [i * length // batch_size for i in range(batch_size)]
        # NOTE: this is not a count of parameter updates. It is just a count of
        # calls of ``__next__``.
        self.iteration = 0

    def __next__(self):
        """Return the next mini-batch as a list of ``(current, next)`` pairs.

        Raises:
            StopIteration: If ``repeat`` is false and the first epoch has
                been completed.
        """
        # This iterator returns a list representing a mini-batch. Each item
        # indicates a different position in the original sequence. Each item is
        # represented by a pair of two word IDs. The first word is at the
        # "current" position, while the second word at the next position.
        # At each iteration, the iteration count is incremented, which pushes
        # forward the "current" position.
        length = len(self.dataset)
        if not self.repeat and self.iteration * self.batch_size >= length:
            # If not self.repeat, this iterator stops at the end of the first
            # epoch (i.e., when all words are visited once).
            raise StopIteration
        cur_words = self.get_words()
        # get_words() reads self.iteration, so bumping it between the two
        # calls is what makes the second call return the *next* words.
        self.iteration += 1
        next_words = self.get_words()

        epoch = self.iteration * self.batch_size // length
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch

        return list(zip(cur_words, next_words))

    @property
    def epoch_detail(self):
        """Floating point version of ``epoch``.

        This is a property rather than a method: Chainer's built-in
        iterators expose ``epoch_detail`` as a property, and
        ``StandardUpdater.epoch_detail`` reads it as an attribute.
        """
        return self.iteration * self.batch_size / len(self.dataset)

    def get_words(self):
        """Return the item at the current position of every stream."""
        length = len(self.dataset)
        return [self.dataset[(offset + self.iteration) % length]
                for offset in self.offsets]

    def serialize(self, serializer):
        """Save or restore ``iteration`` and ``epoch`` through ``serializer``."""
        # It is important to serialize the state to be recovered on resume.
        self.iteration = serializer('iteration', self.iteration)
        self.epoch = serializer('epoch', self.epoch)