# Deep Weight Prior

This repo contains code for our ICLR 2019 paper *The Deep Weight Prior*.

The deep weight prior is a generative model for kernels of convolutional neural networks that acts as a prior distribution while training on new datasets.

Abstract: Bayesian inference is known to provide a general framework for incorporating prior knowledge or specific properties into machine learning models via carefully choosing a prior distribution. In this work, we propose a new type of prior distributions for convolutional neural networks, deep weight prior (DWP), that exploit generative models to encourage a specific structure of trained convolutional filters e.g., spatial correlations of weights. We define DWP in the form of an implicit distribution and propose a method for variational inference with such type of implicit priors. In experiments, we show that DWP improves the performance of Bayesian neural networks when training data are limited, and initialization of weights with samples from DWP accelerates training of conventional convolutional neural networks.

## Experiments

### Classification

### Convergence

### Random Feature Extraction

## Launch experiments

### Bayesian NN with Deep Weight Prior

Set up the environment:

```bash
export DATA_ROOT=<path to MNIST/notMNIST>
export DWP_PATH=<path to save logs>
```

Train source models (100 runs, each with a unique seed):

```bash
python train-net.py --data not-mnist --bs 256 --hid_dim 256 512 --l2 0.001 --model fconv --net_cfg vanilla --seed <unique per launch>
```
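Launching 100 runs by hand is tedious. A minimal launcher sketch (assumption: `train-net.py` is in the current directory; the snippet only builds the commands — pass each one to `subprocess.run(cmd, check=True)` to actually train):

```python
# Build the 100 source-model training commands, one unique seed per run.
# Sketch only: commands are constructed, not executed.
def source_model_cmd(seed):
    return [
        "python", "train-net.py",
        "--data", "not-mnist", "--bs", "256",
        "--hid_dim", "256", "512", "--l2", "0.001",
        "--model", "fconv", "--net_cfg", "vanilla",
        "--seed", str(seed),  # unique per launch
    ]

cmds = [source_model_cmd(seed) for seed in range(100)]
```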

Next, prune the kernels whose norms are almost zero. The selected kernels used for training the priors can be found in data/fconv/conv7x7/ and data/fconv/conv5x5/.
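The pruning step can be sketched as follows, here in NumPy for self-containedness. The norm threshold is an assumption (the released kernel sets in data/fconv/ were selected by the authors):

```python
import numpy as np

def prune_kernels(kernels, threshold=1e-2):
    """Keep only kernels whose L2 norm exceeds `threshold`.

    kernels: array of shape (N, k, k). The threshold value is an
    assumption; near-zero-norm kernels are degenerate and carry no
    structure worth modeling with the prior.
    """
    norms = np.linalg.norm(kernels.reshape(len(kernels), -1), axis=1)
    return kernels[norms > threshold]

# Example: 5 informative kernels plus 3 collapsed (all-zero) ones.
rng = np.random.default_rng(0)
kernels = np.concatenate([rng.standard_normal((5, 7, 7)), np.zeros((3, 7, 7))])
kept = prune_kernels(kernels)  # only the 5 non-degenerate kernels remain
```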

Examples of 7x7 kernels

Train VAE priors:

```bash
python train-vae.py --hidden_dim 32 --kernel_dim 7 --z_dim 2 --num_epochs 300 --data_dir data/fconv/conv7x7
python train-vae.py --hidden_dim 64 --kernel_dim 5 --z_dim 4 --num_epochs 300 --data_dir data/fconv/conv5x5
```

VAE latent space for 7x7 kernels

Trained VAEs can be found at `data/fconv/vae7x7/` and `data/fconv/vae5x5/`, and can be loaded with `utils.load_vae`.

Train a BNN with the VAE prior (we vary `data_split_seed` from 0 to 9):

```bash
python train-bnn.py --data mnist --data_split_seed <seed> --num_examples 50 --epochs 10000 --model fconv --vae_list data/fconv/vae7x7/ data/fconv/vae5x5/ --init_list vae vae --net_cfg bayes --hid_dim 32 128 --prior_list vae vae --eval_freq 100
```
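The 10-seed sweep can be scripted in the same spirit, e.g. (a sketch that only formats the command strings; run each with your shell or `subprocess`):

```python
# Format the BNN training command for each data_split_seed from 0 to 9.
# Sketch only: the strings are built, not executed.
base = ("python train-bnn.py --data mnist --num_examples 50 --epochs 10000 "
        "--model fconv --vae_list data/fconv/vae7x7/ data/fconv/vae5x5/ "
        "--init_list vae vae --net_cfg bayes --hid_dim 32 128 "
        "--prior_list vae vae --eval_freq 100")
cmds = [f"{base} --data_split_seed {seed}" for seed in range(10)]
```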

## Citation

If you find this code useful, please cite our paper:

```bibtex
@article{atanov2019dwp,
  title={The Deep Weight Prior},
  author={Atanov, Andrei and Ashukha, Arsenii and Struminsky, Kirill and Vetrov, Dmitry and Welling, Max},
  journal={arXiv preprint arXiv:1810.06943},
  year={2019}
}
```