A production-ready implementation of Variational Autoencoders with advanced features including beta-VAE, KL annealing, comprehensive evaluation metrics, and interactive demos.
- Modern Architecture: Beta-VAE and Convolutional VAE implementations
- Advanced Training: KL annealing, gradient clipping, mixed precision training
- Comprehensive Evaluation: FID, reconstruction metrics, latent space analysis
- Interactive Demo: Streamlit web interface for sample generation and exploration
- Production Ready: Type hints, comprehensive testing, CI/CD pipeline
- Flexible Configuration: YAML-based configuration with OmegaConf
- Multiple Datasets: MNIST, Fashion-MNIST, CelebA support
To get started:

- Clone the repository:

```bash
git clone https://github.com/kryptologyst/Variational-Autoencoder-Implementation.git
cd Variational-Autoencoder-Implementation
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

- Train a model:

```bash
python scripts/train.py --config configs/config.yaml
```

- Launch the interactive demo:

```bash
streamlit run demo/streamlit_app.py
```

The repository is organized as follows:

```
├── src/vae/                 # Core VAE implementation
│   ├── models.py            # VAE model architectures
│   ├── data.py              # Data loading and preprocessing
│   ├── training.py          # Training utilities and Lightning module
│   ├── evaluation.py        # Evaluation metrics and utilities
│   └── visualization.py     # Sampling and visualization tools
├── configs/                 # Configuration files
│   ├── config.yaml          # Default configuration
│   ├── conv_vae.yaml        # Convolutional VAE config
│   └── celeba.yaml          # CelebA dataset config
├── scripts/                 # Training and sampling scripts
│   ├── train.py             # Main training script
│   └── sample.py            # Sampling and evaluation script
├── demo/                    # Interactive demos
│   └── streamlit_app.py     # Streamlit web interface
├── tests/                   # Unit tests
│   └── test_vae.py          # Test cases
├── data/                    # Dataset storage
├── checkpoints/             # Model checkpoints
├── logs/                    # Training logs
├── assets/                  # Generated samples and visualizations
└── results/                 # Evaluation results
```
Train a VAE model with the default configuration:

```bash
python scripts/train.py --config configs/config.yaml
```

Train with custom parameters:

```bash
python scripts/train.py --config configs/config.yaml --override training.max_epochs=50 training.learning_rate=0.001
```

Train different model types:

```bash
# Convolutional VAE
python scripts/train.py --config configs/conv_vae.yaml

# CelebA dataset
python scripts/train.py --config configs/celeba.yaml
```

Evaluate a trained model:

```bash
python scripts/sample.py --checkpoint checkpoints/best_model.ckpt --mode evaluate
```

Generate samples from a trained model:

```bash
python scripts/sample.py --checkpoint checkpoints/best_model.ckpt --mode samples --num-samples 100
```

Create interpolations:

```bash
python scripts/sample.py --checkpoint checkpoints/best_model.ckpt --mode interpolations --num-pairs 10
```

Launch the Streamlit demo:

```bash
streamlit run demo/streamlit_app.py
```

The demo provides the following (a minimal page sketch follows the list):
- Random sample generation
- Latent space interpolation
- Dimensional traversal
- Model information and statistics
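As an illustration of the sample-generation feature, a stripped-down Streamlit page might look like the sketch below. The `BetaVAE` import path, constructor arguments, `decode()` call, and checkpoint layout are all assumptions for illustration; the shipped `demo/streamlit_app.py` is the authoritative version.

```python
# Hypothetical minimal sampling page; demo/streamlit_app.py is more complete.
import streamlit as st
import torch

from vae.models import BetaVAE  # assumed import path and class name

st.title("VAE Sample Explorer")
num_samples = st.slider("Number of samples", 1, 64, 16)

@st.cache_resource
def load_model():
    # Constructor arguments and checkpoint layout are assumptions.
    model = BetaVAE(input_dim=784, hidden_dims=[400, 200], latent_dim=20)
    ckpt = torch.load("checkpoints/best_model.ckpt", map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    model.eval()
    return model

model = load_model()
with torch.no_grad():
    z = torch.randn(num_samples, 20)            # sample latents from the prior
    images = model.decode(z).view(-1, 28, 28)   # assumed decode() API

st.image([img.numpy() for img in images], width=96, clamp=True)
```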
The project uses YAML configuration files with OmegaConf. Key configuration sections:
```yaml
model:
  type: "beta_vae"  # or "conv_vae"
  params:
    input_dim: 784
    hidden_dims: [400, 200]
    latent_dim: 20
    beta: 1.0
    use_batch_norm: true
    dropout: 0.1

training:
  max_epochs: 100
  learning_rate: 1e-3
  batch_size: 128
  beta_start: 0.0
  beta_end: 1.0
  beta_schedule: "linear"  # linear, cyclical, sigmoid
  scheduler:
    type: "cosine"
    eta_min: 1e-6

data:
  dataset: "mnist"  # mnist, fashionmnist, celeba
  data_dir: "./data"
  image_size: 28
  augment: false
  val_split: 0.1
  test_split: 0.1
```
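Loading such a file and applying command-line style overrides with OmegaConf might look like the following sketch (the dotted-key override strings mirror the `--override` flag shown earlier):

```python
from omegaconf import OmegaConf

# Load the YAML config and merge in dotted-key overrides,
# mirroring the --override flag of scripts/train.py.
cfg = OmegaConf.load("configs/config.yaml")
overrides = OmegaConf.from_dotlist(
    ["training.max_epochs=50", "training.learning_rate=0.001"]
)
cfg = OmegaConf.merge(cfg, overrides)

print(cfg.model.type)           # "beta_vae"
print(cfg.training.max_epochs)  # 50 after the override
```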
The Beta-VAE model provides (see the loss sketch after this list):

- Fully connected encoder/decoder
- Configurable beta parameter for the KL-divergence weight
- KL annealing for stable training
- Batch normalization and dropout support
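A beta-weighted objective of the kind used here can be sketched as follows; the function name and reduction choices are illustrative, not the module's exact API:

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(recon_x, x, mu, log_var, beta=1.0):
    """Illustrative beta-VAE objective: reconstruction + beta * KL."""
    # Reconstruction term; binary cross-entropy suits [0, 1]-valued MNIST-style data.
    recon = F.binary_cross_entropy(recon_x, x, reduction="sum") / x.size(0)
    # Closed-form KL divergence between N(mu, sigma^2) and the N(0, I) prior.
    kl = -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    # beta = 1 recovers the standard VAE; beta > 1 encourages disentanglement.
    return recon + beta * kl
```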
The Convolutional VAE offers (see the encoder/decoder sketch after this list):

- Convolutional encoder/decoder
- Suitable for image data
- Configurable channel dimensions
- Transpose convolutions for upsampling
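For intuition, a minimal 28x28 single-channel encoder/decoder pair in this style could look like the sketch below; channel sizes and kernel choices are illustrative, not the shipped `ConvVAE`'s exact architecture:

```python
import torch.nn as nn

# Illustrative shapes for 1x28x28 inputs (e.g., MNIST).
encoder = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),   # 28 -> 14
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 14 -> 7
    nn.ReLU(),
    nn.Flatten(),                                           # 64 * 7 * 7 features
)
decoder = nn.Sequential(
    nn.Unflatten(1, (64, 7, 7)),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 7 -> 14
    nn.ReLU(),
    nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 14 -> 28
    nn.Sigmoid(),
)
```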
The implementation includes comprehensive evaluation metrics (a reconstruction-metric sketch follows the list):
- Reconstruction Quality: MSE, MAE, PSNR, SSIM
- Latent Space: KL divergence, latent space statistics
- Sample Quality: FID (Fréchet Inception Distance), diversity metrics
- Interpolation: Smoothness and quality of latent interpolations
- Traversal: Dimensional analysis of latent space
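The reconstruction metrics, for instance, reduce to a few tensor operations. A hedged sketch follows; the evaluator in `src/vae/evaluation.py` may differ in details such as reductions and SSIM handling:

```python
import torch

def reconstruction_metrics(x, x_hat, max_val=1.0):
    """Illustrative MSE/MAE/PSNR for reconstructions scaled to [0, max_val]."""
    mse = torch.mean((x - x_hat) ** 2)
    mae = torch.mean(torch.abs(x - x_hat))
    # PSNR in decibels; higher means a more faithful reconstruction.
    psnr = 10.0 * torch.log10(max_val ** 2 / mse)
    return {"mse": mse.item(), "mae": mae.item(), "psnr": psnr.item()}
```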
KL annealing supports multiple schedules (see the helper sketch after this list):
- Linear: Gradual increase from beta_start to beta_end
- Cyclical: Cosine-based cyclical annealing
- Sigmoid: S-shaped annealing curve
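A schedule helper consistent with these three options might look like this sketch; the function name and cycle count are illustrative, not the training module's API:

```python
import math

def beta_at_step(step, total_steps, beta_start=0.0, beta_end=1.0,
                 schedule="linear", num_cycles=4):
    """Illustrative KL-annealing schedules: linear, cyclical, sigmoid."""
    t = min(step / max(total_steps, 1), 1.0)  # training progress in [0, 1]
    if schedule == "linear":
        frac = t
    elif schedule == "cyclical":
        # Cosine-shaped ramp repeated over num_cycles cycles.
        cycle_t = (t * num_cycles) % 1.0
        frac = 0.5 * (1.0 - math.cos(math.pi * cycle_t))
    elif schedule == "sigmoid":
        # S-shaped curve centered at the midpoint of training.
        frac = 1.0 / (1.0 + math.exp(-12.0 * (t - 0.5)))
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    return beta_start + (beta_end - beta_start) * frac
```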
Mixed precision training is enabled automatically when CUDA is available, for faster training and reduced memory usage.
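With PyTorch Lightning this typically reduces to a `Trainer` precision flag, as in the sketch below (the string values follow Lightning 2.x naming; the training script may wire this from the config instead):

```python
import torch
import pytorch_lightning as pl

# Request 16-bit mixed precision only when a CUDA device is present.
precision = "16-mixed" if torch.cuda.is_available() else "32-true"
trainer = pl.Trainer(max_epochs=100, precision=precision)
```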
Training can be monitored in several ways (logger setup is sketched after this list):

- WandB integration for experiment tracking
- TensorBoard support for local logging
- Automatic sample generation and logging
- Latent space visualization
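Either logger attaches to the Lightning `Trainer` in one line. A sketch follows; the project name and log directory are assumptions:

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

wandb_logger = WandbLogger(project="vae-implementation")  # assumed project name
tb_logger = TensorBoardLogger(save_dir="logs")            # matches the logs/ directory
trainer = pl.Trainer(logger=[wandb_logger, tb_logger])    # or pass just one
```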
Reproducibility is handled throughout (see the sketch after this list):

- Deterministic seeding for all random operations
- Reproducible data loading
- Consistent model initialization
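In Lightning-based code this usually amounts to a seed call plus a deterministic flag. A sketch, assuming the seed value itself comes from the config:

```python
import pytorch_lightning as pl

# Seeds Python, NumPy, and PyTorch, including dataloader worker processes.
pl.seed_everything(42, workers=True)
# Prefer deterministic kernels where PyTorch provides them.
trainer = pl.Trainer(deterministic=True)
```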
Run the test suite:

```bash
pytest tests/
```

Run specific test categories:

```bash
# Test model architectures
pytest tests/test_vae.py::TestBetaVAE

# Test loss functions
pytest tests/test_vae.py::TestVAELoss

# Test evaluation metrics
pytest tests/test_vae.py::TestVAEEvaluator
```

Expected performance benchmarks:

- MNIST: ~95% reconstruction accuracy, FID < 50
- Fashion-MNIST: ~90% reconstruction accuracy, FID < 100
- CelebA: FID < 200 (64x64 images)
Approximate training times:

- MNIST: ~5 minutes on GPU, ~30 minutes on CPU
- Fashion-MNIST: ~10 minutes on GPU, ~45 minutes on CPU
- CelebA: ~2 hours on GPU (200 epochs)
To contribute:

- Fork the repository
- Create a feature branch
- Make your changes
- Add tests for new functionality
- Run the test suite
- Submit a pull request
This project is licensed under the MIT License - see the LICENSE file for details.
If you use this implementation in your research, please cite:
```bibtex
@software{vae_implementation,
  title={Variational Autoencoder Implementation},
  author={Kryptologyst},
  year={2025},
  url={https://github.com/kryptologyst/Variational-Autoencoder-Implementation}
}
```

Acknowledgments:

- Original VAE paper: Kingma & Welling (2014)
- Beta-VAE: Higgins et al. (2017)
- PyTorch Lightning team for the excellent training framework
- Streamlit team for the interactive demo framework
Common issues and remedies:

- CUDA out of memory: Reduce the batch size or use gradient accumulation (see the example after this list)
- Slow training: Enable mixed precision training or use smaller models
- Poor reconstruction quality: Increase model capacity or adjust beta parameter
- Mode collapse: Use beta-VAE with higher beta values or add regularization
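For instance, the batch size can be lowered without editing the config file, using the override mechanism shown earlier (the value is illustrative):

```bash
# Halve the default batch size of 128 to ease GPU memory pressure.
python scripts/train.py --config configs/config.yaml --override training.batch_size=64
```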
For further help:

- Check the issues section for common problems
- Review the configuration files for parameter tuning
- Run the test suite to verify installation
- Use the interactive demo to explore model behavior
Planned enhancements:

- Support for more datasets (CIFAR-10, LSUN)
- Additional VAE variants (WAE, VQ-VAE)
- Improved evaluation metrics (IS, LPIPS)
- Model compression and quantization
- Distributed training support
- Web API for model serving