/
stax.py
354 lines (299 loc) · 13.4 KB
/
stax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stax is a small but flexible neural net specification library from scratch.
For an example of its use, see examples/resnet50.py.
"""
import functools
import itertools
import operator as op
from jax import lax
from jax import random
import jax.numpy as jnp
from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
leaky_relu, selu, gelu, normalize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros
# Aliases for backwards compatibility: each old name below is bound to the
# same object as the current name imported above.
glorot = glorot_normal  # initializer factory, called as glorot_normal(...)
randn = normal  # initializer factory, called as normal(...)
logsoftmax = log_softmax  # activation function from jax.nn
# Following the convention used in Keras and tf.layers, we use CamelCase for the
# names of layer constructors, like Conv and Relu, while using snake_case for
# other functions, like lax.conv and relu.
# Each layer constructor function returns an (init_fun, apply_fun) pair, where
#   init_fun: takes an rng key and an input shape and returns an
#     (output_shape, params) pair,
#   apply_fun: takes params, inputs, and an rng key (passed as the `rng`
#     keyword argument) and applies the layer.
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer.

  Args:
    out_dim: size of the last dimension of the layer output.
    W_init: initializer called as `W_init(rng, shape)` to create the weight
      matrix of shape `(input_shape[-1], out_dim)`.
    b_init: initializer called as `b_init(rng, shape)` to create the bias of
      shape `(out_dim,)`.

  Returns:
    An (init_fun, apply_fun) pair. `apply_fun` computes
    `jnp.dot(inputs, W) + b`.
  """
  def init_fun(rng, input_shape):
    # Only the trailing dimension changes; leading dimensions pass through.
    result_shape = input_shape[:-1] + (out_dim,)
    w_key, b_key = random.split(rng)
    weights = W_init(w_key, (input_shape[-1], out_dim))
    bias = b_init(b_key, (out_dim,))
    return result_shape, (weights, bias)

  def apply_fun(params, inputs, **kwargs):
    weights, bias = params
    return jnp.dot(inputs, weights) + bias

  return init_fun, apply_fun
def GeneralConv(dimension_numbers, out_chan, filter_shape,
                strides=None, padding='VALID', W_init=None,
                b_init=normal(1e-6)):
  """Layer construction function for a general convolution layer.

  Args:
    dimension_numbers: an `(lhs_spec, rhs_spec, out_spec)` triple of layout
      strings, e.g. `('NHWC', 'HWIO', 'NHWC')`.
    out_chan: number of output channels.
    filter_shape: spatial extent of the kernel, one entry per spatial dim.
    strides: window strides; defaults to 1 along every spatial dimension.
    padding: padding argument forwarded to the `lax` convolution routines.
    W_init: kernel initializer called as `W_init(rng, shape)`; defaults to
      `glorot_normal` configured with the 'I' and 'O' positions of `rhs_spec`.
    b_init: bias initializer called as `b_init(rng, shape)`.

  Returns:
    An (init_fun, apply_fun) pair.
  """
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  one = (1,) * len(filter_shape)
  if not strides:
    strides = one
  if not W_init:
    W_init = glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

  def init_fun(rng, input_shape):
    # Walk the kernel layout: 'O' and 'I' take channel counts, every other
    # letter consumes the next spatial size from filter_shape.
    spatial_sizes = iter(filter_shape)
    kernel_shape = []
    for c in rhs_spec:
      if c == 'O':
        kernel_shape.append(out_chan)
      elif c == 'I':
        kernel_shape.append(input_shape[lhs_spec.index('C')])
      else:
        kernel_shape.append(next(spatial_sizes))
    output_shape = lax.conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    # Bias has out_chan at the 'C' position of the output layout and 1
    # elsewhere; leading 1s are then stripped.
    padded_bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(
        itertools.dropwhile(lambda d: d == 1, padded_bias_shape))
    w_key, b_key = random.split(rng)
    kernel = W_init(w_key, kernel_shape)
    bias = b_init(b_key, bias_shape)
    return output_shape, (kernel, bias)

  def apply_fun(params, inputs, **kwargs):
    kernel, bias = params
    conv_out = lax.conv_general_dilated(
        inputs, kernel, strides, padding, one, one,
        dimension_numbers=dimension_numbers)
    return conv_out + bias

  return init_fun, apply_fun


