The goal of this repository is to provide users with a collection of benchmarks to evaluate Flax optimizers. We aim to provide both fast benchmarks, to quickly evaluate new optimizers, and slower classic benchmarks, to help authors of deep-learning papers who want to publish work on an optimizer.
To do so, we built an infrastructure to:
- run optimizers on a panel of datasets (adapted from TensorFlow Datasets),
- save all meaningful training metrics and parameters to disk (as human-readable JSON files; a sketch follows this list),
- plot the results later.
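To make the second point concrete, here is a minimal sketch of what one saved run could look like. The field names and file layout here are illustrative assumptions, not the library's actual format:

```python
import json
from pathlib import Path

# hypothetical record of a single training run; the actual fields saved
# by the library may differ, these names are for illustration only
experiment = {
    "optimizer": "Adam",
    "dataset": "MNIST",
    "learning_rate": 1e-3,
    "train_loss": [2.31, 0.54, 0.21],
    "test_loss": [2.29, 0.60, 0.25],
}

Path("experiments").mkdir(exist_ok=True)
with open("experiments/adam_mnist.json", "w") as file:
    json.dump(experiment, file, indent=2)  # indent keeps the file human-readable
```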
This is a work in progress; expect the elements in the TODO list to be completed within a few weeks.
You can install this library with:

```sh
pip install git+https://github.com/nestordemeure/flaxOptimizersBenchmark.git
```
## TODO

- MNIST
- Imagenette (V2)
- Imagewoof (V2)
- add some datasets:
  - imagenet
  - COCO
  - wikitext
- code to load experiments by search criteria, or to load all of them and apply the criteria afterward with a simple filter (see the loading sketch after this list)
- plotting functions that take an `Experiment` as input (or a `(dataset, architecture)` pair to compare optimizers; see the plotting sketch after this list):
  - bar plot of jit/run time
  - bar plot of final train/test loss
  - plot of metrics across time (train and/or test) (for one or all optimizers)
  - final loss as a function of the starting learning rate
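As an illustration of the loading step, here is a minimal sketch assuming each run was saved as one JSON file in an `experiments/` folder (the folder name and the `dataset` field are assumptions carried over from the saving sketch above):

```python
import json
from pathlib import Path

def load_experiments(folder="experiments"):
    """Loads every experiment saved as a JSON file in `folder`."""
    return [json.loads(path.read_text()) for path in Path(folder).glob("*.json")]

# load all experiments, then apply the search criteria with a simple filter
experiments = load_experiments()
mnist_runs = [exp for exp in experiments if exp.get("dataset") == "MNIST"]
```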
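In the same spirit, here is a sketch of one of the planned plots, the bar plot of final train/test loss per optimizer. Matplotlib is assumed, and the `experiments` records follow the hypothetical format used in the sketches above:

```python
import matplotlib.pyplot as plt

def plot_final_losses(experiments):
    """Bar plot of the final train/test loss reached by each optimizer."""
    names = [exp["optimizer"] for exp in experiments]
    train = [exp["train_loss"][-1] for exp in experiments]
    test = [exp["test_loss"][-1] for exp in experiments]
    positions = range(len(names))
    # side-by-side bars: train shifted left, test shifted right
    plt.bar([p - 0.2 for p in positions], train, width=0.4, label="train")
    plt.bar([p + 0.2 for p in positions], test, width=0.4, label="test")
    plt.xticks(list(positions), names)
    plt.ylabel("final loss")
    plt.legend()
    plt.show()

plot_final_losses(mnist_runs)  # uses the list built in the loading sketch
```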
The following will *not* be included:
- data augmentation (as it is very problem specific, and we are focusing on the optimizers rather than the individual problems)
- learning rate and weight decay schedulers (these might be added later)
You can find optimizers compatible with Flax in the following repositories:
- flaxOptimizers contains implementations of a large number of optimizers in Flax.
- AdahessianJax contains my implementation of the Adahessian second-order optimizer in Flax.
- Flax.optim contains the official Flax optimizers.