Skip to content

Commit

Permalink
Merge e34a0ae into 126c432
Browse files Browse the repository at this point in the history
  • Loading branch information
delta2323 committed Aug 15, 2017
2 parents 126c432 + e34a0ae commit 8d5a2fb
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 87 deletions.
10 changes: 7 additions & 3 deletions chainer/testing/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from chainer import training


def get_trainer_with_mock_updater(stop_trigger=(10, 'iteration')):
def get_trainer_with_mock_updater(
stop_trigger=(10, 'iteration'), iter_per_epoch=10):
"""Returns a :class:`~chainer.training.Trainer` object with mock updater.
The returned trainer can be used for testing the trainer itself and the
Expand All @@ -26,14 +27,17 @@ def get_trainer_with_mock_updater(stop_trigger=(10, 'iteration')):
updater.epoch = 0
updater.epoch_detail = 0
updater.is_new_epoch = True
iter_per_epoch = 10
updater.previous_epoch_detail = None

def update():
updater.update_core()
updater.iteration += 1
updater.epoch = updater.iteration // iter_per_epoch
updater.epoch_detail = updater.iteration / iter_per_epoch
updater.is_new_epoch = updater.epoch == updater.epoch_detail
updater.is_new_epoch = (updater.iteration - 1) // \
iter_per_epoch != updater.epoch
updater.previous_epoch_detail = (updater.iteration - 1) \
/ iter_per_epoch

updater.update = update
trainer = training.Trainer(updater, stop_trigger)
Expand Down
67 changes: 63 additions & 4 deletions chainer/training/triggers/interval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import warnings


class IntervalTrigger(object):

"""Trigger based on a fixed interval.
Expand All @@ -23,6 +26,11 @@ def __init__(self, period, unit):
self.period = period
assert unit == 'epoch' or unit == 'iteration'
self.unit = unit

self._previous_iteration = 0
self._previous_epoch_detail = 0.

# count is kept for backward compatibility
self.count = 0

def __call__(self, trainer):
Expand All @@ -40,9 +48,60 @@ def __call__(self, trainer):
"""
updater = trainer.updater
if self.unit == 'epoch':
prev = self.count
self.count = updater.epoch_detail // self.period
return prev != self.count
epoch_detail = updater.epoch_detail
previous_epoch_detail = self._previous_epoch_detail

# if pvevious_epoch_detail is invalid value,
# use the value of updater.
if previous_epoch_detail < 0:
previous_epoch_detail = updater.previous_epoch_detail

# count is kept for backward compatibility
self.count = epoch_detail // self.period

fire = previous_epoch_detail // self.period != \
epoch_detail // self.period
else:
iteration = updater.iteration
return iteration > 0 and iteration % self.period == 0
previous_iteration = self._previous_iteration

# if pvevious_iteration is invalid value,
# guess it from current iteration.
if previous_iteration < 0:
previous_iteration = iteration - 1

fire = previous_iteration // self.period != \
iteration // self.period

# save current values
self._previous_iteration = updater.iteration
if hasattr(updater, 'epoch_detail'):
self._previous_epoch_detail = updater.epoch_detail

return fire

def serialize(self, serializer):
try:
self._previous_iteration = serializer(
'previous_iteration', self._previous_iteration)
except KeyError:
warnings.warn(
'The previous value of iteration is not saved. '
'IntervalTrigger guesses it using current iteration. '
'If this trigger is not called at every iteration, '
'it may not work correctly.')
# set a negative value for invalid
self._previous_iteration = -1

try:
self._previous_epoch_detail = serializer(
'previous_epoch_detail', self._previous_epoch_detail)
except KeyError:
warnings.warn(
'The previous value of epoch_detail is not saved. '
'IntervalTrigger uses the value of '
'trainer.updater.previous_epoch_detail. '
'If this trigger is not called at every iteration, '
'it may not work correctly.')
# set a negative value for invalid
self._previous_epoch_detail = -1.
21 changes: 21 additions & 0 deletions examples/ptb/train_ptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(self, dataset, batch_size, repeat=True):
# NOTE: this is not a count of parameter updates. It is just a count of
# calls of ``__next__``.
self.iteration = 0
# use -1 instead of None internally
self._previous_epoch_detail = -1.

def __next__(self):
# This iterator returns a list representing a mini-batch. Each item
Expand All @@ -80,6 +82,7 @@ def __next__(self):
# epoch (i.e., when all words are visited once).
raise StopIteration
cur_words = self.get_words()
self._previous_epoch_detail = self.epoch_detail
self.iteration += 1
next_words = self.get_words()

Expand All @@ -95,6 +98,12 @@ def epoch_detail(self):
# Floating point version of epoch.
return self.iteration * self.batch_size / len(self.dataset)

@property
def previous_epoch_detail(self):
if self._previous_epoch_detail < 0:
return None
return self._previous_epoch_detail

def get_words(self):
# It returns a list of current words.
return [self.dataset[(offset + self.iteration) % len(self.dataset)]
Expand All @@ -104,6 +113,18 @@ def serialize(self, serializer):
# It is important to serialize the state to be recovered on resume.
self.iteration = serializer('iteration', self.iteration)
self.epoch = serializer('epoch', self.epoch)
try:
self._previous_epoch_detail = serializer(
'previous_epoch_detail', self._previous_epoch_detail)
except KeyError:
# guess previous_epoch_detail for older version
self._previous_epoch_detail = self.epoch + \
(self.current_position - self.batch_size) / len(self.dataset)
if self.epoch_detail > 0:
self._previous_epoch_detail = max(
self._previous_epoch_detail, 0.)
else:
self._previous_epoch_detail = -1.


# Custom updater for truncated BackProp Through Time (BPTT)
Expand Down
49 changes: 38 additions & 11 deletions tests/chainer_tests/testing_tests/test_training.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,50 @@
from __future__ import division

import math
import unittest

from chainer import testing


@testing.parameterize(*testing.product({
'stop_trigger': [(5, 'iteration'), (5, 'epoch')],
'iter_per_epoch': [0.5, 1, 1.5, 5],
}))
class TestGetTrainerWithMockUpdater(unittest.TestCase):

def setUp(self):
self.trainer = testing.get_trainer_with_mock_updater((5, 'iteration'))

def test_update_count(self):
count = [0]

def check_count(trainer):
count[0] += 1
self.assertEqual(trainer.updater.iteration, count[0])

self.trainer.extend(check_count)
self.trainer = testing.get_trainer_with_mock_updater(
self.stop_trigger, self.iter_per_epoch)

def test_run(self):
iteration = [0]

def check(trainer):
iteration[0] += 1

self.assertEqual(trainer.updater.iteration, iteration[0])
self.assertEqual(
trainer.updater.epoch, iteration[0] // self.iter_per_epoch)
self.assertEqual(
trainer.updater.epoch_detail,
iteration[0] / self.iter_per_epoch)
self.assertEqual(
trainer.updater.is_new_epoch,
(iteration[0] - 1) // self.iter_per_epoch !=
iteration[0] // self.iter_per_epoch)
self.assertEqual(
trainer.updater.previous_epoch_detail,
(iteration[0] - 1) / self.iter_per_epoch)

self.trainer.extend(check)
self.trainer.run()
self.assertEqual(count[0], 5)

if self.stop_trigger[1] == 'iteration':
self.assertEqual(iteration[0], self.stop_trigger[0])
elif self.stop_trigger[1] == 'epoch':
self.assertEqual(
iteration[0],
math.ceil(self.stop_trigger[0] * self.iter_per_epoch))


testing.run_module(__name__, __file__)

0 comments on commit 8d5a2fb

Please sign in to comment.