Revert "[backport] New style forget"
hvy committed Mar 27, 2018
1 parent 99ca6b8 commit 65fe145
Showing 2 changed files with 48 additions and 70 deletions.
76 changes: 41 additions & 35 deletions chainer/functions/util/forget.py
@@ -1,54 +1,63 @@
-import chainer
+from chainer.backends import cuda
 from chainer import function
-from chainer import function_node
 from chainer import variable
 
 
-def _call_func(func, xs):
-    outs = func(*xs)
+class _DummyFunction(function.Function):
 
-    if isinstance(outs, tuple):
-        for i, out in enumerate(outs):
-            if isinstance(out, variable.Variable):
-                continue
-            n = i + 1
-            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(
-                n if n < 20 else n % 10, 'th')
-            msg = ('{}{} element of a returned tuple is not Variable, '
-                   'but is {}').format(n, suffix, type(out))
-            raise RuntimeError(msg)
-    elif isinstance(outs, variable.Variable):
-        outs = (outs,)
-    else:
-        msg = ('A tuple of Variables or a Variable are expected, but {} '
-               'is returned.'.format(type(outs)))
-        raise RuntimeError(msg)
+    def __init__(self, grads):
+        self.grads = grads
 
+    def forward(self, inputs):
+        xp = cuda.get_array_module(*inputs)
+        return xp.array(0),
+
-    return outs
+    def backward(self, inputs, outputs):
+        return self.grads
 
 
-class Forget(function_node.FunctionNode):
+class Forget(function.Function):
 
     def __init__(self, func):
         if not callable(func):
             raise TypeError('func must be callable')
+
         self.func = func
 
+    def _call_func(self, xs):
+        outs = self.func(*xs)
+
+        if isinstance(outs, tuple):
+            for i, out in enumerate(outs):
+                if isinstance(out, variable.Variable):
+                    continue
+                n = i + 1
+                suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(
+                    n if n < 20 else n % 10, 'th')
+                msg = ('{}{} element of a returned tuple is not Variable, '
+                       'but is {}').format(n, suffix, type(out))
+                raise RuntimeError(msg)
+        elif isinstance(outs, variable.Variable):
+            outs = (outs,)
+        else:
+            msg = ('A tuple of Variables or a Variable are expected, but {} '
+                   'is returned.'.format(type(outs)))
+            raise RuntimeError(msg)
+
+        return outs
+
     def forward(self, inputs):
-        self.retain_inputs(tuple(range(len(inputs))))
         with function.no_backprop_mode():
             xs = [variable.Variable(x) for x in inputs]
-            outs = _call_func(self.func, xs)
+            outs = self._call_func(xs)
         return tuple(out.data for out in outs)
 
-    def backward(self, indexes, grad_outputs):
-        inputs = self.get_retained_inputs()
+    def backward(self, inputs, grads):
         with function.force_backprop_mode():
-            outs = _call_func(self.func, inputs)
-        # Return gradients that are further backproable
-        return chainer.grad(
-            outs, inputs, grad_outputs=grad_outputs,
-            enable_double_backprop=True)
+            xs = [variable.Variable(x) for x in inputs]
+            outs = self._call_func(xs)
+            _DummyFunction(grads)(*outs).backward()
+        return tuple(x.grad for x in xs)
 
 
 def forget(func, *xs):
@@ -119,7 +128,4 @@ def forget(func, *xs):
"""
xs = tuple(x if isinstance(x, variable.Variable) else
variable.Variable(x, requires_grad=True) for x in xs)
y = Forget(func).apply(xs)
if len(y) == 1:
y, = y
return y
return Forget(func)(*xs)
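
For orientation, here is a minimal usage sketch of chainer.functions.forget, the API this revert touches. It is illustrative only and not part of the commit: the wrapped lambda, the array shapes, and the final prints are assumptions chosen to mirror the test below.

import numpy as np

import chainer
import chainer.functions as F

x = chainer.Variable(np.random.uniform(-1, 1, (3, 4)).astype(np.float32))
y = chainer.Variable(np.random.uniform(-1, 1, (3, 4)).astype(np.float32))

# forget() runs the wrapped function without keeping its intermediate
# results, then recomputes them during backward(), trading compute for memory.
z = F.forget(lambda a, b: a + b + a, x, y)
z.grad = np.ones((3, 4), dtype=np.float32)
z.backward()

print(x.grad)  # gradient of z w.r.t. x: an array of 2s
print(y.grad)  # gradient of z w.r.t. y: an array of 1s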
42 changes: 7 additions & 35 deletions tests/chainer_tests/functions_tests/util_tests/test_forget.py
@@ -4,11 +4,8 @@
 import six
 
 import chainer
-from chainer import cuda
 from chainer import functions
-from chainer import gradient_check
 from chainer import testing
-from chainer.testing import attr
 
 
 class TestForget(unittest.TestCase):
@@ -17,11 +14,6 @@ def setUp(self):
         self.x = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
         self.y = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
         self.gz = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
-        self.ggx = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
-        self.ggy = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
-
-        self.check_backward_options = {'atol': 5e-4, 'rtol': 5e-3}
-        self.check_double_backward_options = {'atol': 5e-3, 'rtol': 5e-2}
 
     def check_forward(self, x_data, y_data):
         x = chainer.Variable(x_data)
@@ -33,38 +25,18 @@ def test_forward_cpu(self):
         self.check_forward(self.x, self.y)
 
     def check_backward(self, x_data, y_data, gz_data):
-        def f(x, y):
-            return functions.forget(lambda x, y: (x + y + x), x, y)
+        x = chainer.Variable(x_data)
+        y = chainer.Variable(y_data)
+        z = functions.forget(lambda x, y: (x + y + x,), x, y)
+        z.grad = gz_data
+        z.backward()
 
-        gradient_check.check_backward(
-            f, (x_data, y_data), gz_data, **self.check_backward_options)
+        testing.assert_allclose(x.grad, gz_data * 2)
+        testing.assert_allclose(y.grad, gz_data)
 
     def test_backward_cpu(self):
         self.check_backward(self.x, self.y, self.gz)
 
-    @attr.gpu
-    def test_backward_gpu(self):
-        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.y),
-                            cuda.to_gpu(self.gz))
-
-    def check_double_backward(self, x_data, y_data, gz_data, ggx_data,
-                              ggy_data):
-        def f(x, y):
-            return functions.forget(lambda x, y: (x * x * 3 + y * x,), x, y)
-
-        gradient_check.check_double_backward(
-            f, (x_data, y_data), gz_data, (ggx_data, ggy_data),
-            **self.check_double_backward_options)
-
-    def test_double_backward_cpu(self):
-        self.check_double_backward(self.x, self.y, self.gz, self.ggx, self.ggy)
-
-    @attr.gpu
-    def test_double_backward_gpu(self):
-        self.check_double_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.y),
-                                   cuda.to_gpu(self.gz), cuda.to_gpu(self.ggx),
-                                   cuda.to_gpu(self.ggy))
-
 
 class TestForgetError(unittest.TestCase):
 