ML Reproducibility Challenge 2020: Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One
This repo contains a re-implementation of the 2020 ICLR paper Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One. A reproducibility report, submitted to the 2020 ML Reproducibility Challenge, is available here. Local code was run on macOS Mojave, version 10.14.6.
- Install conda
- Create conda environment: `conda env create -f environment.yml`
- Activate environment: `conda activate ml_reprod_hybrid_energy_models`
- Confirm you can run the training scripts `python train_supervised.py` and `python train_JEM_algorithm.py`
- Model parameters and hyperparameters are stored in `params.json`. Modify these as you see fit; the defaults are those described in the paper. The neural nets used are given in the `models` directory.
- To train a model with the supervised training method, run `python train_supervised.py`. Model artefacts (checkpoints) will be stored in `./artefacts_supervised`.
- To train a model with the joint energy-based model (JEM) training method, run `python train_JEM_algorithm.py`. Model artefacts (checkpoints, images) will be stored in `./artefacts`. Note: The JEM training technique is unstable, and the training run will likely crash. This is discussed in Appendix H.3 of the paper. If and when it does crash, do the following: In `params.json`, change `params["load_from_checkpoint"]` to `True`, and change `params["start_epoch"]` to the epoch where it crashed. Try loading from earlier checkpoints if the most recent one crashes. If all else fails, restart training from the beginning.
- Once the model is trained, run `python calibration.py` to generate the calibration plots for both the supervised and JEM training methods. An example calibration plot is shown in `example_artefacts`.
- To generate fresh SGLD samples (from a randomly initialized buffer), run `python generate_samples.py`. By default, it runs for 20 SGLD steps, but this can be changed. Example SGLD evolutions for 1, 5, 10, 20 and 50 steps are given in `example_artefacts/fresh_sgld_samples`.
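The crash-recovery edit to `params.json` can also be scripted rather than done by hand. The snippet below is a minimal sketch, not part of this repo: the helper name `resume_from_epoch` is hypothetical, and it assumes `params.json` is a plain JSON object with `load_from_checkpoint` and `start_epoch` as top-level keys. Python's `True` is written to the file as JSON `true`.

```python
import json
from pathlib import Path


def resume_from_epoch(params_path, crash_epoch):
    """Set the two resume fields in a params file and return the updated dict.

    Hypothetical helper: assumes the file is a flat JSON object containing
    the keys "load_from_checkpoint" and "start_epoch".
    """
    path = Path(params_path)
    params = json.loads(path.read_text())

    # Tell the training script to restore weights instead of starting fresh.
    params["load_from_checkpoint"] = True
    # Resume from the epoch where the run crashed (or an earlier one).
    params["start_epoch"] = crash_epoch

    # All other keys are written back unchanged.
    path.write_text(json.dumps(params, indent=4) + "\n")
    return params
```

For example, after a crash at epoch 12, calling `resume_from_epoch("params.json", 12)` and then re-running `python train_JEM_algorithm.py` would follow the recovery steps described above. If that checkpoint also crashes, call it again with an earlier epoch.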
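For readers new to SGLD, the idea behind sampling from a classifier can be sketched in a few lines. This is an illustration only and is not the code in `generate_samples.py`: it swaps the neural net for a hypothetical toy classifier whose logit for class y is `-||x - c_y||^2 / 2`, and the names `sgld_sample` and `centers`, as well as the step size and noise scale, are assumptions chosen for the example. The energy is `E(x) = -LogSumExp_y f(x)[y]`, so each step moves the sample along the gradient of the LogSumExp of the logits and adds Gaussian noise.

```python
import math
import random


def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def toy_logits(x, centers):
    """Hypothetical classifier: logit for class y is -||x - c_y||^2 / 2."""
    return [-0.5 * sum((xd - cd) ** 2 for xd, cd in zip(x, c)) for c in centers]


def grad_logsumexp(x, centers):
    """Gradient of LogSumExp_y f(x)[y] w.r.t. x (the negative energy gradient).

    For the toy logits above this is sum_y softmax(f(x))[y] * (c_y - x).
    """
    probs = softmax(toy_logits(x, centers))
    return [
        sum(p * (c[d] - x[d]) for p, c in zip(probs, centers))
        for d in range(len(x))
    ]


def sgld_sample(centers, n_steps=20, step_size=0.1, noise_std=0.01, seed=0):
    """Run SGLD from a uniformly random start and return the final sample."""
    rng = random.Random(seed)
    dim = len(centers[0])
    # Random initialization, standing in for a fresh sample buffer entry.
    x = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
    for _ in range(n_steps):
        g = grad_logsumexp(x, centers)
        # Gradient step toward lower energy plus Gaussian noise.
        x = [
            xd + step_size * gd + rng.gauss(0.0, noise_std)
            for xd, gd in zip(x, g)
        ]
    return x
```

With `centers = [[2.0, 2.0], [-2.0, -2.0]]`, running more steps lets the sample settle near one of the two class centers, which mirrors how the example images in `example_artefacts/fresh_sgld_samples` sharpen as the number of SGLD steps grows.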