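Code to run the simulations of the paper "Error-driven Input Modulation: Solving the Credit Assignment Problem without a Backward Pass"
by Giorgia Dellaferrera and Gabriel Kreiman.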
Presented at ICML 2022: https://proceedings.mlr.press/v162/dellaferrera22a.html
We ran the experiments with the following environments:
NumPy framework (fully connected models): Python 3.9.5, NumPy 1.19.5, Keras 2.5.0
PyTorch framework (convolutional models): Python 3.7.10, NumPy 1.19.2, PyTorch 1.6.0
The notebook Tutorial_PEPITA_FullyConnectedNets_CIFAR-10.ipynb provides a simple tutorial on how to implement and run the PEPITA training scheme for fully connected models. The entire framework is pytorch-based. The settings and results are the same as reported in the paper.
Training for 100 epochs takes approximately 1.5 hours on a CPU.
The experiments are run through main.py, which uses functions in functions.py and utils.py.
The entire framework is numpy-based and relies on the keras library to load the datasets.
For example, to run PEPITA with the standard settings on the MNIST dataset:
python main.py --exp_name Experiment1 \
--learn_type ERIN --n_runs 1 --train_epochs 100 \
--sample_passes 2 --n_samples all --eta 0.1 --dropout 0.9 \
--eta_decay --mnist --validation --batch_size 64 \
--update_type mom --w_init he_uniform \
--build auto --struct uniform --start_size 1024 --n_hlayers 1 --act_hidden relu --act_out softmax
Note that the training scheme for PEPITA is denoted as ERIN (ERror-INput).
If you train with PEPITA (ERIN), make sure to set --sample_passes 2, so that each input goes through two forward passes.
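The two forward passes can be illustrated with a minimal NumPy sketch for a network with one hidden layer. This is written from our reading of the paper's update rule, not copied from main.py: the function names, the column-wise batch layout, and the averaging over the batch are assumptions made for illustration, and the sketch omits dropout, momentum, and learning-rate decay.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Column-wise softmax; subtracting the max keeps exp() numerically stable.
    z = z - z.max(axis=0, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)

def forward(x, W1, W2):
    # x has shape (n_in, batch); returns hidden activations and outputs.
    h1 = relu(W1 @ x)
    out = softmax(W2 @ h1)
    return h1, out

def pepita_step(x, target, W1, W2, F, eta):
    """One PEPITA-style update (hypothetical helper, not the repository's API)."""
    # First pass: standard forward pass on the clean input.
    h1, out = forward(x, W1, W2)
    e = out - target                      # output error, shape (n_out, batch)

    # Second pass: the input is modulated by the error projected through F.
    x_mod = x + F @ e                     # F has shape (n_in, n_out)
    h1_mod, _ = forward(x_mod, W1, W2)

    # Updates use only quantities from the two forward passes.
    batch = x.shape[1]                    # averaging over the batch is an assumption
    dW1 = (h1 - h1_mod) @ x_mod.T / batch
    dW2 = e @ h1_mod.T / batch
    return W1 - eta * dW1, W2 - eta * dW2, e
```

Note that no gradient is propagated backward through W2: the first-layer update depends only on the difference between the hidden activations of the two passes.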
To train the network with backpropagation, substitute --learn_type ERIN with --learn_type BP. Remember to set --sample_passes 1.
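When switching between the two training schemes from a script, it can help to derive --sample_passes from --learn_type so the two never disagree. The helper below is a hypothetical sketch (build_command is not part of the repository); it only assembles the flags from the example command above into an argument list and does not launch anything.

```python
def build_command(learn_type, exp_name="Experiment1"):
    """Assemble the main.py argument list for ERIN (PEPITA) or BP.

    Hypothetical convenience helper; the flags are those of the README example.
    """
    if learn_type not in ("ERIN", "BP"):
        raise ValueError("learn_type must be 'ERIN' or 'BP'")
    # PEPITA needs two forward passes per input, backpropagation only one.
    sample_passes = "2" if learn_type == "ERIN" else "1"
    return [
        "python", "main.py", "--exp_name", exp_name,
        "--learn_type", learn_type, "--n_runs", "1", "--train_epochs", "100",
        "--sample_passes", sample_passes, "--n_samples", "all",
        "--eta", "0.1", "--dropout", "0.9",
        "--eta_decay", "--mnist", "--validation", "--batch_size", "64",
        "--update_type", "mom", "--w_init", "he_uniform",
        "--build", "auto", "--struct", "uniform", "--start_size", "1024",
        "--n_hlayers", "1", "--act_hidden", "relu", "--act_out", "softmax",
    ]
```

The returned list can be passed to subprocess.run from the repository root, or joined with spaces to print the equivalent shell command.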
The experiments are run through main_pytorch.py, which uses functions in models.py. The entire framework is pytorch-based.
For example, to run PEPITA with the standard settings on the MNIST dataset:
python main_pytorch.py --exp_name Experiment2 \
--learn_type ERIN --n_runs 1 --train_epochs 100 \
--eta 0.01 --dropout 0.9 --Bstd 0.05 \
--eta_decay --dataset mn --batch_size 50 \
--update_type mom --w_init he_uniform \
--model Net1conv1fcXL
The argument Bstd defines the standard deviation of the projection matrix.
Here we use B instead of F (as in the paper) to denote the projection matrix, to avoid confusion with torch.nn.functional.
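As a rough illustration of what Bstd controls, the sketch below samples a fixed random projection matrix whose entries have a given standard deviation. It assumes a zero-mean Gaussian and uses NumPy for brevity; the distribution, the scaling, and the matrix orientation used in models.py may differ, and sample_projection is a hypothetical name.

```python
import numpy as np

def sample_projection(n_in, n_out, b_std, seed=0):
    """Sample a fixed projection matrix B of shape (n_in, n_out).

    Assumption: zero-mean Gaussian entries with standard deviation b_std.
    B maps an output error of size n_out back to the input space of size n_in.
    """
    rng = np.random.default_rng(seed)
    return rng.normal(loc=0.0, scale=b_std, size=(n_in, n_out))

# Example: flattened 28x28 MNIST input and 10 classes, with --Bstd 0.05.
B = sample_projection(28 * 28, 10, b_std=0.05)
```

The matrix is sampled once and kept fixed during training; only the network weights are updated.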
The notebook plot_compute_slowness.ipynb contains the function that extracts the convergence rate as the "slowness" parameter.
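For orientation, one way to obtain such a parameter is to fit the accuracy curve with a plateau function, accuracy = max_accuracy * epoch / (slowness + epoch), where a smaller slowness means faster convergence. The sketch below assumes this functional form and uses a linearized least-squares fit; it is not the notebook's code, and fit_slowness is a hypothetical name.

```python
import numpy as np

def fit_slowness(epochs, accuracy):
    """Fit accuracy = max_acc * epoch / (slowness + epoch).

    Uses the linearization 1/accuracy = 1/max_acc + (slowness/max_acc) * (1/epoch),
    so epochs and accuracies must be strictly positive.
    Returns (max_acc, slowness).
    """
    epochs = np.asarray(epochs, dtype=float)
    accuracy = np.asarray(accuracy, dtype=float)
    # Degree-1 polyfit returns (slope, intercept) of the linearized relation.
    slope, intercept = np.polyfit(1.0 / epochs, 1.0 / accuracy, 1)
    max_acc = 1.0 / intercept
    slowness = slope * max_acc
    return max_acc, slowness
```

The linearized fit weights early epochs more heavily than a direct nonlinear fit would, so on noisy curves the two approaches can give somewhat different values.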
@InProceedings{pmlr-v162-dellaferrera22a,
title = {Error-driven Input Modulation: Solving the Credit Assignment Problem without a Backward Pass},
author = {Dellaferrera, Giorgia and Kreiman, Gabriel},
booktitle = {Proceedings of the 39th International Conference on Machine Learning},
pages = {4937--4955},
year = {2022},
editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
volume = {162},
series = {Proceedings of Machine Learning Research},
month = {17--23 Jul},
publisher = {PMLR},
pdf = {https://proceedings.mlr.press/v162/dellaferrera22a/dellaferrera22a.pdf},
url = {https://proceedings.mlr.press/v162/dellaferrera22a.html},
}