This is a Python3 / TensorFlow implementation of PixelSNAIL.
This codebase is based on OpenAI's PixelCNN++ code.
To run this code, you need:
- a machine with multiple GPUs
- Python3
- NumPy and TensorFlow
Use the `train.py` script to train the model.
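If you prefer to launch runs from Python (for example, to sweep hyperparameters), you can keep the flags in a dictionary and turn them into an argument list. This is a minimal sketch: `build_train_command` and `launch` are hypothetical helpers that are not part of this codebase, and the dictionary simply mirrors the CIFAR10 command given later in this README.

```python
import shlex
import subprocess
import sys


def build_train_command(flags, script="train.py", python=sys.executable):
    """Turn a {flag: value} dict into an argv list for train.py (hypothetical helper)."""
    return [python, script] + [f"--{name}={value}" for name, value in flags.items()]


def launch(flags):
    """Run train.py with the given flags; raises CalledProcessError on failure."""
    subprocess.run(build_train_command(flags), check=True)


# Mirrors the CIFAR10 training command in this README.
cifar_flags = {
    "data_set": "cifar",
    "model": "h12_pool2_smallkey",
    "nr_logistic_mix": 10,
    "nr_filters": 256,
    "batch_size": 8,
    "init_batch_size": 8,
    "dropout_p": 0.5,
    "polyak_decay": 0.9995,
    "save_interval": 10,
}

cmd = build_train_command(cifar_flags, python="python")
print(shlex.join(cmd))
# launch(cifar_flags)  # uncomment to actually start training
```

Because the helper emits `--name=value` pairs, the printed string can be pasted directly into a shell.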
You can download our pretrained (TensorFlow) CIFAR10 and ImageNet models.
Training on CIFAR10:

```
python train.py \
--data_set=cifar \
--model=h12_pool2_smallkey \
--nr_logistic_mix=10 \
--nr_filters=256 \
--batch_size=8 \
--init_batch_size=8 \
--dropout_p=0.5 \
--polyak_decay=0.9995 \
--save_interval=10
```
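The `--nr_logistic_mix` flag comes from PixelCNN++, where it sets the number of components in the discretized mixture-of-logistics output distribution; we assume it has the same meaning here. The sketch below computes the probability of a single 8-bit sub-pixel value under such a mixture. It is deliberately simplified: it models one channel only and omits the linear dependence between RGB channels that PixelCNN++ uses. The function names and parameter values are illustrative, not taken from this codebase.

```python
import math


def sigmoid(x):
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def discretized_logistic_mixture_prob(x, logits, means, log_scales):
    """P(value == x) for integer x in 0..255 under a mixture of discretized logistics.

    Values are rescaled to [-1, 1]; each of the 256 bins has half-width 1/255.
    The lowest and highest bins absorb the remaining tail mass.
    """
    v = 2.0 * x / 255.0 - 1.0
    # Mixture weights: softmax over the component logits.
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    prob = 0.0
    for w, mu, log_s in zip(weights, means, log_scales):
        inv_s = math.exp(-log_s)
        upper = sigmoid((v + 1.0 / 255.0 - mu) * inv_s) if x < 255 else 1.0
        lower = sigmoid((v - 1.0 / 255.0 - mu) * inv_s) if x > 0 else 0.0
        prob += (w / total) * (upper - lower)
    return prob


K = 10  # matches --nr_logistic_mix=10 in the CIFAR10 command
logits = [0.0] * K
means = [-0.9 + 0.2 * k for k in range(K)]
log_scales = [-3.0] * K
probs = [discretized_logistic_mixture_prob(x, logits, means, log_scales) for x in range(256)]
print(round(sum(probs), 6))  # 1.0: the bins tile the whole real line
```

Adjacent bin edges coincide, so the per-component probabilities telescope and the 256 values always sum to one, whatever the number of components.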
Training on ImageNet:

```
python train.py \
--data_set=imagenet \
--model=h12_noup_smallkey \
--nr_logistic_mix=32 \
--nr_filters=256 \
--batch_size=8 \
--init_batch_size=8 \
--learning_rate=0.0001 \
--dropout_p=0.0 \
--polyak_decay=0.9997 \
--save_interval=1
```
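`--polyak_decay` sets the decay of an exponential moving average (Polyak averaging) of the model weights. In PixelCNN++ this is implemented with TensorFlow's `tf.train.ExponentialMovingAverage`, and the averaged weights are used for evaluation; we assume PixelSNAIL inherits this behavior. The plain-Python sketch below shows the update rule and why a decay of 0.9997 averages over roughly 1 / (1 - 0.9997), or about 3,333, recent steps. The helper names are hypothetical.

```python
def polyak_update(averaged, params, decay):
    """One EMA step per weight: avg <- decay * avg + (1 - decay) * param."""
    return [decay * a + (1.0 - decay) * p for a, p in zip(averaged, params)]


def effective_window(decay):
    """Rough number of recent steps that dominate the average: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)


# Toy example: a single weight held at 1.0 while its average starts at 0.0.
decay = 0.9995  # value from the CIFAR10 command
averaged = [0.0]
for step in range(1000):
    averaged = polyak_update(averaged, [1.0], decay)

print(averaged[0])  # follows the closed form 1 - decay**1000
print(round(effective_window(0.9995)), round(effective_window(0.9997)))  # 2000 3333
```

A larger decay gives a smoother but slower-moving average, which is consistent with the ImageNet run using 0.9997 together with a smaller learning rate.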