This repository implements a VQVAE for mnist and a colored version of mnist, and follows up with a simple LSTM for generating numbers.
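At the core of a VQVAE is the vector-quantization step: each encoder output vector is replaced by its nearest codebook embedding, and gradients are passed back through with the straight-through estimator. The sketch below is only an illustration of that step under common default shapes and names; it is not the code used in this repository.

```python
import torch
import torch.nn as nn

class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup with a straight-through gradient.
    Illustrative sketch only; names/shapes may differ from this repo's code."""
    def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
        self.beta = beta

    def forward(self, z_e):
        # z_e: (B, D, H, W) encoder output -> flatten to (B*H*W, D)
        B, D, H, W = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)
        # Squared L2 distance to every codebook vector, pick the closest
        dists = torch.cdist(flat, self.codebook.weight)
        indices = dists.argmin(dim=1)
        z_q = self.codebook(indices).reshape(B, H, W, D).permute(0, 3, 1, 2)
        # Codebook loss + commitment loss (van den Oord et al., 2017)
        loss = ((z_q - z_e.detach()) ** 2).mean() + self.beta * ((z_q.detach() - z_e) ** 2).mean()
        # Straight-through estimator: copy gradients from z_q to z_e
        z_q = z_e + (z_q - z_e).detach()
        return z_q, loss, indices.reshape(B, H, W)
```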
- Create a new conda environment with python 3.8, then run the commands below
```
git clone https://github.com/explainingai-code/VQVAE-Pytorch.git
cd VQVAE-Pytorch
pip install -r requirements.txt
```
- For running a simple VQVAE with minimal code to understand the basics
```
python run_simple_vqvae.py
```
- For playing around with VQVAE and training/inferencing the LSTM, use the below commands, passing the desired configuration file as the config argument
  - `python -m tools.train_vqvae` for training VQVAE
  - `python -m tools.infer_vqvae` for generating reconstructions and encoder outputs for LSTM training
  - `python -m tools.train_lstm` for training the minimal LSTM
  - `python -m tools.generate_images` for using the trained LSTM to generate some numbers
- `config/vqvae_mnist.yaml` - VQVAE for training on black and white mnist images
- `config/vqvae_colored_mnist.yaml` - VQVAE with more embedding vectors for training colored mnist images
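These yaml files drive every run. A minimal sketch of loading one is below, assuming PyYAML; the exact keys inside (for example where `task_name` and `output_train_dir`, mentioned in the Output description further down, live) are defined by the repository's own config files.

```python
import yaml

# Illustrative only: load a run configuration. The key layout inside the yaml
# (e.g. where task_name / output_train_dir sit) is defined by this repository.
with open("config/vqvae_mnist.yaml") as f:
    config = yaml.safe_load(f)
print(config)
```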
For setting up the dataset, follow the data preparation steps at https://github.com/explainingai-code/Pytorch-VAE#data-preparation
Verify the data directory has the following structure:
```
VQVAE-Pytorch/data/train/images/{0/1/.../9}
	*.png
VQVAE-Pytorch/data/test/images/{0/1/.../9}
	*.png
```
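To sanity-check that layout, a small standalone sketch like the one below can count the (image, label) pairs. It relies only on the folder structure shown above and is not the repository's dataset code.

```python
from pathlib import Path

def list_mnist_images(root="data/train/images"):
    """Collect (path, digit-label) pairs from the 0..9 subfolders.
    Run from the repository root so the relative path resolves."""
    samples = []
    for digit_dir in sorted(Path(root).iterdir()):
        if not digit_dir.is_dir():
            continue
        label = int(digit_dir.name)
        samples.extend((png, label) for png in digit_dir.glob("*.png"))
    return samples

print(len(list_mnist_images()))  # should equal the number of training pngs
```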
Outputs will be saved according to the configuration present in the yaml files.
For every run, a folder named after the `task_name` key in the config will be created, and `output_train_dir` will be created inside it.
During training of VQVAE the following output will be saved
- Best model checkpoints (VQVAE and LSTM) in the `task_name` directory
During inference the following output will be saved
- Reconstructions for a sample of the test set in `task_name/output_train_dir/reconstruction.png`
- Encoder outputs on the train set for LSTM training in `task_name/output_train_dir/mnist_encodings.pkl`
- LSTM generation output in `task_name/output_train_dir/generation_results.png`
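The `mnist_encodings.pkl` file is what the LSTM stage consumes. Assuming it is an ordinary pickle of the discrete encoder outputs (the exact structure inside is determined by `tools/infer_vqvae.py` and is not documented here), it can be inspected with something like:

```python
import pickle

# Hypothetical inspection snippet; replace task_name / output_train_dir with the
# values from your config. The layout of the pickled object depends on infer_vqvae.
with open("task_name/output_train_dir/mnist_encodings.pkl", "rb") as f:
    encodings = pickle.load(f)
print(type(encodings))
```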
Running `run_simple_vqvae` should be very quick (as it is a very simple model) and give you the reconstructions below (input on a black background and reconstruction on a white background).
Running the default config VQVAE for mnist should give you the reconstructions below for both versions.
Sample generation output after just 10 epochs. Training the VQVAE and LSTM longer, and with more parameters (codebook size, codebook dimension, channels, LSTM hidden dimension, etc.), will give better results.
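For intuition on what the generation stage does: the grid of codebook indices is treated as a sequence, the LSTM samples each next index given the previous ones, and the sampled grid is mapped back through the codebook and decoded to an image by the VQVAE decoder. The sketch below shows only the sampling loop, with a toy, untrained LSTM and made-up sizes; it is not the model trained by `tools/train_lstm.py`.

```python
import torch
import torch.nn as nn

# Toy autoregressive sampler over codebook indices; illustrative only.
num_embeddings, hidden = 512, 128
seq_len = 7 * 7  # e.g. a 7x7 grid of latent indices

embed = nn.Embedding(num_embeddings + 1, hidden)  # +1 for a start token
lstm = nn.LSTM(hidden, hidden, batch_first=True)
head = nn.Linear(hidden, num_embeddings)

tokens, state = [num_embeddings], None  # begin with the start token
with torch.no_grad():
    for _ in range(seq_len):
        x = embed(torch.tensor([[tokens[-1]]]))       # (1, 1, hidden)
        out, state = lstm(x, state)
        probs = head(out[:, -1]).softmax(dim=-1)      # distribution over codebook indices
        tokens.append(torch.multinomial(probs, 1).item())

indices = torch.tensor(tokens[1:]).reshape(7, 7)
# indices would then be embedded via the codebook and passed to the VQVAE decoder
```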
```
@misc{oord2018neural,
      title={Neural Discrete Representation Learning},
      author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
      year={2018},
      eprint={1711.00937},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```