# 2D convolution with channels-last inputs/outputs and HWIO kernels.
Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
def GeneralConvTranspose(dimension_numbers, out_chan, filter_shape,
                         strides=None, padding='VALID', W_init=None,
                         b_init=normal(1e-6)):
  """Layer construction function for a general transposed-convolution layer.

  Args:
    dimension_numbers: an `(lhs_spec, rhs_spec, out_spec)` triple of layout
      strings, e.g. `('NHWC', 'HWIO', 'NHWC')`.
    out_chan: number of output channels.
    filter_shape: spatial extent of the kernel, one entry per spatial dim.
    strides: strides forwarded to `lax.conv_transpose`; defaults to 1 along
      every spatial dimension.
    padding: padding argument forwarded to the `lax` routines.
    W_init: kernel initializer called as `W_init(rng, shape)`; defaults to
      `glorot_normal` configured with the 'I' and 'O' positions of `rhs_spec`.
    b_init: bias initializer called as `b_init(rng, shape)`.

  Returns:
    An (init_fun, apply_fun) pair.
  """
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  one = (1,) * len(filter_shape)
  if not strides:
    strides = one
  if not W_init:
    W_init = glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

  def _kernel_dim(c, input_shape, spatial_sizes):
    # 'O' -> output channels, 'I' -> input channels, else next spatial size.
    if c == 'O':
      return out_chan
    if c == 'I':
      return input_shape[lhs_spec.index('C')]
    return next(spatial_sizes)

  def init_fun(rng, input_shape):
    spatial_sizes = iter(filter_shape)
    kernel_shape = [_kernel_dim(c, input_shape, spatial_sizes)
                    for c in rhs_spec]
    output_shape = lax.conv_transpose_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    # out_chan at the 'C' slot of the output layout, with leading 1s dropped.
    padded_bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(
        itertools.dropwhile(lambda d: d == 1, padded_bias_shape))
    w_key, b_key = random.split(rng)
    kernel = W_init(w_key, kernel_shape)
    bias = b_init(b_key, bias_shape)
    return output_shape, (kernel, bias)

  def apply_fun(params, inputs, **kwargs):
    kernel, bias = params
    deconv_out = lax.conv_transpose(inputs, kernel, strides, padding,
                                    dimension_numbers=dimension_numbers)
    return deconv_out + bias

  return init_fun, apply_fun


# 1D and 2D transposed convolutions with channels-last layouts.
Conv1DTranspose = functools.partial(GeneralConvTranspose, ('NHC', 'HIO', 'NHC'))
ConvTranspose = functools.partial(GeneralConvTranspose,
                                  ('NHWC', 'HWIO', 'NHWC'))
def BatchNorm(axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True,
              beta_init=zeros, gamma_init=ones):
  """Layer construction function for a batch normalization layer.

  Args:
    axis: axis index or tuple of axis indices to normalize over.
    epsilon: passed to `normalize` as its `epsilon` argument.
    center: if True, a learned offset `beta` is added after normalization.
    scale: if True, the normalized value is multiplied by a learned `gamma`.
    beta_init: initializer called as `beta_init(rng, shape)` when `center`.
    gamma_init: initializer called as `gamma_init(rng, shape)` when `scale`.

  Returns:
    An (init_fun, apply_fun) pair whose params are `(beta, gamma)`; a disabled
    parameter is stored as an empty tuple.
  """
  axis = (axis,) if jnp.isscalar(axis) else axis

  def init_fun(rng, input_shape):
    # Parameters span exactly the dimensions that are not normalized over.
    param_shape = tuple(d for i, d in enumerate(input_shape) if i not in axis)
    beta_key, gamma_key = random.split(rng)
    beta = beta_init(beta_key, param_shape) if center else ()
    gamma = gamma_init(gamma_key, param_shape) if scale else ()
    return input_shape, (beta, gamma)

  def apply_fun(params, x, **kwargs):
    beta, gamma = params
    # TODO(phawkins): jnp.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    # Index that inserts a length-1 dimension at each normalized axis so the
    # parameters broadcast against x.
    expand = tuple(None if i in axis else slice(None)
                   for i in range(jnp.ndim(x)))
    out = normalize(x, axis, epsilon=epsilon)
    if scale:
      out = gamma[expand] * out
    if center:
      out = out + beta[expand]
    return out

  return init_fun, apply_fun
