This repository provides a reference implementation for boosted generative models as described in the paper:
Boosted Generative Models
Aditya Grover and Stefano Ermon.
AAAI Conference on Artificial Intelligence (AAAI), 2018.
https://arxiv.org/pdf/1702.08484.pdf
The codebase is implemented in Python 3.6. To install the necessary requirements, run the following commands:
pip install -r requirements.txt
bash install.sh
The code takes an input dataset as a CSV file. Each row represents one datapoint with comma-separated features. Sample train, validation, and test files for the nltcs dataset are included in the data/ directory.
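As an illustration of this input format, the following minimal sketch writes a tiny CSV file and loads it with NumPy. The file name `toy.train.csv` and the helper `load_csv_dataset` are hypothetical and not part of this repository; the actual file naming used in data/ may differ.

```python
import os
import tempfile

import numpy as np


def load_csv_dataset(path):
    """Load a CSV file where each row is one datapoint (hypothetical helper)."""
    # ndmin=2 keeps a 2-D array even if the file contains a single row.
    return np.loadtxt(path, delimiter=",", ndmin=2)


# Write a tiny toy file: 3 datapoints with 4 features each.
tmpdir = tempfile.mkdtemp()
toy_path = os.path.join(tmpdir, "toy.train.csv")  # hypothetical file name
with open(toy_path, "w") as f:
    f.write("0,1,1,0\n1,0,0,1\n1,1,0,0\n")

data = load_csv_dataset(toy_path)
print(data.shape)  # rows are datapoints, columns are features
```

Rows index datapoints and columns index features, so a dataset with N datapoints and D features loads as an array of shape (N, D).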
Learning and inference for boosted generative models are handled by the main.py script, which provides the following command-line arguments:
--seed INT Random seed for numpy and tensorflow
--datadir STR Directory containing dataset files
--dataset STR Name of dataset
--resultdir STR Directory for saving tf checkpoints
--run-addbgm BOOL Runs additive boosting if True
--addbgm-alpha FLOAT LIST Space-separated list of model weights for additive boosting
--run-genbgm BOOL Runs multiplicative generative boosting if True
--genbgm-alpha FLOAT LIST Space-separated list of model weights for multiplicative generative boosting
--genbgm-beta FLOAT LIST Space-separated list of reweighting exponents for multiplicative generative boosting
--run-discbgm BOOL Runs multiplicative discriminative boosting if True
--discbgm-alpha FLOAT LIST Space-separated list of model weights for multiplicative discriminative boosting
--discbgm-epochs INT Number of epochs of training for each discriminator
--discbgm-burn-in INT Number of discarded burn-in samples for Markov chains
--run-classifier BOOL Uses generative model for classification if True
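The sketch below shows one way options of this kind could be declared with Python's standard `argparse` module. It is not the parser in main.py: it covers only a subset of the options, the defaults are invented for illustration, and treating the `--run-*` options as bare switches is an assumption based on how the example commands in this README pass them.

```python
import argparse


def build_parser():
    # Hypothetical parser covering a subset of the options listed above.
    parser = argparse.ArgumentParser(description="Boosted generative models (sketch)")
    parser.add_argument("--seed", type=int, default=0)  # default is an assumption
    parser.add_argument("--dataset", type=str, default="nltcs")  # default is an assumption
    # Boolean switches: present on the command line means True.
    parser.add_argument("--run-genbgm", action="store_true")
    parser.add_argument("--run-classifier", action="store_true")
    # nargs="+" collects one or more space-separated floats into a list.
    parser.add_argument("--genbgm-alpha", type=float, nargs="+", default=[])
    parser.add_argument("--genbgm-beta", type=float, nargs="+", default=[])
    return parser


args = build_parser().parse_args(
    "--dataset nltcs --run-genbgm --genbgm-alpha 0.5 0.5 --genbgm-beta 0.25 0.125".split()
)
print(args.run_genbgm, args.genbgm_alpha, args.genbgm_beta)
```

With `nargs="+"`, each FLOAT LIST option yields a Python list with one entry per model in the ensemble, so the length of the list determines the number of boosting rounds in this sketch.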
The following commands learn boosted ensembles with two models and evaluate each ensemble for density estimation and classification.
Meta-algorithm: multiplicative generative boosting
python src/main.py --dataset nltcs --run-genbgm --genbgm-alpha 0.5 0.5 --genbgm-beta 0.25 0.125 --run-classifier
Meta-algorithm: multiplicative discriminative boosting
python src/main.py --dataset nltcs --run-discbgm --discbgm-alpha 1. 1. --run-classifier
Meta-algorithm: additive boosting
python src/main.py --dataset nltcs --run-addbgm --addbgm-alpha 0.5 0.25 --run-classifier
You can also run any combination of the meta-algorithms together as shown below.
python src/main.py --dataset nltcs --run-genbgm --genbgm-alpha 0.5 0.5 --genbgm-beta 0.25 0.125 --run-discbgm --discbgm-alpha 1. 1. --run-addbgm --addbgm-alpha 0.5 0.25 --run-classifier
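To give some intuition for the role of the model weights passed via the `--*-alpha` options, the following sketch combines per-model log-densities in two generic ways: a weighted mixture (additive) and a weighted product (multiplicative). It is an illustration under assumed definitions and is not the code in src/main.py: the function names and the log-density values are made up, the additive weights are assumed to be normalized to sum to one, and the multiplicative result is left unnormalized.

```python
import numpy as np


def additive_log_density(log_h, alpha):
    """Log of a weighted mixture: log(sum_t w_t * h_t(x)), with w = alpha / sum(alpha)."""
    log_h = np.asarray(log_h, dtype=float)
    w = np.asarray(alpha, dtype=float)
    w = w / w.sum()  # assumption: weights are normalized to sum to one
    # Log-sum-exp over models (axis 0) for numerical stability.
    return np.logaddexp.reduce(np.log(w)[:, None] + log_h, axis=0)


def multiplicative_log_density(log_h, alpha):
    """Unnormalized log of a weighted product: sum_t alpha_t * log h_t(x)."""
    log_h = np.asarray(log_h, dtype=float)
    a = np.asarray(alpha, dtype=float)
    return (a[:, None] * log_h).sum(axis=0)


# Made-up log-densities of two models on three datapoints (shape: models x points).
log_h = np.log([[0.2, 0.5, 0.3],
                [0.4, 0.4, 0.2]])

print(additive_log_density(log_h, [0.5, 0.5]))
print(multiplicative_log_density(log_h, [0.5, 0.5]))
```

The mixture stays a valid density whenever the individual models are, whereas the weighted product generally needs a normalization constant before it can be reported as a log-likelihood.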
If you find boosted generative models useful in your research, please consider citing the following paper:
@inproceedings{grover2018boosted,
title={Boosted Generative Models},
author={Grover, Aditya and Ermon, Stefano},
booktitle={AAAI Conference on Artificial Intelligence},
year={2018}}