"""
This code is copied from official chainer examples
- https://github.com/chainer/chainer/blob/e2fe6f8023e635f8c1fc9c89e85d075ebd50c529/examples/ptb/train_ptb.py
"""
import chainer


# Dataset iterator to create a batch of sequences at different positions.
# This iterator returns pairs of the current words and the next words. Each
# example is part of a sequence starting from a different offset; the
# offsets are equally spaced within the whole sequence.
class ParallelSequentialIterator(chainer.dataset.Iterator):
    def __init__(self, dataset, batch_size, repeat=True):
        self.dataset = dataset
        self.batch_size = batch_size
        # Number of completed sweeps over the dataset. In this case, it is
        # incremented if every word is visited at least once after the last
        # increment.
        self.epoch = 0
        # True if the epoch is incremented at the last iteration.
        self.is_new_epoch = False
        self.repeat = repeat
        length = len(dataset)
        # Offsets maintain the position of each sequence in the mini-batch.
        self.offsets = [i * length // batch_size for i in range(batch_size)]
        # NOTE: this is not a count of parameter updates. It is just a count
        # of calls of ``__next__``.
        self.iteration = 0
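
    # Illustrative note (added; not in the original file): for a dataset of
    # length 10 and batch_size 2, ``offsets`` is [0, 5], so the two parallel
    # sequences in each mini-batch start at words 0 and 5 and advance in
    # lockstep as ``iteration`` grows.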

    def __next__(self):
        # This iterator returns a list representing a mini-batch. Each item
        # indicates a different position in the original sequence. Each item
        # is represented by a pair of two word IDs. The first word is at the
        # "current" position, while the second word is at the next position.
        # At each iteration, the iteration count is incremented, which pushes
        # forward the "current" position.
        length = len(self.dataset)
        if not self.repeat and self.iteration * self.batch_size >= length:
            # If not self.repeat, this iterator stops at the end of the first
            # epoch (i.e., when all words are visited once).
            raise StopIteration
        cur_words = self.get_words()
        self.iteration += 1
        next_words = self.get_words()

        epoch = self.iteration * self.batch_size // length
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch
        return list(zip(cur_words, next_words))
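
    # Worked example (added; not in the original file): with length 10 and
    # batch_size 2, the first call returns [(w[0], w[1]), (w[5], w[6])] and
    # the second [(w[1], w[2]), (w[6], w[7])]. After 5 calls,
    # iteration * batch_size == 10, so ``epoch`` advances to 1 and
    # ``is_new_epoch`` is True for that call.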

    @property
    def epoch_detail(self):
        # Floating point version of epoch.
        return self.iteration * self.batch_size / len(self.dataset)

    def get_words(self):
        # It returns a list of words at the current position of each
        # sequence in the mini-batch.
        return [self.dataset[(offset + self.iteration) % len(self.dataset)]
                for offset in self.offsets]

    def serialize(self, serializer):
        # It is important to serialize the state to be recovered on resume.
        self.iteration = serializer('iteration', self.iteration)
        self.epoch = serializer('epoch', self.epoch)
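

if __name__ == '__main__':
    # Minimal usage sketch (added for illustration; not in the original
    # file, which feeds this iterator Penn Treebank word IDs inside a
    # Chainer trainer). A toy dataset of 10 integer word IDs stands in
    # for real data.
    import numpy as np

    toy_dataset = np.arange(10, dtype=np.int32)
    it = ParallelSequentialIterator(toy_dataset, batch_size=2, repeat=False)
    for batch in it:
        # Each batch is a list of (current word, next word) pairs, one per
        # parallel sequence, e.g. [(0, 1), (5, 6)] on the first call.
        print(it.epoch, it.epoch_detail, batch)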