-
Notifications
You must be signed in to change notification settings - Fork 0
/
activation.py
232 lines (155 loc) · 5.51 KB
/
activation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
r"""Activation functions.

Thin :class:`Module` wrappers around element-wise and normalizing functions
from :mod:`jax.nn` and :mod:`jax.numpy`.
"""

# Public API of this module. Each name must appear exactly once; a duplicate
# entry (previously 'SiLU') is redundant and confuses API/doc tooling.
__all__ = [
    'Identity',
    'Tanh',
    'Sigmoid',
    'SiLU',
    'Softplus',
    'Softmax',
    'ReLU',
    'LeakyReLU',
    'ELU',
    'CELU',
    'GELU',
    'SELU',
]
import jax
import jax.numpy as jnp
from jax import Array
from typing import Sequence, Union
# isort: local
from .module import Module
class Activation(Module):
    r"""Abstract activation class.

    Common base of every activation module in this file. Concrete subclasses
    implement :py:`__call__(x)`; those holding hyper-parameters also define
    their own constructor and :py:`tree_repr`.
    """

    def __init__(self):
        r"""Initializes a hyper-parameter-free activation (stores nothing)."""
        # NOTE(review): Module.__init__ is intentionally(?) not called here;
        # presumably Module needs no initialization — confirm.
class Identity(Activation):
    r"""Creates an identity activation function.

    .. math:: y = x

    The input object itself is returned; no copy is made.
    """

    def __call__(self, x: Array) -> Array:
        y = x  # pass-through: no transformation applied
        return y
class Tanh(Activation):
    r"""Creates a hyperbolic tangent (tanh) activation function.

    .. math:: y = \tanh(x)

    The function is applied element-wise via :func:`jax.numpy.tanh`.
    """

    def __call__(self, x: Array) -> Array:
        return jnp.tanh(x)
class Sigmoid(Activation):
    r"""Creates a sigmoid (logistic) activation function.

    .. math:: y = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Delegates to :func:`jax.nn.sigmoid`.
    """

    def __call__(self, x: Array) -> Array:
        y = jax.nn.sigmoid(x)
        return y
class Softplus(Activation):
    r"""Creates a softplus activation function, a smooth approximation of ReLU.

    .. math:: y = \log(1 + \exp(x))

    Delegates to :func:`jax.nn.softplus`.
    """

    def __call__(self, x: Array) -> Array:
        y = jax.nn.softplus(x)
        return y
class Softmax(Activation):
    r"""Creates a softmax activation function.

    .. math:: y_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Arguments:
        axis: The axis(es) over which the sum is performed. Any sequence of
            ints (e.g. a list) is stored as a tuple.
    """

    def __init__(self, axis: Union[int, Sequence[int]] = -1):
        # jax.nn.softmax documents `axis` as an int or a *tuple* of ints. A
        # list is unhashable and may be rejected when used as a static
        # argument, so normalize any sequence to a tuple up front.
        if isinstance(axis, Sequence):
            axis = tuple(axis)

        self.axis = axis

    def __call__(self, x: Array) -> Array:
        return jax.nn.softmax(x, axis=self.axis)

    def tree_repr(self, **kwargs) -> str:
        return f'{self.__class__.__name__}(axis={self.axis})'
class ReLU(Activation):
    r"""Creates a rectified linear unit (ReLU) activation function.

    .. math:: y = \max(x, 0)

    Delegates to :func:`jax.nn.relu`.
    """

    def __call__(self, x: Array) -> Array:
        y = jax.nn.relu(x)
        return y
class LeakyReLU(Activation):
    r"""Creates a leaky-ReLU activation function.

    .. math:: y = \begin{cases}
            \alpha x & \text{if } x \leq 0 \\
            x & \text{otherwise}
        \end{cases}

    Arguments:
        alpha: The negative slope :math:`\alpha`, stored as a JAX array.
    """

    def __init__(self, alpha: Union[float, Array] = 0.01):
        slope = jnp.asarray(alpha)
        self.alpha = slope

    def __call__(self, x: Array) -> Array:
        return jax.nn.leaky_relu(x, negative_slope=self.alpha)

    def tree_repr(self, **kwargs) -> str:
        name = type(self).__name__
        return f'{name}(alpha={self.alpha})'
