Skip to content

Commit

Permalink
Merge a1561cc into ba18b83
Browse files Browse the repository at this point in the history
  • Loading branch information
haowen-xu committed May 12, 2018
2 parents ba18b83 + a1561cc commit adee1be
Show file tree
Hide file tree
Showing 27 changed files with 1,671 additions and 42 deletions.
22 changes: 6 additions & 16 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,28 +1,18 @@
[run]
branch = True
source = tfsnippet

[report]
# Regexes for lines to exclude from consideration
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover

# Don't complain about missing debug-only code:
if self\.debug

# Don't complain if tests don't hit defensive assertion code:
raise AssertionError
raise NotImplementedError

# Don't complain if non-runnable code isn't run:
if 0:
if __name__ == .__main__.:


[run]
ignore_errors = True
omit =
# test code need not coverage statistics
tests/*

# maintenance scripts need not coverage statistics
scripts/*

# imported functions are just thin wrappers around libraries, skip tests
tfsnippet/utils/imported.py
setup.py
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ env:
matrix:
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.5
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.5
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.6
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.6
- PYTHON_VERSION=2 TENSORFLOW_VERSION=1.8
- PYTHON_VERSION=3 TENSORFLOW_VERSION=1.8
install:
- docker pull "ipwx/travis-tensorflow-docker:py${PYTHON_VERSION}tf${TENSORFLOW_VERSION}"
script:
Expand Down
7 changes: 7 additions & 0 deletions docs/api/tfsnippet.trainer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
tfsnippet\.trainer
==================

.. automodule:: tfsnippet.trainer
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ API Docs
api/tfsnippet.modules
api/tfsnippet.scaffold
api/tfsnippet.stochastic
api/tfsnippet.trainer
api/tfsnippet.utils
api/tfsnippet.variational

Expand Down
2 changes: 1 addition & 1 deletion tests/dataflows/test_array_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_arrays(self):
arrays, 4, shuffle=False, skip_incomplete=False)
self.assertIsInstance(df, ArrayFlow)
for i, arr in enumerate(arrays):
self.assertIs(arr, df.arrays[i])
self.assertIs(arr, df.the_arrays[i])
self.assertEquals(2, df.array_count)
self.assertEquals(5, df.data_length)
self.assertEquals(((), (2,)), df.data_shapes)
Expand Down
55 changes: 55 additions & 0 deletions tests/dataflows/test_data_mappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import unittest

import numpy as np
import pytest
from mock import Mock

from tfsnippet.dataflow import DataMapper, SlidingWindow


class DataMapperTestCase(unittest.TestCase):

def test_error(self):
dm = DataMapper()
dm._transform = Mock(return_value=np.array([1, 2, 3]))
with pytest.raises(TypeError, match='The output of .* is not a tuple'):
dm(np.array([1, 2, 3]))


class SlidingWindowTestCase(unittest.TestCase):

def test_props(self):
arr = np.arange(13)
sw = SlidingWindow(arr, window_size=3)
self.assertIs(arr, sw.data_array)
self.assertEquals(3, sw.window_size)

def test_transform(self):
arr = np.arange(13)
sw = SlidingWindow(arr, window_size=3)
np.testing.assert_equal(
[[0, 1, 2], [5, 6, 7], [3, 4, 5]],
sw(np.asarray([0, 5, 3]))[0]
)

def test_as_flow(self):
arr = np.arange(13)
sw = SlidingWindow(arr, window_size=3)
batches = list(sw.as_flow(batch_size=4))
self.assertEquals(3, len(batches))
np.testing.assert_equal(
[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]],
batches[0][0]
)
np.testing.assert_equal(
[[4, 5, 6], [5, 6, 7], [6, 7, 8], [7, 8, 9]],
batches[1][0]
)
np.testing.assert_equal(
[[8, 9, 10], [9, 10, 11], [10, 11, 12]],
batches[2][0]
)


if __name__ == '__main__':
unittest.main()
23 changes: 13 additions & 10 deletions tests/scaffold/test_train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ def test_logs(self):
def test_valid_metric_default_settings(self):
logs = []
with TrainLoop([], print_func=logs.append) as loop:
self.assertEqual(loop._valid_metric_name, 'valid_loss')
self.assertTrue(loop._valid_metric_smaller_is_better)
self.assertEqual(loop.valid_metric_name, 'valid_loss')
self.assertTrue(loop.valid_metric_smaller_is_better)
self.assertFalse(loop.use_early_stopping)
for _ in loop.iter_epochs():
best_metric = 1.
for _, valid_loss in loop.iter_steps([0.8, 0.6, 0.7]):
Expand All @@ -184,11 +185,12 @@ def test_valid_metric_default_settings(self):

def test_valid_metric_with_custom_settings(self):
logs = []
with TrainLoop([], print_func=logs.append,
v = tf.get_variable('a', shape=[1], dtype=tf.int32)
with TrainLoop([v], print_func=logs.append,
valid_metric_name='y',
valid_metric_smaller_is_better=False) as loop:
self.assertEqual(loop._valid_metric_name, 'y')
self.assertFalse(loop._valid_metric_smaller_is_better)
self.assertEqual(loop.valid_metric_name, 'y')
self.assertFalse(loop.valid_metric_smaller_is_better)
for _ in loop.iter_epochs():
best_metric = 0.
for _, y in loop.iter_steps([0.7, 0.6, 0.8]):
Expand All @@ -214,13 +216,13 @@ def test_valid_metric_with_custom_settings(self):

def test_valid_metric_with_valid_acc(self):
with TrainLoop([], valid_metric_name='valid_acc') as loop:
self.assertEqual(loop._valid_metric_name, 'valid_acc')
self.assertFalse(loop._valid_metric_smaller_is_better)
self.assertEqual(loop.valid_metric_name, 'valid_acc')
self.assertFalse(loop.valid_metric_smaller_is_better)

def test_valid_metric_with_y_as_name(self):
with TrainLoop([], valid_metric_name='y') as loop:
self.assertEqual(loop._valid_metric_name, 'y')
self.assertTrue(loop._valid_metric_smaller_is_better)
self.assertEqual(loop.valid_metric_name, 'y')
self.assertTrue(loop.valid_metric_smaller_is_better)

def test_training_summary(self):
a = tf.get_variable('a', dtype=tf.float32, shape=(2, 3))
Expand Down Expand Up @@ -380,7 +382,8 @@ def test_early_stopping(self):
# test early-stopping with no valid metric committed
set_variable_values([a, b], [1, 2])
self.assertEqual(get_variable_values([a, b]), [1, 2])
with TrainLoop([a], early_stopping=True):
with TrainLoop([a], early_stopping=True) as loop:
self.assertTrue(loop.use_early_stopping)
set_variable_values([a, b], [10, 20])
self.assertEqual(get_variable_values([a, b]), [10, 20])

Expand Down
Empty file added tests/trainer/__init__.py
Empty file.
157 changes: 157 additions & 0 deletions tests/trainer/test_base_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import functools
import unittest

import numpy as np
import pytest
import tensorflow as tf
from mock import Mock

from tfsnippet.dataflow import DataFlow
from tfsnippet.scaffold import TrainLoop
from tfsnippet.trainer import *


class BaseTrainerTestCase(tf.test.TestCase):

def test_props(self):
loop = Mock(valid_metric_name='valid_loss')
df = Mock()

t = BaseTrainer(loop, [12, 34], df, feed_dict={'a': 56})
self.assertIs(loop, t.loop)
self.assertEquals([12, 34], t.inputs)
self.assertIs(df, t.data_flow)
self.assertEquals({'a': 56}, t.feed_dict)
self.assertEquals(
(t.before_epochs, t.before_steps, t.after_steps, t.after_epochs),
t.hook_lists
)
for hkl in t.hook_lists:
self.assertIsInstance(hkl, HookList)

def test_add_and_remove_hooks(self):
loop = Mock(
valid_metric_name='valid_loss',
print_logs=Mock(return_value=None, __repr__=lambda o: 'print_logs')
)
df = Mock()
val1 = Validator(loop, 1., [], df)
val2 = Validator(loop, 2., [], df)
anneal1 = AnnealingDynamicValue(1., .5)
anneal2 = AnnealingDynamicValue(2., .5)

# test add
t = BaseTrainer(loop, [12, 34], df)
t.log_after_epochs(3)
t.log_after_steps(4)
t.validate_after_steps(
Mock(return_value=None, __repr__=lambda o: 'val_step'), 5)
t.validate_after_epochs(
Mock(return_value=None, __repr__=lambda o: 'val_epoch'), 6)
t.anneal_after_steps(
Mock(return_value=None, __repr__=lambda o: 'anneal_step'), 7)
t.anneal_after_epochs(
Mock(return_value=None, __repr__=lambda o: 'anneal_epoch'), 8)
t.validate_after_steps(val1, 9)
t.validate_after_epochs(val2, 10)
t.anneal_after_steps(anneal1, 11)
t.anneal_after_epochs(anneal2, 12)

self.assertEquals('HookList()', repr(t.before_steps))
self.assertEquals('HookList()', repr(t.before_epochs))
steps_ans = 'HookList(val_step:5,{!r}:9,anneal_step:7,' \
'{!r}:11,print_logs:4)'.format(val1.run, anneal1.anneal)
self.assertEquals(steps_ans, repr(t.after_steps))
epochs_ans = 'HookList(val_epoch:6,{!r}:10,anneal_epoch:8,' \
'{!r}:12,print_logs:3)'.format(val2.run, anneal2.anneal)
self.assertEquals(epochs_ans, repr(t.after_epochs))

# test remove
t.remove_log_hooks()
steps_ans = 'HookList(val_step:5,{!r}:9,anneal_step:7,' \
'{!r}:11)'.format(val1.run, anneal1.anneal)
self.assertEquals(steps_ans, repr(t.after_steps))
epochs_ans = 'HookList(val_epoch:6,{!r}:10,anneal_epoch:8,' \
'{!r}:12)'.format(val2.run, anneal2.anneal)
self.assertEquals(epochs_ans, repr(t.after_epochs))

t.remove_validation_hooks()
steps_ans = 'HookList(anneal_step:7,{!r}:11)'.format(anneal1.anneal)
self.assertEquals(steps_ans, repr(t.after_steps))
epochs_ans = 'HookList(anneal_epoch:8,{!r}:12)'.format(anneal2.anneal)
self.assertEquals(epochs_ans, repr(t.after_epochs))

t.remove_annealing_hooks()
self.assertEquals('HookList()', repr(t.after_steps))
self.assertEquals('HookList()', repr(t.after_epochs))

def test_run(self):
with tf.Session().as_default() as session:
df = DataFlow.arrays([np.arange(6, dtype=np.float32)], batch_size=4)
ph = tf.placeholder(tf.float32, shape=[None])
ph2 = tf.placeholder(tf.float32, shape=[])
ph3 = tf.placeholder(tf.float32, shape=[])

def log_message(m):
logged_messages.append(m)
logged_messages = []

# test default loss weight and merged feed dict
with TrainLoop([], max_epoch=2) as loop:
t = BaseTrainer(loop, [ph], df, feed_dict={ph2: 34})
t._fit_step = Mock(return_value=None)
t.before_epochs.add_hook(
functools.partial(log_message, 'before_epoch'))
t.before_steps.add_hook(
functools.partial(log_message, 'before_step'))
t.after_steps.add_hook(
functools.partial(log_message, 'after_step'))
t.after_epochs.add_hook(
functools.partial(log_message, 'after_epoch'))

t.run({ph3: 56})
self.assertEquals(4, len(t._fit_step.call_args_list))
for i, call_args in enumerate(t._fit_step.call_args_list[:-2]):
call_session, call_feed_dict = call_args[0]
self.assertIs(session, call_session)
np.testing.assert_equal(
np.arange(6, dtype=np.float32)[i * 4: (i + 1) * 4],
call_feed_dict[ph]
)
self.assertEquals(34, call_feed_dict[ph2])
self.assertEquals(56, call_feed_dict[ph3])

self.assertEquals(
['before_epoch', 'before_step', 'after_step',
'before_step', 'after_step', 'after_epoch'] * 2,
logged_messages
)

# test override feed dict
with TrainLoop([], max_epoch=1) as loop:
t = BaseTrainer(loop, [ph], df, feed_dict={ph2: 34})
t._fit_step = Mock(return_value=None)
t.run(feed_dict={ph2: 56})

for i, call_args in enumerate(t._fit_step.call_args_list):
call_session, call_feed_dict = call_args[0]
self.assertEquals(56, call_feed_dict[ph2])
self.assertNotIn(ph3, call_feed_dict)

# test re-entrant error
with TrainLoop([], max_epoch=1) as loop:
t = BaseTrainer(loop, [ph], df)
t._fit_step = Mock(return_value=None)

def reentrant_error():
with pytest.raises(
RuntimeError, match=r'`run\(\)` is not re-entrant'):
t.run()
reentrant_error = Mock(wraps=reentrant_error)
t.after_steps.add_hook(reentrant_error)
t.run()
self.assertTrue(reentrant_error.called)


if __name__ == '__main__':
unittest.main()
40 changes: 40 additions & 0 deletions tests/trainer/test_dynamic_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest

import pytest

from tfsnippet.trainer import *


class DynamicValuesTestCase(unittest.TestCase):

def test_SimpleDynamicValue(self):
v = SimpleDynamicValue(123)
self.assertEquals(123, v.get())
v.set(456)
self.assertEquals(456, v.get())
v.set(SimpleDynamicValue(789))
self.assertEquals(789, v.get())

with pytest.raises(ValueError, match='Cannot set the value to `self`.'):
v.set(v)

def test_AnnealingDynamicValue(self):
v = AnnealingDynamicValue(1, .5)
self.assertEquals(1, v.get())
self.assertEquals(.5, v.ratio)
v.ratio = .25
self.assertEquals(.25, v.ratio)

v.anneal()
self.assertEquals(.25, v.get())
v.anneal()
self.assertEquals(.0625, v.get())

v.set(2.)
self.assertEquals(2., v.get())
v.anneal()
self.assertEquals(.5, v.get())


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit adee1be

Please sign in to comment.