diff --git a/chainer/functions/__init__.py b/chainer/functions/__init__.py
index 91fc26f2d2e6..86b6fe71f21f 100644
--- a/chainer/functions/__init__.py
+++ b/chainer/functions/__init__.py
@@ -206,6 +206,7 @@
 from chainer.functions.math.logarithm_1p import log1p  # NOQA
 from chainer.functions.math.logsumexp import logsumexp  # NOQA
 from chainer.functions.math.logsumexp import LogSumExp  # NOQA
+from chainer.functions.math.matmul import batch_matmul  # NOQA
 from chainer.functions.math.matmul import matmul  # NOQA
 from chainer.functions.math.matmul import MatMul  # NOQA
 from chainer.functions.math.maximum import maximum  # NOQA
diff --git a/chainer/functions/math/matmul.py b/chainer/functions/math/matmul.py
index 5db4f1fc3f65..71680ee7c029 100644
--- a/chainer/functions/math/matmul.py
+++ b/chainer/functions/math/matmul.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy
 
 from chainer import cuda
@@ -154,3 +156,82 @@ def matmul(a, b, transa=False, transb=False):
 
     """
     return MatMul(transa=transa, transb=transb)(a, b)
+
+
+def _get_size(typ, index):
+    if index == 2 and type_check.eval(typ.ndim) == 2:
+        return 1
+    else:
+        return typ.shape[index]
+
+
+def _batch_matmul(a, b, transa, transb, transout):
+    a = a.reshape(a.shape[:2] + (-1,))
+    b = b.reshape(b.shape[:2] + (-1,))
+    return _matmul(a, b, transa, transb, transout)
+
+
+class BatchMatMul(function.Function):
+
+    def __init__(self, transa=False, transb=False):
+        self.transa = transa
+        self.transb = transb
+
+    def check_type_forward(self, in_types):
+        type_check.expect(in_types.size() == 2)
+        a_type, b_type = in_types
+
+        type_check.expect(
+            a_type.dtype == numpy.float32,
+            b_type.dtype == numpy.float32
+        )
+
+        _check_ndim(a_type, lower=2, upper=3)
+        _check_ndim(b_type, lower=2, upper=3)
+
+        a_idx = _get_check_index(self.transa, False, row_idx=1, col_idx=2)
+        b_idx = _get_check_index(self.transb, True, row_idx=1, col_idx=2)
+        a_size = _get_size(a_type, a_idx)
+        b_size = _get_size(b_type, b_idx)
+        type_check.expect(
+            a_size == b_size
+        )
+
+    def forward(self, x):
+        a, b = x
+        return _batch_matmul(a, b, self.transa, self.transb, False),
+
+    def backward(self, x, gy):
+        a, b = x
+        ga = _batch_matmul(gy[0], b, False, not self.transb,
+                           self.transa).reshape(a.shape)
+        gb = _batch_matmul(a, gy[0], not self.transa, False,
+                           self.transb).reshape(b.shape)
+        return ga, gb
+
+
+def batch_matmul(a, b, transa=False, transb=False):
+    """Computes the batch matrix multiplications of two sets of arrays.
+
+    Args:
+        a (Variable): The left operand of the batch matrix multiplications.
+            A 2-D array of shape ``(B, N)`` is considered as B
+            :math:`N \\times 1` matrices.
+            A 3-D array of shape ``(B, M, N)`` is considered as B
+            :math:`M \\times N` matrices.
+        b (Variable): The right operand of the batch matrix multiplications.
+            Its array is treated as matrices in the same way as ``a``'s array.
+        transa (bool): If ``True``, transpose each matrix in ``a``.
+        transb (bool): If ``True``, transpose each matrix in ``b``.
+
+    Returns:
+        ~chainer.Variable: The result of the batch matrix multiplications as a
+        3-D array.
+
+    .. deprecated:: v3.0.0
+       batch_matmul is deprecated. Use ``matmul`` instead.
+
+    """
+    warnings.warn('batch_matmul is deprecated. Use matmul instead.',
+                  DeprecationWarning)
+    return BatchMatMul(transa=transa, transb=transb)(a, b)
diff --git a/docs/source/reference/functions.rst b/docs/source/reference/functions.rst
index 17f079b18794..dfd6e9f9e299 100644
--- a/docs/source/reference/functions.rst
+++ b/docs/source/reference/functions.rst
@@ -169,6 +169,7 @@ Mathematical functions
    chainer.functions.average
    chainer.functions.batch_inv
    chainer.functions.batch_l2_norm_squared
+   chainer.functions.batch_matmul
    chainer.functions.bias
    chainer.functions.ceil
    chainer.functions.clip
diff --git a/tests/chainer_tests/functions_tests/math_tests/test_matmul.py b/tests/chainer_tests/functions_tests/math_tests/test_matmul.py
index b7f7d8e47d51..5d3fc1070d04 100644
--- a/tests/chainer_tests/functions_tests/math_tests/test_matmul.py
+++ b/tests/chainer_tests/functions_tests/math_tests/test_matmul.py
@@ -126,6 +126,119 @@ def test_matmul_backward_gpu(self):
             cuda.to_gpu(self.gy), atol=1e-2, rtol=1e-2)
 
 
+@testing.parameterize(*testing.product_dict(
+    [
+        # matmul
+        {'x1_shape': (2, 3), 'x2_shape': (2, 3), 'gy_shape': (2, 1, 1),
+         'transa': True, 'transb': False},
+        {'x1_shape': (2, 3), 'x2_shape': (2, 3), 'gy_shape': (2, 3, 3),
+         'transa': False, 'transb': True},
+
+        # batched matmul
+        {'x1_shape': (3, 2, 5), 'x2_shape': (3, 5, 4), 'gy_shape': (3, 2, 4),
+         'transa': False, 'transb': False},
+        {'x1_shape': (3, 5, 2), 'x2_shape': (3, 5, 4), 'gy_shape': (3, 2, 4),
+         'transa': True, 'transb': False},
+        {'x1_shape': (3, 2, 5), 'x2_shape': (3, 4, 5), 'gy_shape': (3, 2, 4),
+         'transa': False, 'transb': True},
+        {'x1_shape': (3, 5, 2), 'x2_shape': (3, 4, 5), 'gy_shape': (3, 2, 4),
+         'transa': True, 'transb': True},
+
+        # batched matmul 2d x 3d
+        {'x1_shape': (3, 5), 'x2_shape': (3, 1, 4), 'gy_shape': (3, 5, 4),
+         'transa': False, 'transb': False},
+        {'x1_shape': (3, 5), 'x2_shape': (3, 5, 4), 'gy_shape': (3, 1, 4),
+         'transa': True, 'transb': False},
+        {'x1_shape': (3, 5), 'x2_shape': (3, 4, 1), 'gy_shape': (3, 5, 4),
+         'transa': False, 'transb': True},
+        {'x1_shape': (3, 5), 'x2_shape': (3, 4, 5), 'gy_shape': (3, 1, 4),
+         'transa': True, 'transb': True},
+
+        # batched matmul 3d x 2d
+        {'x1_shape': (3, 2, 5), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 1),
+         'transa': False, 'transb': False},
+        {'x1_shape': (3, 5, 2), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 1),
+         'transa': True, 'transb': False},
+        {'x1_shape': (3, 2, 1), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 5),
+         'transa': False, 'transb': True},
+        {'x1_shape': (3, 1, 2), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 5),
+         'transa': True, 'transb': True},
+
+        # batchsize = 1
+        {'x1_shape': (1, 2, 5), 'x2_shape': (1, 5, 4), 'gy_shape': (1, 2, 4),
+         'transa': False, 'transb': False},
+    ]
+))
+class TestBatchMatMul(unittest.TestCase):
+    x1_dtype = numpy.float32
+    x2_dtype = numpy.float32
+
+    def setUp(self):
+        self.x1 = numpy.random.uniform(.5, 1, self.x1_shape)
+        self.x1 = self.x1.astype(self.x1_dtype)
+        self.x2 = numpy.random.uniform(.5, 1, self.x2_shape)
+        self.x2 = self.x2.astype(self.x2_dtype)
+        ret_dtype = numpy.result_type(self.x1_dtype, self.x2_dtype)
+        self.gy = numpy.random.uniform(-1, 1, self.gy_shape).astype(ret_dtype)
+
+        self.op = lambda x, y: F.batch_matmul(
+            x, y, transa=self.transa, transb=self.transb)
+        self.forward_answer = self._get_forward_answer(
+            self.x1, self.x2, self.transa, self.transb)
+
+    def _get_forward_answer(self, x1, x2, transa, transb):
+        x1 = x1.reshape(x1.shape[:2] + (-1,))
+        if transa and x1.ndim >= 2:
+            x1 = x1.swapaxes(-1, -2)
+
+        x2 = x2.reshape(x2.shape[:2] + (-1,))
+        if transb and x2.ndim >= 2:
+            x2 = x2.swapaxes(-1, -2)
+
+        if x1.ndim <= 2:
+            return numpy.dot(x1, x2)
+        else:
+            return numpy.einsum('...ij,...jk->...ik', x1, x2)
+
+    def check_forward(self, x1_data, x2_data, atol=1e-4, rtol=1e-5):
+        x1 = chainer.Variable(x1_data)
+        x2 = chainer.Variable(x2_data)
+        y = self.op(x1, x2)
+        testing.assert_allclose(self.forward_answer, y.data, atol, rtol)
+
+    @condition.retry(3)
+    def test_matmul_forward_cpu(self):
+        if self.x1.dtype == numpy.float16 or self.x2.dtype == numpy.float16:
+            self.check_forward(self.x1, self.x2, atol=1e-3, rtol=1e-3)
+        else:
+            self.check_forward(self.x1, self.x2)
+
+    @attr.gpu
+    @condition.retry(3)
+    def test_matmul_forward_gpu(self):
+        if self.x1.dtype == numpy.float16 or self.x2.dtype == numpy.float16:
+            self.check_forward(cuda.to_gpu(self.x1), cuda.to_gpu(self.x2),
+                               atol=1e-3, rtol=1e-3)
+        else:
+            self.check_forward(cuda.to_gpu(self.x1), cuda.to_gpu(self.x2))
+
+    def check_backward(self, x1_data, x2_data, y_grad, atol, rtol):
+        gradient_check.check_backward(
+            self.op, (x1_data, x2_data), y_grad, atol=atol, rtol=rtol,
+            dtype=numpy.float32)
+
+    @condition.retry(3)
+    def test_matmul_backward_cpu(self):
+        self.check_backward(self.x1, self.x2, self.gy, atol=1e-2, rtol=5e-2)
+
+    @attr.gpu
+    @condition.retry(3)
+    def test_matmul_backward_gpu(self):
+        self.check_backward(
+            cuda.to_gpu(self.x1), cuda.to_gpu(self.x2),
+            cuda.to_gpu(self.gy), atol=1e-2, rtol=1e-2)
+
+
 class TestMatMulInvalid(unittest.TestCase):
 
     def test_invalid_shape(self):
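For reference, a minimal usage sketch (not part of the patch) of the re-added, deprecated batch_matmul next to the generalized matmul it defers to. The shapes are illustrative, and it assumes F.matmul on this branch accepts batched 3-D inputs, which is what the deprecation message implies.

# Sketch only: shapes and variable names below are illustrative.
import numpy as np
import chainer.functions as F

a = np.random.rand(3, 2, 5).astype(np.float32)   # 3 matrices of shape 2x5
b = np.random.rand(3, 5, 4).astype(np.float32)   # 3 matrices of shape 5x4

y_old = F.batch_matmul(a, b)   # deprecated path; now emits DeprecationWarning
y_new = F.matmul(a, b)         # generalized matmul, assuming it handles the batch axis
assert y_old.shape == y_new.shape == (3, 2, 4)

# For 2-D inputs, batch_matmul treats shape (B, N) as B column vectors (N x 1);
# with plain matmul that convention has to be made explicit, e.g. a[:, :, None].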
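BatchMatMul.backward reuses _batch_matmul with the transpose flags flipped, which is the usual bilinear gradient rule: for C = A B computed per batch, dL/dA = dL/dC B^T and dL/dB = A^T dL/dC. A small NumPy check of that rule for the no-transpose case, again only a sketch:

# Sketch: verify dL/dA = gY @ B^T and dL/dB = A^T @ gY for batched C = A @ B.
import numpy as np

rng = np.random.RandomState(0)
a = rng.rand(3, 2, 5)
b = rng.rand(3, 5, 4)
gy = rng.rand(3, 2, 4)                 # upstream gradient dL/dC

ga = np.einsum('bij,bkj->bik', gy, b)  # gy @ b^T, batch-wise; same shape as a
gb = np.einsum('bji,bjk->bik', a, gy)  # a^T @ gy, batch-wise; same shape as b

# Central finite difference on one entry of a; exact up to rounding,
# since the map is bilinear in a.
eps, idx = 1e-6, (1, 0, 2)
ap, am = a.copy(), a.copy()
ap[idx] += eps
am[idx] -= eps
num = ((np.matmul(ap, b) - np.matmul(am, b)) * gy).sum() / (2 * eps)
assert np.allclose(num, ga[idx], atol=1e-4)
assert ga.shape == a.shape and gb.shape == b.shape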