Yoon Kim committed Jan 18, 2018
1 parent c66e515 commit aed0bf4
Showing 7 changed files with 20,365 additions and 16 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,6 +1,8 @@
*.pt
*.amat
*.mat
*.out
*.out~
*.pyc
*.pt~
.gitignore~
@@ -13,4 +15,5 @@
*.model
*.h5
*.tar.gz
*.hdf5
*.hdf5
*.txt
73 changes: 72 additions & 1 deletion README.md
@@ -1 +1,72 @@
# savi
# Semi-Amortized Variational Inference

## Dependencies
The code was tested with `python 3.6` and `pytorch 0.2`. It also requires the `h5py` package.
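
As a quick sanity check before running anything, the sketch below reports the interpreter version and whether the two packages named above can be imported. The helper name `missing_packages` is hypothetical and not part of this repository; `torch` is assumed to be the import name for PyTorch.

```python
import importlib.util
import sys

# Packages named in the Dependencies section.
# Assumption: "torch" is the import name of the PyTorch package.
REQUIRED_PACKAGES = ["torch", "h5py"]


def missing_packages(names):
    """Return the subset of `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Print the running interpreter version, e.g. to compare against 3.6.
    print("python", ".".join(str(part) for part in sys.version_info[:3]))
    missing = missing_packages(REQUIRED_PACKAGES)
    print("missing packages:", ", ".join(missing) if missing else "none")
```

This only checks that the packages are importable; it does not verify that the installed versions match the ones the code was tested with.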

## Data
The raw datasets can be downloaded from [here](https://drive.google.com/file/d/1PecZKhrPkMZmvMyiOJfMS-FsHW3FE0rH/view?usp=sharing).

Text experiments use the Yahoo dataset from [Yang et al. 2017](https://arxiv.org/pdf/1702.08139.pdf), which is itself derived from [Zhang et al. 2015](https://arxiv.org/abs/1509.01626).

Image experiments use the OMNIGLOT dataset from [Lake et al. 2015](https://cims.nyu.edu/~brenden/LakeEtAl2015Science.pdf) with preprocessing from [Burda et al. 2015](https://arxiv.org/pdf/1509.00519.pdf).

Please cite the original papers when using the data.

## Text
After downloading the data, run
```
python preprocess_text.py --trainfile data/yahoo/train.txt --valfile data/yahoo/val.txt --testfile data/yahoo/test.txt --outputfile data/yahoo/yahoo
```
This will create the `hdf5` files (data tensors) to be used by the model.

The basic model command is
```
python train_text.py --train_file data/yahoo/yahoo-train.hdf5 --val_file data/yahoo/yahoo-val.hdf5 --gpu 1 --checkpoint_path model-path
```
where `model-path` is the path to save the best model and the `*.hdf5` files are obtained from running `preprocess_text.py`. You can specify which GPU to use by changing the value passed to the `--gpu` option.

To train the various models, add the following:
- Autoregressive (i.e. language model): `--model autoreg`
- VAE: `--model vae`
- SVI: `--model svi --svi_steps 20 --acc_param_grads 0`
- SAVI-VAE: `--model savi --svi_steps 20 --train_n2n 0 --train_kl 0 --acc_param_grads 0`
- SAVI-KL: `--model savi --svi_steps 20 --train_n2n 0 --train_kl 1 --acc_param_grads 0`
- SAVI-N2N: `--model savi --svi_steps 20 --train_n2n 1 --acc_param_grads 1`

The number of SVI steps can be changed with the `--svi_steps` option.

To evaluate, run
```
python train_text.py --train_from model-path --test_file data/yahoo/yahoo-test.hdf5 --test 1 --gpu 1
```
Make sure to append the relevant model configuration at test time too.
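
To keep the training and evaluation flags in sync, the commands above can be assembled from a single table. The following is a minimal sketch, not part of the repository: the flag values are copied from the list above, while the dictionary keys (`savi-vae`, `savi-kl`, `savi-n2n`) and the helper names `train_command` and `eval_command` are hypothetical labels chosen for illustration.

```python
# Model-configuration flags, copied from the list of variants above.
# The dictionary keys are illustrative labels, not options of train_text.py.
MODEL_FLAGS = {
    "autoreg": ["--model", "autoreg"],
    "vae": ["--model", "vae"],
    "svi": ["--model", "svi", "--svi_steps", "20", "--acc_param_grads", "0"],
    "savi-vae": ["--model", "savi", "--svi_steps", "20", "--train_n2n", "0",
                 "--train_kl", "0", "--acc_param_grads", "0"],
    "savi-kl": ["--model", "savi", "--svi_steps", "20", "--train_n2n", "0",
                "--train_kl", "1", "--acc_param_grads", "0"],
    "savi-n2n": ["--model", "savi", "--svi_steps", "20", "--train_n2n", "1",
                 "--acc_param_grads", "1"],
}


def train_command(variant, checkpoint_path="model-path", gpu=1):
    """Build the basic training command plus the flags for `variant`."""
    base = ["python", "train_text.py",
            "--train_file", "data/yahoo/yahoo-train.hdf5",
            "--val_file", "data/yahoo/yahoo-val.hdf5",
            "--gpu", str(gpu),
            "--checkpoint_path", checkpoint_path]
    return base + MODEL_FLAGS[variant]


def eval_command(variant, checkpoint_path="model-path", gpu=1):
    """Build the evaluation command with the same model flags appended."""
    base = ["python", "train_text.py",
            "--train_from", checkpoint_path,
            "--test_file", "data/yahoo/yahoo-test.hdf5",
            "--test", "1",
            "--gpu", str(gpu)]
    return base + MODEL_FLAGS[variant]


if __name__ == "__main__":
    # Dry run: print the command pairs instead of executing them.
    for variant in MODEL_FLAGS:
        print(" ".join(train_command(variant)))
        print(" ".join(eval_command(variant)))
```

Each returned list can be passed to `subprocess.run` as-is; printing first makes it easy to confirm that the evaluation command carries the same configuration as the training command.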

## Images
After downloading the data, run
```
python preprocess_img.py --raw_file data/omniglot/chardata.mat --output data/omniglot/omniglot.pt
```

To train, the basic command is
```
python train_img.py --data_file data/omniglot/omniglot.pt --gpu 1 --checkpoint_path model-path
```

To train the various models, add the following:
- Autoregressive: `--model autoreg`
- VAE: `--model vae`
- SVI: `--model svi --svi_steps 20 --acc_param_grads 0`
- SAVI-VAE: `--model savi --svi_steps 20 --train_n2n 0 --train_kl 0 --acc_param_grads 0`
- SAVI-KL: `--model savi --svi_steps 20 --train_n2n 0 --train_kl 1 --acc_param_grads 0`
- SAVI-N2N: `--model savi --svi_steps 20 --train_n2n 1 --acc_param_grads 1`

To evaluate, run
```
python train_img.py --train_from model-path --test 1 --gpu 1
```

## Acknowledgements
Some of our image code is based on the [VAE with a VampPrior](https://github.com/jmtomczak/vae_vampprior) repo.

## License
MIT
