/
functions.py
403 lines (308 loc) · 11 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared neural network activations and other functions."""
import operator
import numpy as np
from typing import Any, Optional, Tuple, Union
from jax import custom_jvp
from jax._src import dtypes
from jax import lax
from jax import core
from jax.core import AxisName
from .. import util
from jax.scipy.special import expit
from jax.scipy.special import logsumexp as _logsumexp
import jax.numpy as jnp
# Annotation alias used by the function signatures in this module. It is
# simply ``typing.Any`` and therefore places no static constraint on values.
Array = Any
# activations
@custom_jvp
def relu(x: Array) -> Array:
  r"""Rectified linear unit (ReLU) activation.

  Applies the following function to every element:

  .. math::
    \mathrm{relu}(x) = \max(x, 0)

  Differentiation uses the custom rule registered right after this
  definition instead of the derivative of ``jnp.maximum``.

  Args:
    x : input array

  Returns:
    The element-wise maximum of ``x`` and ``0``.
  """
  return jnp.maximum(x, 0)


# Custom JVP rule: the tangent passes through unchanged where the input is
# strictly positive and is replaced by zeros everywhere else (x == 0 included).
relu.defjvps(
    lambda tangent, primal_out, x: lax.select(
        x > 0, tangent, lax.full_like(tangent, 0)))
def softplus(x: Array) -> Array:
  r"""Softplus activation.

  Applies the following function to every element:

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  The value is obtained as ``logaddexp(x, 0)``, i.e. ``log(exp(x) + exp(0))``.

  Args:
    x : input array

  Returns:
    The element-wise softplus of ``x``.
  """
  zero = 0
  return jnp.logaddexp(x, zero)
def soft_sign(x: Array) -> Array:
  r"""Soft-sign activation.

  Applies the following function to every element:

  .. math::
    \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

  Args:
    x : input array

  Returns:
    ``x`` divided element-wise by ``|x| + 1``.
  """
  denominator = jnp.abs(x) + 1
  return x / denominator
def sigmoid(x: Array) -> Array:
  r"""Logistic sigmoid activation.

  Applies the following function to every element:

  .. math::
    \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

  This is a thin wrapper that forwards to ``jax.scipy.special.expit``.

  Args:
    x : input array

  Returns:
    The element-wise sigmoid of ``x``.
  """
  result = expit(x)
  return result
def silu(x: Array) -> Array:
  r"""SiLU activation (also exported under the name ``swish``).

  Applies the following function to every element:

  .. math::
    \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

  Args:
    x : input array

  Returns:
    ``x`` multiplied element-wise by ``sigmoid(x)``.
  """
  gate = sigmoid(x)
  return x * gate


# ``swish`` is a second name bound to the very same function object.
swish = silu
def log_sigmoid(x: Array) -> Array:
  r"""Logarithm of the sigmoid activation.

  Applies the following function to every element:

  .. math::
    \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

  The value is computed through the identity ``-softplus(-x)``.

  Args:
    x : input array

  Returns:
    The element-wise log-sigmoid of ``x``.
  """
  negated = -x
  return -softplus(negated)
def elu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Exponential linear unit (ELU) activation.

  Applies the following function to every element:

  .. math::
    \mathrm{elu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(x) - 1\right), & x \le 0
    \end{cases}

  Args:
    x : input array
    alpha : scalar or array of alpha values (default: 1.0)

  Returns:
    ``x`` where it is positive, ``alpha * expm1(x)`` elsewhere.
  """
  is_positive = x > 0
  # Positive entries are swapped for 0 before calling expm1, so the branch
  # that the final `where` throws away is never evaluated on them.
  exp_input = jnp.where(is_positive, 0., x)
  negative_branch = alpha * jnp.expm1(exp_input)
  return jnp.where(is_positive, x, negative_branch)
def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
  r"""Leaky rectified linear unit activation.

  Applies the following function to every element:

  .. math::
    \mathrm{leaky\_relu}(x) = \begin{cases}
      x, & x \ge 0\\
      \alpha x, & x < 0
    \end{cases}

  where :math:`\alpha` = :code:`negative_slope`.

  Args:
    x : input array
    negative_slope : array or scalar specifying the negative slope
      (default: 0.01)

  Returns:
    ``x`` where it is non-negative, ``negative_slope * x`` elsewhere.
  """
  scaled = negative_slope * x
  keep_input = x >= 0
  return jnp.where(keep_input, x, scaled)