class ELU(Activation):
    r"""Creates an exponential linear unit (ELU) activation function.

    .. math:: y = \begin{cases}
            \alpha (\exp(x) - 1) & \text{if } x \leq 0 \\
            x & \text{otherwise}
        \end{cases}

    References:
        | Fast and Accurate Deep Network Learning by Exponential Linear Units (Clevert et al., 2015)
        | https://arxiv.org/abs/1511.07289

    Arguments:
        alpha: The coefficient :math:`\alpha`, stored as a JAX array.
    """

    def __init__(self, alpha: Union[float, Array] = 1.0):
        self.alpha = jnp.asarray(alpha)

    def __call__(self, x: Array) -> Array:
        # Forward the stored coefficient. Omitting it makes jax.nn.elu fall
        # back to its default alpha=1.0, silently ignoring the constructor arg.
        return jax.nn.elu(x, self.alpha)

    def tree_repr(self, **kwargs) -> str:
        return f'{self.__class__.__name__}(alpha={self.alpha})'
class CELU(ELU):
    r"""Creates a continuously-differentiable ELU (CELU) activation function.

    .. math:: y = \max(x, 0) + \alpha \min(0, \exp(x / \alpha) - 1)

    The constructor and :py:`tree_repr` are inherited from :class:`ELU`.

    References:
        | Continuously Differentiable Exponential Linear Units (Barron, 2017)
        | https://arxiv.org/abs/1704.07483

    Arguments:
        alpha: The coefficient :math:`\alpha`.
    """

    def __call__(self, x: Array) -> Array:
        # Forward the inherited coefficient; without it jax.nn.celu would
        # always use its default alpha=1.0.
        return jax.nn.celu(x, self.alpha)
class GELU(Activation):
    r"""Creates a Gaussian error linear unit (GELU) activation function.

    .. math:: y = \frac{x}{2}
        \left(1 + \mathrm{erf}\left(\frac{x}{\sqrt{2}}\right)\right)

    With :py:`approximate=True` (the default), the tanh approximation is used
    instead:

    .. math:: y = \frac{x}{2}
        \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}(x + 0.044715 x^3)\right)\right)

    References:
        | Gaussian Error Linear Units (Hendrycks et al., 2017)
        | https://arxiv.org/abs/1606.08415v4

    Arguments:
        approximate: Whether to use the approximate or exact formulation.
    """

    def __init__(self, approximate: bool = True):
        self.approximate = approximate

    def __call__(self, x: Array) -> Array:
        return jax.nn.gelu(x, approximate=self.approximate)

    def tree_repr(self, **kwargs) -> str:
        name = type(self).__name__
        return f'{name}(approximate={self.approximate})'
class SELU(Activation):
    r"""Creates a self-normalizing ELU (SELU) activation function.

    .. math:: y = \lambda \begin{cases}
            \alpha (\exp(x) - 1) & \text{if } x \leq 0 \\
            x & \text{otherwise}
        \end{cases}

    where :math:`\lambda \approx 1.0507` and :math:`\alpha \approx 1.6732` are
    fixed constants (not configurable here).

    References:
        | Self-Normalizing Neural Networks (Klambauer et al., 2017)
        | https://arxiv.org/abs/1706.02515
    """

    def __call__(self, x: Array) -> Array:
        y = jax.nn.selu(x)
        return y
class SiLU(Activation):
    r"""Creates a sigmoid linear unit (SiLU) activation function, also known as swish.

    .. math:: y = x \sigma(x)

    References:
        | Gaussian Error Linear Units (Hendrycks et al., 2017)
        | https://arxiv.org/abs/1606.08415v4

        | Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning (Elfwing et al., 2017)
        | https://arxiv.org/abs/1702.03118
    """

    def __call__(self, x: Array) -> Array:
        y = jax.nn.silu(x)
        return y