Skip to content

Commit

Permalink
add minibatch_slices_iterator and ArrayMiniBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
korepwx committed Apr 15, 2018
1 parent 4c35cde commit 07be715
Show file tree
Hide file tree
Showing 9 changed files with 429 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import tensorflow as tf

from tfsnippet.scaffold import early_stopping
from tfsnippet.scaffold import EarlyStopping
from tfsnippet.utils import TemporaryDirectory, get_default_session_or_error


Expand Down Expand Up @@ -38,21 +38,21 @@ def test_param_vars_must_not_be_empty(self):
with self.test_session():
with pytest.raises(
ValueError, match='`param_vars` must not be empty'):
with early_stopping([]):
with EarlyStopping([]):
pass

def test_early_stopping_context_without_updating_loss(self):
with self.test_session():
a, b, c = _populate_variables()
with early_stopping([a, b]) as es:
with EarlyStopping([a, b]) as es:
set_variable_values([a], [10])
self.assertFalse(es.ever_updated)
self.assertEqual(get_variable_values([a, b, c]), [10, 2, 3])

def test_the_first_loss_will_always_cause_saving(self):
with self.test_session():
a, b, c = _populate_variables()
with early_stopping([a, b]) as es:
with EarlyStopping([a, b]) as es:
set_variable_values([a], [10])
self.assertTrue(es.update(1.))
set_variable_values([a, b], [100, 20])
Expand All @@ -63,7 +63,7 @@ def test_the_first_loss_will_always_cause_saving(self):
def test_memorize_the_best_loss(self):
with self.test_session():
a, b, c = _populate_variables()
with early_stopping([a, b]) as es:
with EarlyStopping([a, b]) as es:
set_variable_values([a], [10])
self.assertTrue(es.update(1.))
self.assertAlmostEqual(es.best_metric, 1.)
Expand All @@ -80,7 +80,7 @@ def test_memorize_the_best_loss(self):
def test_initial_loss(self):
with self.test_session():
a, b, c = _populate_variables()
with early_stopping([a, b], initial_metric=.6) as es:
with EarlyStopping([a, b], initial_metric=.6) as es:
set_variable_values([a], [10])
self.assertFalse(es.update(1.))
self.assertAlmostEqual(es.best_metric, .6)
Expand All @@ -92,14 +92,14 @@ def test_initial_loss(self):
def test_initial_loss_is_tensor(self):
with self.test_session():
a, b, c = _populate_variables()
with early_stopping([a, b], initial_metric=tf.constant(.5)) as es:
with EarlyStopping([a, b], initial_metric=tf.constant(.5)) as es:
np.testing.assert_equal(es.best_metric, .5)

def test_do_not_restore_on_error(self):
with self.test_session():
a, b, c = _populate_variables()
with pytest.raises(ValueError, match='value error'):
with early_stopping([a, b], restore_on_error=False) as es:
with EarlyStopping([a, b], restore_on_error=False) as es:
self.assertTrue(es.update(1.))
set_variable_values([a, b], [10, 20])
raise ValueError('value error')
Expand All @@ -110,7 +110,7 @@ def test_restore_on_error(self):
with self.test_session():
a, b, c = _populate_variables()
with pytest.raises(ValueError, match='value error'):
with early_stopping([a, b], restore_on_error=True) as es:
with EarlyStopping([a, b], restore_on_error=True) as es:
self.assertTrue(es.update(1.))
set_variable_values([a, b], [10, 20])
raise ValueError('value error')
Expand All @@ -120,7 +120,7 @@ def test_restore_on_error(self):
def test_bigger_is_better(self):
with self.test_session():
a, b, c = _populate_variables()
with early_stopping([a, b], smaller_is_better=False) as es:
with EarlyStopping([a, b], smaller_is_better=False) as es:
set_variable_values([a], [10])
self.assertTrue(es.update(.5))
self.assertAlmostEqual(es.best_metric, .5)
Expand All @@ -138,7 +138,7 @@ def test_cleanup_checkpoint_dir(self):
a, b, c = _populate_variables()
with TemporaryDirectory() as tempdir:
checkpoint_dir = os.path.join(tempdir, '1')
with early_stopping([a, b], checkpoint_dir=checkpoint_dir) as es:
with EarlyStopping([a, b], checkpoint_dir=checkpoint_dir) as es:
self.assertTrue(es.update(1.))
self.assertTrue(
os.path.exists(os.path.join(checkpoint_dir, 'latest')))
Expand All @@ -149,8 +149,8 @@ def test_not_cleanup_checkpoint_dir(self):
a, b, c = _populate_variables()
with TemporaryDirectory() as tempdir:
checkpoint_dir = os.path.join(tempdir, '2')
with early_stopping([a, b], checkpoint_dir=checkpoint_dir,
cleanup=False) as es:
with EarlyStopping([a, b], checkpoint_dir=checkpoint_dir,
cleanup=False) as es:
self.assertTrue(es.update(1.))
self.assertTrue(
os.path.exists(os.path.join(checkpoint_dir, 'latest')))
Expand Down
2 changes: 1 addition & 1 deletion tests/scaffold/test_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tfsnippet.utils import TemporaryDirectory


class LoggingUtilsTestCase(unittest.TestCase):
class LoggingUtilsTestCase(tf.test.TestCase):

def test_summarize_variables(self):
# test variable summaries
Expand Down
121 changes: 121 additions & 0 deletions tests/scaffold/test_mini_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import unittest

import numpy as np
import pytest

from tfsnippet.scaffold import (minibatch_slices_iterator, ArrayMiniBatch,
MiniBatch)


