Merge 34355b8 into e6a1f8a
beam2d committed Aug 22, 2017
2 parents e6a1f8a + 34355b8 commit 391808f
Showing 4 changed files with 103 additions and 32 deletions.
54 changes: 30 additions & 24 deletions chainer/functions/connection/linear.py
@@ -1,14 +1,9 @@
from chainer import function
from chainer import function_node
import chainer.functions
from chainer.utils import type_check


def _as_mat(x):
if x.ndim == 2:
return x
return x.reshape(len(x), -1)


class LinearFunction(function.Function):
class LinearFunction(function_node.FunctionNode):

def check_type_forward(self, in_types):
n_in = in_types.size()
@@ -18,9 +13,9 @@ def check_type_forward(self, in_types):
type_check.expect(
x_type.dtype.kind == 'f',
w_type.dtype.kind == 'f',
x_type.ndim >= 2,
x_type.ndim == 2,
w_type.ndim == 2,
type_check.prod(x_type.shape[1:]) == w_type.shape[1],
x_type.shape[1] == w_type.shape[1],
)
if type_check.eval(n_in) == 3:
b_type = in_types[2]
@@ -31,7 +26,7 @@ def check_type_forward(self, in_types):
)

def forward(self, inputs):
x = _as_mat(inputs[0])
x = inputs[0]
W = inputs[1]

if not type_check.same_types(*inputs):
@@ -43,20 +38,25 @@ def forward(self, inputs):
if len(inputs) == 3:
b = inputs[2]
y += b
self.retain_inputs((0, 1)) # b is not retained
return y,

def backward(self, inputs, grad_outputs):
x = _as_mat(inputs[0])
W = inputs[1]
gy = grad_outputs[0]
def backward(self, indexes, grad_outputs):
x, W = self.get_retained_inputs()
gy, = grad_outputs

gx = gy.dot(W).astype(x.dtype, copy=False).reshape(inputs[0].shape)
gW = gy.T.dot(x).astype(W.dtype, copy=False)
if len(inputs) == 3:
gb = gy.sum(0)
return gx, gW, gb
else:
return gx, gW
ret = []
if 0 in indexes:
gx = linear(gy, W.T)
ret.append(chainer.functions.cast(gx, x.dtype))
if 1 in indexes:
gW = linear(gy.T, x.T)
ret.append(chainer.functions.cast(gW, W.dtype))
if 2 in indexes:
gb = chainer.functions.sum(gy, axis=0)
ret.append(gb)

return ret


def linear(x, W, b=None):
@@ -96,7 +96,13 @@ def linear(x, W, b=None):
(3, 5)
"""
if x.ndim > 2:
x = x.reshape(len(x), -1)

if b is None:
return LinearFunction()(x, W)
args = x, W
else:
return LinearFunction()(x, W, b)
args = x, W, b

y, = LinearFunction().apply(args)
return y
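For context: since y = x·W^T + b, the rewritten backward computes gx = gy·W, gW = gy^T·x, and gb as the sum of gy over the batch axis, but now through differentiable ops (linear, cast, sum), so the gradients themselves stay on the computational graph. The retain_inputs((0, 1)) call is what makes x and W available to backward without keeping b. Below is a minimal sketch of what that enables; the shapes and the use of chainer.grad are illustrative assumptions, not part of the diff.

import numpy as np
import chainer
import chainer.functions as F

# Illustrative shapes: batch of 3, input dim 4, output dim 5.
x = chainer.Variable(np.random.randn(3, 4).astype(np.float32))
W = chainer.Variable(np.random.randn(5, 4).astype(np.float32))
b = chainer.Variable(np.random.randn(5).astype(np.float32))

y = F.linear(x, W, b)                    # y = x W^T + b, shape (3, 5)
loss = F.sum(y * y)

# Because LinearFunction.backward now returns Variables, the first-order
# gradient graph can itself be differentiated.
gx, = chainer.grad([loss], [x], enable_double_backprop=True)
ggx, = chainer.grad([F.sum(gx * gx)], [x])
print(gx.shape, ggx.shape)               # (3, 4) (3, 4)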
5 changes: 5 additions & 0 deletions chainer/variable.py
@@ -638,6 +638,11 @@ def requires_grad(self):
"""It indicates that ``grad`` will be set in backward calculation."""
return self._requires_grad

@property
def T(self):
"""Transposition of this variable."""
return chainer.functions.transpose(self)

def to_cpu(self):
"""Copies the data and gradient arrays to CPU."""
if self.data is None:
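The new Variable.T property is a thin wrapper over chainer.functions.transpose, so it yields another Variable on the graph rather than a raw array; it is what lets the backward above write W.T and x.T directly. A small sketch (the shape is an illustrative assumption):

import numpy as np
import chainer

v = chainer.Variable(np.arange(6, dtype=np.float32).reshape(2, 3))
vt = v.T                                  # same as chainer.functions.transpose(v)
print(vt.shape)                           # (3, 2)
print(isinstance(vt, chainer.Variable))   # True: the result stays differentiable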
25 changes: 18 additions & 7 deletions tests/chainer_tests/function_hooks_tests/test_timer.py
@@ -8,22 +8,34 @@
from chainer import cuda
from chainer import function_hooks
from chainer import functions
from chainer.functions.connection import linear
from chainer import links
from chainer.functions.math import basic_math
from chainer import testing
from chainer.testing import attr


def check_history(self, t, function_type, return_type):
self.assertIsInstance(t[0].function, function_type)
func = getattr(t[0], 'function', t[0])
self.assertIsInstance(func, function_type)
self.assertIsInstance(t[1], return_type)


class SimpleLink(chainer.Link):
def __init__(self):
super(SimpleLink, self).__init__()
with self.init_scope():
init_w = numpy.random.uniform(-1, 1, (3, 5)).astype(
numpy.float32)
self.w = chainer.Parameter(init_w)

def __call__(self, x):
return self.w * x


class TestTimerHookToLink(unittest.TestCase):

def setUp(self):
self.h = function_hooks.TimerHook()
self.l = links.Linear(5, 5)
self.l = SimpleLink()
self.x = numpy.random.uniform(-0.1, 0.1, (3, 5)).astype(numpy.float32)
self.gy = numpy.random.uniform(-0.1, 0.1, (3, 5)).astype(numpy.float32)

@@ -34,8 +46,7 @@ def check_forward(self, x):
with self.h:
self.l(chainer.Variable(x))
self.assertEqual(1, len(self.h.call_history))
check_history(self, self.h.call_history[0],
linear.LinearFunction, float)
check_history(self, self.h.call_history[0], basic_math.Mul, float)

def test_forward_cpu(self):
self.check_forward(self.x)
@@ -56,7 +67,7 @@ def check_backward(self, x, gy):
for entry in self.h.call_history:
if entry[0].label == '_ + _':
continue
check_history(self, entry, linear.LinearFunction, float)
check_history(self, entry, basic_math.Mul, float)

def test_backward_cpu(self):
self.check_backward(self.x, self.gy)
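The timer-hook test now exercises a plain element-wise multiply through SimpleLink instead of links.Linear, so it no longer depends on which class implements the linear connection. A rough sketch of the TimerHook usage the test relies on, with arbitrary shapes; the unwrapping of call_history entries mirrors the check_history helper and is an assumption about what those entries contain:

import numpy as np
import chainer
from chainer import function_hooks

x = chainer.Variable(np.random.uniform(-0.1, 0.1, (3, 5)).astype(np.float32))
w = chainer.Variable(np.random.uniform(-1, 1, (3, 5)).astype(np.float32))

hook = function_hooks.TimerHook()
with hook:
    y = w * x                            # recorded as a basic_math.Mul call
    y.grad = np.ones_like(y.data)
    y.backward()                         # backward calls are recorded too

# Each entry pairs the recorded function with its elapsed time in seconds.
for func, elapsed in hook.call_history:
    func = getattr(func, 'function', func)   # unwrap, as check_history does
    print(type(func).__name__, elapsed)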
51 changes: 50 additions & 1 deletion tests/chainer_tests/functions_tests/connection_tests/test_linear.py
@@ -26,6 +26,12 @@ def setUp(self):

self.x = numpy.random.uniform(-1, 1, (4, 3)).astype(self.x_dtype)
self.gy = numpy.random.uniform(-1, 1, (4, 2)).astype(self.x_dtype)
self.ggx = numpy.random.uniform(-1, 1, self.x.shape).astype(
self.x_dtype)
self.ggW = numpy.random.uniform(-1, 1, self.W.shape).astype(
self.W_dtype)
self.ggb = numpy.random.uniform(-1, 1, self.b.shape).astype(
self.x_dtype)
self.y = self.x.dot(self.W.T) + self.b
self.check_forward_options = {}
self.check_backward_options = {}
@@ -78,7 +84,7 @@ def check_backward(self, x_data, W_data, b_data, y_grad):
args = args + (b_data,)

gradient_check.check_backward(
linear.LinearFunction(), args, y_grad,
linear.linear, args, y_grad,
eps=1e-2, **self.check_backward_options)

@condition.retry(3)
@@ -101,5 +107,48 @@ def test_backward_gpu_nobias(self):
self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W),
None, cuda.to_gpu(self.gy))

def check_double_backward(self, x_data, W_data, b_data, y_grad,
x_grad_grad, W_grad_grad, b_grad_grad):
args = x_data, W_data
grad_grads = x_grad_grad, W_grad_grad
if b_data is not None:
args += b_data,
grad_grads += b_grad_grad,

# non-linear function for testing
def nonlinear(x, W, b=None):
y = linear.linear(x, W, b)
return y * y

gradient_check.check_double_backward(
nonlinear, args, (y_grad,), grad_grads,
**self.check_backward_options)

@condition.retry(3)
def test_double_backward_cpu(self):
self.check_double_backward(self.x, self.W, self.b, self.gy,
self.ggx, self.ggW, self.ggb)

@condition.retry(3)
def test_double_backward_cpu_nobias(self):
self.check_double_backward(self.x, self.W, None, self.gy,
self.ggx, self.ggW, None)

@attr.gpu
@condition.retry(3)
def test_double_backward_gpu(self):
self.check_double_backward(
cuda.to_gpu(self.x), cuda.to_gpu(self.W), cuda.to_gpu(self.b),
cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW),
cuda.to_gpu(self.ggb))

@attr.gpu
@condition.retry(3)
def test_double_backward_gpu_nobias(self):
self.check_double_backward(
cuda.to_gpu(self.x), cuda.to_gpu(self.W), None,
cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW),
None)


testing.run_module(__name__, __file__)
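For reference, the new double-backward tests follow the standard gradient_check pattern: a purely linear map has an identically zero second derivative, so the test composes linear with y * y to give check_double_backward something non-trivial to verify numerically. A stand-alone sketch of that pattern with illustrative float32 shapes; the tolerances are assumptions, not taken from the test:

import numpy as np
from chainer import gradient_check
import chainer.functions as F

x = np.random.uniform(-1, 1, (4, 3)).astype(np.float32)
W = np.random.uniform(-1, 1, (2, 3)).astype(np.float32)
gy = np.random.uniform(-1, 1, (4, 2)).astype(np.float32)
ggx = np.random.uniform(-1, 1, x.shape).astype(np.float32)
ggW = np.random.uniform(-1, 1, W.shape).astype(np.float32)

def nonlinear(x, W):
    # Squaring the output makes the second derivative non-zero.
    y = F.linear(x, W)
    return y * y

gradient_check.check_double_backward(
    nonlinear, (x, W), (gy,), (ggx, ggW), eps=1e-2, atol=1e-3, rtol=1e-3)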