def elementwise(fun, **fun_kwargs):
  """Layer that applies a scalar function elementwise on its inputs.

  Args:
    fun: function called as `fun(inputs, **fun_kwargs)`.
    **fun_kwargs: fixed keyword arguments forwarded to `fun` on every call.

  Returns:
    An (init_fun, apply_fun) pair with no parameters; the output shape equals
    the input shape.
  """
  def init_fun(rng, input_shape):
    return input_shape, ()

  def apply_fun(params, inputs, **kwargs):
    # Call-time kwargs (e.g. rng) are accepted but not forwarded to fun.
    return fun(inputs, **fun_kwargs)

  return init_fun, apply_fun


# Ready-made parameterless activation layers.
Tanh = elementwise(jnp.tanh)
Relu = elementwise(relu)
Exp = elementwise(jnp.exp)
LogSoftmax = elementwise(log_softmax, axis=-1)
Softmax = elementwise(softmax, axis=-1)
Softplus = elementwise(softplus)
Sigmoid = elementwise(sigmoid)
Elu = elementwise(elu)
LeakyRelu = elementwise(leaky_relu)
Selu = elementwise(selu)
Gelu = elementwise(gelu)
def _pooling_layer(reducer, init_val, rescaler=None):
  """Build a pooling-layer constructor from a window reduction.

  Args:
    reducer: binary reduction passed to `lax.reduce_window`.
    init_val: initial value passed to `lax.reduce_window`.
    rescaler: optional factory called as `rescaler(window_shape, strides,
      padding)` that returns a `rescale(outputs, inputs, spec)` function
      applied to the pooled output.

  Returns:
    A `PoolingLayer` layer constructor.
  """
  def PoolingLayer(window_shape, strides=None, padding='VALID', spec=None):
    """Layer construction function for a pooling layer."""
    strides = strides or (1,) * len(window_shape)
    # The rescaler sees the spatial-only window and strides.
    rescale = rescaler(window_shape, strides, padding) if rescaler else None

    # Without a spec, the first and last axes are treated as non-spatial.
    if spec is None:
      batch_chan_axes = (0, len(window_shape) + 1)
    else:
      batch_chan_axes = (spec.index('N'), spec.index('C'))

    # Insert a size-1 window and stride at each non-spatial axis.
    full_window, full_strides = window_shape, strides
    for ax in sorted(batch_chan_axes):
      full_window = full_window[:ax] + (1,) + full_window[ax:]
      full_strides = full_strides[:ax] + (1,) + full_strides[ax:]

    def init_fun(rng, input_shape):
      pads = lax.padtype_to_pads(input_shape, full_window,
                                 full_strides, padding)
      unit = (1,) * len(full_window)
      pooled_shape = lax.reduce_window_shape_tuple(
          input_shape, full_window, full_strides, pads, unit, unit)
      return pooled_shape, ()

    def apply_fun(params, inputs, **kwargs):
      pooled = lax.reduce_window(inputs, init_val, reducer, full_window,
                                 full_strides, padding)
      if rescale:
        return rescale(pooled, inputs, spec)
      return pooled

    return init_fun, apply_fun

  return PoolingLayer


MaxPool = _pooling_layer(lax.max, -jnp.inf)
SumPool = _pooling_layer(lax.add, 0.)


