Improvements to the `empirical` and `predict` modules.
Since `predict` depends on the layout of the empirical kernel, the changes are combined in one CL.

1) Empirical:

1.1) Unify all kernel shapes: all kernels, empirical and analytical, NTK and NNGP, now have the `(N, N, X, X, Y, Y, ...)` layout (subject to 1.2 below). This removes the tedious checks and bugs caused by different kernels having different shapes, potentially at a small compute expense; IMO the simplification is worth the cost (how big the cost is, if any, has not been measured yet).

1.2) Support `trace_axes` and `diagonal_axes` to indicate which output axes to treat as i.i.d. (e.g. logits) and compute the mean trace over, or on which axes to compute only the diagonal (a more flexible version of `diagonal_spatial` / `diagonal_batch`). For the implicit kernel this currently does not improve performance, but for the direct one it should. Note that the default is now `trace_axes=(-1,)` everywhere for consistency, which may break existing code: previously it was `()` for NTK and `(-1,)` for NNGP. The new shape layout may also be breaking.

1.3) More type annotations and some doc fixes.

2) Predict:

2.0) Support `trace_axes` as in `empirical`.
2.1) Use the JAX ODE solver instead of scipy.
2.2) Make all predict methods efficiently support array-valued time inputs.
2.3) Support `learning_rate` in all methods, for (IMO) a simpler correspondence between finite- and infinite-width optimizers.
2.4) Make all predictors return a single `predict_fn` instead of init/predict/get etc.
2.5) Make all predictors work for outputs and kernels of arbitrary (but consistent) numbers of dimensions.
2.6) Fuse gradient_descent / momentum into one method.
2.7) Massive refactoring and removal of repetitive code.
2.8) More type annotations and doc fixes.
2.9) Fix all GitHub-reported bugs (and TPU test failures).
2.10) Remove device-placement boilerplate and pre-jitting.
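The unified kernel layout in 1.1) and the `trace_axes` / `diagonal_axes` reductions in 1.2) can be sketched in plain numpy on a toy linear model (this is an illustration of the shape semantics only, not the library implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, Y = 4, 3, 5                    # batch size, input dim, output (logit) dim
x = rng.normal(size=(N, D))

# For a linear model f(x) = W @ x the empirical NTK factorizes as
# k[i, j, a, b] = <x_i, x_j> * delta_ab, already in the unified
# (N, N, Y, Y) layout (no spatial axes in this toy example).
k = np.einsum('id,jd->ij', x, x)[:, :, None, None] * np.eye(Y)

# trace_axes=(-1,): treat the logits as i.i.d. and take the mean trace
# over the paired output axes -> shape (N, N).
k_trace = np.trace(k, axis1=-2, axis2=-1) / Y

# diagonal_axes=(-1,): keep only the diagonal of the paired output
# axes -> shape (N, N, Y).
k_diag = np.diagonal(k, axis1=-2, axis2=-1)

assert k.shape == (N, N, Y, Y)
assert k_trace.shape == (N, N)
assert k_diag.shape == (N, N, Y)
# For this toy kernel both reductions recover the input Gram matrix.
assert np.allclose(k_trace, x @ x.T)
```

Each axis listed in `trace_axes` collapses a pair of kernel axes to nothing, and each axis in `diagonal_axes` collapses a pair to a single axis, which is where the potential memory/compute savings of 1.2) come from.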
2.11) Fuse `gp_inference` and `gradient_descent_mse_gp` into a single `gradient_descent_mse_ensemble` that treats `t=None` as infinity, for a simpler API and consistency with other methods.
2.12) Stop supporting iterator diagonal regularizers for simplicity, but make `gp_inference_mat` public as `gp_inference`. In follow-up CLs, methods will likely support np.array-valued diagonal regularizers.
2.13) Make `gp_inference` (former `gp_inference_mat`) and `gradient_descent_mse_ensemble` (former `gradient_descent_mse_gp`) return predictors, like the other methods. These predictors cache intermediary information (the kernel matrix, its eigendecomposition, its Cholesky factorization), allowing efficient repeated invocation with different times / test points / get values / with and without covariances.
2.14) Add support for efficient train-set predictions (`x_test=None`) at finite and infinite time.

3) stax:

3.1) Fix a bug when the channel axis preceded the batch axis, and add tests.
3.2) Tighten / add some type annotations.

4) batch: remove device-placement boilerplate / pre-jitting, simplify code.

5) General: add or fix type annotations / docs throughout.

PiperOrigin-RevId: 317946977
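The predictor design above (`t=None` as infinity, array-valued `t`, train-set predictions, cached eigendecomposition) can be sketched in plain numpy. All names here are illustrative, not the library API, and the sketch assumes zero-initialized, scalar-output predictions `y` of shape `(N,)`:

```python
import numpy as np

def gradient_descent_mse_sketch(k_train_train, y_train, learning_rate=1.0):
    # Eigendecomposition is computed once and cached in the closure, so
    # repeated calls with different times / test kernels are cheap.
    evals, evecs = np.linalg.eigh(k_train_train)

    def predict_fn(t=None, k_test_train=None):
        if k_test_train is None:
            # x_test=None: predict on the training set itself.
            k_test_train = k_train_train
        if t is None:
            # t=None is treated as t = infinity: plain kernel regression.
            op = 1.0 / evals
        else:
            # Array-valued t broadcasts over a leading time axis.
            t = np.asarray(t)[..., None]
            op = (1.0 - np.exp(-learning_rate * evals * t)) / evals
        # Mean prediction f_t = k_test_train @ U diag(op) U^T y.
        coeffs = op * (evecs.T @ y_train)       # shape (..., N)
        return coeffs @ evecs.T @ k_test_train.T

    return predict_fn

rng = np.random.default_rng(1)
N = 6
x = rng.normal(size=(N, 2))
sq_dists = np.sum((x[:, None] - x[None]) ** 2, axis=-1)
k = np.exp(-sq_dists) + 1e-6 * np.eye(N)      # PSD RBF kernel + small diag reg
y = rng.normal(size=N)

predict_fn = gradient_descent_mse_sketch(k, y)
f_inf = predict_fn()                          # t=None: infinite training time
f_t = predict_fn(t=np.array([0.0, 1.0, 10.0]))

assert np.allclose(f_inf, y, atol=1e-6)       # interpolates the training data
assert f_t.shape == (3, N)
assert np.allclose(f_t[0], 0.0)               # zero-init predictions at t=0
```

The point of the closure is that the expensive factorization happens once, while each call only pays for elementwise functions of the eigenvalues and a couple of matrix products, which is what makes querying many times or test sets from one predictor cheap.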
24 changed files with 4,168 additions and 3,659 deletions.