-
Notifications
You must be signed in to change notification settings - Fork 0
/
linear.py
270 lines (221 loc) · 8.08 KB
/
linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
r"""Linear layers.

This module defines affine layers: a dense layer (:class:`Linear`), an
N-dimensional convolution whose number of spatial axes is given by the length
of its kernel size (:class:`Conv`), and the corresponding transposed
convolution (:class:`ConvTransposed`).
"""
__all__ = [
    'Linear',
    'Conv',
    'ConvTransposed',
]
import jax
import math
from jax import Array
from typing import *
from .module import Module, Parameter
from ..numpy import flatten, unflatten
class Linear(Module):
    r"""Creates a linear (fully connected) layer.

    .. math:: y = x W + b

    The weight :math:`W` is stored with shape :math:`(C, C')` and the input is
    right-multiplied by it. Both :math:`W` and :math:`b` are initialized
    uniformly in :math:`[-1 / \sqrt{C}, 1 / \sqrt{C}]`.

    Arguments:
        key: A PRNG key for initialization.
        in_features: The number of input features :math:`C`.
        out_features: The number of output features :math:`C'`.
        bias: Whether the layer learns an additive bias :math:`b` or not.
    """

    def __init__(
        self,
        key: Array,
        in_features: int,
        out_features: int,
        bias: bool = True,
    ):
        # The key is always split in two, even without a bias, so the weight
        # initialization does not depend on the ``bias`` flag.
        weight_key, bias_key = jax.random.split(key, 2)
        bound = 1 / math.sqrt(in_features)

        def _uniform(subkey, shape):
            # Symmetric uniform initialization shared by weight and bias.
            return jax.random.uniform(
                key=subkey,
                shape=shape,
                minval=-bound,
                maxval=bound,
            )

        self.weight = Parameter(_uniform(weight_key, (in_features, out_features)))
        self.bias = Parameter(_uniform(bias_key, (out_features,))) if bias else None

    def __call__(self, x: Array) -> Array:
        r"""
        Arguments:
            x: The input vector :math:`x`, with shape :math:`(*, C)`.

        Returns:
            The output vector :math:`y`, with shape :math:`(*, C')`.
        """
        y = x @ self.weight()
        if self.bias is not None:
            y = y + self.bias()
        return y
