Merge pull request #7533 from toslunar/ptb-example-iter-reset
Implement `reset` method in the PTB example
niboshi committed Jun 17, 2019
2 parents 4b0daf5 + e9a5c11 commit d8c34ad
Showing 1 changed file with 8 additions and 4 deletions.
examples/ptb/train_ptb.py (12 changes: 8 additions & 4 deletions)
@@ -57,18 +57,22 @@ def forward(self, x):
 class ParallelSequentialIterator(chainer.dataset.Iterator):
 
     def __init__(self, dataset, batch_size, repeat=True):
         super(ParallelSequentialIterator, self).__init__()
         self.dataset = dataset
         self.batch_size = batch_size  # batch size
+        self.repeat = repeat
+        length = len(dataset)
+        # Offsets maintain the position of each sequence in the mini-batch.
+        self.offsets = [i * length // batch_size for i in range(batch_size)]
+        self.reset()
+
+    def reset(self):
         # Number of completed sweeps over the dataset. In this case, it is
         # incremented if every word is visited at least once after the last
         # increment.
         self.epoch = 0
         # True if the epoch is incremented at the last iteration.
         self.is_new_epoch = False
-        self.repeat = repeat
-        length = len(dataset)
-        # Offsets maintain the position of each sequence in the mini-batch.
-        self.offsets = [i * length // batch_size for i in range(batch_size)]
         # NOTE: this is not a count of parameter updates. It is just a count of
         # calls of ``__next__``.
         self.iteration = 0
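The change keeps the construction-time configuration in `__init__` and moves the rewindable state (`epoch`, `is_new_epoch`, `iteration`) into `reset()`, which `__init__` now calls. This lets the same iterator object be rewound and reused; Chainer's `Evaluator` extension, for instance, can call `reset()` on a reusable iterator instead of copying it. Below is a minimal, self-contained sketch of the same pattern; the `ToyIterator` class and its simplified `__next__` are illustrative assumptions, not code from `train_ptb.py`.

# Illustrative sketch only (not from train_ptb.py): a reset-able iterator
# whose __init__ stores configuration and delegates rewindable state to
# reset(), mirroring the structure introduced by this commit.
import chainer


class ToyIterator(chainer.dataset.Iterator):

    def __init__(self, dataset, batch_size, repeat=True):
        super(ToyIterator, self).__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.repeat = repeat
        self.reset()

    def reset(self):
        # Rewindable state lives here so the iterator can be reused.
        self.epoch = 0
        self.is_new_epoch = False
        self.iteration = 0

    def __next__(self):
        if not self.repeat and self.epoch > 0:
            raise StopIteration
        length = len(self.dataset)
        start = self.iteration * self.batch_size % length
        batch = [self.dataset[(start + i) % length]
                 for i in range(self.batch_size)]
        self.iteration += 1
        epoch = self.iteration * self.batch_size // length
        self.is_new_epoch = self.epoch < epoch
        self.epoch = epoch
        return batch

    @property
    def epoch_detail(self):
        return self.iteration * self.batch_size / len(self.dataset)


it = ToyIterator(list(range(10)), batch_size=4)
first = next(it)
next(it)
it.reset()       # rewinds epoch, is_new_epoch and iteration
assert next(it) == first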
