-
Notifications
You must be signed in to change notification settings - Fork 0
/
recurrent.py
352 lines (268 loc) · 9.3 KB
/
recurrent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
r"""Recurrent layers"""
# Names exported by `from <module> import *`: the abstract cell interface, the
# recurrent wrapper and the concrete cells defined in this file.
__all__ = [
'Cell',
'Recurrent',
'BRCell',
'MGUCell',
'GRUCell',
'LSTMCell',
]
import jax
import jax.numpy as jnp
from jax import Array
from jax.random import KeyArray
from typing import *
from .linear import Linear
from .module import Module
class Cell(Module):
    r"""Abstract base class of recurrent cells.

    A cell bundles a recurrence function :math:`f` of the form

    .. math:: (h_i, y_i) = f(h_{i-1}, x_i)

    together with an initial hidden state :math:`h_0`. Subclasses provide the
    recurrence through :py:`__call__` and the initial state through :py:`init`.

    Warning:
        The recurrence function :math:`f` should have no side effects.
    """

    def __call__(self, h: Any, x: Any) -> Tuple[Any, Any]:
        r"""Applies one step of the recurrence.

        Arguments:
            h: The previous hidden state :math:`h_{i-1}`.
            x: The input :math:`x_i`.

        Returns:
            The hidden state and output :math:`(h_i, y_i)`.

        Raises:
            NotImplementedError: Always, the base class defines no recurrence.
        """

        raise NotImplementedError

    def init(self) -> Any:
        r"""Builds the state from which the recurrence starts.

        Returns:
            The initial hidden state :math:`h_0`.

        Raises:
            NotImplementedError: Always, the base class defines no initial state.
        """

        raise NotImplementedError
class Recurrent(Module):
    r"""Creates a recurrent layer that unrolls a cell over a sequence.

    Arguments:
        cell: A recurrent cell.
        reverse: Whether to apply the recurrence in reverse or not.
    """

    def __init__(self, cell: Cell, reverse: bool = False):
        self.cell = cell
        self.reverse = reverse

    def __call__(self, xs: Any) -> Any:
        r"""Runs the cell over every element of the sequence.

        Arguments:
            xs: A sequence of inputs :math:`x_i`, stacked on the leading axis.
                When inputs are vectors, :py:`xs` has shape :math:`(L, C)`.

        Returns:
            A sequence of outputs :math:`y_i`, stacked on the leading axis. When outputs
            are vectors, :py:`ys` has shape :math:`(L, C')`.
        """

        h0 = self.cell.init()

        # The cell is used directly as the scan body: it maps (h, x) to (h, y).
        # The final hidden state is discarded, only the stacked outputs are kept.
        outputs = jax.lax.scan(self.cell, h0, xs, reverse=self.reverse)[1]

        return outputs
class GRUCell(Cell):
    r"""Creates a gated recurrent unit (GRU) cell.

    References:
        | Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation (Cho et al., 2014)
        | https://arxiv.org/abs/1406.1078

    Arguments:
        key: A PRNG key for initialization.
        in_features: The number of input features :math:`C`.
        hid_features: The number of hidden features :math:`H`.
        bias: Whether the cell learns additive biases or not.
    """

    def __init__(
        self,
        key: KeyArray,
        in_features: int,
        hid_features: int,
        bias: bool = True,
    ):
        # One independent key per linear layer.
        key_h, key_x = jax.random.split(key, 2)

        # Each layer emits the reset, update and candidate pre-activations at once,
        # hence the 3 * H output features.
        gate_features = 3 * hid_features

        self.lin_h = Linear(key_h, hid_features, gate_features, bias)
        self.lin_x = Linear(key_x, in_features, gate_features, bias)

        self.in_features = in_features
        self.hid_features = hid_features

    @jax.jit
    def __call__(self, h: Array, x: Array) -> Tuple[Array, Array]:
        r"""Applies one GRU step.

        Arguments:
            h: The previous hidden state :math:`h_{i-1}`, with shape :math:`(*, H)`.
            x: The input vector :math:`x_i`, with shape :math:`(*, C)`.

        Returns:
            The hidden state :math:`(h_i, h_i)`.
        """

        reset_h, update_h, cand_h = jnp.split(self.lin_h(h), 3, axis=-1)
        reset_x, update_x, cand_x = jnp.split(self.lin_x(x), 3, axis=-1)

        reset = jax.nn.sigmoid(reset_x + reset_h)
        update = jax.nn.sigmoid(update_x + update_h)

        # The reset gate scales only the hidden-state contribution to the candidate.
        candidate = jax.nn.tanh(cand_x + reset * cand_h)

        # Interpolate between the candidate and the previous state.
        h_next = (1 - update) * candidate + update * h

        return h_next, h_next

    def init(self) -> Array:
        r"""Builds the initial hidden state.

        Returns:
            The initial hidden state :math:`h_0 = 0`, with shape :math:`(H)`.
        """

        return jnp.zeros(self.hid_features)
