
Closing the Amortization Gap in Bayesian Deep Generative Models

This project explores the integration of amortized variational inference (A-VI) with deep Bayesian variational autoencoders (VAEs). We cover both the theoretical foundations and the practical steps needed to close the amortization gap in Bayesian modeling. Through numerical experiments, we demonstrate how A-VI enhances computational efficiency and modeling accuracy on benchmark image datasets such as MNIST and FashionMNIST.

Overview

Amortized variational inference (A-VI) has emerged as a promising approach for improving the efficiency of Bayesian deep generative models. This project investigates how effectively A-VI closes the amortization gap relative to traditional variational inference methods such as factorized variational inference (F-VI), also known as mean-field variational inference. We conduct numerical experiments comparing A-VI, across varying neural network architectures, against F-VI and constant-VI.
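
The structural difference between the two approaches is easiest to see in code. Below is a minimal PyTorch sketch, not the implementation in src/model/model.py: the class name, layer widths, and dimensions are illustrative assumptions. F-VI maintains free variational parameters for every data point, while A-VI learns one shared encoder that maps any input to its variational parameters.

import torch
import torch.nn as nn

n_points, data_dim, latent_dim = 60000, 784, 8

# F-VI / mean-field: one free Gaussian factor (mu_n, log-var_n) per data point.
# The number of variational parameters grows linearly with the dataset size.
fvi_mu = nn.Parameter(torch.zeros(n_points, latent_dim))
fvi_logvar = nn.Parameter(torch.zeros(n_points, latent_dim))

# A-VI: a single shared (here purely linear) encoder maps any x_n to its
# variational parameters, amortizing inference across the whole dataset.
class AmortizedEncoder(nn.Module):
    def __init__(self, data_dim, latent_dim, width=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(data_dim, width),
                                  nn.Linear(width, width))
        self.mu = nn.Linear(width, latent_dim)
        self.logvar = nn.Linear(width, latent_dim)

    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.logvar(h)

encoder = AmortizedEncoder(data_dim, latent_dim)
x_batch = torch.randn(32, data_dim)   # stand-in for a mini-batch of images
mu, logvar = encoder(x_batch)         # variational parameters for the batch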

Key Findings

  • A-VI, when implemented with a sufficiently deep neural network, can achieve the same evidence lower bound (ELBO) and reconstruction mean squared error (MSE) as F-VI while running 2 to 3 times faster (the shared objective is stated below).
  • These results highlight the potential of A-VI for addressing the amortization interpolation problem and suggest that a deep linear encoder-decoder network with full Bayesian inference over the latent variables can effectively approximate an ideal inference function.
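
For context, A-VI and F-VI optimize the same per-datapoint objective; they differ only in how the variational distribution is parameterized. Writing the ELBO for a single observation $x_n$ with latent variable $z_n$ (our notation, chosen to match the $q_n$ used in the results below):

$$\mathcal{L}(q_n) = \mathbb{E}_{q_n(z_n)}\left[\log p_\theta(x_n \mid z_n)\right] - \mathrm{KL}\left(q_n(z_n) \,\|\, p(z_n)\right)$$

Under F-VI, each $q_n$ has its own free parameters, so the number of variational parameters grows with the dataset. Under A-VI, a shared encoder $f_\phi$ sets $q_n(z_n) = q(z_n \mid f_\phi(x_n))$, and the amortization gap is the loss in ELBO incurred by this restriction.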

Getting Started

To get started with this project, clone the repository and install the required dependencies:

git clone https://github.com/jordandeklerk/Amortized-Bayes.git
cd Amortized-Bayes
pip install -r requirements.txt

Then run the main.py script:

python main.py

Project Structure

├── experiment.py
├── images
│   ├── fmnist_comp.png
│   ├── fmnist_elbo.png
│   ├── fmnist_mse.png
│   ├── fmnist_mse_test.png
│   ├── index.md
│   ├── mnist_comp.png
│   ├── mnist_elbo.png
│   ├── mnist_mse.png
│   ├── mnist_mse_test.png
│   ├── re1.png
│   ├── re2.png
│   ├── reparm.png
│   ├── reparm4.png
│   ├── vae.png
│   └── variational.png
├── main.py
├── src
│   ├── model
│   │   └── model.py
│   └── utils
│       ├── config.py
│       ├── optimizer.py
│       └── parser.py
└── train.py

Main Results

MNIST

Our MNIST results, presented in Figure 3, examine the effect of different network widths and configurations. After 5,000 epochs, amortized variational inference (A-VI) achieved ELBO values comparable to factorized variational inference (F-VI) once the network was sufficiently deep (k ≥ 64). We also evaluated the reconstruction mean squared error (MSE) on both the training and test sets and found that A-VI closed the performance gap there as well.

Figure 3: Results for the MNIST dataset

Moreover, A-VI proved 2 to 3 times faster than F-VI, as shown in Figure 4. The speedup comes from sharing inference computation across data points through the encoder, which removes the need to estimate a separate variational factor $q_n$ for each latent variable $z_n$.

Figure 4: Computational efficiency of A-VI on MNIST

FashionMNIST

Our results for the FashionMNIST experiments, presented in Figure 5, support the same conclusions as the MNIST experiments.

Figure 5: Results for the FashionMNIST dataset

We observe a similar computational speedup on the FashionMNIST dataset, as shown in Figure 6.

Figure 6: Computational efficiency of A-VI on FashionMNIST

In Figure 7, we present reconstructed images for a sample of five original images from the MNIST and FashionMNIST datasets. These reconstructions, produced with a linear neural network, show relatively low visual quality; image fidelity was not the primary focus of this project. Using convolutional layers for both the encoder and decoder would substantially improve the visual quality of the reconstructions (a model sketch follows the figure).

Figure 7: Reconstructed images for MNIST and FashionMNIST
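
As an illustration of the kind of model involved, here is a minimal sketch of a linear encoder-decoder VAE trained by minimizing the negative ELBO with the reparameterization trick. This is our own simplified sketch, not the code in src/model/model.py; the dimensions and layer widths are assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearVAE(nn.Module):
    # Minimal linear encoder-decoder VAE (illustrative sketch).
    def __init__(self, data_dim=784, latent_dim=8, width=64):
        super().__init__()
        self.enc = nn.Linear(data_dim, width)
        self.enc_mu = nn.Linear(width, latent_dim)
        self.enc_logvar = nn.Linear(width, latent_dim)
        self.dec = nn.Sequential(nn.Linear(latent_dim, width),
                                 nn.Linear(width, data_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable.
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * logvar) * eps
        x_hat = torch.sigmoid(self.dec(z))   # pixel intensities in [0, 1]
        return x_hat, mu, logvar

def negative_elbo(x, x_hat, mu, logvar):
    # Reconstruction term plus KL divergence to the standard normal prior.
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl

Replacing the nn.Linear stacks with convolutional layers in the encoder and decoder is the change suggested above for sharper reconstructions.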