def hard_tanh(x: Array) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation.

  Applies the following function to every element:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      -1, & x < -1\\
      x, & -1 \le x \le 1\\
      1, & 1 < x
    \end{cases}

  Args:
    x : input array

  Returns:
    ``x`` with entries above 1 replaced by 1 and entries below -1 replaced
    by -1.
  """
  # Resolve the lower bound first, then let the upper bound take precedence.
  lower_limited = jnp.where(x < -1, -1, x)
  return jnp.where(x > 1, 1, lower_limited)
def celu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Continuously-differentiable exponential linear unit activation.

  Computes the element-wise function:

  .. math::
    \mathrm{celu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
    \end{cases}

  For more information, see
  `Continuously Differentiable Exponential Linear Units
  <https://arxiv.org/pdf/1704.07483.pdf>`_.

  Args:
    x : input array
    alpha : array or scalar (default: 1.0)

  Returns:
    ``x`` where it is positive, ``alpha * expm1(x / alpha)`` elsewhere.
  """
  is_positive = x > 0
  # Feed 0 instead of positive entries into expm1. Without this, a large
  # positive `x` makes expm1 overflow to inf in the branch that `where`
  # discards; the forward value is unaffected, but reverse-mode
  # differentiation multiplies a zero cotangent by that inf and yields NaN.
  safe_x = jnp.where(is_positive, 0., x)
  return jnp.where(is_positive, x, alpha * jnp.expm1(safe_x / alpha))
def selu(x: Array) -> Array:
  r"""Scaled exponential linear unit (SELU) activation.

  Applies the following function to every element:

  .. math::
    \mathrm{selu}(x) = \lambda \begin{cases}
      x, & x > 0\\
      \alpha e^x - \alpha, & x \le 0
    \end{cases}

  where :math:`\lambda = 1.0507009873554804934193349852946` and
  :math:`\alpha = 1.6732632423543772848170429916717`.

  For more information, see
  `Self-Normalizing Neural Networks
  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.

  Args:
    x : input array

  Returns:
    ``elu(x, alpha)`` multiplied by the fixed scale :math:`\lambda`.
  """
  # Fixed constants from the formula above; they are not configurable.
  selu_scale = 1.0507009873554804934193349852946
  selu_alpha = 1.6732632423543772848170429916717
  unscaled = elu(x, selu_alpha)
  return selu_scale * unscaled