def _normalize_by_window_size(dims, strides, padding):
  """Return a rescale function that divides pooled sums by window sizes.

  Args:
    dims: spatial window shape.
    strides: spatial window strides.
    padding: padding argument forwarded to `lax.reduce_window`.
  """
  def rescale(outputs, inputs, spec):
    if spec is None:
      batch_chan_axes = (0, inputs.ndim - 1)
    else:
      batch_chan_axes = (spec.index('N'), spec.index('C'))

    spatial_shape = tuple(d for i, d in enumerate(inputs.shape)
                          if i not in batch_chan_axes)
    # Sum-pooling an array of ones counts the elements under each window.
    unit_input = jnp.ones(spatial_shape, dtype=inputs.dtype)
    counts = lax.reduce_window(unit_input, 0., lax.add, dims, strides, padding)
    # Re-insert the non-spatial axes so counts broadcast against outputs.
    for ax in sorted(batch_chan_axes):
      counts = jnp.expand_dims(counts, ax)
    return outputs / counts

  return rescale


AvgPool = _pooling_layer(lax.add, 0., _normalize_by_window_size)
def Flatten():
  """Layer construction function for flattening all but the leading dim.

  Returns:
    An (init_fun, apply_fun) pair with no parameters. The output shape is
    `(input_shape[0], product of the remaining dimensions)`.
  """
  def init_fun(rng, input_shape):
    flat_size = 1
    for dim in input_shape[1:]:
      flat_size = flat_size * dim
    return (input_shape[0], flat_size), ()

  def apply_fun(params, inputs, **kwargs):
    leading = inputs.shape[0]
    return jnp.reshape(inputs, (leading, -1))

  return init_fun, apply_fun


# The module-level name is rebound to the layer pair itself, so it is used
# as `Flatten`, not `Flatten()`.
Flatten = Flatten()
def Identity():
  """Layer construction function for an identity layer.

  Returns:
    An (init_fun, apply_fun) pair with no parameters that passes both the
    input shape and the inputs through unchanged.
  """
  def init_fun(rng, input_shape):
    return input_shape, ()

  def apply_fun(params, inputs, **kwargs):
    return inputs

  return init_fun, apply_fun


# The module-level name is rebound to the layer pair itself.
Identity = Identity()
def FanOut(num):
  """Layer construction function for a fan-out layer.

  Args:
    num: number of copies of the input to produce.

  Returns:
    An (init_fun, apply_fun) pair with no parameters; both the output shape
    and the output are lists holding `num` references to the input.
  """
  def init_fun(rng, input_shape):
    return [input_shape] * num, ()

  def apply_fun(params, inputs, **kwargs):
    return [inputs] * num

  return init_fun, apply_fun
def FanInSum():
  """Layer construction function for a fan-in sum layer.

  Returns:
    An (init_fun, apply_fun) pair with no parameters. The output shape is the
    first of the input shapes, and the output is `sum(inputs)`.
  """
  def init_fun(rng, input_shape):
    first_shape = input_shape[0]
    return first_shape, ()

  def apply_fun(params, inputs, **kwargs):
    return sum(inputs)

  return init_fun, apply_fun


# The module-level name is rebound to the layer pair itself.
FanInSum = FanInSum()
def FanInConcat(axis=-1):
  """Layer construction function for a fan-in concatenation layer.

  Args:
    axis: axis along which the inputs are concatenated.

  Returns:
    An (init_fun, apply_fun) pair with no parameters.
  """
  def init_fun(rng, input_shape):
    first = input_shape[0]
    # Resolve a possibly negative axis against the rank of the first input.
    ax = axis % len(first)
    total = sum(shape[ax] for shape in input_shape)
    return first[:ax] + (total,) + first[ax + 1:], ()

  def apply_fun(params, inputs, **kwargs):
    return jnp.concatenate(inputs, axis)

  return init_fun, apply_fun
