Skip to content

Commit

Permalink
added checkpoint_dir for TrainLoop, but not tested
Browse files Browse the repository at this point in the history
  • Loading branch information
haowen-xu committed Feb 5, 2019
1 parent 873a304 commit 2fb0faa
Show file tree
Hide file tree
Showing 20 changed files with 568 additions and 312 deletions.
14 changes: 3 additions & 11 deletions tests/scaffold/test_train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,12 @@ def test_counter_attributes(self):
self.assertEqual(loop.step, 0)
self.assertIsNone(loop.max_epoch)
self.assertIsNone(loop.max_step)
self.assertEqual(loop.summary_metric_prefix, 'metrics/')

with TrainLoop([], initial_epoch=1, initial_step=3,
max_epoch=2, max_step=10, summary_metric_prefix='123/'
) as loop:
self.assertEqual(loop.epoch, 1)
self.assertEqual(loop.step, 3)
with TrainLoop([], max_epoch=2, max_step=10,
summary_metric_prefix='123/') as loop:
self.assertEqual(loop.max_epoch, 2)
self.assertEqual(loop.max_step, 10)
self.assertEqual(loop.summary_metric_prefix, '123/')
self.assertEqual(loop._summary_metric_prefix, '123/')
loop.max_epoch = 20
loop.max_step = 100
self.assertEqual(loop.max_epoch, 20)
Expand Down Expand Up @@ -510,13 +506,9 @@ def test_tensor_arguments(self):
with TrainLoop([a],
early_stopping=True,
initial_valid_metric=tf.constant(1.23),
initial_epoch=tf.constant(4),
initial_step=tf.constant(5),
max_epoch=tf.constant(6),
max_step=tf.constant(7)) as loop:
self.assertAlmostEqual(loop._early_stopping._best_metric, 1.23)
self.assertEqual(loop.epoch, 4)
self.assertEqual(loop.step, 5)
self.assertEqual(loop.max_epoch, 6)
self.assertEqual(loop.max_step, 7)

Expand Down
102 changes: 54 additions & 48 deletions tests/trainer/test_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,80 +35,82 @@ def test_add_and_remove_hooks(self):
t.log_after_steps(3)
t.log_after_epochs(4)
t.evaluate_after_steps(
Mock(return_value=None, __repr__=lambda o: 'eval_step'), 5)
Mock(return_value=None, __repr__=lambda o: 'eval'), 5)
t.evaluate_after_epochs(
Mock(return_value=None, __repr__=lambda o: 'eval_epoch'), 6)
Mock(return_value=None, __repr__=lambda o: 'eval'), 6)
t.anneal_after_steps(
Mock(return_value=None, __repr__=lambda o: 'anneal_step'), 7)
Mock(return_value=None, __repr__=lambda o: 'anneal'), 7)
t.anneal_after_epochs(
Mock(return_value=None, __repr__=lambda o: 'anneal_epoch'), 8)
Mock(return_value=None, __repr__=lambda o: 'anneal'), 8)
t.evaluate_after_steps(eval1, 9)
t.evaluate_after_epochs(eval2, 10)
t.anneal_after_steps(anneal1, 11)
t.anneal_after_epochs(anneal2, 12)
t.log_after(steps=13)
t.log_after(epochs=14)
t.evaluate_after(
Mock(return_value=None, __repr__=lambda o: 'eval_step2'),
Mock(return_value=None, __repr__=lambda o: 'eval2'),
steps=15
)
t.evaluate_after(
Mock(return_value=None, __repr__=lambda o: 'eval_epoch2'),
Mock(return_value=None, __repr__=lambda o: 'eval2'),
epochs=16
)
t.anneal_after(
Mock(return_value=None, __repr__=lambda o: 'anneal_step2'),
Mock(return_value=None, __repr__=lambda o: 'anneal2'),
steps=17
)
t.anneal_after(
Mock(return_value=None, __repr__=lambda o: 'anneal_epoch2'),
Mock(return_value=None, __repr__=lambda o: 'anneal2'),
epochs=18
)

self.assertEqual(
repr(t.events._event_handlers_map[EventKeys.AFTER_STEP_EVAL]),
'[eval_step:5, {!r}:9, eval_step2:15]'.format(eval1.run)
repr(t.events._event_handlers_map[EventKeys.STEP_EVALUATION]),
'[eval:step:5, {!r}:step:9, eval2:step:15]'.format(eval1.run)
)
self.assertEqual(
repr(t.events._event_handlers_map[EventKeys.AFTER_STEP_ANNEAL]),
'[anneal_step:7, {!r}:11, anneal_step2:17]'.format(anneal1.anneal)
repr(t.events._event_handlers_map[EventKeys.STEP_ANNEALING]),
'[anneal:step:7, {!r}:step:11, anneal2:step:17]'.
format(anneal1.anneal)
)
self.assertEqual(
repr(t.events._event_handlers_map[EventKeys.AFTER_STEP_LOG]),
'[print_logs:3, print_logs:13]'
repr(t.events._event_handlers_map[EventKeys.STEP_LOGGING]),
'[print_logs:step:3, print_logs:step:13]'
)

