Add `stax.repeat`, allowing a compiled-loop composition of layers (per #168). This allows faster and less memory-hungry compilation of very deep networks.

Note that compiled loops require the layer to keep shapes and other static metadata unchanged. This necessitates some warnings (see the docstring), and makes `repeat` less flexible than `stax.serial`.
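A minimal usage sketch (not part of this commit; the width 128 and depth 100 are illustrative). Since the repeated block must map inputs to outputs of the same shape, the `Dense` width below matches the input width:

```python
import jax.numpy as np
from jax import random
from neural_tangents import stax

# A block that preserves the activation shape: width 128 in, width 128 out.
block = stax.serial(stax.Dense(128), stax.Relu())
init_fn, apply_fn, kernel_fn = stax.repeat(block, 100)

key = random.PRNGKey(0)
x = random.normal(key, (4, 128))   # input width must match `Dense(128)`

_, params = init_fn(key, x.shape)
y = apply_fn(params, x)            # same shape as `x`: (4, 128)
nngp = kernel_fn(x, None, 'nngp')  # NNGP kernel of the depth-100 network
```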

Co-authored-by: Jens Glaser <jens.glaser@gmail.com>
PiperOrigin-RevId: 493193125
romanngg and Jens Glaser committed Dec 6, 2022
1 parent b25bf21 commit fe98c9c
Showing 4 changed files with 265 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/stax.rst
@@ -17,6 +17,7 @@ Layers to combine multiple other layers into one.
   :toctree: _autosummary

   parallel
   repeat
   serial


101 changes: 99 additions & 2 deletions neural_tangents/_src/stax/combinators.py
@@ -19,7 +19,7 @@
import warnings

import frozendict
-from jax import random
+from jax import random, lax
import jax.example_libraries.stax as ostax
from .requirements import Diagonal, get_req, layer, requires
from ..utils.kernel import Kernel
@@ -36,6 +36,9 @@ def serial(*layers: Layer) -> InternalLayer:
    *layers:
      a sequence of layers, each an `(init_fn, apply_fn, kernel_fn)` triple.

  See Also:
    :obj:`~neural_tangents.stax.repeat` for compiled repeated composition.

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the serial composition of the given sequence of layers.
Expand All @@ -54,6 +57,101 @@ def kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
return init_fn, apply_fn, kernel_fn


@layer
def repeat(layer: Layer, n: int) -> InternalLayer:
"""Compose `layer` in a compiled loop `n` times.
Equivalent to `serial(*([layer] * n))`, but allows faster compilation time
for large `n` (but same runtime).
.. warning::
`apply_fn` of the `layer` is assumed to keep the activation (`x`) shape
unchanged.
.. warning::
`kernel_fn` of the `layer` is assumed to keep the
:class:`~neural_tangents.Kernel` metadata unchanged. This is most notably
not satisfied in :obj:`~neural_tangents.stax.Conv` and other convolutional
layers which flip the `is_reversed` attribute with each application. A
workaround is to either use `serial(*([layer] * n))`, or to use
`repeat(serial(layer, layer), n // 2)` instead of `repeat(layer, n)` for an
even `n`, i.e. to use two (or, generally, any even number of) convolutions
per `layer` instead of one (or, generally, any odd number), such that
`layer` does not alter the `is_reversed` attribute. Similar caution should
be applied to other :class:`~neural_tangents.Kernel` attributes.
See Also:
`RepeatTest` in `tests/stax/combinators_test.py` for examples and
:obj:`~neural_tangents.stax.serial` for unrolled composition.
Example:
>>> from neural_tangents import stax
>>> #
>>> layer = stax.serial(stax.Dense(128), stax.Relu())
>>> depth = 100
>>> #
>>> # Unrolled loop:
>>> nn_unrolled = stax.serial(*([layer] * depth))
>>> #
>>> # Compiled loop:
>>> nn_compiled = stax.repeat(layer, depth)
>>> # `nn_unrolled` and `nn_compiled` perform the same computation, but
>>> # `nn_compiled` compiles faster and with smaller memory footprint.
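    >>> #
    >>> # A sketch of the convolutional workaround from the warning above
    >>> # (widths and filter shapes are illustrative, not from this commit):
    >>> # repeat an even number of convolutions per step, so each repeated
    >>> # block leaves `is_reversed` unchanged.
    >>> conv = stax.serial(stax.Conv(128, (3, 3), padding='SAME'), stax.Relu())
    >>> nn_conv = stax.repeat(stax.serial(conv, conv), depth // 2)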

  Args:
    layer:
      layer to be repeated. Outputs must have the same shape and other
      metadata as inputs.

    n:
      number of times to repeat a layer (depth).

  Returns:
    A new layer, meaning an `(init_fn, apply_fn, kernel_fn)` triple,
    representing the repeated composition of `layer` `n` times.
  """
  init_fn, apply_fn, kernel_fn = layer

  def init_fn_repeat(rng, input_shape):
    # Run one initialization to verify that `layer` preserves the input
    # shape, which is required for `lax.scan`.
    out_shape, _ = init_fn(rng, input_shape)
    if out_shape != input_shape:
      raise ValueError(
          f'`init_fn` produces a different output shape {out_shape} than the '
          f'input shape {input_shape}. Please use the `serial(*([layer] * n))` '
          f'construction in this setting.'
      )

    def init_fn_scan(rng, _):
      rng, layer_rng = random.split(rng)
      _, params = init_fn(layer_rng, input_shape)
      return rng, params

    # Stack the parameters of `n` independent initializations of `layer`.
    _, params = lax.scan(init_fn_scan, rng, None, n)
    return out_shape, params

  def apply_fn_repeat(params, inputs, **kwargs):
    def apply_fn_scan(x, params):
      return apply_fn(params, x, **kwargs), None

    # Apply `layer` `n` times, threading activations through the scan carry.
    outputs, _ = lax.scan(apply_fn_scan, inputs, params, n)
    return outputs

  @requires(**get_req(kernel_fn))
  def kernel_fn_repeat(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
    if n > 0:
      # Apply the first iteration outside of the scan, so the scan carry has
      # the fixed structure required for the remaining `n - 1` iterations.
      k = kernel_fn(k, **kwargs)

      def kernel_fn_scan(k, _):
        k = kernel_fn(k, **kwargs)
        return k, None

      k, _ = lax.scan(kernel_fn_scan, k, None, n - 1)

    return k

  return init_fn_repeat, apply_fn_repeat, kernel_fn_repeat


@layer
def parallel(*layers: Layer) -> InternalLayer:
"""Combinator for composing layers in parallel.
Expand Down Expand Up @@ -155,4 +253,3 @@ def _get_input_req_attr(
raise NotImplementedError(k)

return req

1 change: 1 addition & 0 deletions neural_tangents/stax.py
@@ -77,6 +77,7 @@
from ._src.stax.combinators import (
    parallel,
    serial,
    repeat
)


164 changes: 164 additions & 0 deletions tests/stax/combinators_test.py
@@ -0,0 +1,164 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for `neural_tangents/_src/stax/combinators.py`."""

import random as prandom

from absl.testing import absltest
from jax import random
from jax.config import config
import jax.numpy as np
from neural_tangents import stax
from tests import test_utils


config.parse_flags_with_absl()
config.update('jax_numpy_rank_promotion', 'raise')


test_utils.update_test_tolerance()

prandom.seed(1)


class RepeatTest(test_utils.NeuralTangentsTestCase):

  def _test_repeat(self, x1, x2, layer, n, rng_params, **kwargs):
    # Reference: unrolled composition of `layer` `n` times.
    init_fn, apply_fn, kernel_fn = (stax.Identity() if n == 0 else
                                    stax.serial(*([layer] * n)))
    init_fn_repeat, apply_fn_repeat, kernel_fn_repeat = stax.repeat(layer, n)

    out_shape, params = init_fn(rng_params, x1.shape)
    out_shape_repeat, params_repeat = init_fn_repeat(rng_params, x1.shape)

    self.assertEqual(out_shape, out_shape_repeat)

    # Kwargs (e.g. `pattern`) come as per-input pairs; take the first input's.
    kwargs1 = {k: kwargs[k][0] for k in kwargs}
    out = apply_fn(params, x1, **kwargs1)
    out_repeat = apply_fn_repeat(params_repeat, x1, **kwargs1)

    self.assertAllClose(out, out_repeat)

    for get in [None, 'ntk', 'nngp', 'cov1', ('nngp', 'cov1'), ('cov1', 'ntk')]:
      with self.subTest(get=get):
        k = kernel_fn(x1, x2, get, **kwargs)
        k_repeat = kernel_fn_repeat(x1, x2, get, **kwargs)
        self.assertAllClose(k, k_repeat)

  @test_utils.product(
      same_inputs=[
          False,
          True
      ],
      n=[
          0,
          1,
          2,
          3,
      ],
      layer=[
          stax.Identity(),
          stax.Dense(3),
          stax.serial(stax.Identity()),
          stax.serial(stax.Dense(3)),
          stax.GlobalAvgPool(),
          stax.serial(stax.Dense(3), stax.Relu()),
          stax.serial(stax.Dense(3), stax.Relu(), stax.Dense(3))
      ]
  )
  def test_repeat(
      self,
      same_inputs,
      n,
      layer
  ):
    rng_input, rng_params = random.split(random.PRNGKey(1), 2)
    x1 = np.cos(random.normal(rng_input, (2, 3)))
    x2 = None if same_inputs else random.normal(rng_input, (4, 3))

    self._test_repeat(x1, x2, layer, n, rng_params)

  @test_utils.product(
      same_inputs=[
          False,
          True
      ],
      n=[
          0,
          1,
          2,
          3,
      ],
      layer=[
          stax.serial(stax.Conv(3, (2, 2), padding='SAME'),
                      stax.Relu(),
                      stax.Conv(3, (2, 2), padding='SAME'),
                      stax.Gelu()
                      ),
      ]
  )
  def test_repeat_conv(
      self,
      same_inputs,
      n,
      layer
  ):
    rng_input, rng_params = random.split(random.PRNGKey(1), 2)
    x1 = np.cos(random.normal(rng_input, (2, 4, 4, 3)))
    x2 = None if same_inputs else random.normal(rng_input, (4, 4, 4, 3))

    self._test_repeat(x1, x2, layer, n, rng_params)

  @test_utils.product(
      same_inputs=[
          False,
          True
      ],
      n=[
          0,
          1,
          2,
          3,
      ],
      layer=[
          stax.Aggregate(),
          stax.serial(stax.Dense(3), stax.Aggregate(), stax.Abs()),
          stax.serial(stax.Conv(3, (2, 2), padding='SAME'),
                      stax.Aggregate(),
                      stax.Abs(),
                      stax.Conv(3, (1, 2), padding='SAME'),
                      )
      ]
  )
  def test_repeat_agg(
      self,
      same_inputs,
      n,
      layer
  ):
    rng_input, rng_params, rng_p1, rng_p2 = random.split(random.PRNGKey(1), 4)
    x1 = np.cos(random.normal(rng_input, (2, 4, 3, 3)))
    x2 = None if same_inputs else random.normal(rng_input, (4, 4, 3, 3))

    # Aggregation patterns have shape (batch, spatial..., spatial...).
    p1 = random.normal(rng_p1, x1.shape[:-1] + x1.shape[1:-1])
    p2 = p1 if x2 is None else random.normal(rng_p2,
                                             x2.shape[:-1] + x2.shape[1:-1])

    self._test_repeat(x1, x2, layer, n, rng_params, pattern=(p1, p2))
    self._test_repeat(x1, x2, layer, n, rng_params, pattern=(None, None))


if __name__ == '__main__':
  absltest.main()
