Fast and Easy Infinite Neural Networks in Python
Neural Tangents is a high-level neural network API for specifying complex, hierarchical neural networks of both finite and infinite width. Neural Tangents allows researchers to define, train, and evaluate infinite networks as easily as finite ones.
Infinite (in width or channel count) neural networks are Gaussian Processes (GPs) with a kernel function determined by their architecture (see References for details and nuances of this correspondence).
Neural Tangents allows you to construct a neural network model with the usual building blocks like convolutions, pooling, residual connections, nonlinearities etc. and obtain not only the finite model, but also the kernel function of the respective GP.
The library is written in Python using JAX and leverages XLA to run out-of-the-box on CPU, GPU, or TPU. Kernel computation is highly optimized for speed and memory efficiency, and can be automatically distributed over multiple accelerators with near-perfect scaling.
Neural Tangents is a work in progress. We happily welcome contributions!
- 5-Minute intro
- Package description
- Technical gotchas
- Training dynamics of wide but finite networks
To install Neural Tangents, first follow JAX's installation instructions. With JAX installed, using Neural Tangents should be as easy as:
```
git clone https://github.com/google/neural-tangents
pip install -e neural-tangents
```
You can then run the examples (using `tensorflow_datasets`) by calling:
```
pip install tensorflow tensorflow-datasets

python neural-tangents/examples/weight_space.py
python neural-tangents/examples/function_space.py
```
Finally, you can run tests by calling:
```
for f in neural-tangents/neural_tangents/tests/*.py; do python $f; done
```
If you would prefer, you can get started without installing by checking out our Colab examples.
See this Colab for a detailed tutorial. Below is a very quick introduction.
Our library closely follows JAX's API for specifying neural networks, `stax`. In `stax`, a network is defined by a pair of functions `(init_fn, apply_fn)`, initializing the trainable parameters and computing the outputs of the network respectively. Below is an example of defining a 3-layer network and computing its outputs `y` given inputs `x`.
```python
from jax import random
from jax.experimental import stax

init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)

key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)

y = apply_fn(params, x)  # (10, 1) np.ndarray outputs of the neural network
```
Neural Tangents is designed to serve as a drop-in replacement for `stax`, extending the `(init_fn, apply_fn)` tuple to a triple `(init_fn, apply_fn, kernel_fn)`, where `kernel_fn` is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs `x1` and `x2`.
```python
from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

kernel = kernel_fn(x1, x2, 'nngp')
```
`kernel_fn` can compute two covariance matrices corresponding to the Neural Network Gaussian Process (NNGP) and Neural Tangent (NT) kernels respectively. The NNGP kernel corresponds to the Bayesian infinite neural network. The NTK corresponds to the (continuous) gradient descent trained infinite network. In the above example we compute the NNGP kernel, but we could compute the NTK or both as follows:
```python
# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp')  # (10, 20) np.ndarray
ntk = kernel_fn(x1, x2, 'ntk')    # (10, 20) np.ndarray

# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp  # True
both.ntk == ntk    # True

# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))

# Default is to return ('nngp', 'ntk')
nngp, ntk = kernel_fn(x1, x2)
```
Doing inference with infinite networks trained on MSE loss reduces to classical GP inference, for which we also provide convenient tools:
```python
import neural_tangents as nt

x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1))  # training targets

y_test_nngp = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                      get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network

y_test_ntk = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
                                     get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)
```
We can define a more complex, (infinitely) Wide Residual Network using the same `nt.stax` building blocks:
```python
from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
```
The Neural Tangents (`nt`) package contains the following modules and methods:
- `stax` - primitives to construct neural networks, like `Conv`, `Relu`, `serial`, `parallel` etc.
- `predict` - predictions with infinite networks:
  - `predict.gp_inference` - either fully Bayesian inference (`get='nngp'`) or inference with a network trained to full convergence (infinite time) on MSE loss using continuous gradient descent (`get='ntk'`).
  - `predict.gradient_descent_mse` - inference with a network trained on MSE loss with continuous gradient descent for an arbitrary finite time.
  - `predict.gradient_descent` - inference with a network trained on arbitrary loss with continuous gradient descent for an arbitrary finite time (using an ODE solver).
  - `predict.momentum` - inference with a network trained on arbitrary loss with continuous momentum gradient descent for an arbitrary finite time (using an ODE solver).
- `monte_carlo_kernel_fn` - compute a Monte Carlo kernel estimate of any `(init_fn, apply_fn)`, not necessarily specified via `nt.stax`, enabling kernel computation for infinite networks without closed-form expressions (see the sketch after this list).
- Tools to investigate training dynamics of wide but finite neural networks, like `linearize`, `empirical_kernel_fn` and more. See Training dynamics of wide but finite networks for details.
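For example, here is a minimal sketch of estimating the kernels of a plain `jax.experimental.stax` network with `monte_carlo_kernel_fn`; the architecture and `n_samples=100` are illustrative choices, not taken from the examples above:

```python
from jax import random
from jax.experimental import stax
import neural_tangents as nt

# Any (init_fn, apply_fn) pair works, including plain jax.experimental.stax.
init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)

key, x1_key, x2_key = random.split(random.PRNGKey(1), 3)
x1 = random.normal(x1_key, (10, 100))
x2 = random.normal(x2_key, (20, 100))

# Estimate the kernels by averaging over 100 random initializations.
kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples=100)
nngp_estimate = kernel_fn(x1, x2, 'nngp')  # (10, 20) np.ndarray
ntk_estimate = kernel_fn(x1, x2, 'ntk')    # (10, 20) np.ndarray
```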
To enable 64-bit precision, set the respective JAX flag before importing
neural_tangents (see the JAX guide), for example:
```python
from jax.config import config
config.update("jax_enable_x64", True)

import neural_tangents as nt  # 64-bit precision enabled
```
We note the following differences between our library and `jax.experimental.stax`:
- All `nt.stax` layers are instantiated with a function call, i.e. `nt.stax.Relu()` vs. `jax.experimental.stax.Relu`.
- All layers with trainable parameters use the NTK parameterization (see Remark 1 in the References).
- `nt.stax` and `jax.experimental.stax` may have different layers and options available (for example, `nt.stax` layers support `CIRCULAR` padding, but only the `NHWC` data format).
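As a minimal side-by-side sketch of the first difference (using only layers that already appear in the examples above):

```python
from jax.experimental import stax as jax_stax
from neural_tangents import stax as nt_stax

# jax.experimental.stax layers are ready-made (init_fn, apply_fn) pairs:
jax_relu = jax_stax.Relu

# nt.stax layers are created by calling a layer constructor and return an
# (init_fn, apply_fn, kernel_fn) triple:
nt_relu = nt_stax.Relu()
```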
We will be dropping Python 2 support before 2020.
Training dynamics of wide but finite networks
The kernel of an infinite network, `kernel_fn(x1, x2).ntk`, combined with `nt.predict.gradient_descent_mse`, allows one to analytically track the outputs of an infinitely wide neural network trained on MSE loss throughout training. Here we discuss the implications for wide but finite neural networks and present tools to study their evolution in weight space (trainable parameters of the network) and function space (outputs of the network).
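As a minimal sketch of this analytic tracking (reusing `kernel_fn`, `x_train`, `x_test` and `y_train` from the examples above; the training time `t = 5.` and zero initial outputs are illustrative choices):

```python
import jax.numpy as np
import neural_tangents as nt

# Analytic infinite-width NTK of the architecture defined with nt.stax.
ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
ntk_test_train = kernel_fn(x_test, x_train, 'ntk')

mse_predictor = nt.predict.gradient_descent_mse(
    ntk_train_train, y_train, ntk_test_train)

# Mean outputs of the infinite network are zero at initialization.
y_train_0 = np.zeros((10, 1))
y_test_0 = np.zeros((20, 1))

t = 5.
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0)
# (10, 1) and (20, 1) np.ndarray train and test outputs after `t` units of
# time of continuous gradient descent on MSE loss
```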
Continuous gradient descent in an infinite network has been shown to correspond to training a linear (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.
For this, we provide two convenient methods, `nt.linearize` and `nt.taylor_expand`, which allow you to linearize or get an arbitrary-order Taylor expansion of any function `apply_fn(params, x)` around some initial parameters `params_0`, e.g. `apply_fn_lin = nt.linearize(apply_fn, params_0)`.
You can use `apply_fn_lin(params, x)` exactly as you would any other function (including as an input to JAX optimizers). This makes it easy to compare the training trajectory of a neural network with that of its linearization.
Previous theory and experiments have examined the linearization of neural networks from inputs to logits or pre-activations, rather than from inputs to post-activations, which are substantially more nonlinear.
```python
import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))

apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = np.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2
x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])

logits = apply_fn_lin((W, b), x)  # (3, 2) np.ndarray
```
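If you want a higher-order approximation instead, a minimal sketch (assuming an `nt.taylor_expand(f, params, degree)` counterpart to `nt.linearize`, and reusing the objects defined above) could be:

```python
# Assumed usage: second-order Taylor expansion of apply_fn around (W_0, b_0),
# called exactly like apply_fn_lin above.
apply_fn_quad = nt.taylor_expand(apply_fn, (W_0, b_0), 2)
logits_quad = apply_fn_quad((W, b), x)  # (3, 2) np.ndarray
```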
Outputs of a linearized model evolve identically to those of an infinite one, but with a different kernel - specifically, the Neural Tangent Kernel evaluated on the specific `apply_fn` of the finite network given the specific `params_0` that the network is initialized with. For this, we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params)` that allows you to compute the empirical NTK and NNGP kernels on specific `params`.
```python
import jax.random as random
import jax.numpy as np
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return np.dot(x, W) + b

W_0 = np.array([[1., 0.], [0., 1.]])
b_0 = np.zeros((2,))
params = (W_0, b_0)

key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))

kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, x_train, params, 'ntk')
ntk_test_train = kernel_fn(x_test, x_train, params, 'ntk')
mse_predictor = nt.predict.gradient_descent_mse(
    ntk_train_train, y_train, ntk_test_train)

t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)

y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
```
What to Expect
The success or failure of the linear approximation is highly architecture dependent. However, some rules of thumb that we've observed are:
Convergence as the network size increases.
- For fully-connected networks one generally observes very strong agreement by the time the layer width is 512 (RMSE of about 0.05 at the end of training).
- For convolutional networks one generally observes reasonable agreement by the time the number of channels is 512.
Convergence at small learning rates.
With a new model it is therefore advisable to start with a very large model on a small dataset using a small learning rate.
Neural Tangents has been used in the following papers:
Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent.
Jaehoon Lee*, Lechao Xiao*, Samuel S. Schoenholz, Yasaman Bahri, Roman Novak, Jascha Sohl-Dickstein, Jeffrey Pennington
Training Dynamics of Deep Networks using Stochastic Gradient Descent via Neural Tangent Kernel.
Soufiane Hayou, Arnaud Doucet, Judith Rousseau
Please let us know if you make use of the code in a publication and we'll add it to the list!
If you use the code in a publication, please cite the repo using the provided .bib entry.