class BRCell(Cell):
    r"""Creates a bistable recurrent cell (BRC).

    References:
        | A bio-inspired bistable recurrent cell allows for long-lasting memory (Vecoven et al., 2021)
        | https://arxiv.org/abs/2006.05252

    Arguments:
        key: A PRNG key for initialization.
        in_features: The number of input features :math:`C`.
        hid_features: The number of hidden features :math:`H`.
        bias: Whether the cell learns additive biases or not.
        modulated: Whether to use neuromodulation or not.
    """

    def __init__(
        self,
        key: KeyArray,
        in_features: int,
        hid_features: int,
        bias: bool = True,
        modulated: bool = True,
    ):
        # Three keys are always drawn; the modulated variant only uses the first two.
        keys = jax.random.split(key, 3)

        self.modulated = modulated

        # The input layer emits the pre-activations of a, c and the candidate at once.
        self.lin_x = Linear(keys[0], in_features, 3 * hid_features, bias)

        if self.modulated:
            # The hidden state feeds a and c through a linear layer.
            self.lin_h = Linear(keys[1], hid_features, 2 * hid_features, bias)
        else:
            # The hidden state feeds a and c through element-wise weights.
            self.wa = jax.random.normal(keys[1], (hid_features,))
            self.wc = jax.random.normal(keys[2], (hid_features,))

        self.in_features = in_features
        self.hid_features = hid_features

    @jax.jit
    def __call__(self, h: Array, x: Array) -> Tuple[Array, Array]:
        r"""Applies one BRC step.

        Arguments:
            h: The previous hidden state :math:`h_{i-1}`, with shape :math:`(*, H)`.
            x: The input vector :math:`x_i`, with shape :math:`(*, C)`.

        Returns:
            The hidden state :math:`(h_i, h_i)`.
        """

        if self.modulated:
            ah, ch = jnp.split(self.lin_h(h), 2, axis=-1)
        else:
            ah = self.wa * h
            ch = self.wc * h

        ax, cx, gx = jnp.split(self.lin_x(x), 3, axis=-1)

        # a lies in (0, 2) and scales the feedback of h into the candidate.
        a = 1.0 + jax.nn.tanh(ax + ah)
        # c lies in (0, 1) and interpolates between the candidate and the previous state.
        c = jax.nn.sigmoid(cx + ch)
        g = jax.nn.tanh(gx + a * h)
        h = (1 - c) * g + c * h

        return h, h

    def init(self) -> Array:
        r"""Builds the initial hidden state.

        Returns:
            The initial hidden state :math:`h_0 = 0`, with shape :math:`(H)`.
        """

        return jnp.zeros(self.hid_features)