self.assertEqual(
repr(t.events._event_handlers_map[EventKeys.AFTER_EPOCH_EVAL]),
'[eval_epoch:6, {!r}:10, eval_epoch2:16]'.format(eval2.run)
repr(t.events._event_handlers_map[EventKeys.EPOCH_EVALUATION]),
'[eval:epoch:6, {!r}:epoch:10, eval2:epoch:16]'.format(eval2.run)
)
self.assertEqual(
repr(t.events._event_handlers_map[EventKeys.AFTER_EPOCH_ANNEAL]),
'[anneal_epoch:8, {!r}:12, anneal_epoch2:18]'.format(anneal2.anneal)
repr(t.events._event_handlers_map[EventKeys.EPOCH_ANNEALING]),
'[anneal:epoch:8, {!r}:epoch:12, anneal2:epoch:18]'.
format(anneal2.anneal)
)
self.assertEqual(
repr(t.events._event_handlers_map[EventKeys.AFTER_EPOCH_LOG]),
'[print_logs:4, print_logs:14]'
repr(t.events._event_handlers_map[EventKeys.EPOCH_LOGGING]),
'[print_logs:epoch:4, print_logs:epoch:14]'
)

# test remove
t.remove_log_hooks()
self.assertNotIn(
EventKeys.AFTER_STEP_LOG, t.events._event_handlers_map)
EventKeys.STEP_LOGGING, t.events._event_handlers_map)
self.assertNotIn(
EventKeys.AFTER_EPOCH_LOG, t.events._event_handlers_map)
EventKeys.EPOCH_LOGGING, t.events._event_handlers_map)

t.remove_validation_hooks()
self.assertNotIn(
EventKeys.AFTER_STEP_EVAL, t.events._event_handlers_map)
EventKeys.STEP_EVALUATION, t.events._event_handlers_map)
self.assertNotIn(
EventKeys.AFTER_EPOCH_EVAL, t.events._event_handlers_map)
EventKeys.EPOCH_EVALUATION, t.events._event_handlers_map)

t.remove_annealing_hooks()
self.assertNotIn(
EventKeys.AFTER_STEP_ANNEAL, t.events._event_handlers_map)
EventKeys.STEP_ANNEALING, t.events._event_handlers_map)
self.assertNotIn(
EventKeys.AFTER_EPOCH_ANNEAL, t.events._event_handlers_map)
EventKeys.EPOCH_ANNEALING, t.events._event_handlers_map)

