Natural Gradient Descent in Wide Neural Networks

Python code used in the paper "Understanding Approximate Fisher Information for Fast Convergence of Natural Gradient Descent in Wide Neural Networks" (to appear at NeurIPS 2020 as an oral presentation). This repository provides simple JAX- and NumPy-based implementations of natural gradient descent (NGD) with the exact or an approximate Fisher Information Matrix (FIM), in both parameter space and function space (via the empirical or analytical NTK).
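As a rough illustration of what "parameter space" and "function space" mean here, the following NumPy sketch checks that, for the MSE loss with a scalar output, a damped NGD step computed from the exact FIM matches the step computed from the empirical NTK Gram matrix. It is not taken from this repository: the random Jacobian `J`, the residual, the damping `rho`, and the learning rate `lr` are assumed values chosen only for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 8, 20                       # samples, parameters (scalar network output)
J = rng.normal(size=(n, p))        # stand-in Jacobian d f(x_i) / d theta (assumed values)
residual = rng.normal(size=n)      # f(x) - y under the MSE loss (assumed values)
rho, lr = 1e-3, 1.0                # damping and learning rate (assumed values)

# Parameter space: FIM for the MSE loss is F = J^T J / n, a (p, p) matrix.
F = J.T @ J / n
grad = J.T @ residual / n
step_param = -lr * np.linalg.solve(F + rho * np.eye(p), grad)

# Function space: empirical NTK Gram matrix Theta = J J^T, an (n, n) matrix.
Theta = J @ J.T
step_func = -lr * J.T @ np.linalg.solve(Theta + n * rho * np.eye(n), residual)

# Both routes give the same parameter update (push-through identity).
assert np.allclose(step_param, step_func)
```

The two forms differ mainly in cost: the parameter-space solve involves a p x p system, while the function-space solve involves an n x n system, which is smaller whenever the number of parameters exceeds the number of training samples.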

Code NTK Loss Exact BD BTD K-FAC Unit-wise
jax-based empirical MSE, cross-entropy -
numpy-based empirical, analytical MSE -

NOTE: The NumPy-based code supports only three-layer MLPs. The JAX-based code implements NGD with the empirical NTK (for finite-width DNNs) on top of Neural Tangents. It supports more general DNN architectures and (multi-)GPU acceleration, but it does not support NGD with the analytical NTK (for infinite-width DNNs).
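To make the empirical/analytical distinction concrete, here is a minimal NumPy sketch that compares the two kernels. It is an illustration under stated assumptions, not the repository's code: it uses a one-hidden-layer ReLU network in the NTK parameterization (chosen for brevity, not the three-layer MLP above), standard normal weights, and hypothetical helper names `empirical_ntk` and `analytical_ntk`. The analytical form uses the standard arc-cosine kernel expressions for ReLU.

```python
import numpy as np

def empirical_ntk(X, W, v):
    """Theta = J J^T for f(x) = v . relu(W x / sqrt(d)) / sqrt(m) at finite width m."""
    m, d = W.shape
    pre = X @ W.T / np.sqrt(d)                 # (n, m) pre-activations
    act = np.maximum(pre, 0.0)                 # relu
    dact = (pre > 0).astype(X.dtype)           # relu'
    J_v = act / np.sqrt(m)                     # (n, m): df/dv
    # df/dW[j, k] = v_j * relu'(pre_j) * x_k / sqrt(m * d)
    J_W = (dact * v)[:, :, None] * X[:, None, :] / np.sqrt(m * d)
    J = np.concatenate([J_v, J_W.reshape(len(X), -1)], axis=1)
    return J @ J.T

def analytical_ntk(X):
    """Infinite-width limit of the kernel above for N(0, 1) weights."""
    d = X.shape[1]
    norms = np.linalg.norm(X, axis=1)
    outer = np.outer(norms, norms)
    gram = X @ X.T
    cos = np.clip(gram / outer, -1.0, 1.0)
    theta = np.arccos(cos)
    k0 = (np.pi - theta) / (2 * np.pi)         # E[relu'(u) relu'(u')]
    k1 = outer / d * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)  # E[relu(u) relu(u')]
    return k1 + gram / d * k0

rng = np.random.default_rng(0)
n, d, m = 4, 5, 50_000                         # samples, input dim, hidden width (assumed)
X = rng.normal(size=(n, d))
X *= np.sqrt(d) / np.linalg.norm(X, axis=1, keepdims=True)  # rows with norm sqrt(d)
W = rng.normal(size=(m, d))
v = rng.normal(size=m)

emp = empirical_ntk(X, W, v)
ana = analytical_ntk(X)
print(np.abs(emp - ana).max())                 # gap shrinks as the width m grows
```

The empirical kernel depends on the sampled weights and needs the Jacobian of an actual finite network, whereas the analytical kernel depends only on the inputs, which is why it corresponds to the infinite-width case.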

Setup

$ git clone git@github.com:kazukiosawa/ngd_in_wide_nn.git
$ cd ngd_in_wide_nn
$ pip install -r requirements.txt

To use a GPU, follow JAX's installation guide.

How to run

See jax-based or numpy-based for instructions on running each implementation.

Citation

@misc{karakida2020understanding,
      title={Understanding Approximate Fisher Information for Fast Convergence of Natural Gradient Descent in Wide Neural Networks}, 
      author={Ryo Karakida and Kazuki Osawa},
      year={2020},
      eprint={2010.00879},
      archivePrefix={arXiv},
      primaryClass={stat.ML}
}

(To appear in NeurIPS 2020 as an oral presentation)
