Python code for *Understanding Approximate Fisher Information for Fast Convergence of Natural Gradient Descent in Wide Neural Networks* (to appear in NeurIPS 2020 as an oral presentation). This repository provides simple JAX- and NumPy-based implementations of natural gradient descent (NGD) with the exact or approximate Fisher information matrix (FIM), both in parameter space and in function space (i.e., via the empirical or analytical NTK).
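For orientation, here is how these quantities relate for MSE loss (standard definitions, stated only as background and not as this repository's notation): with network outputs $f_\theta(x)$ stacked over the $N$ training inputs, targets $y$, and Jacobian $J = \partial f_\theta(x)/\partial\theta$, the exact FIM and the empirical NTK Gram matrix are built from the same Jacobian, and exact NGD in parameter space reduces to a function-space update through the NTK (up to damping and constant factors absorbed into the learning rate $\eta$):

```math
F = \tfrac{1}{N} J^\top J, \qquad \Theta = J J^\top, \qquad \theta \leftarrow \theta - \eta\, J^\top \Theta^{-1}\bigl(f_\theta(x) - y\bigr)
```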
| Code | NTK | Loss | Exact | BD | BTD | K-FAC | Unit-wise |
|---|---|---|---|---|---|---|---|
| jax-based | empirical | MSE, cross-entropy | - | | | | |
| numpy-based | empirical, analytical | MSE | - | | | | |
NOTE: The NumPy-based code supports only a three-layer MLP. The JAX-based code implements NGD with the empirical NTK (for finite-width DNNs) on top of Neural Tangents; it supports more general DNN architectures and (multi-)GPU acceleration, but it does not support NGD with the analytical NTK (for infinite-width DNNs).
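As a rough illustration of the function-space form above, the following is a minimal, self-contained JAX sketch of one NGD step with the exact FIM for MSE loss via the empirical NTK Gram matrix. It is not the API of this repository: `apply_fn`, `ngd_step_mse`, and the `damping` term are assumptions made only for this example.

```python
# Minimal sketch only -- not this repository's API. `apply_fn(params, x)` is assumed to
# return (N, C) outputs, `params` is any pytree, and `damping` is an illustrative regularizer.
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def ngd_step_mse(params, apply_fn, x, y, lr=1.0, damping=1e-6):
    """One NGD step with the exact FIM for MSE loss, in the empirical-NTK form
    params <- params - lr * J^T (J J^T + damping * I)^{-1} (f(x) - y)."""
    flat, unravel = ravel_pytree(params)

    def out_fn(p):
        # Flatten the (N, C) outputs to a vector of length N*C.
        return apply_fn(unravel(p), x).reshape(-1)

    out = out_fn(flat)
    J = jax.jacobian(out_fn)(flat)       # Jacobian of outputs w.r.t. parameters, shape (N*C, P)
    ntk = J @ J.T                        # empirical NTK Gram matrix, shape (N*C, N*C)
    residual = out - y.reshape(-1)
    # Solve the NTK system instead of forming an explicit inverse.
    v = jnp.linalg.solve(ntk + damping * jnp.eye(ntk.shape[0]), residual)
    return unravel(flat - lr * (J.T @ v))
```

The BD, BTD, K-FAC, and unit-wise variants listed in the table are structured approximations of the FIM studied in the paper; see jax-based and numpy-based for the actual implementations.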
    $ git clone git@github.com:kazukiosawa/ngd_in_wide_nn.git
    $ cd ngd_in_wide_nn
    $ pip install -r requirements.txt
To use a GPU, follow JAX's installation guide.
Visit jax-based or numpy-based for more information on each implementation.
If you use this code, please cite:

    @misc{karakida2020understanding,
        title={Understanding Approximate Fisher Information for Fast Convergence of Natural Gradient Descent in Wide Neural Networks},
        author={Ryo Karakida and Kazuki Osawa},
        year={2020},
        eprint={2010.00879},
        archivePrefix={arXiv},
        primaryClass={stat.ML}
    }