Add an experimental prototype for computing the empirical NTK in TensorFlow.
Other minor changes:
- add missing NT pip installs in Colabs;
- stop using `jax.example_libraries.stax` in examples, since it relies on rank promotion, and JAX's flag that raises errors on rank promotion can leak into test targets where it isn't necessarily set (see the sketch after this list);
- fix a bug in `nt.NtkImplementation.AUTO` that caused failures on nested dictionaries as inputs/parameters.
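
For context, here is a minimal sketch (illustrative, not part of this commit) of how JAX's rank-promotion flag turns implicit broadcasting into an error:

# Hedged sketch: `jax_numpy_rank_promotion` is a standard JAX config option;
# once set to 'raise', any op that needs implicit rank promotion errors out,
# which is how the flag can leak into and break unrelated test targets.
import jax
import jax.numpy as jnp

jax.config.update('jax_numpy_rank_promotion', 'raise')

x = jnp.ones((3, 1))  # rank 2
y = jnp.ones((3,))    # rank 1
try:
  _ = x + y  # would require promoting `y` from rank 1 to rank 2
except ValueError as e:
  print('Rank promotion disallowed:', e)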

PiperOrigin-RevId: 455853984
romanngg committed Jun 19, 2022 · 1 parent 99c002c · commit e28971c
Showing 20 changed files with 938 additions and 191 deletions.
163 changes: 82 additions & 81 deletions README.md

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions examples/empirical_ntk.py
@@ -22,8 +22,8 @@
 import jax
 from jax import numpy as np
 from jax import random
-from jax.example_libraries import stax
 import neural_tangents as nt
+from neural_tangents import stax


 def main(unused_argv):
@@ -32,13 +32,13 @@ def main(unused_argv):
   x2 = random.normal(key2, (3, 8, 8, 3))

   # A vanilla CNN.
-  init_fn, f = stax.serial(
+  init_fn, f, _ = stax.serial(
       stax.Conv(8, (3, 3)),
-      stax.Relu,
+      stax.Relu(),
       stax.Conv(8, (3, 3)),
-      stax.Relu,
+      stax.Relu(),
       stax.Conv(8, (3, 3)),
-      stax.Flatten,
+      stax.Flatten(),
       stax.Dense(10)
   )

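The call-site changes above reflect an API difference: `neural_tangents.stax` layers are constructors returning `(init_fn, apply_fn, kernel_fn)` triples, whereas in `jax.example_libraries.stax`, layers like `Relu` and `Flatten` are ready-made `(init_fn, apply_fn)` pairs. A minimal sketch of the new-style usage (illustrative, not from the diff):

from neural_tangents import stax

# Each nt.stax layer is called to produce a layer triple, and `serial`
# itself returns (init_fn, apply_fn, kernel_fn); the third element is the
# analytic infinite-width kernel, discarded as `_` in the example above.
init_fn, apply_fn, _ = stax.serial(stax.Dense(8), stax.Relu(), stax.Dense(10))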
13 changes: 13 additions & 0 deletions examples/experimental/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
111 changes: 111 additions & 0 deletions examples/experimental/empirical_ntk_tf.py
@@ -0,0 +1,111 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Minimal highly-experimental Tensorflow NTK example."""

from absl import app
import neural_tangents as nt
import tensorflow as tf


tf.random.set_seed(1)


def _get_ntks(f, x1, x2, params, vmap_axes):
"""Returns a list of NTKs computed using different implementations."""
kwargs = dict(
f=f,
trace_axes=(),
vmap_axes=vmap_axes,
)

# Default, baseline Jacobian contraction.
jacobian_contraction = nt.experimental.empirical_ntk_fn_tf(
**kwargs,
implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
# (3, 6, 10, 10) full `tf.Tensor` test-train NTK
ntk_jc = jacobian_contraction(x2, x1, params)

# NTK-vector products-based implementation.
ntk_vector_products = nt.experimental.empirical_ntk_fn_tf(
**kwargs,
implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)
ntk_vp = ntk_vector_products(x2, x1, params)

# Structured derivatives-based implementation.
structured_derivatives = nt.experimental.empirical_ntk_fn_tf(
**kwargs,
implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
ntk_sd = structured_derivatives(x2, x1, params)

# Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
auto = nt.experimental.empirical_ntk_fn_tf(
**kwargs,
implementation=nt.NtkImplementation.AUTO)
ntk_auto = auto(x2, x1, params)

return [ntk_jc, ntk_vp, ntk_sd, ntk_auto]


def _check_ntks(ntks):
# Check that implementations match
for ntk1 in ntks:
for ntk2 in ntks:
diff = tf.reduce_max(tf.abs(ntk1 - ntk2))
print(f'NTK implementation diff {diff}.')
assert diff < 1e-4, diff

print('All NTK implementations match.')


def _compute_and_check_ntks(f, x1, x2, params):
ntks = _get_ntks(f, x1, x2, params, vmap_axes=None)
ntks_vmap = _get_ntks(f, x1, x2, params, vmap_axes=0)
_check_ntks(ntks + ntks_vmap)


def main(unused_argv):
x1 = tf.random.normal((6, 8, 8, 3), seed=1)
x2 = tf.random.normal((3, 8, 8, 3), seed=2)

# A vanilla CNN `tf.keras.Model` example.
print('A Keras CNN example.')

f = tf.keras.Sequential()
f.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu'))
f.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu'))
f.add(tf.keras.layers.Conv2D(16, (3, 3)))
f.add(tf.keras.layers.Flatten())
f.add(tf.keras.layers.Dense(10))

f.build((None, *x1.shape[1:]))

_, params = nt.experimental.get_apply_fn_and_params(f)
_compute_and_check_ntks(f, x1, x2, params)

# A `tf.function` example.
print('A `tf.function` example.')

params_tf = tf.random.normal((1, 2, 3, 4), seed=3)

@tf.function(input_signature=[tf.TensorSpec(None),
tf.TensorSpec((None, *x1.shape[1:]))])
def f_tf(params, x):
return tf.transpose(x, (0, 3, 1, 2)) * tf.reduce_mean(params**2) + 1.

_compute_and_check_ntks(f_tf, x1, x2, params_tf)


if __name__ == '__main__':
app.run(main)
4 changes: 2 additions & 2 deletions examples/weight_space.py
@@ -27,7 +27,7 @@
 from jax import jit
 from jax import random
 from jax.example_libraries import optimizers
-from jax.example_libraries.stax import logsoftmax
+from jax.nn import log_softmax
 import jax.numpy as np
 import neural_tangents as nt
 from neural_tangents import stax
@@ -66,7 +66,7 @@ def main(unused_argv):
   state_lin = opt_init(params)

   # Create a cross-entropy loss function.
-  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)
+  loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

# Specialize the loss function to compute gradients for both linearized and
# full networks.
2 changes: 2 additions & 0 deletions neural_tangents/__init__.py
@@ -32,3 +32,5 @@
 from ._src.empirical import taylor_expand

 from ._src.monte_carlo import monte_carlo_kernel_fn
+
+from . import experimental
2 changes: 1 addition & 1 deletion neural_tangents/_src/empirical.py
@@ -1500,7 +1500,7 @@ def _to_tuple_tree(x: PyTree) -> Tuple:
     return tuple(_to_tuple_tree(x_i) for x_i in x)

   if isinstance(x, dict):
-    return tuple((k, v) for k, v in sorted(x.items()))
+    return tuple((k, _to_tuple_tree(v)) for k, v in sorted(x.items()))

   return x

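This one-line recursion is the nested-dictionary fix mentioned in the commit message: previously, dict values that were themselves dicts were left as-is, so the resulting tuple tree was unhashable and `nt.NtkImplementation.AUTO` failed on such inputs/parameters. A standalone sketch of the fixed behavior (the list/tuple branch is inferred from the surrounding function):

# Hedged re-implementation of the fixed `_to_tuple_tree`, which converts a
# pytree into a hashable nested tuple.
def to_tuple_tree(x):
  if isinstance(x, (list, tuple)):
    return tuple(to_tuple_tree(x_i) for x_i in x)
  if isinstance(x, dict):
    # The fix: recurse into values so nested dicts are also tupled.
    return tuple((k, to_tuple_tree(v)) for k, v in sorted(x.items()))
  return x

params = {'dense': {'w': [1.0, 2.0], 'b': 0.5}, 'scale': 3.0}
print(to_tuple_tree(params))
# (('dense', (('b', 0.5), ('w', (1.0, 2.0)))), ('scale', 3.0))
hash(to_tuple_tree(params))  # hashable now; the old version raised TypeError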
6 changes: 3 additions & 3 deletions neural_tangents/_src/predict.py
@@ -31,7 +31,6 @@
 from functools import lru_cache
 from typing import Callable, Dict, Generator, Iterable, NamedTuple, Optional, Tuple, Union

-from typing_extensions import Protocol
 import jax
 from jax import grad
 from jax.experimental import ode
@@ -40,6 +39,7 @@
 from jax.tree_util import tree_all, tree_map
 import numpy as onp
 import scipy as osp
+from typing_extensions import Protocol
 from .utils import dataclasses, utils
 from .utils.typing import Axes, Get, KernelFn

@@ -317,8 +317,8 @@ def gradient_descent(
   >>> kernel_fn = nt.empirical_ntk_fn(f)
   >>> k_test_train = kernel_fn(x_test, x_train, params)
   >>>
-  >>> from jax.example_libraries import stax
-  >>> cross_entropy = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)
+  >>> from jax.nn import log_softmax
+  >>> cross_entropy = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)
   >>> predict_fn = nt.predict.gradient_descent(
   >>>   cross_entropy, k_train_train, y_train, learning_rate, momentum)
   >>>
16 changes: 16 additions & 0 deletions neural_tangents/experimental/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .empirical_tf.empirical import empirical_ntk_fn_tf
from .empirical_tf.empirical import get_apply_fn_and_params
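
These two exports form the public surface of the prototype. A hedged usage sketch, mirroring the example file added in this commit:

import neural_tangents as nt
import tensorflow as tf

# Build a small Keras model so its parameters exist.
f = tf.keras.Sequential([tf.keras.layers.Dense(10)])
f.build((None, 4))

# Extract parameters from the built model, then construct an NTK function.
_, params = nt.experimental.get_apply_fn_and_params(f)
ntk_fn = nt.experimental.empirical_ntk_fn_tf(f, trace_axes=(), vmap_axes=0)

x1 = tf.random.normal((6, 4), seed=1)
x2 = tf.random.normal((3, 4), seed=2)
ntk = ntk_fn(x1, x2, params)  # full NTK of shape (6, 3, 10, 10)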
13 changes: 13 additions & 0 deletions neural_tangents/experimental/empirical_tf/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
