-
Notifications
You must be signed in to change notification settings - Fork 228
/
stax.py
165 lines (145 loc) · 4.36 KB
/
stax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Closed-form NNGP and NTK library.
This library contains layers mimicking those in
:obj:`jax.example_libraries.stax` with similar API apart from:
1) Instead of an `(init_fn, apply_fn)` tuple, layers return a triple
`(init_fn, apply_fn, kernel_fn)`, where the added `kernel_fn` maps a
:class:`~neural_tangents.Kernel` to a new :class:`~neural_tangents.Kernel`, and
represents the change in the analytic NTK and NNGP kernels
(:attr:`~neural_tangents.Kernel.nngp`, :attr:`~neural_tangents.Kernel.ntk`).
These functions are chained / stacked together within the :obj:`serial` or
:obj:`parallel` combinators, similarly to `init_fn` and `apply_fn`.
For details, please see "`Neural Tangents: Fast and Easy Infinite Neural
Networks in Python <https://arxiv.org/abs/1912.02803>`_".
2) In layers with random weights, NTK parameterization is used by default
(see page 3 in
"`Neural Tangent Kernel: Convergence and Generalization in Neural Networks
<https://arxiv.org/abs/1806.07572>`_"). Standard parameterization can be
specified for :obj:`Conv` and :obj:`Dense` layers by a keyword argument
`parameterization`. For details, please see "`On the infinite width limit of
neural networks with a standard parameterization
<https://arxiv.org/abs/2001.07301>`_".
3) Some functionality may be missing (e.g.
:obj:`jax.example_libraries.stax.BatchNorm`), and some may be
present only in our library (e.g. :attr:`~Padding.CIRCULAR` padding,
:obj:`LayerNorm`, :obj:`GlobalAvgPool`, :obj:`GlobalSelfAttention`, flexible
batch and channel axes etc.).
Example:
>>> from jax import random
>>> import neural_tangents as nt
>>> from neural_tangents import stax
>>> #
>>> key1, key2, key3 = random.split(random.PRNGKey(1), 3)
>>> x_train = random.normal(key1, (20, 32, 32, 3))
>>> y_train = random.uniform(key3, (20, 10))
>>> x_test = random.normal(key2, (5, 32, 32, 3))
>>> #
>>> init_fn, apply_fn, kernel_fn = stax.serial(
...     stax.Conv(128, (3, 3)),
...     stax.Relu(),
...     stax.Conv(256, (3, 3)),
...     stax.Relu(),
...     stax.Conv(512, (3, 3)),
...     stax.Flatten(),
...     stax.Dense(10)
... )
>>> #
>>> predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
...                                                       y_train)
>>> #
>>> # (5, 10) jnp.ndarray NNGP test prediction
>>> y_test_nngp = predict_fn(x_test=x_test, get='nngp')
>>> #
>>> # (5, 10) jnp.ndarray NTK test prediction
>>> y_test_ntk = predict_fn(x_test=x_test, get='ntk')
"""
# Layer combinators, combining multiple layers into a single layer.
from ._src.stax.combinators import (
parallel,
serial,
repeat
)
# Elementwise nonlinearities.
from ._src.stax.elementwise import (
ABRelu,
Abs,
Cos,
Elementwise,
ElementwiseNumerical,
Erf,
Exp,
ExpNormalized,
Gabor,
Gaussian,
Gelu,
Hermite,
LeakyRelu,
Monomial,
Polynomial,
Rbf,
RectifiedMonomial,
Relu,
Sigmoid_like,
Sign,
Sin,
)
# Linear layers.
from ._src.stax.linear import (
Aggregate,
AvgPool,
Conv,
ConvLocal,
ConvTranspose,
Dense,
Identity,
Index,
DotGeneral,
Dropout,
Flatten,
GlobalAvgPool,
GlobalSelfAttention,
GlobalSumPool,
ImageResize,
LayerNorm,
SumPool,
)
# Helper object for the `Index` layer.
from ._src.stax.linear import (
Slice
)
# Branching layers.
from ._src.stax.branching import (
FanInConcat,
FanInProd,
FanInSum,
FanOut,
)
# Enums to specify layer behavior.
from ._src.stax.linear import (
AggregateImplementation,
AttentionMechanism,
Padding,
PositionalEmbedding,
)
# Decorators and classes for constructing your own layers.
from ._src.stax.requirements import (
Bool,
Diagonal,
MaskedArray,
layer,
requires,
supports_masking,
unmask_fn,
)