Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Tests for PickableSequentialChain.copy()
Browse files Browse the repository at this point in the history
  • Loading branch information
ktns committed Feb 2, 2019
1 parent a825040 commit 30b5cd6
Showing 1 changed file with 94 additions and 11 deletions.
105 changes: 94 additions & 11 deletions tests/links_tests/model_tests/test_pickable_sequential_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,9 @@ def forward(self, inputs):
return inputs[0] * 2,


@testing.parameterize(
{'pick': None},
{'pick': 'f2'},
{'pick': ('f2',)},
{'pick': ('l2', 'l1', 'f2')},
{'pick': ('l2', 'l2')},
)
class TestPickableSequentialChain(unittest.TestCase):
class PickableSequentialChainTestBase(object):

def setUp(self):
def setUpBase(self):
self.l1 = ConstantStubLink(np.random.uniform(size=(1, 3, 24, 24)))
self.f1 = DummyFunc()
self.f2 = DummyFunc()
Expand Down Expand Up @@ -94,8 +87,8 @@ def check_deletion(self):
x = self.link.xp.asarray(self.x)

if self.pick == 'l1' or \
(isinstance(self.pick, tuple) and
'l1' in self.pick):
(isinstance(self.pick, tuple)
and 'l1' in self.pick):
with self.assertRaises(AttributeError):
del self.link.l1
return
Expand All @@ -118,6 +111,96 @@ def test_deletion_gpu(self):
self.check_deletion()


@testing.parameterize(
{'pick': None},
{'pick': 'f2'},
{'pick': ('f2',)},
{'pick': ('l2', 'l1', 'f2')},
{'pick': ('l2', 'l2')},
)
class TestPickableSequentialChain(
unittest.TestCase, PickableSequentialChainTestBase):
def setUp(self):
self.setUpBase()


@testing.parameterize(
*testing.product({
'mode': ['init', 'share', 'copy'],
'pick': [None, 'f1', ('f1', 'f2'), ('l2', 'l2'), ('l2', 'l1', 'f2')]
})
)
class TestCopiedPickableSequentialChain(
unittest.TestCase, PickableSequentialChainTestBase):

def setUp(self):
self.setUpBase()

self.f100 = DummyFunc()
self.l100 = ConstantStubLink(np.random.uniform(size=(1, 3, 24, 24)))

self.link, self.original_link = \
self.link.copy(mode=self.mode), self.link

def check_unchanged(self, link, x):
class Checker(object):
def __init__(self, tester, link, x):
self.tester = tester
self.link = link
self.x = x

def __enter__(self):
self.expected = self.link(self.x)

def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
return None

self.actual = self.link(self.x)

if isinstance(self.expected, tuple):
self.tester.assertEqual(
len(self.expected), len(self.actual))
for e, a in zip(self.expected, self.actual):
self.tester.assertEqual(type(e.array), type(a.array))
np.testing.assert_equal(
to_cpu(e.array), to_cpu(a.array))
else:
self.tester.assertEqual(type(self.expected.array),
type(self.actual.array))
np.testing.assert_equal(
to_cpu(self.expected.array),
to_cpu(self.actual.array))

return Checker(self, link, x)

def test_original_unaffected_by_setting_pick(self):
with self.check_unchanged(self.original_link, self.x):
self.link.pick = 'f2'

def test_original_unaffected_by_function_addition(self):
with self.check_unchanged(self.original_link, self.x):
with self.link.init_scope():
self.link.f100 = self.f100

def test_original_unaffected_by_link_addition(self):
with self.check_unchanged(self.original_link, self.x):
with self.link.init_scope():
self.link.l100 = self.l100

def test_original_unaffected_by_function_deletion(self):
with self.check_unchanged(self.original_link, self.x):
with self.link.init_scope():
self.link.pick = None
del self.link.f1

def test_original_unaffected_by_link_deletion(self):
with self.check_unchanged(self.original_link, self.x):
with self.link.init_scope():
self.link.pick = None
del self.link.l1


@testing.parameterize(
{'pick': 'l1', 'layer_names': ['l1']},
{'pick': 'f1', 'layer_names': ['l1', 'f1']},
Expand Down

0 comments on commit 30b5cd6

Please sign in to comment.