
Towards Scaling Difference Target Propagation with Backprop Targets

This repository is the official implementation of "Towards Scaling Difference Target Propagation with Backprop Targets", currently under review at ICML 2022. The code runs on Python > 3.7 with PyTorch >= 1.7.0.

Installation

pip install -e .

(Optional) We recommend using a conda environment. The specs of our environment are stored in conda_env_specs.txt.

Naming of methods:

| Name in paper                       | Name in codebase                                     |
| ----------------------------------- | ---------------------------------------------------- |
| L-DRL                               | `DTP`                                                |
| Backpropagation                     | `BaselineModel`                                      |
| DRL                                 | `meulemans_dtp` (based on the original authors' repo) |
| Target Propagation                  | `TargetProp`                                         |
| Difference Target Propagation       | `VanillaDTP`                                         |
| "Parallel" L-DRL (not in the paper) | `ParallelDTP`                                        |

Codebase structure

The main logic of our method is in target_prop/models/dtp.py

An initial PyTorch implementation of our DTP model can be found under target_prop/legacy. This model was then re-implemented using PyTorch-Lightning.
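For orientation, the target computation at the heart of difference target propagation (implemented in full in target_prop/models/dtp.py) can be sketched as follows. This is an illustrative toy, not the repo's API: all names (`f1`, `f2`, `g`, etc.) are hypothetical, and the feedback function here is an exact pseudo-inverse rather than a learned network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy two-layer chain: h1 = f1(x), h2 = f2(h1).
W1 = rng.normal(size=(4, 3))
W2 = rng.normal(size=(2, 4))

def f1(x):
    return np.tanh(W1 @ x)

def f2(h):
    return np.tanh(W2 @ h)

def g(h2):
    # Feedback function: here the exact pseudo-inverse of f2.
    # In DTP this would be a learned feedback network that only
    # approximates the inverse; the correction term below absorbs
    # its reconstruction error.
    return np.linalg.pinv(W2) @ np.arctanh(h2)

x = rng.normal(size=3)
h1 = f1(x)
h2 = f2(h1)

# A downstream target for h2 (e.g. from a small gradient step on the loss).
t2 = h2 - 0.1 * rng.normal(size=2)

# Difference target propagation: propagate the target backwards and add
# the difference correction term (h1 - g(h2)).
t1 = g(t2) + h1 - g(h2)
```

Note the sanity check built into the formula: if the downstream target equals the activation (t2 == h2), the correction term guarantees that the propagated target equals the layer's own activation (t1 == h1), regardless of how imperfect g is.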

Here is how the codebase is roughly structured:

├── main_pl.py             # Training script used in the paper
├── main.py                # training script (legacy)
├── figure_4_3.py          # Script for Figure 4.3
├── data                   # Data for figure 4.3
├── final_figures          # Resulting figures 4.3
├── meulemans_dtp          # Codebase for DRL (Meulemans repo)
├── numerical_experiments  # Initial scripts for creating the figures (used for fig. 4.2) 
└── target_prop
    ├── datasets  # Datasets
    ├── legacy    # initial implementation
    ├── models    # Code for all the models except DRL
    └── networks  # Networks (SimpleVGG, LetNet, ResNet)

Running the code

  • Recreating figure 4.2:

    $ python -m numerical_experiments figure_4_2

    The figure save location will then be displayed on the console.

  • Recreating figure 4.3:

    $ pytest target_prop/networks/lenet_test.py
    $ python plot.py

To run the PyTorch-Lightning re-implementation of DTP on CIFAR-10 with a VGG-like architecture, use the following command:

python main_pl.py run dtp simple_vgg

To see a list of available command-line options, use the `--help` flag:

python main_pl.py --help
python main_pl.py run --help

To train the modified version of the above DTP model, with "parallel" feedback weight training, on CIFAR-10, use the following command:

python main_pl.py run parallel_dtp simple_vgg

ImageNet

To train DTP on the downsampled 32x32 ImageNet dataset, run:

python main_pl.py run dtp <architecture> --dataset imagenet32

Legacy Implementation

To run the legacy training script on CIFAR-10, use the following command:

python main.py --batch-size 128 \
    --C 128 128 256 256 512 \
    --iter 20 30 35 55 20 \
    --epochs 90 \
    --lr_b 1e-4 3.5e-4 8e-3 8e-3 0.18 \
    --noise 0.4 0.4 0.2 0.2 0.08 \
    --lr_f 0.08 \
    --beta 0.7 \
    --path CIFAR-10 \
    --scheduler --wdecay 1e-4
