/
mgu.py
71 lines (54 loc) · 1.88 KB
/
mgu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy
import chainer
from chainer.functions.activation import sigmoid
from chainer.functions.activation import tanh
from chainer.functions.array import concat
from chainer.functions.math import linear_interpolate
from chainer import link
from chainer.links.connection import linear
class MGUBase(link.Chain):
    """Parameters and single-step update shared by the MGU links.

    A minimal gated unit keeps one gate ``f`` and one candidate state
    ``h_bar``. Both are produced by fully connected layers whose input is
    the previous hidden state concatenated with the current input.

    Args:
        n_inputs: Number of input features per example.
        n_units: Number of hidden units (size of the recurrent state).
    """

    def __init__(self, n_inputs, n_units):
        super(MGUBase, self).__init__()
        # Both layers see [h, x], hence the summed input width.
        joined_width = n_inputs + n_units
        with self.init_scope():
            # Gate projection (created first, as before, so parameter
            # initialisation order is unchanged).
            self.W_f = linear.Linear(joined_width, n_units)
            # Candidate-state projection.
            self.W_h = linear.Linear(joined_width, n_units)

    def _call_mgu(self, h, x):
        """Advance the recurrence by one step.

        Args:
            h: Previous hidden state. Presumably shaped (batch, n_units);
                it is joined with ``x`` by ``concat`` on Chainer's default
                axis (1) -- TODO confirm callers always pass 2-D arrays.
            x: Current input, presumably shaped (batch, n_inputs).

        Returns:
            The new hidden state ``f * h_bar + (1 - f) * h``, where
            ``f = sigmoid(W_f [h, x])`` and
            ``h_bar = tanh(W_h [f * h, x])``.
        """
        gate_input = concat.concat([h, x])
        gate = sigmoid.sigmoid(self.W_f(gate_input))
        # The gate also masks the old state fed into the candidate layer.
        candidate_input = concat.concat([gate * h, x])
        candidate = tanh.tanh(self.W_h(candidate_input))
        # linear_interpolate(p, a, b) == p * a + (1 - p) * b
        return linear_interpolate.linear_interpolate(gate, candidate, h)
class StatelessMGU(MGUBase):
    """MGU link that stores no state; the caller threads ``h`` through.

    Usage: ``h = link(h, x)`` for each time step.
    """

    def __call__(self, h, x):
        """Return the next hidden state computed from ``h`` and ``x``."""
        # Bind to the base implementation explicitly, exactly as the
        # original class-level alias ``__call__ = MGUBase._call_mgu`` did.
        return MGUBase._call_mgu(self, h, x)
class StatefulMGU(MGUBase):
    """MGU link that keeps its hidden state between calls.

    The state lives in ``self.h``. It is ``None`` after construction or
    :meth:`reset_state`, in which case the next call starts from an
    all-zeros state sized from the input's batch dimension.

    Args:
        in_size: Number of input features per example.
        out_size: Number of hidden units (size of the stored state).
    """

    def __init__(self, in_size, out_size):
        super(StatefulMGU, self).__init__(in_size, out_size)
        self._state_size = out_size
        self.reset_state()

    def to_cpu(self):
        """Move parameters and the stored state (if any) to the CPU.

        Returns:
            This link, matching ``chainer.Link.to_cpu`` so calls can be
            chained (``model = StatefulMGU(...).to_cpu()``).
        """
        super(StatefulMGU, self).to_cpu()
        if self.h is not None:
            self.h.to_cpu()
        return self

    def to_gpu(self, device=None):
        """Move parameters and the stored state (if any) to a GPU.

        Args:
            device: Target device, forwarded unchanged to the base link and
                to the state Variable.

        Returns:
            This link, matching ``chainer.Link.to_gpu`` so calls can be
            chained.
        """
        super(StatefulMGU, self).to_gpu(device)
        if self.h is not None:
            self.h.to_gpu(device)
        return self

    def set_state(self, h):
        """Replace the stored hidden state with ``h``.

        ``h`` is moved to the link's array backend in place (``h_`` is the
        same object, not a copy), so the caller's Variable is moved too.

        Args:
            h: A ``chainer.Variable`` holding the new state.
        """
        assert isinstance(h, chainer.Variable)
        h_ = h
        if self.xp is numpy:
            h_.to_cpu()
        else:
            # NOTE(review): moves to the *current* GPU, which may differ
            # from the device this link lives on -- confirm for multi-GPU.
            h_.to_gpu()
        self.h = h_

    def reset_state(self):
        """Forget the stored state; the next call starts from zeros."""
        self.h = None

    def __call__(self, x):
        """Run one step on input ``x``, store the result and return it.

        Args:
            x: Current input; its first axis is treated as the batch axis.

        Returns:
            The new hidden state (also kept in ``self.h``).
        """
        if self.h is None:
            n_batch = x.shape[0]
            # Use the input's dtype so the zero state can be concatenated
            # with ``x`` in _call_mgu (a hard-coded float32 breaks any
            # float16/float64 input).
            h_data = self.xp.zeros(
                (n_batch, self._state_size), dtype=x.dtype)
            h = chainer.Variable(h_data)
        else:
            h = self.h
        self.h = self._call_mgu(h, x)
        return self.h