class MGUCell(Cell):
    r"""Creates a minimal gated unit (MGU) cell.

    References:
        | Minimal Gated Unit for Recurrent Neural Networks (Zhou et al., 2016)
        | https://arxiv.org/pdf/1603.09420

    Arguments:
        key: A PRNG key for initialization.
        in_features: The number of input features :math:`C`.
        hid_features: The number of hidden features :math:`H`.
        bias: Whether the cell learns additive biases or not.
    """

    def __init__(
        self,
        key: KeyArray,
        in_features: int,
        hid_features: int,
        bias: bool = True,
    ):
        # One independent key per linear layer. Sharing a key between `lin_fh` and
        # `lin_hh` would hand both layers the same source of randomness.
        keys = jax.random.split(key, 3)

        self.lin_fh = Linear(keys[0], hid_features, 1 * hid_features, bias)
        self.lin_hh = Linear(keys[1], hid_features, 1 * hid_features, bias)
        # The input layer emits the gate and candidate pre-activations at once.
        self.lin_x = Linear(keys[2], in_features, 2 * hid_features, bias)

        self.in_features = in_features
        self.hid_features = hid_features

    @jax.jit
    def __call__(self, h: Array, x: Array) -> Tuple[Array, Array]:
        r"""Applies one MGU step.

        Arguments:
            h: The previous hidden state :math:`h_{i-1}`, with shape :math:`(*, H)`.
            x: The input vector :math:`x_i`, with shape :math:`(*, C)`.

        Returns:
            The hidden state :math:`(h_i, h_i)`.
        """

        fh = self.lin_fh(h)
        fx, gx = jnp.split(self.lin_x(x), 2, axis=-1)

        f = jax.nn.sigmoid(fx + fh)

        # The single gate also scales the hidden state fed to the candidate.
        gh = self.lin_hh(f * h)
        g = jax.nn.tanh(gx + gh)

        # NOTE(review): here f weights the previous state and (1 - f) the candidate,
        # whereas the referenced paper appears to use the opposite weighting.
        # Confirm that this convention is intended.
        h = (1 - f) * g + f * h

        return h, h

    def init(self) -> Array:
        r"""Builds the initial hidden state.

        Returns:
            The initial hidden state :math:`h_0 = 0`, with shape :math:`(H)`.
        """

        return jnp.zeros(self.hid_features)
class LSTMCell(Cell):
    r"""Creates a long short-term memory (LSTM) cell.

    References:
        | Long Short-Term Memory (Hochreiter et al., 1997)
        | https://ieeexplore.ieee.org/abstract/document/6795963

    Arguments:
        key: A PRNG key for initialization.
        in_features: The number of input features :math:`C`.
        hid_features: The number of hidden features :math:`H`.
        bias: Whether the cell learns additive biases or not.
    """

    def __init__(
        self,
        key: KeyArray,
        in_features: int,
        hid_features: int,
        bias: bool = True,
    ):
        # One independent key per linear layer.
        key_h, key_x = jax.random.split(key, 2)

        # Each layer emits the pre-activations of the four gates at once.
        gate_features = 4 * hid_features

        self.lin_h = Linear(key_h, hid_features, gate_features, bias)
        self.lin_x = Linear(key_x, in_features, gate_features, bias)

        self.in_features = in_features
        self.hid_features = hid_features

    @jax.jit
    def __call__(
        self,
        hc: Tuple[Array, Array],
        x: Array,
    ) -> Tuple[Tuple[Array, Array], Array]:
        r"""Applies one LSTM step.

        Arguments:
            hc: The previous hidden and cell states :math:`(h_{i-1}, c_{i-1})`,
                each with shape :math:`(*, H)`.
            x: The input vector :math:`x_i`, with shape :math:`(*, C)`.

        Returns:
            The hidden and cell states :math:`((h_i, c_i), h_i)`.
        """

        h_prev, c_prev = hc

        # Hidden and input contributions are summed before being split per gate.
        pre = self.lin_h(h_prev) + self.lin_x(x)
        pre_i, pre_f, pre_g, pre_o = jnp.split(pre, 4, axis=-1)

        gate_i = jax.nn.sigmoid(pre_i)
        gate_f = jax.nn.sigmoid(pre_f)
        cand = jax.nn.tanh(pre_g)
        gate_o = jax.nn.sigmoid(pre_o)

        c_next = gate_f * c_prev + gate_i * cand
        h_next = gate_o * jax.nn.tanh(c_next)

        return (h_next, c_next), h_next

    def init(self) -> Tuple[Array, Array]:
        r"""Builds the initial hidden and cell states.

        Returns:
            The initial hidden and cell states :math:`h_0 = c_0 = 0`,
            each with shape :math:`(H)`.
        """

        h0 = jnp.zeros(self.hid_features)
        c0 = jnp.zeros(self.hid_features)

        return h0, c0