
Commit

Merge c58507b into 391e2c0
okuta committed Aug 15, 2017
2 parents 391e2c0 + c58507b commit 999b28e
Showing 4 changed files with 196 additions and 2 deletions.
3 changes: 1 addition & 2 deletions chainer/functions/__init__.py
@@ -206,8 +206,7 @@
from chainer.functions.math.logarithm_1p import log1p # NOQA
from chainer.functions.math.logsumexp import logsumexp # NOQA
from chainer.functions.math.logsumexp import LogSumExp # NOQA
from chainer.functions.math.matmul import matmul # NOQA
from chainer.functions.math.matmul import MatMul # NOQA
from chainer.functions.math.matmul import batch_matmul # NOQA
from chainer.functions.math.matmul import matmul # NOQA
from chainer.functions.math.matmul import MatMul # NOQA
from chainer.functions.math.maximum import maximum # NOQA
81 changes: 81 additions & 0 deletions chainer/functions/math/matmul.py
@@ -1,3 +1,5 @@
import warnings

import numpy

from chainer import cuda
@@ -154,3 +156,82 @@ def matmul(a, b, transa=False, transb=False):
"""
return MatMul(transa=transa, transb=transb)(a, b)


def _get_size(typ, index):
if index == 2 and type_check.eval(typ.ndim) == 2:
return 1
else:
return typ.shape[index]


def _batch_matmul(a, b, transa, transb, transout):
a = a.reshape(a.shape[:2] + (-1,))
b = b.reshape(b.shape[:2] + (-1,))
return _matmul(a, b, transa, transb, transout)


class BatchMatMul(function.Function):

def __init__(self, transa=False, transb=False):
self.transa = transa
self.transb = transb

def check_type_forward(self, in_types):
type_check.expect(in_types.size() == 2)
a_type, b_type = in_types

type_check.expect(
a_type.dtype == numpy.float32,
b_type.dtype == numpy.float32
)

_check_ndim(a_type, lower=2, upper=3)
_check_ndim(b_type, lower=2, upper=3)

a_idx = _get_check_index(self.transa, False, row_idx=1, col_idx=2)
b_idx = _get_check_index(self.transb, True, row_idx=1, col_idx=2)
a_size = _get_size(a_type, a_idx)
b_size = _get_size(b_type, b_idx)
type_check.expect(
a_size == b_size
)

def forward(self, x):
a, b = x
return _batch_matmul(a, b, self.transa, self.transb, False),

def backward(self, x, gy):
a, b = x
ga = _batch_matmul(gy[0], b, False, not self.transb,
self.transa).reshape(a.shape)
gb = _batch_matmul(a, gy[0], not self.transa, False,
self.transb).reshape(b.shape)
return ga, gb


def batch_matmul(a, b, transa=False, transb=False):
"""Computes the batch matrix multiplications of two sets of arrays.
Args:
a (Variable): The left operand of the batch matrix multiplications.
A 2-D array of shape ``(B, N)`` is considered as B
:math:`N \\times 1` matrices.
A 3-D array of shape ``(B, M, N)`` is considered as B
:math:`M \\times N` matrices.
b (Variable): The right operand of the batch matrix multiplications.
Its array is treated as matrices in the same way as ``a``'s array.
transa (bool): If ``True``, transpose each matrix in ``a``.
transb (bool): If ``True``, transpose each matrix in ``b``.
Returns:
~chainer.Variable: The result of the batch matrix multiplications as a
3-D array.
.. deprecated:: v3.0.0
batch_matmul is deprecated. Use ``matmul`` instead.
"""
warnings.warn('batch_matmul is deprecated. Use matmul instead.',
DeprecationWarning)
return BatchMatMul(transa=transa, transb=transb)(a, b)
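
For reference, a minimal usage sketch of the deprecated API next to the replacement suggested by the deprecation note. The shapes are hypothetical, and the assumption that matmul accepts 3-D operands for per-batch products follows from the note above rather than from this diff:

import numpy
import chainer.functions as F

# Hypothetical inputs: a batch of 3 matrix pairs.
a = numpy.random.rand(3, 2, 5).astype(numpy.float32)
b = numpy.random.rand(3, 5, 4).astype(numpy.float32)

# Deprecated call: emits a DeprecationWarning; result shape is (3, 2, 4).
y_old = F.batch_matmul(a, b)

# Suggested replacement (assumed equivalent for 3-D operands).
y_new = F.matmul(a, b)

# batch_matmul treats a 2-D (B, N) operand as B column vectors (N x 1);
# with matmul the extra axis has to be added explicitly.
v = numpy.random.rand(3, 5).astype(numpy.float32)
y_vec_old = F.batch_matmul(v, v, transa=True)                     # (3, 1, 1)
y_vec_new = F.matmul(v[:, :, None], v[:, :, None], transa=True)   # (3, 1, 1)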
1 change: 1 addition & 0 deletions docs/source/reference/functions.rst
@@ -169,6 +169,7 @@ Mathematical functions
chainer.functions.average
chainer.functions.batch_inv
chainer.functions.batch_l2_norm_squared
chainer.functions.batch_matmul
chainer.functions.bias
chainer.functions.ceil
chainer.functions.clip
113 changes: 113 additions & 0 deletions tests/chainer_tests/functions_tests/math_tests/test_matmul.py
@@ -126,6 +126,119 @@ def test_matmul_backward_gpu(self):
cuda.to_gpu(self.gy), atol=1e-2, rtol=1e-2)


@testing.parameterize(*testing.product_dict(
[
# matmul
{'x1_shape': (2, 3), 'x2_shape': (2, 3), 'gy_shape': (2, 1, 1),
'transa': True, 'transb': False},
{'x1_shape': (2, 3), 'x2_shape': (2, 3), 'gy_shape': (2, 3, 3),
'transa': False, 'transb': True},
# batched matmul
{'x1_shape': (3, 2, 5), 'x2_shape': (3, 5, 4), 'gy_shape': (3, 2, 4),
'transa': False, 'transb': False},
{'x1_shape': (3, 5, 2), 'x2_shape': (3, 5, 4), 'gy_shape': (3, 2, 4),
'transa': True, 'transb': False},
{'x1_shape': (3, 2, 5), 'x2_shape': (3, 4, 5), 'gy_shape': (3, 2, 4),
'transa': False, 'transb': True},
{'x1_shape': (3, 5, 2), 'x2_shape': (3, 4, 5), 'gy_shape': (3, 2, 4),
'transa': True, 'transb': True},
# batched matmul 2d x 3d
{'x1_shape': (3, 5), 'x2_shape': (3, 1, 4), 'gy_shape': (3, 5, 4),
'transa': False, 'transb': False},
{'x1_shape': (3, 5), 'x2_shape': (3, 5, 4), 'gy_shape': (3, 1, 4),
'transa': True, 'transb': False},
{'x1_shape': (3, 5), 'x2_shape': (3, 4, 1), 'gy_shape': (3, 5, 4),
'transa': False, 'transb': True},
{'x1_shape': (3, 5), 'x2_shape': (3, 4, 5), 'gy_shape': (3, 1, 4),
'transa': True, 'transb': True},
# batched matmul 3d x 2d
{'x1_shape': (3, 2, 5), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 1),
'transa': False, 'transb': False},
{'x1_shape': (3, 5, 2), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 1),
'transa': True, 'transb': False},
{'x1_shape': (3, 2, 1), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 5),
'transa': False, 'transb': True},
{'x1_shape': (3, 1, 2), 'x2_shape': (3, 5), 'gy_shape': (3, 2, 5),
'transa': True, 'transb': True},
# batchsize = 1
{'x1_shape': (1, 2, 5), 'x2_shape': (1, 5, 4), 'gy_shape': (1, 2, 4),
'transa': False, 'transb': False},
]
))
class TestBatchMatMul(unittest.TestCase):
x1_dtype = numpy.float32
x2_dtype = numpy.float32

def setUp(self):
self.x1 = numpy.random.uniform(.5, 1, self.x1_shape)
self.x1 = self.x1.astype(self.x1_dtype)
self.x2 = numpy.random.uniform(.5, 1, self.x2_shape)
self.x2 = self.x2.astype(self.x2_dtype)
ret_dtype = numpy.result_type(self.x1_dtype, self.x2_dtype)
self.gy = numpy.random.uniform(-1, 1, self.gy_shape).astype(ret_dtype)

self.op = lambda x, y: F.batch_matmul(
x, y, transa=self.transa, transb=self.transb)
self.forward_answer = self._get_forward_answer(
self.x1, self.x2, self.transa, self.transb)

def _get_forward_answer(self, x1, x2, transa, transb):
x1 = x1.reshape(x1.shape[:2] + (-1,))
if transa and x1.ndim >= 2:
x1 = x1.swapaxes(-1, -2)

x2 = x2.reshape(x2.shape[:2] + (-1,))
if transb and x2.ndim >= 2:
x2 = x2.swapaxes(-1, -2)

if x1.ndim <= 2:
return numpy.dot(x1, x2)
else:
return numpy.einsum('...ij,...jk->...ik', x1, x2)

def check_forward(self, x1_data, x2_data, atol=1e-4, rtol=1e-5):
x1 = chainer.Variable(x1_data)
x2 = chainer.Variable(x2_data)
y = self.op(x1, x2)
testing.assert_allclose(self.forward_answer, y.data, atol, rtol)

@condition.retry(3)
def test_matmul_forward_cpu(self):
if self.x1.dtype == numpy.float16 or self.x2.dtype == numpy.float16:
self.check_forward(self.x1, self.x2, atol=1e-3, rtol=1e-3)
else:
self.check_forward(self.x1, self.x2)

@attr.gpu
@condition.retry(3)
def test_matmul_forward_gpu(self):
if self.x1.dtype == numpy.float16 or self.x2.dtype == numpy.float16:
self.check_forward(cuda.to_gpu(self.x1), cuda.to_gpu(self.x2),
atol=1e-3, rtol=1e-3)
else:
self.check_forward(cuda.to_gpu(self.x1), cuda.to_gpu(self.x2))

def check_backward(self, x1_data, x2_data, y_grad, atol, rtol):
gradient_check.check_backward(
self.op, (x1_data, x2_data), y_grad, atol=atol, rtol=rtol,
dtype=numpy.float32)

@condition.retry(3)
def test_matmul_backward_cpu(self):
self.check_backward(self.x1, self.x2, self.gy, atol=1e-2, rtol=5e-2)

@attr.gpu
@condition.retry(3)
def test_matmul_backward_gpu(self):
self.check_backward(
cuda.to_gpu(self.x1), cuda.to_gpu(self.x2),
cuda.to_gpu(self.gy), atol=1e-2, rtol=1e-2)


class TestMatMulInvalid(unittest.TestCase):

def test_invalid_shape(self):