class MiniBatchSlicesIteratorTestCase(unittest.TestCase):
    """Checks the slices yielded by ``minibatch_slices_iterator``."""

    def test_minibatch_slices_iterator(self):
        """Compare the yielded slices with hand-computed expectations.

        The three positional arguments are passed through unchanged.
        Judging from the expected slices they are the total length, the
        batch size, and a flag that drops a trailing short batch.
        """
        # Each row: (first arg, second arg, third arg, expected slices).
        cases = (
            # nothing to slice
            (0, 10, False, []),
            # fewer items than one batch -> a single short slice
            (9, 10, False, [slice(0, 9, 1)]),
            # exactly one full batch
            (10, 10, False, [slice(0, 10, 1)]),
            # one full batch plus a trailing short batch
            (10, 9, False, [slice(0, 9, 1), slice(9, 10, 1)]),
            # same input, but the trailing short batch is dropped
            (10, 9, True, [slice(0, 9, 1)]),
        )
        for first, second, third, expected in cases:
            produced = list(minibatch_slices_iterator(first, second, third))
            self.assertEqual(produced, expected)


class MiniBatchTestCase(unittest.TestCase):
    """Tests for the iteration protocol of the ``MiniBatch`` base class."""

    def test_get_iterator_reentrant(self):
        """Sequential iteration is allowed, nested iteration is rejected.

        A minimal subclass yields the single value ``123``.  Iterating it
        twice in a row must work both times, whereas starting a second
        iteration while the first one is still running must raise a
        ``RuntimeError`` mentioning that the iterator is not re-entrant.
        """
        class MyMiniBatch(MiniBatch):
            def _get_iterator(self):
                return iter([123])

        m = MyMiniBatch()

        # First pass: plain iteration.
        for b in m:
            self.assertEqual(123, b)

        # Second pass: the object can be iterated again once the first
        # loop has finished, but not from inside an active loop.
        for b in m:
            self.assertEqual(123, b)
            with pytest.raises(
                    RuntimeError,
                    match='get_iterator of MiniBatch is not re-entrant'):
                for _ in m:
                    pass


class ArrayMiniBatchTestCase(unittest.TestCase):
    """Tests for ``ArrayMiniBatch``: properties, validation and iteration."""

    def test_property(self):
        """Constructor arguments are exposed through read-only properties."""
        m = ArrayMiniBatch(
            arrays=[np.arange(12).reshape([4, 3]), np.arange(4)],
            batch_size=5,
            shuffle=True,
            ignore_incomplete_batch=True
        )
        self.assertEqual(2, m.array_count)
        self.assertEqual(4, m.data_length)
        # Per-array shapes with the leading (data length) axis removed.
        self.assertEqual(((3,), ()), m.data_shapes)
        self.assertEqual(5, m.batch_size)
        self.assertTrue(m.ignore_incomplete_batch)
        self.assertTrue(m.is_shuffled)

        # test default options
        m = ArrayMiniBatch([np.arange(12)], 5)
        self.assertFalse(m.ignore_incomplete_batch)
        self.assertFalse(m.is_shuffled)

    def test_errors(self):
        """Invalid `arrays` arguments are rejected with a ``ValueError``."""
        with pytest.raises(
                ValueError, match='`arrays` must not be empty'):
            _ = ArrayMiniBatch([], 3)
        with pytest.raises(
                ValueError, match='`arrays` must be numpy-like arrays'):
            _ = ArrayMiniBatch([np.arange(3).tolist()], 3)
        with pytest.raises(
                ValueError, match='`arrays` must be at least 1-d arrays'):
            _ = ArrayMiniBatch([np.array(0)], 3)
        with pytest.raises(
                ValueError, match='`arrays` must have the same data length'):
            _ = ArrayMiniBatch([np.arange(3), np.arange(4)], 3)

    def test_get_iterator(self):
        """Iteration yields one tuple of array slices per mini-batch."""
        # test single array, without shuffle, no ignore
        b = [a[0] for a in ArrayMiniBatch([np.arange(12)], 5)]
        self.assertEqual(3, len(b))
        np.testing.assert_array_equal(np.arange(0, 5), b[0])
        np.testing.assert_array_equal(np.arange(5, 10), b[1])
        np.testing.assert_array_equal(np.arange(10, 12), b[2])

        # test single array, without shuffle, ignore
        b = [a[0] for a in ArrayMiniBatch(
            [np.arange(12)], 5, ignore_incomplete_batch=True)]
        self.assertEqual(2, len(b))
        np.testing.assert_array_equal(np.arange(0, 5), b[0])
        np.testing.assert_array_equal(np.arange(5, 10), b[1])

        # test dual arrays, without shuffle, no ignore
        b = list(ArrayMiniBatch([np.arange(6), np.arange(12).reshape([6, 2])],
                                5))
        self.assertEqual(2, len(b))
        np.testing.assert_array_equal(np.arange(0, 5), b[0][0])
        np.testing.assert_array_equal(np.arange(5, 6), b[1][0])
        np.testing.assert_array_equal(
            np.arange(0, 10).reshape([5, 2]), b[0][1])
        np.testing.assert_array_equal(
            np.arange(10, 12).reshape([1, 2]), b[1][1])

        # test single array, with shuffle, no ignore
        # The batch order is random, so only check that every element
        # shows up exactly once across all batches.
        b = [a[0] for a in ArrayMiniBatch([np.arange(12)], 5, shuffle=True)]
        self.assertEqual(3, len(b))
        np.testing.assert_array_equal(
            np.arange(12), sorted(np.concatenate(b)))


# Allow running this test module directly as a script, in addition to
# discovery through a test runner.
if __name__ == '__main__':
    unittest.main()
Loading

0 comments on commit 07be715

Please sign in to comment.