Skip to content

Commit

Permalink
fix test_timer test for new-style-exp
Browse files Browse the repository at this point in the history
  • Loading branch information
aonotas committed Aug 25, 2017
1 parent 665d571 commit c8dfa97
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/chainer_tests/function_hooks_tests/test_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def check_history(self, t, function_type, return_type):


class SimpleLink(chainer.Link):

def __init__(self):
super(SimpleLink, self).__init__()
with self.init_scope():
Expand Down Expand Up @@ -88,7 +89,7 @@ def setUp(self):
self.gy = numpy.random.uniform(-0.1, 0.1, (3, 5)).astype(numpy.float32)

def check_forward(self, x):
self.f(chainer.Variable(x))
self.f.apply((chainer.Variable(x),))
self.assertEqual(1, len(self.h.call_history))
check_history(self, self.h.call_history[0], functions.Exp, float)

Expand All @@ -101,7 +102,7 @@ def test_forward_gpu(self):

def check_backward(self, x, gy):
x = chainer.Variable(x)
y = self.f(x)
y = self.f.apply((x,))[0]
y.grad = gy
y.backward()
self.assertEqual(2, len(self.h.call_history))
Expand Down Expand Up @@ -163,15 +164,15 @@ def setUp(self):

def test_summary(self):
x = self.x
self.f(chainer.Variable(x))
self.f(chainer.Variable(x))
self.f.apply((chainer.Variable(x),))
self.f.apply((chainer.Variable(x),))
self.assertEqual(2, len(self.h.call_history))
self.assertEqual(1, len(self.h.summary()))

def test_print_report(self):
x = self.x
self.f(chainer.Variable(x))
self.f(chainer.Variable(x))
self.f.apply((chainer.Variable(x),))
self.f.apply((chainer.Variable(x),))
io = six.StringIO()
self.h.print_report(file=io)
expect = r'''\AFunctionName ElapsedTime Occurrence
Expand Down

0 comments on commit c8dfa97

Please sign in to comment.