# test error add
func_list = [
Expand Down Expand Up @@ -137,18 +139,21 @@ def test_hook_freq(self):
t.evaluate_after(f, steps=5)

for i in range(1, 6):
t.events.fire(EventKeys.AFTER_STEP_EVAL, i)
t.events.fire(EventKeys.AFTER_STEP_EVAL, 7)
t.events.fire(EventKeys.AFTER_STEP_EVAL, 10)
t.loop.step = i
t.events.fire(EventKeys.STEP_EVALUATION, t)
t.loop.step = 7
t.events.fire(EventKeys.STEP_EVALUATION, t)
t.loop.step = 10
t.events.fire(EventKeys.STEP_EVALUATION, t)

self.assertEqual(f.call_count, 2)

def test_run(self):
with self.test_session() as session:
df = DataFlow.arrays([np.arange(6, dtype=np.float32)], batch_size=4)

def log_event(m, *args):
logged_events.append((m,) + args)
def log_event(m, trainer):
logged_events.append((m, trainer))
logged_events = []

# test default loss weight and merged feed dict
Expand All @@ -158,13 +163,13 @@ def log_event(m, *args):
t._iter_steps = Mock(wraps=lambda: loop.iter_steps(df))
for key in [EventKeys.BEFORE_EPOCH,
EventKeys.BEFORE_STEP,
EventKeys.AFTER_STEP_ANNEAL,
EventKeys.AFTER_STEP_EVAL,
EventKeys.AFTER_STEP_LOG,
EventKeys.STEP_ANNEALING,
EventKeys.STEP_EVALUATION,
EventKeys.STEP_LOGGING,
EventKeys.AFTER_STEP,
EventKeys.AFTER_EPOCH_ANNEAL,
EventKeys.AFTER_EPOCH_EVAL,
EventKeys.AFTER_EPOCH_LOG,
EventKeys.EPOCH_ANNEALING,
EventKeys.EPOCH_EVALUATION,
EventKeys.EPOCH_LOGGING,
EventKeys.AFTER_EPOCH]:
t.events.on(key, functools.partial(log_event, key))

Expand All @@ -184,21 +189,21 @@ def log_event(m, *args):
expected_logged_events = sum(
[
[
(EventKeys.BEFORE_EPOCH, epoch + 1),
(EventKeys.BEFORE_EPOCH, t),
] + sum([
[
(EventKeys.BEFORE_STEP, epoch * 2 + step + 1),
(EventKeys.AFTER_STEP_EVAL, epoch * 2 + step + 1),
(EventKeys.AFTER_STEP_ANNEAL, epoch * 2 + step + 1),
(EventKeys.AFTER_STEP_LOG, epoch * 2 + step + 1),
(EventKeys.AFTER_STEP, epoch * 2 + step + 1),
(EventKeys.BEFORE_STEP, t),
(EventKeys.STEP_EVALUATION, t),
(EventKeys.STEP_ANNEALING, t),
(EventKeys.STEP_LOGGING, t),
(EventKeys.AFTER_STEP, t),
]
for step in [0, 1]
], []) + [
(EventKeys.AFTER_EPOCH_EVAL, epoch + 1),
(EventKeys.AFTER_EPOCH_ANNEAL, epoch + 1),
(EventKeys.AFTER_EPOCH_LOG, epoch + 1),
(EventKeys.AFTER_EPOCH, epoch + 1)
(EventKeys.EPOCH_EVALUATION, t),
(EventKeys.EPOCH_ANNEALING, t),
(EventKeys.EPOCH_LOGGING, t),
(EventKeys.AFTER_EPOCH, t)
]
for epoch in [0, 1]
],
Expand All @@ -212,7 +217,8 @@ def log_event(m, *args):
t._run_step = Mock(return_value=None)
t._iter_steps = Mock(wraps=lambda: loop.iter_steps(df))

def reentrant_error(step):
def reentrant_error(trainer):
self.assertIs(trainer, t)
with pytest.raises(
RuntimeError, match=r'`run\(\)` is not re-entrant'):
t.run()
Expand Down
18 changes: 18 additions & 0 deletions tests/utils/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ def test_events(self):
events.fire('ev1', 123, value=456)
self.assertEqual(f1.call_args, ((123,), {'value': 456}))

def test_order(self):
    """Check handler ordering for `fire` and `reverse_fire`.

    Three handlers are registered on the same event; `fire` is expected
    to invoke them in registration order, `reverse_fire` in the opposite
    order.
    """
    calls = []

    def recorder(tag):
        # build a zero-argument handler that records `tag` when invoked
        def handler():
            calls.append(tag)
        return handler

    source = EventSource()
    for tag in (1, 2, 3):
        source.on('ev', recorder(tag))

    # forward firing: handlers run in the order they were registered
    source.fire('ev')
    self.assertListEqual(calls, [1, 2, 3])

    # empty the record in place, so the handlers keep writing to the same list
    calls[:] = []

    # reverse firing: handlers run last-registered first
    source.reverse_fire('ev')
    self.assertListEqual(calls, [3, 2, 1])

def test_clear(self):
f1 = Mock()
f2 = Mock()
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/auto_encoders/bernoulli_latent_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def plot_samples(loop):
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=10)
trainer.evaluate_after_epochs(
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/auto_encoders/conv_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ def plot_samples(loop):
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=10)
trainer.evaluate_after_epochs(
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/auto_encoders/dense_real_nvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def plot_samples(loop):
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=10)
trainer.evaluate_after_epochs(
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/auto_encoders/gm_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ def evaluate_classifier(loop):
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=10)
trainer.evaluate_after_epochs(
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/auto_encoders/mixture_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def plot_samples(loop):
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=10)
trainer.evaluate_after_epochs(
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/auto_encoders/planar_nf.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ def plot_samples(loop):
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=10)
trainer.evaluate_after_epochs(
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/auto_encoders/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def plot_samples(loop):
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=10)
trainer.evaluate_after_epochs(
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/classification/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def main():
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=5)
trainer.log_after_epochs(freq=1)
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/classification/cifar10_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def main():
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=5)
trainer.log_after_epochs(freq=1)
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/classification/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def main():
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=5)
trainer.log_after_epochs(freq=1)
Expand Down
4 changes: 2 additions & 2 deletions tfsnippet/examples/classification/mnist_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def main():
time_metric_name='test_time'
)
evaluator.events.on(
spt.EventKeys.AFTER_EVALUATION,
lambda: results.update_metrics(evaluator.last_metrics_dict)
spt.EventKeys.AFTER_EXECUTION,
lambda e: results.update_metrics(evaluator.last_metrics_dict)
)
trainer.evaluate_after_epochs(evaluator, freq=5)
trainer.log_after_epochs(freq=1)
Expand Down

0 comments on commit 2fb0faa

Please sign in to comment.