Skip to content

Commit

Permalink
Fix name and test
Browse files Browse the repository at this point in the history
  • Loading branch information
unnonouno committed Oct 18, 2016
1 parent 2d1126c commit d8afb15
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
2 changes: 1 addition & 1 deletion chainer/links/__init__.py
Expand Up @@ -49,8 +49,8 @@
Linear = linear.Linear
LSTM = lstm.LSTM
StatelessLSTM = lstm.StatelessLSTM
MGU = mgu.MGU
StatefulMGU = mgu.StatefulMGU
StatelessMGU = mgu.StatelessMGU
MLPConvolution2D = mlp_convolution_2d.MLPConvolution2D
NStepLSTM = n_step_lstm.NStepLSTM
Parameter = parameter.Parameter
Expand Down
21 changes: 13 additions & 8 deletions chainer/links/connection/mgu.py
Expand Up @@ -9,26 +9,31 @@
from chainer.links.connection import linear


class MGU(link.Chain):
class MGUBase(link.Chain):

def __init__(self, n_inputs, n_units):
super(MGU, self).__init__(
super(MGUBase, self).__init__(
W_f=linear.Linear(n_inputs + n_units, n_units),
W_h=linear.Linear(n_inputs + n_units, n_units)
)

def __call__(self, h, x):
def _call_mgu(self, h, x):
f = sigmoid.sigmoid(self.W_f(concat.concat([h, x])))
h_bar = tanh.tanh(self.W_h(concat.concat([f * h, x])))
h_new = linear_interpolate.linear_interpolate(f, h_bar, h)
return h_new


class StatefulMGU(MGU):
class StatelessMGU(MGUBase):

__call__ = MGUBase._call_mgu


class StatefulMGU(MGUBase):

def __init__(self, in_size, out_size):
super(StatefulMGU, self).__init__(in_size, out_size)
self.state_size = out_size
self._state_size = out_size
self.reset_state()

def to_cpu(self):
Expand All @@ -55,12 +60,12 @@ def reset_state(self):

def __call__(self, x):
if self.h is None:
n_batch = len(x.data)
n_batch = x.shape[0]
h_data = self.xp.zeros(
(n_batch, self.state_size), dtype=numpy.float32)
(n_batch, self._state_size), dtype=numpy.float32)
h = chainer.Variable(h_data)
else:
h = self.h

self.h = MGU.__call__(self, h, x)
self.h = self._call_mgu(h, x)
return self.h
38 changes: 33 additions & 5 deletions tests/chainer_tests/links_tests/connection_tests/test_mgu.py
@@ -1,6 +1,7 @@
import unittest

import numpy
import six

import chainer
from chainer import cuda
Expand All @@ -9,7 +10,19 @@
from chainer.testing import attr


class TestMGU(unittest.TestCase):
def sigmoid(x):
return 1 / (1 + numpy.exp(-x))


def mgu(W_f, W_h, h, x):
f = sigmoid(numpy.concatenate([h, x]).dot(W_f.T))
hx = numpy.concatenate([f * h, x])
h_bar = numpy.tanh(hx.dot(W_h.T))
h_new = f * h_bar + (1 - f) * h
return h_new


class TestStatelessMGU(unittest.TestCase):

in_size = 4
out_size = 5
Expand All @@ -22,12 +35,18 @@ def setUp(self):
self.gy = numpy.random.uniform(
-1, 1, (3, self.out_size)).astype(numpy.float32)

self.mgu = links.MGU(self.in_size, self.out_size)
self.mgu = links.StatelessMGU(self.in_size, self.out_size)

def check_forward(self, h_data, x_data):
h = chainer.Variable(h_data)
x = chainer.Variable(x_data)
self.mgu(h, x)
y = self.mgu(h, x)

W_f = cuda.to_cpu(self.mgu.W_f.W.data)
W_h = cuda.to_cpu(self.mgu.W_h.W.data)
for i in six.moves.range(3):
h_new = mgu(W_f, W_h, self.h[i], self.x[i])
testing.assert_allclose(h_new, y.data[i])

def test_forward_cpu(self):
self.check_forward(self.h, self.x)
Expand All @@ -53,8 +72,17 @@ def setUp(self):

def check_forward(self, x_data):
x = chainer.Variable(x_data)
self.mgu(x)
self.mgu(x)
W_f = cuda.to_cpu(self.mgu.W_f.W.data)
W_h = cuda.to_cpu(self.mgu.W_h.W.data)
y1 = self.mgu(x)
y2 = self.mgu(x)

h = numpy.zeros(self.out_size, dtype='f')
for i in six.moves.range(3):
h1 = mgu(W_f, W_h, h, self.x[i])
testing.assert_allclose(h1, y1.data[i])
h2 = mgu(W_f, W_h, h1, self.x[i])
testing.assert_allclose(h2, y2.data[i])

def test_forward_cpu(self):
self.check_forward(self.x)
Expand Down

0 comments on commit d8afb15

Please sign in to comment.