class Conv(Module):
    r"""Creates a convolution layer.

    .. math:: y = W * x + b

    The kernel :math:`W` is stored with shape :math:`(k_1, \dots, k_n, C / G, C')`.
    Both :math:`W` and :math:`b` are initialized uniformly in
    :math:`[-1 / \sqrt{F}, 1 / \sqrt{F}]` with :math:`F = k_1 \cdots k_n \, C / G`.

    References:
        | A guide to convolution arithmetic for deep learning (Dumoulin et al., 2016)
        | https://arxiv.org/abs/1603.07285

    Arguments:
        key: A PRNG key for initialization.
        in_channels: The number of input channels :math:`C`.
        out_channels: The number of output channels :math:`C'`.
        kernel_size: The size of the kernel :math:`W` in each spatial axis.
        bias: Whether the layer learns an additive bias :math:`b` or not.
        stride: The stride coefficient in each spatial axis.
        dilation: The dilation coefficient in each spatial axis.
        padding: The padding applied to each end of each spatial axis. A single
            integer is used for both ends of every axis; otherwise one
            ``(low, high)`` pair per spatial axis.
        groups: The number of channel groups :math:`G`.
            Both :math:`C` and :math:`C'` must be divisible by :math:`G`.

    Raises:
        ValueError: If :math:`G < 1`, or if :math:`C` or :math:`C'` is not
            divisible by :math:`G`.
    """

    def __init__(
        self,
        key: Array,
        in_channels: int,
        out_channels: int,
        kernel_size: Sequence[int],
        bias: bool = True,
        stride: Union[int, Sequence[int]] = 1,
        dilation: Union[int, Sequence[int]] = 1,
        padding: Union[int, Sequence[Tuple[int, int]]] = 0,
        groups: int = 1,
    ):
        # Enforce the documented divisibility contract up front instead of
        # silently flooring ``in_channels // groups`` below.
        if groups < 1:
            raise ValueError(f"groups must be a positive integer, got {groups}")
        if in_channels % groups != 0 or out_channels % groups != 0:
            raise ValueError(
                f"in_channels ({in_channels}) and out_channels ({out_channels}) "
                f"must be divisible by groups ({groups})"
            )

        # With feature grouping, each output channel only sees C / G inputs.
        in_channels = in_channels // groups

        # Broadcast scalar hyper-parameters to every spatial axis.
        if isinstance(stride, int):
            stride = [stride] * len(kernel_size)
        if isinstance(dilation, int):
            dilation = [dilation] * len(kernel_size)
        if isinstance(padding, int):
            padding = [(padding, padding)] * len(kernel_size)

        keys = jax.random.split(key, 2)
        # Uniform bound 1 / sqrt(fan_in), with fan_in = prod(k_i) * C / G.
        lim = 1 / math.sqrt(math.prod(kernel_size) * in_channels)

        self.kernel = Parameter(
            jax.random.uniform(
                key=keys[0],
                shape=(*kernel_size, in_channels, out_channels),
                minval=-lim,
                maxval=lim,
            )
        )

        if bias:
            self.bias = Parameter(
                jax.random.uniform(
                    key=keys[1],
                    shape=(out_channels,),
                    minval=-lim,
                    maxval=lim,
                )
            )
        else:
            self.bias = None

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.groups = groups

    def __call__(self, x: Array) -> Array:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(*, H_1, \dots, H_n, C)`.

        Returns:
            The output tensor :math:`y`, with shape :math:`(*, H_1', \dots, H_n', C')`,
            such that

            .. math:: H_i' =
                \left\lfloor \frac{H_i + p_i - d_i \times (k_i - 1) - 1}{s_i} \right\rfloor + 1

            where :math:`k_i`, :math:`s_i`, :math:`d_i` and :math:`p_i` are respectively
            the kernel size, the stride coefficient, the dilation coefficient and the
            total padding of the :math:`i`-th spatial axis.
        """
        # Leading axes before the (spatial..., channel) axes are treated as batch;
        # presumably ``flatten`` merges them into one axis and ``unflatten``
        # restores them — TODO confirm against ..numpy.
        batch = x.shape[: -self.ndim]
        x = flatten(x, 0, -self.ndim)

        x = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=self.kernel(),
            dimension_numbers=self.dimensions,
            window_strides=self.stride,
            rhs_dilation=self.dilation,
            padding=self.padding,
            feature_group_count=self.groups,
        )

        x = unflatten(x, 0, batch)

        if self.bias is None:
            return x
        else:
            return x + self.bias()

    @property
    def ndim(self) -> int:
        # Number of trailing axes consumed per sample: n spatial axes + channels.
        return len(self.kernel_size) + 1

    @property
    def dimensions(self) -> jax.lax.ConvDimensionNumbers:
        # lhs/out layout after flattening: (batch, H_1..H_n, C) -> batch at 0,
        # features last, spatial axes in between.
        # rhs layout: kernel is (k_1..k_n, C/G, C') -> out-features last,
        # in-features second to last, spatial axes first.
        return jax.lax.ConvDimensionNumbers(
            (0, self.ndim, *range(1, self.ndim)),
            (self.ndim, self.ndim - 1, *range(0, self.ndim - 1)),
            (0, self.ndim, *range(1, self.ndim)),
        )
class ConvTransposed(Conv):
    r"""Creates a transposed convolution layer.

    This layer can be seen as the gradient of :class:`Conv` with respect to its
    input. It is sometimes called a "deconvolution", although it does not
    actually compute the inverse of a convolution. It is implemented as a
    regular convolution over the input dilated by the stride, with padding
    derived from the kernel extent.

    References:
        | A guide to convolution arithmetic for deep learning (Dumoulin et al., 2016)
        | https://arxiv.org/abs/1603.07285

    Arguments:
        key: A PRNG key for initialization.
        in_channels: The number of input channels :math:`C`.
        out_channels: The number of output channels :math:`C'`.
        kernel_size: The size of the kernel :math:`W` in each spatial axis.
        bias: Whether the layer learns an additive bias :math:`b` or not.
        stride: The stride coefficient in each spatial axis.
        dilation: The dilation coefficient in each spatial axis.
        padding: The padding applied to each end of each spatial axis.
        groups: The number of channel groups :math:`G`.
            Both :math:`C` and :math:`C'` must be divisible by :math:`G`.
    """

    def __call__(self, x: Array) -> Array:
        r"""
        Arguments:
            x: The input tensor :math:`x`, with shape :math:`(*, H_1, \dots, H_n, C)`.

        Returns:
            The output tensor :math:`y`, with shape :math:`(*, H_1', \dots, H_n', C')`,
            such that

            .. math:: H_i' = (H_i - 1) \times s_i + d_i \times (k_i - 1) - p_i + 1

            where :math:`k_i`, :math:`s_i`, :math:`d_i` and :math:`p_i` are respectively
            the kernel size, the stride coefficient, the dilation coefficient and the
            total padding of the :math:`i`-th spatial axis.
        """
        lead_shape = x.shape[: -self.ndim]
        spatial_axes = self.ndim - 1

        # The stride is applied as input (lhs) dilation; the window itself
        # always moves one step at a time.
        out = jax.lax.conv_general_dilated(
            lhs=flatten(x, 0, -self.ndim),
            rhs=self.kernel(),
            window_strides=[1] * spatial_axes,
            padding=self.transposed_padding,
            lhs_dilation=self.stride,
            rhs_dilation=self.dilation,
            dimension_numbers=self.dimensions,
            feature_group_count=self.groups,
        )
        out = unflatten(out, 0, lead_shape)

        return out if self.bias is None else out + self.bias()

    @property
    def transposed_padding(self) -> Sequence[Tuple[int, int]]:
        # Each side is padded by the dilated kernel reach d * (k - 1), reduced
        # by the user-requested padding on that side.
        pads = []
        for size, dil, pad in zip(self.kernel_size, self.dilation, self.padding):
            reach = dil * (size - 1)
            pads.append((reach - pad[0], reach - pad[1]))
        return pads