/
zoneoutlstm.py
119 lines (97 loc) · 3.93 KB
/
zoneoutlstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy
from chainer.backends import cuda
from chainer.functions.activation import sigmoid
from chainer.functions.activation import tanh
from chainer.functions.array import reshape
from chainer.functions.array import split_axis
from chainer.functions.noise import zoneout
from chainer import link
from chainer.links.connection import linear
from chainer.utils import argument
from chainer import variable
class StatefulZoneoutLSTM(link.Chain):
    """Stateful LSTM link whose state updates are passed through zoneout.

    The link keeps a cell state :attr:`c` and an output :attr:`h` between
    calls. On every call the freshly computed cell state and output are
    combined with the previous ones by ``zoneout`` using :attr:`c_ratio`
    and :attr:`h_ratio` respectively.

    Args:
        in_size (int): Input size of the ``upward`` linear connection.
        out_size (int): Number of LSTM units; also the size of the
            stored states.
        c_ratio (float): Ratio handed to ``zoneout`` for the cell state.
        h_ratio (float): Ratio handed to ``zoneout`` for the output.

    Attributes:
        upward (~chainer.links.Linear): Input-to-gates connection.
        lateral (~chainer.links.Linear): Previous-output-to-gates
            connection (no bias).
        c (~chainer.Variable): Cell state, or ``None`` after a reset.
        h (~chainer.Variable): Previous output, or ``None`` after a reset.

    """

    def __init__(self, in_size, out_size, c_ratio=0.5, h_ratio=0.5, **kwargs):
        argument.check_unexpected_kwargs(
            kwargs, train='train argument is not supported anymore. '
            'Use chainer.using_config')
        argument.assert_kwargs_empty(kwargs)

        super(StatefulZoneoutLSTM, self).__init__()
        self.state_size = out_size
        self.c_ratio = c_ratio
        self.h_ratio = h_ratio
        self.reset_state()

        with self.init_scope():
            # One linear map produces all four gate pre-activations at once.
            self.upward = linear.Linear(in_size, 4 * out_size)
            self.lateral = linear.Linear(out_size, 4 * out_size, nobias=True)

    def to_cpu(self):
        """Moves the parameters and any stored state to the CPU.

        Returns:
            This link itself, so that calls can be chained.

        """
        super(StatefulZoneoutLSTM, self).to_cpu()
        # The states are plain attributes, not registered parameters, so
        # they have to be transferred explicitly.
        for state in (self.c, self.h):
            if state is not None:
                state.to_cpu()
        return self

    def to_gpu(self, device=None):
        """Moves the parameters and any stored state to a GPU.

        Args:
            device: Device specifier forwarded to the parent ``to_gpu`` and
                to the stored state variables.

        Returns:
            This link itself, so that calls can be chained.

        """
        super(StatefulZoneoutLSTM, self).to_gpu(device)
        for state in (self.c, self.h):
            if state is not None:
                state.to_gpu(device)
        return self

    def set_state(self, c, h):
        """Sets the internal state.

        It sets the :attr:`c` and :attr:`h` attributes. The given variables
        are moved in place to the device this link currently lives on.

        Args:
            c (~chainer.Variable): A new cell states of LSTM units.
            h (~chainer.Variable): A new output at the previous time step.

        """
        assert isinstance(c, variable.Variable)
        assert isinstance(h, variable.Variable)
        if self.xp is numpy:
            c.to_cpu()
            h.to_cpu()
        else:
            c.to_gpu(self._device_id)
            h.to_gpu(self._device_id)
        self.c = c
        self.h = h

    def reset_state(self):
        """Resets the internal state.

        It sets ``None`` to the :attr:`c` and :attr:`h` attributes.

        """
        self.c = self.h = None

    def _zero_state(self, x):
        """Returns a zero-filled state variable matching the batch of ``x``.

        The array has shape ``(len(x.data), self.state_size)`` and the dtype
        of ``x.data``; it is allocated with this link's array module under
        this link's device.
        """
        xp = self.xp
        with cuda.get_device_from_id(self._device_id):
            return variable.Variable(
                xp.zeros((len(x.data), self.state_size),
                         dtype=x.data.dtype))

    def __call__(self, x):
        """Updates the internal state and returns the LSTM outputs.

        Args:
            x (~chainer.Variable): A new batch from the input sequence.

        Returns:
            ~chainer.Variable: Outputs of updated LSTM units.

        """
        lstm_in = self.upward(x)
        if self.h is not None:
            lstm_in += self.lateral(self.h)
        else:
            # First step after a reset: there is no previous output to feed
            # through ``lateral``, so start from zeros.
            self.h = self._zero_state(x)
        if self.c is None:
            self.c = self._zero_state(x)

        # View the 4 * units pre-activations as (batch, units, 4) and split
        # along the last axis to obtain the four gate inputs.
        batch_size = len(lstm_in.data)
        lstm_in = reshape.reshape(
            lstm_in, (batch_size, lstm_in.data.shape[1] // 4, 4))
        a, i, f, o = split_axis.split_axis(lstm_in, 4, 2)
        a = reshape.reshape(a, (len(a.data), self.state_size))
        i = reshape.reshape(i, (len(i.data), self.state_size))
        f = reshape.reshape(f, (len(f.data), self.state_size))
        o = reshape.reshape(o, (len(o.data), self.state_size))

        # Candidate cell state of a standard LSTM step.
        c_tmp = tanh.tanh(a) * sigmoid.sigmoid(i) + sigmoid.sigmoid(f) * self.c
        # Zoneout combines the previous state with the candidate one.
        self.c = zoneout.zoneout(self.c, c_tmp, self.c_ratio)
        self.h = zoneout.zoneout(self.h,
                                 sigmoid.sigmoid(o) * tanh.tanh(c_tmp),
                                 self.h_ratio)
        return self.h