This is the companion code for the paper Unscented Autoencoder by Faris Janjoš, Lars Rosenbaum, Maxim Dolgov, and J. Marius Zöllner. The paper can be found here. The code allows the users to reproduce and extend the results reported in the paper. Please cite the above paper when reporting, reproducing or extending the results.
The code is adapted from the following open-source VAE project PyTorchVAE.
- Python >= 3.5
- PyTorch >= 1.3
- PyTorch Lightning >= 0.6.0
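As a quick sanity check of the environment, the sketch below verifies the interpreter version and reports whether the required packages are importable (package version checks are left to pip and requirements.txt):

```python
import sys
import importlib.util

# Interpreter version against the stated requirement.
assert sys.version_info >= (3, 5), "Python >= 3.5 required"

# Report whether the required packages can be found on this system.
for pkg in ("torch", "pytorch_lightning"):
    found = importlib.util.find_spec(pkg) is not None
    print(pkg, "found" if found else "MISSING")
```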
$ git clone https://github.com/boschresearch/unscented-autoencoder
$ cd unscented-autoencoder
$ pip install -r requirements.txt
For setting up the CelebA dataset, please refer to PyTorchVAE. FashionMNIST and CIFAR10 should be downloaded and set up automatically upon first training.
The models included in the repository are implemented in the following files:
- Vanilla VAE / UAE:
models/vanilla_vae.py
  (this base model is used to realize both the VAE and the UAE by setting different config parameters)
- RAE:
models/rae.py
- IWAE:
models/iwae.py
- WAE-MMD:
models/wae_mmd.py
- Full covariance models for the VAE / UAE:
models/fullcov_uae.py
All model training config files are tied to a dataset. Each dataset's config set is stored in:
- FashionMNIST:
configs_fashion_mnist/
- CIFAR10:
configs_cifar10/
- CelebA:
configs_celeba/
Each dataset config folder contains configs for the models listed above.
python run.py -c configs_cifar10/uae.yaml
: runs the CIFAR10 training of the full UAE model.

Additional flags:
- hsearch=True: ray-based hyperparameter search with the config defined in hsearch_config in run.py
- del_eval=True: removes generated image folders (they can get pretty large)
Different models can be realized by setting the following config parameters:
- ut_sampling: sets sampling sigma points or the reparameterization trick
- sigma_recon: sets averaging of outputs in a single reconstruction loss (Unscented Transform) or multiple per-sample reconstruction losses (vanilla VAE)
- multi_sample: sets the number of samples and the heuristic
- frob_norm: sets the Wasserstein metric or KL divergence
- dec_reg: sets the decoder regularization
For example, in order to train the simplified UT-VAE model (with only the Unscented Transform), set frob_norm to False and dec_reg to None in uae.yaml.
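A sketch of what the relevant part of uae.yaml might look like for this example — only frob_norm and dec_reg are named in the text above; the surrounding key name is an assumption about the config layout:

```yaml
model_params:          # section name is an assumption
  frob_norm: False     # use KL divergence instead of the Wasserstein metric
  dec_reg: None        # disable decoder regularization
```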
This software is a research prototype, solely developed for and published as part of the publication cited above. It will neither be maintained nor monitored in any way.
Unscented Autoencoder is open-sourced under the AGPL-3.0 license. See the LICENSE file for details.