def Dropout(rate, mode='train'):
  """Layer construction function for a dropout layer with given rate.

  Args:
    rate: probability of *keeping* each element. It is used as the Bernoulli
      success probability of the keep mask, and kept elements are divided by
      it; dropped elements become 0.
    mode: 'train' applies dropout; any other value returns the inputs
      unchanged.

  Returns:
    An (init_fun, apply_fun) pair with no parameters.

  Raises:
    ValueError: from `apply_fun`, when `mode == 'train'` and no `rng` keyword
      argument was supplied.
  """
  def init_fun(rng, input_shape):
    return input_shape, ()

  def apply_fun(params, inputs, **kwargs):
    # Outside training the layer is an identity and needs no randomness, so
    # do not demand a PRNG key.
    if mode != 'train':
      return inputs
    rng = kwargs.get('rng', None)
    if rng is None:
      # apply_fun only accepts the key as a keyword argument, so the message
      # must show the `rng=` form.
      msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
             "argument. That is, instead of `apply_fun(params, inputs)`, call "
             "it like `apply_fun(params, inputs, rng=rng)` where `rng` is a "
             "jax.random.PRNGKey value.")
      raise ValueError(msg)
    keep = random.bernoulli(rng, rate, inputs.shape)
    return jnp.where(keep, inputs / rate, 0)

  return init_fun, apply_fun
# Composing layers via combinators
def serial(*layers):
  """Combinator for composing layers in serial.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
    composition of the given sequence of layers.
  """
  nlayers = len(layers)
  init_funs, apply_funs = zip(*layers)

  def init_fun(rng, input_shape):
    params = []
    shape = input_shape
    for layer_init in init_funs:
      # Split off a fresh key per layer and carry the remainder forward.
      rng, layer_rng = random.split(rng)
      shape, layer_params = layer_init(layer_rng, shape)
      params.append(layer_params)
    return shape, params

  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.pop('rng', None)
    if rng is None:
      rngs = (None,) * nlayers
    else:
      rngs = random.split(rng, nlayers)
    out = inputs
    # Each layer consumes the previous layer's output and its own key.
    for layer_apply, layer_params, layer_rng in zip(apply_funs, params, rngs):
      out = layer_apply(layer_params, out, rng=layer_rng, **kwargs)
    return out

  return init_fun, apply_fun
def parallel(*layers):
  """Combinator for composing layers in parallel.

  The layer resulting from this combinator is often used with the FanOut and
  FanInSum layers.

  Args:
    *layers: a sequence of layers, each an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the
    parallel composition of the given sequence of layers. In particular, the
    returned layer takes a sequence of inputs and returns a sequence of outputs
    with the same length as the argument `layers`. Its `init_fun` returns an
    `(output_shapes, params)` tuple, each a tuple with one entry per layer.
  """
  nlayers = len(layers)
  init_funs, apply_funs = zip(*layers)

  def init_fun(rng, input_shape):
    rngs = random.split(rng, nlayers)
    per_layer = [init(layer_rng, shape)
                 for init, layer_rng, shape
                 in zip(init_funs, rngs, input_shape)]
    # Transpose [(shape, params), ...] into (shapes, params). Materialize the
    # result: a bare zip object is a one-shot iterator that cannot be indexed
    # or read twice.
    return tuple(zip(*per_layer))

  def apply_fun(params, inputs, **kwargs):
    rng = kwargs.pop('rng', None)
    rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
    return [f(p, x, rng=r, **kwargs)
            for f, p, x, r in zip(apply_funs, params, inputs, rngs)]

  return init_fun, apply_fun
def shape_dependent(make_layer):
  """Combinator to delay layer constructor pair until input shapes are known.

  Args:
    make_layer: a one-argument function that takes an input shape as an argument
      (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.

  Returns:
    A new layer, meaning an (init_fun, apply_fun) pair, representing the same
    layer as returned by `make_layer` but with its construction delayed until
    input shapes are known.
  """
  def init_fun(rng, input_shape):
    layer_init, _ = make_layer(input_shape)
    return layer_init(rng, input_shape)

  def apply_fun(params, inputs, **kwargs):
    # The layer is rebuilt on every call from the shape of the actual inputs.
    _, layer_apply = make_layer(inputs.shape)
    return layer_apply(params, inputs, **kwargs)

  return init_fun, apply_fun