Commit

Add references to "Fast Neural Kernel Embeddings for General Activations"

PiperOrigin-RevId: 473827049
romanngg committed Sep 12, 2022
1 parent 7c8729d commit c2e8d07
Showing 8 changed files with 171 additions and 123 deletions.
200 changes: 108 additions & 92 deletions README.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion docs/stax.rst
@@ -70,7 +70,8 @@ Linear layers without any trainable parameters.

Elementwise nonlinear
--------------------------------------
-Pointwise nonlinear layers.
+Pointwise nonlinear layers. For details, please see "`Fast Neural Kernel Embeddings for General Activations
+<https://arxiv.org/abs/2209.04121>`_".

.. autosummary::
:toctree: _autosummary
3 changes: 2 additions & 1 deletion examples/elementwise.py
@@ -14,7 +14,8 @@

"""Example of automatically deriving the closed-form NTK from NNGP.
-For details, see :obj:`~neural_tangents.stax.Elementwise`.
+For details, see :obj:`~neural_tangents.stax.Elementwise` and "`Fast Neural
+Kernel Embeddings for General Activations <https://arxiv.org/abs/2209.04121>`_".
"""

from absl import app
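
Note on the docstring in this hunk: a minimal sketch of what deriving the NTK term from the NNGP can look like. It is not the library's implementation. It assumes the derivation rests on the Gaussian identity E[phi'(u) phi'(v)] = d/d(cov12) E[phi(u) phi(v)] for zero-mean jointly Gaussian (u, v), takes phi = sin as an illustrative activation, and uses a central finite difference as a stand-in for autodiff so that only the standard library is needed. The names `sin_nngp`, `d_nngp_finite_difference`, and `cos_nngp` are hypothetical.

```python
import math


def sin_nngp(cov12, var1, var2):
    """Closed-form E[sin(u) sin(v)] for zero-mean jointly Gaussian (u, v).

    var1 = Var(u), var2 = Var(v), cov12 = Cov(u, v).
    """
    return math.exp(-(var1 + var2) / 2.0) * math.sinh(cov12)


def d_nngp_finite_difference(nngp_fn, cov12, var1, var2, eps=1e-5):
    """Derivative of nngp_fn with respect to cov12.

    Assumption: a central finite difference stands in for autodiff here.
    """
    upper = nngp_fn(cov12 + eps, var1, var2)
    lower = nngp_fn(cov12 - eps, var1, var2)
    return (upper - lower) / (2.0 * eps)


def cos_nngp(cov12, var1, var2):
    """Closed-form E[cos(u) cos(v)], i.e. the kernel of the derivative of sin."""
    return math.exp(-(var1 + var2) / 2.0) * math.cosh(cov12)


# The derivative of the sin NNGP should match the kernel of sin' = cos.
derived = d_nngp_finite_difference(sin_nngp, cov12=0.3, var1=1.0, var2=0.8)
expected = cos_nngp(cov12=0.3, var1=1.0, var2=0.8)
print(abs(derived - expected))
```

With autodiff in place of the finite difference, the same derivative would be exact up to floating-point error rather than approximate.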
4 changes: 3 additions & 1 deletion examples/elementwise_numerical.py
@@ -14,7 +14,9 @@

"""Example of approximating the NNGP and NTK using quadrature and autodiff.
-For details, see :obj:`~neural_tangents.stax.ElementwiseNumerical`.
+For details, see :obj:`~neural_tangents.stax.ElementwiseNumerical` and "`Fast
+Neural Kernel Embeddings for General Activations
+<https://arxiv.org/abs/2209.04121>`_".
"""

from absl import app
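
Note on the docstring in this hunk: a rough sketch of approximating an NNGP entry E[phi(u) phi(v)] with Gauss-Hermite quadrature. It covers only the quadrature half of "quadrature and autodiff" and is not the code path used by `ElementwiseNumerical`. The function names, the tensor-product rule, and the default `deg=40` are assumptions made for illustration. The activation sin is chosen because its closed form is available as a check.

```python
import numpy as np


def nngp_quadrature(phi, cov12, var1, var2, deg=40):
    """Approximate E[phi(u) phi(v)] for zero-mean jointly Gaussian (u, v)."""
    # hermgauss targets the weight exp(-x^2); rescale to the standard normal.
    x, w = np.polynomial.hermite.hermgauss(deg)
    z = np.sqrt(2.0) * x
    w = w / np.sqrt(np.pi)

    # Tensor-product rule over two independent standard normals z1, z2.
    z1, z2 = np.meshgrid(z, z, indexing="ij")
    w12 = np.outer(w, w)

    # Build (u, v) with Var(u)=var1, Var(v)=var2, Cov(u, v)=cov12.
    rho = cov12 / np.sqrt(var1 * var2)
    u = np.sqrt(var1) * z1
    v = np.sqrt(var2) * (rho * z1 + np.sqrt(1.0 - rho**2) * z2)
    return float(np.sum(w12 * phi(u) * phi(v)))


def sin_nngp_closed_form(cov12, var1, var2):
    """Closed-form E[sin(u) sin(v)], used here only to check the quadrature."""
    return float(np.exp(-(var1 + var2) / 2.0) * np.sinh(cov12))


approx = nngp_quadrature(np.sin, cov12=0.3, var1=1.0, var2=0.8)
exact = sin_nngp_closed_form(cov12=0.3, var1=1.0, var2=0.8)
print(abs(approx - exact))
```

Smooth activations such as sin converge quickly under this rule. Non-smooth ones such as ReLU generally need a higher degree for comparable accuracy.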
7 changes: 5 additions & 2 deletions examples/imdb.py
@@ -4,8 +4,11 @@
https://github.com/google/neural-tangents/blob/main/examples/infinite_fcn.py
By default, this example does inference on a very small subset, and uses small
-word embeddings for performance. A 300/300 train/test split takes 30 seconds
-on a machine with 2 Titan X Pascal GPUs, please adjust settings accordingly.
+word embeddings for performance. A 300/300 train/test split takes 30 seconds
+on a machine with 2 Titan X Pascal GPUs, please adjust settings accordingly.
+For details, please see "`Infinite attention: NNGP and NTK for deep attention
+networks <https://arxiv.org/abs/2006.10540>`_".
"""

import time
12 changes: 11 additions & 1 deletion neural_tangents/_src/stax/elementwise.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Elementwise nonlinearities / activation functions."""
+"""Elementwise nonlinearities / activation functions.
+For details, please see "`Fast Neural Kernel Embeddings for General Activations
+<https://arxiv.org/abs/2209.04121>`_".
+"""

import functools
import operator as op
@@ -1028,6 +1032,9 @@ def Elementwise(
to use the custom implementation, since it uses symbolically simplified
expressions that are more precise and numerically stable.
+For details, please see "`Fast Neural Kernel Embeddings for General
+Activations <https://arxiv.org/abs/2209.04121>`_".
See Also:
`examples/elementwise.py`.
@@ -1124,6 +1131,9 @@ def ElementwiseNumerical(
Supports general activation functions using Gauss-Hermite quadrature.
+For details, please see "`Fast Neural Kernel Embeddings for General
+Activations <https://arxiv.org/abs/2209.04121>`_".
See Also:
`examples/elementwise_numerical.py`.
26 changes: 26 additions & 0 deletions neural_tangents/_src/stax/linear.py
@@ -3072,14 +3072,18 @@ def _conv_kernel_full_spatial_shared(
the number of spatial dimensions (e.g. 2 for images). Has shape
`(batch_size_1, [batch_size_2,] height, height, width, width, depth,
depth, ...)`.
filter_shape:
positive integers, the convolutional filters spatial shape
(e.g. `(3, 3)` for a 2D convolution).
strides:
positive integers, the CNN strides (e.g. `(1, 1)` for a 2D
convolution).
padding:
a `Padding` enum, e.g. `Padding.CIRCULAR`.
batch_ndim:
number of batch dimensions, 1 or 2.
@@ -3171,14 +3175,18 @@ def _conv_kernel_full_spatial_unshared(
the number of spatial dimensions (e.g. 2 for images). Has shape
`(batch_size_1, [batch_size_2,] height, height, width, width, depth,
depth, ...)`.
filter_shape:
positive integers, the convolutional filters spatial shape
(e.g. `(3, 3)` for a 2D convolution).
strides:
positive integers, the CNN strides (e.g. `(1, 1)` for a 2D
convolution).
padding:
a `Padding` enum, e.g. `Padding.CIRCULAR`.
batch_ndim:
number of batch dimensions, 1 or 2.
@@ -3222,14 +3230,18 @@ def _conv_kernel_full_spatial_transpose(
the number of spatial dimensions (e.g. 2 for images). Has shape
`(batch_size_1, [batch_size_2,] height, height, width, width, depth,
depth, ...)`.
filter_shape:
positive integers, the convolutional filters spatial shape
(e.g. `(3, 3)` for a 2D convolution).
strides:
positive integers, the CNN strides (e.g. `(1, 1)` for a 2D
convolution).
padding:
a `Padding` enum, e.g. `Padding.CIRCULAR`.
batch_ndim:
number of batch dimensions, 1 or 2.
@@ -3328,14 +3340,18 @@ def _conv_kernel_diagonal_spatial(
sample-sample-(same position) covariances of CNN inputs. Has `batch_ndim`
batch and `S` spatial dimensions with the shape of `(batch_size_1,
[batch_size_2,] height, width, depth, ...)`.
filter_shape:
tuple of positive integers, the convolutional filters spatial shape
(e.g. `(3, 3)` for a 2D convolution).
strides:
tuple of positive integers, the CNN strides (e.g. `(1, 1)` for a 2D
convolution).
padding:
a `Padding` enum, e.g. `Padding.CIRCULAR`.
batch_ndim:
number of leading batch dimensions, 1 or 2.
@@ -3383,14 +3399,18 @@ def _conv_kernel_diagonal_spatial_transpose(
sample-sample-(same position) covariances of CNN inputs. Has `batch_ndim`
batch and `S` spatial dimensions with the shape of `(batch_size_1,
[batch_size_2,] height, width, depth, ...)`.
filter_shape:
tuple of positive integers, the convolutional filters spatial shape
(e.g. `(3, 3)` for a 2D convolution).
strides:
tuple of positive integers, the CNN strides (e.g. `(1, 1)` for a 2D
convolution).
padding:
a `Padding` enum, e.g. `Padding.CIRCULAR`.
batch_ndim:
number of leading batch dimensions, 1 or 2.
@@ -3439,19 +3459,25 @@ def _pool_kernel(
is the number of spatial dimensions (e.g. 2 for images). Has shape
`(batch_size_1, [batch_size_2,]
height, height, width, width, depth, depth, ...)`.
pool_type:
a `Pooling` enum, e.g. `Pooling.AVG`.
window_shape:
tuple of positive integers, the pooling spatial shape (e.g. `(3, 3)`).
strides:
tuple of positive integers, the pooling strides, e.g. `(1, 1)`.
padding:
a `Padding` enum, e.g. `Padding.CIRCULAR`.
normalize_edges:
`True` to normalize output by the effective receptive field, `False` to
normalize by the window size. Only has effect at the edges when `SAME`
padding is used. Set to `True` to retain correspondence to
`ostax.AvgPool`.
batch_ndim:
number of leading batch dimensions, 1 or 2.
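
Note on `normalize_edges` in the `_pool_kernel` docstring: a small sketch of the two normalizations it describes, applied to plain 1D average pooling of an input rather than to a kernel. The helper `avg_pool_1d_same` is hypothetical and assumes stride 1, an odd window, and zero padding that keeps the output length equal to the input length.

```python
import numpy as np


def avg_pool_1d_same(x, window, normalize_edges):
    """Hypothetical 1D average pooling: stride 1, odd window, zero padding."""
    x = np.asarray(x, dtype=float)
    half = window // 2
    padded = np.pad(x, half)               # zeros added at both edges
    mask = np.pad(np.ones_like(x), half)   # 1 for real inputs, 0 for padding

    sums = np.array([padded[i:i + window].sum() for i in range(x.size)])
    if normalize_edges:
        # Divide by the number of real (non-padded) inputs in each window.
        counts = np.array([mask[i:i + window].sum() for i in range(x.size)])
        return sums / counts
    # Divide by the full window size, padding included.
    return sums / window


by_field = avg_pool_1d_same([1, 2, 3, 4], window=3, normalize_edges=True)
by_window = avg_pool_1d_same([1, 2, 3, 4], window=3, normalize_edges=False)
print(by_field, by_window)
```

The two results agree in the interior and differ only at the first and last positions, where the window overlaps the padding.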
39 changes: 14 additions & 25 deletions notebooks/elementwise.ipynb
@@ -18,6 +18,15 @@
"# Examples of automatic nonlinearity NNGP/NTK computation using `stax.Elementwise` and `stax.ElementwiseNumerical`."
]
},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "e76rfN10CKHn"
+},
+"source": [
+"For details, please see \"[Fast Neural Kernel Embeddings for General Activations](https://arxiv.org/abs/2209.04121)\"."
+]
+},
{
"cell_type": "markdown",
"metadata": {
@@ -29,7 +38,7 @@
},
{
"cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 56,
@@ -63,18 +72,8 @@
},
{
"cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
"metadata": {
-"executionInfo": {
-"elapsed": 6736,
-"status": "ok",
-"timestamp": 1660802709761,
-"user": {
-"displayName": "",
-"userId": ""
-},
-"user_tz": 420
-},
"id": "2-Y93-C7lPOC"
},
"outputs": [],
Expand All @@ -86,18 +85,8 @@
},
{
"cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
"metadata": {
-"executionInfo": {
-"elapsed": 3632,
-"status": "ok",
-"timestamp": 1660802713526,
-"user": {
-"displayName": "",
-"userId": ""
-},
-"user_tz": 420
-},
"id": "8o90BE__iJVS"
},
"outputs": [],
@@ -127,7 +116,7 @@
},
{
"cell_type": "code",
-"execution_count": 4,
+"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 3231,
@@ -203,7 +192,7 @@
},
{
"cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 6559,
