Improvements to the empirical and predict modules.
Since `predict` depends on the layout of the empirical kernel, changes are combined in one CL.

1) Empirical:
1.1) Unify all kernel shapes: all kernels (empirical and analytical, NTK and NNGP) now have the `(N, N, X, X, Y, Y, ...)` layout (subject to change 1.2 below). This removes all the tedious checks and bugs related to different kernels having different shapes, potentially at a small compute expense. However, IMO the simplification is worth the cost (I haven't yet checked how big it is, if any).
1.2) Support `trace_axes` and `diagonal_axes` to indicate which output axes to treat as i.i.d. (e.g. logits) and average the trace over, or on which axes to compute only the diagonal (a more flexible version of `diagonal_spatial` / `diagonal_batch`); see the sketch after this list. Note that for the implicit kernel this currently does not improve performance, but for the direct kernel it should. Note also that the default is now `trace_axes=(-1,)` everywhere for consistency, which may break existing code, since before it was `()` for NTK and `(-1,)` for NNGP. The new shape layout may also break existing code.
1.3) More type annotations and some doc fixes.
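
A minimal sketch of how the new `trace_axes` / `diagonal_axes` arguments are meant to be used (illustrative only, not part of this CL; the architecture, the shapes in the comments, and the assumption that `diagonal_axes` keeps a single axis per diagonal are mine):

```
import jax.random as random
import neural_tangents as nt
from neural_tangents import stax

key1, key2, key3 = random.split(random.PRNGKey(1), 3)
x1 = random.normal(key1, (4, 8))  # 4 inputs with 8 features
x2 = random.normal(key2, (6, 8))  # 6 inputs with 8 features

init_fn, apply_fn, _ = stax.serial(stax.Dense(32), stax.Relu(), stax.Dense(3))
_, params = init_fn(key3, x1.shape)

# Default: average the trace over the output (logit) axis -> (N, N) kernel.
kernel_fn = nt.empirical_kernel_fn(apply_fn, trace_axes=(-1,))
ntk = kernel_fn(x1, x2, params, 'ntk')  # expected shape (4, 6)

# Keep only the diagonal over the output axis instead.
kernel_fn = nt.empirical_kernel_fn(apply_fn, trace_axes=(), diagonal_axes=(-1,))
nngp = kernel_fn(x1, x2, params, 'nngp')  # expected shape (4, 6, 3)

# No tracing / no diagonal: the full `(N, N, X, X, ...)` layout from 1.1.
kernel_fn = nt.empirical_kernel_fn(apply_fn, trace_axes=(), diagonal_axes=())
nngp_full = kernel_fn(x1, x2, params, 'nngp')  # expected shape (4, 6, 3, 3)
```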

2) Predict:
2.0) Support `trace_axes` as in `empirical`.
2.1) Use the JAX ODE solver instead of SciPy.
2.2) Make all predict methods efficiently support array-valued time inputs.
2.3) Support `learning_rate` in all methods, for a simpler (IMO) correspondence between finite- and infinite-width optimizers.
2.4) Make all predictors return a single `predict_fn` method instead of init/predict/get etc.
2.5) Make all predictors work for outputs and kernels of arbitrary (but consistent) number of dimensions.
2.6) Fuse `gradient_descent` / `momentum` into one method.
2.7) Massive refactoring and purge of repetitive code.
2.8) More type annotations and doc fixes.
2.9) Fix all GitHub-reported bugs (and TPU test failures).
2.10) Remove device-placement boilerplate and pre-jitting.
2.11) Fuse `gp_inference` and `gradient_descent_mse_gp` into a single `gradient_descent_mse_ensemble` that treats `t=None` as infinity, for a simpler API and consistency with other methods.
2.12) Stop supporting iterator diagonal regularizers for simplicity, but make `gp_inference_mat` public as `gp_inference`. In follow-up CLs, methods will likely support np.array-valued diagonal regularizers.
2.13) Make `gp_inference` (formerly `gp_inference_mat`) and `gradient_descent_mse_ensemble` (formerly `gradient_descent_mse_gp`) return predictors, like the other methods, that cache intermediate information such as the kernel matrix, its eigendecomposition, or its Cholesky factorization, allowing efficient repeated invocation at different times / test points / `get` values / with and without covariances.
2.14) Add support for efficient train-set predictions (`x_test=None`) at finite and infinite time (see the sketch after this list).
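
A minimal sketch of the resulting `predict` API described in 2.2, 2.11 and 2.14 (illustrative only, not part of this CL; the data, architecture, and the expected output shapes in the comments are my assumptions):

```
import jax.numpy as np
import jax.random as random
import neural_tangents as nt
from neural_tangents import stax

key1, key2, key3 = random.split(random.PRNGKey(1), 3)
x_train = random.normal(key1, (10, 8))
y_train = random.uniform(key2, (10, 2))
x_test = random.normal(key3, (5, 8))

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(2))

predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, learning_rate=1., diag_reg=1e-4)

# 2.11: `t=None` is treated as t = infinity.
y_test_inf = predict_fn(t=None, x_test=x_test, get='ntk')  # expected (5, 2)

# 2.2: array-valued training times in a single call.
ts = np.array([1., 10., 100.])
y_test_t = predict_fn(t=ts, x_test=x_test, get='ntk')  # expected (3, 5, 2)

# 2.14: efficient train-set predictions via `x_test=None`.
y_train_t = predict_fn(t=ts, x_test=None, get='nngp')  # expected (3, 10, 2)
```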

3) stax:
3.1) Fix a bug when the channel axis preceded the batch axis, and add tests.
3.2) Tighten / add some type annotations.

4) batch: remove device-placement boilerplate / pre-jitting, simplify code.

5) General: add or fix type annotations / docs throughout.

PiperOrigin-RevId: 317946977
romanngg committed Jun 24, 2020
1 parent a936d87 commit a76bbb4
Showing 24 changed files with 4,168 additions and 3,659 deletions.
38 changes: 23 additions & 15 deletions README.md
@@ -78,7 +78,7 @@ You can now run the examples (using [`tensorflow_datasets`](https://github.com/t
and tests by calling:

```
pip install "tensorflow>=2.2.0rc3" "tensorflow-datasets>=3.0.0"
pip install tensorflow tensorflow-datasets --upgrade
python neural-tangents/examples/infinite_fcn.py
python neural-tangents/examples/weight_space.py
@@ -161,14 +161,23 @@ import neural_tangents as nt
x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1)) # training targets

y_test_nngp = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
get='nngp')
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
y_train)

y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) np.ndarray test predictions of an infinite Bayesian network

y_test_ntk = nt.predict.gp_inference(kernel_fn, x_train, y_train, x_test,
get='ntk')
y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) np.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)

# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp # True
both.ntk == y_test_ntk # True

# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
```


@@ -212,19 +221,19 @@ init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)

## Package description

The `neural_tangents` (`nt`) package contains the following modules and methods:
The `neural_tangents` (`nt`) package contains the following modules and functions:

* `stax` - primitives to construct neural networks like `Conv`, `Relu`, `serial`, `parallel` etc.

* `predict` - predictions with infinite networks:

* `predict.gp_inference` - either fully Bayesian inference (`get='nngp'`) or inference with a network trained to full convergence (infinite time) on MSE loss using continuous gradient descent (`get='ntk'`).
* `predict.gradient_descent_mse` - inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite (`t=None`) time. Computed in closed form.

* `predict.gradient_descent_mse` - inference with a network trained on MSE loss with continuous gradient descent for an arbitrary finite time.
* `predict.gradient_descent` - inference with a single infinite width / linearized network trained on arbitrary loss with continuous (momentum) gradient descent for an arbitrary finite time. Computed using an ODE solver (a sketch follows this list).

* `predict.gradient_descent` - inference with a network trained on arbitrary loss with continuous gradient descent for an arbitrary finite time (using an ODE solver).
* `predict.gradient_descent_mse_ensemble` - inference with an infinite ensemble of infinite width networks, either fully Bayesian (`get='nngp'`) or trained with MSE loss using continuous gradient descent (`get='ntk'`). Finite-time Bayesian inference (e.g. `t=1., get='nngp'`) is interpreted as gradient descent on the top layer only (see [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington)), since it converges to exact Gaussian process inference with NNGP (`t=None, get='nngp'`). Computed in closed form.

* `predict.momentum` - inference with a network trained on arbitrary loss with continuous momentum gradient descent for an arbitrary finite time (using an ODE solver).
* `predict.gp_inference` - exact closed form Gaussian process inference using NNGP (`get='nngp'`), NTK (`get='ntk'`), or both (`get=('nngp', 'ntk')`). Equivalent to `predict.gradient_descent_mse_ensemble` with `t=None` (infinite training time), but has a slightly different API (accepting a precomputed kernel matrix `k_train_train` instead of `kernel_fn` and `x_train`).

* `monte_carlo_kernel_fn` - compute a Monte Carlo kernel estimate of _any_ `(init_fn, apply_fn)`, not necessarily specified `nt.stax`, enabling the kernel computation of infinite networks without closed-form expressions.
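
A minimal sketch of `predict.gradient_descent` with a non-MSE loss (illustrative and not part of this commit's diff; the loss, architecture, and zero initial outputs are my choices):

```
import jax.numpy as np
import jax.random as random
from jax.nn import log_softmax
import neural_tangents as nt
from neural_tangents import stax

key1, key2, key3 = random.split(random.PRNGKey(1), 3)
x_train = random.normal(key1, (6, 8))
x_test = random.normal(key2, (3, 8))
y_train = np.eye(4)[random.randint(key3, (6,), 0, 4)]  # one-hot targets

_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Erf(), stax.Dense(4))
k_train_train = kernel_fn(x_train, x_train, 'ntk')
k_test_train = kernel_fn(x_test, x_train, 'ntk')

# Cross-entropy between network outputs `fx` and one-hot targets `y_hat`.
cross_entropy = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

predict_fn = nt.predict.gradient_descent(
    cross_entropy, k_train_train, y_train, learning_rate=1., momentum=0.9)

# Train / test outputs after t time units, starting from zero initial outputs.
fx_train_t, fx_test_t = predict_fn(t=100., fx_train_0=np.zeros_like(y_train),
                                   fx_test_0=np.zeros((3, 4)),
                                   k_test_train=k_test_train)
```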

@@ -268,7 +277,7 @@ The kernel of an infinite network `kernel_fn(x1, x2).ntk` combined with `nt.pre

Continuous gradient descent in an infinite network has been shown in [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) to correspond to training a _linear_ (in trainable parameters) model, which makes linearized neural networks an important subject of study for understanding the behavior of parameters in wide models.

For this, we provide two convenient methods:
For this, we provide two convenient functions:

* `nt.linearize`, and
* `nt.taylor_expand`,
@@ -305,7 +314,7 @@ logits = apply_fn_lin((W, b), x) # (3, 2) np.ndarray

### Function space:

Outputs of a linearized model evolve identically to those of an infinite one [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) but with a different kernel - specifically, the Neural Tangent Kernel [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler) evaluated on the specific `apply_fn` of the finite network given specific `params_0` that the network is initialized with. For this we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params)` that allows to compute the empirical NTK and NNGP kernels on specific `params`.
Outputs of a linearized model evolve identically to those of an infinite one [[11]](#11-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent-neurips-2019-jaehoon-lee-lechao-xiao-samuel-s-schoenholz-yasaman-bahri-roman-novak-jascha-sohl-dickstein-jeffrey-pennington) but with a different kernel - specifically, the Neural Tangent Kernel [[10]](#10-neural-tangent-kernel-convergence-and-generalization-in-neural-networks-neurips-2018-arthur-jacot-franck-gabriel-clément-hongler) evaluated on the specific `apply_fn` of the finite network given specific `params_0` that the network is initialized with. For this we provide the `nt.empirical_kernel_fn` function that accepts any `apply_fn` and returns a `kernel_fn(x1, x2, params, get)` that allows one to compute the empirical NTK and/or NNGP (based on `get`) kernels on specific `params`.

#### Example:

@@ -330,13 +339,12 @@ y_train = random.uniform(key1, shape=(3, 2))
kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, x_train, params, 'ntk')
ntk_test_train = kernel_fn(x_test, x_train, params, 'ntk')
mse_predictor = nt.predict.gradient_descent_mse(
ntk_train_train, y_train, ntk_test_train)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)

t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
```
12 changes: 8 additions & 4 deletions examples/function_space.py
@@ -1,3 +1,5 @@
# Lint as: python3

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -72,7 +74,7 @@ def main(unused_argv):
ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
g_dd = ntk(x_train, None, params)
g_td = ntk(x_test, x_train, params)
predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

# Get initial values of the network in function space.
fx_train = apply_fn(params, x_train)
@@ -88,11 +90,13 @@ def main(unused_argv):

# Get predictions from analytic computation.
print('Computing analytic prediction.')
fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)
fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

# Print out summary data comparing the linear / nonlinear model.
util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
loss)
util.print_summary('test', y_test, apply_fn(params, x_test), fx_test,
loss)

if __name__ == '__main__':
app.run(main)
9 changes: 3 additions & 6 deletions examples/infinite_fcn.py
@@ -58,12 +58,9 @@ def main(unused_argv):

start = time.time()
# Bayesian and infinite-time gradient descent inference with infinite network.
fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn,
x_train,
y_train,
x_test,
get=('nngp', 'ntk'),
diag_reg=1e-3)
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
y_train, diag_reg=1e-3)
fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)
fx_test_nngp.block_until_ready()
fx_test_ntk.block_until_ready()

1 change: 1 addition & 0 deletions neural_tangents/__init__.py
@@ -19,6 +19,7 @@
from neural_tangents.utils.empirical import empirical_kernel_fn
from neural_tangents.utils.empirical import empirical_nngp_fn
from neural_tangents.utils.empirical import empirical_ntk_fn
from neural_tangents.utils.empirical import empirical_direct_ntk_fn
from neural_tangents.utils.empirical import linearize
from neural_tangents.utils.empirical import taylor_expand
from neural_tangents.utils.monte_carlo import monte_carlo_kernel_fn