def gelu(x: Array, approximate: bool = True) -> Array:
  r"""Gaussian error linear unit activation function.

  If ``approximate=False``, computes the element-wise function:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
      \frac{x}{\sqrt{2}} \right) \right)

  If ``approximate=True``, uses the approximate formulation of GELU:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
      \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

  For more information, see `Gaussian Error Linear Units (GELUs)
  <https://arxiv.org/abs/1606.08415>`_, section 2.

  Args:
    x : input array or scalar. Inputs whose dtype is not inexact (integers,
      booleans) are converted to the default floating-point dtype first.
    approximate: whether to use the approximate or exact formulation.

  Returns:
    The element-wise GELU of ``x``, in the (possibly promoted) dtype of ``x``.
  """
  # Accept Python scalars and other array-likes, which have no `.dtype`.
  x = jnp.asarray(x)
  # The constants and the result below are cast to `x.dtype`. With an integer
  # dtype that cast would truncate sqrt(2/pi) to 0 and round the result, so
  # non-floating inputs are promoted to floating point up front.
  if not jnp.issubdtype(x.dtype, jnp.inexact):
    x = x.astype(float)
  if approximate:
    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
    return x * cdf
  else:
    return jnp.array(x * (lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype)
def glu(x: Array, axis: int = -1) -> Array:
  """Gated linear unit activation.

  ``x`` is cut into two equally sized pieces along ``axis``; the first piece
  is multiplied element-wise by the sigmoid of the second piece.

  Args:
    x : input array
    axis: the axis along which the split should be computed (default: -1)

  Returns:
    The gated product of the two halves of ``x``.

  Raises:
    AssertionError: if the size of ``x`` along ``axis`` is odd.
  """
  axis_size = x.shape[axis]
  assert axis_size % 2 == 0, "axis size must be divisible by 2"
  values, gates = jnp.split(x, 2, axis)
  return values * sigmoid(gates)
# other functions

# Re-export: ``logsumexp`` here is the same object that this module imports
# from ``jax.scipy.special`` under the private name ``_logsumexp``.
logsumexp = _logsumexp
def log_softmax(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = -1) -> Array:
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`[-\infty, 0)`.

  .. math ::
    \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  The maximum along ``axis`` is subtracted from ``x`` before exponentiating;
  that maximum is wrapped in ``stop_gradient``.

  Args:
    x : input array
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.

  Returns:
    An array with the same shape as ``x`` holding the log-softmax values.
  """
  peak = lax.stop_gradient(x.max(axis, keepdims=True))
  centered = x - peak
  log_normalizer = jnp.log(jnp.sum(jnp.exp(centered), axis, keepdims=True))
  return centered - log_normalizer
def softmax(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = -1) -> Array:
  r"""Softmax function.

  Rescales elements to the range :math:`[0, 1]` such that the elements along
  :code:`axis` sum to :math:`1`.

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  The maximum along ``axis`` is subtracted from ``x`` before exponentiating;
  that maximum is wrapped in ``stop_gradient``.

  Args:
    x : input array
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.

  Returns:
    An array with the same shape as ``x`` holding the softmax values.
  """
  peak = lax.stop_gradient(x.max(axis, keepdims=True))
  exponentials = jnp.exp(x - peak)
  normalizer = exponentials.sum(axis, keepdims=True)
  return exponentials / normalizer
def normalize(x: Array,
              axis: Optional[Union[int, Tuple[int, ...]]] = -1,
              mean: Optional[Array] = None,
              variance: Optional[Array] = None,
              epsilon: Array = 1e-5) -> Array:
  """Normalizes an array by subtracting mean and dividing by sqrt(var).

  Args:
    x : input array
    axis: the axis or axes over which the statistics are reduced when they
      are not supplied by the caller.
    mean: value to subtract from ``x``. When ``None`` it is computed as the
      mean of ``x`` over ``axis`` (dimensions kept).
    variance: value to scale by. When ``None`` it is computed as
      ``mean(x**2) - mean**2`` over ``axis`` (dimensions kept).
    epsilon: constant added to ``variance`` before the reciprocal square root.

  Returns:
    ``(x - mean) * rsqrt(variance + epsilon)``.
  """
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True)
  if variance is None:
    # E[x^2] - E[x]^2 is traditionally regarded as less accurate than
    # jnp.var's mean((x - mean(x))**2). It may nevertheless be faster and,
    # for typical activation distributions in low-precision arithmetic, even
    # more accurate when used in neural network normalization layers.
    second_moment = jnp.mean(jnp.square(x), axis, keepdims=True)
    variance = second_moment - jnp.square(mean)
  centered = x - mean
  inverse_std = lax.rsqrt(variance + epsilon)
  return centered * inverse_std
def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension. It must be a
      concrete value convertible to ``int``.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: position of the new one-hot dimension in the output. If the value
      is rejected with a ``TypeError`` when canonicalized as a positional
      axis, it is treated as an axis name instead: ``num_classes`` must then
      equal the size of that named axis, and the result has the same shape as
      ``x``.

  Returns:
    An array of type ``dtype`` that is one where the index matches the class
    position and zero elsewhere.

  Raises:
    ValueError: if ``axis`` is handled as an axis name and ``num_classes``
      differs from the size of that axis.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    # The output has one more dimension than `x`, hence `ndim + 1`.
    class_dim = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    # Named-axis path: compare each index with this position along the axis.
    named_axis_size = lax.psum(1, axis)
    if num_classes != named_axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {named_axis_size}") from None
    position = lax.axis_index(axis)
    return jnp.asarray(x == position, dtype=dtype)
  axis = operator.index(axis)
  # Indices gain a singleton dimension where the classes will be laid out.
  indices = lax.expand_dims(x, (axis,))
  # Class ids 0..num_classes-1 placed along `class_dim`, size 1 elsewhere, so
  # the equality below broadcasts to the full output shape.
  classes_shape = [1] * x.ndim
  classes_shape.insert(class_dim, num_classes)
  class_ids = jnp.arange(num_classes, dtype=x.dtype)
  classes = lax.broadcast_in_dim(class_ids, classes_shape, (class_dim,))
  return jnp.asarray(indices == classes, dtype=dtype)
def relu6(x: Array) -> Array:
  r"""Rectified linear unit capped at 6.

  Applies the following function to every element:

  .. math::
    \mathrm{relu6}(x) = \min(\max(x, 0), 6)

  Args:
    x : input array

  Returns:
    ``x`` limited from below by 0 and from above by 6.
  """
  floored = jnp.maximum(x, 0)
  return jnp.minimum(floored, 6.)
def hard_sigmoid(x: Array) -> Array:
  r"""Hard sigmoid activation.

  Applies the following function to every element:

  .. math::
    \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

  Args:
    x : input array

  Returns:
    ``relu6(x + 3)`` divided by 6.
  """
  shifted = x + 3.
  return relu6(shifted) / 6.
def hard_silu(x: Array) -> Array:
  r"""Hard SiLU activation (also exported under the name ``hard_swish``).

  Applies the following function to every element:

  .. math::
    \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

  Args:
    x : input array

  Returns:
    ``x`` multiplied element-wise by ``hard_sigmoid(x)``.
  """
  gate = hard_sigmoid(x)
  return x * gate


# ``hard_swish`` is a second name bound to the very same function object.
hard_swish = hard_silu