A PyTorch simplified implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation.
It's based on the reference implementation in Theano: https://github.com/tkarras/progressive_growing_of_gans and a few other implementations named in acknowledgements below. It's more modularized and specifically designed for training the Progressively Growing GAN networks with WGAN-GP loss. Main focus for this implementation in terms of domain of generation were spectrograms and other visual representations of audio and thus it contains a few helpers for these tasks (but also works for any other kind of image). For more details on results in audio generation, see our blog post: http://deepsound.io/samplernn_pytorch.html.
The model and implementation proved itself to be succesful - and as for now, the authors are not aware of any other succesful GAN application to audio signal generation.
This code requires Python 3.5+ and was tested under PyTorch 0.2.0+. Other dependencies depend on the usage, most important for sound generation being Librosa 0.4.3+ and SoundFile 0.9.0+ (latter for efficient sound loading) and Pillow 4.0.0+ with scipy 0.19.0+ (possibly also works with earlier versions, these are the ones tested). Lastly, if you wish to use datasets created with original implementation's tool, you will need h5py 2.5.0+. Installation instructions for PyTorch are available on their website: http://pytorch.org/. You can install the rest of the dependencies for sound and image generation by running
pip install -r requirements.txt.
Training on new data is as simple as implementing your own DepthDataset (see dataset.py) with all required methods. A few useful Datasets were already implemented:
- OldH5Dataset - allows use of (fitting in memory!) datasets from original implementation
- DefaultImageFolderDataset - takes any folder with images and creates a proper dataset
- SoundImageDataset - takes any folder with .wav files of length at least (n_fft*hop_length/2 + n_fft) (where n_fft and hop_length are desired parameters of stft). For example, a good start is to use 16000 hz, 5 second audio created by script from our SampleRNN PyTorch implementation. This class can be a dataset of either magnitude spectrograms, "raw" data or real-part stfts, with best results obtained with magnitude spectrograms.
To train the model you need to run
train.py. All model and training hyperparameters are settable in the command line in a very generic way - you can essentially pass ClassName.init_param_name for every class that will be instantiated during the training process. Sometimes you need to specify which classes you may want to instantiate, e.g. which kind of dataset or which output postprocessors. Most hyperparameters have sensible default values, so you don't need to provide all of them. Run
python train.py -h for details. With piano dataset from our script mentioned before, you can run:
python train.py --exp_name TEST \ --dataset_class SoundImageDataset --SoundImageDataset.dir_path PATH_TO_DATASET_FOLDER \ --SoundImageDataset.preload True --SoundImageDataset.img_mode abslog \ --postprocessors ['ImageSaver', 'SoundSaver'] \ --ImageSaver.samples_path images --SoundSaver.samples_path sound
If you would like to keep lod information from original implementation in logging, add --DepthManager.max_lod and --DepthManager.depth_offset parameters.
The results - training log, model checkpoints and postprocessed generated samples will be saved in
results_dir/NUMEXP-exp_name/. NUMEXP is number one greater than maximal number of experiments in results directory.
We also have an option to monitor the metrics using CometML. To use it, just pass your API key as
--comet_key parameter to
train.py. For newer versions of CometML, you may also pass sensible
Main differences compared to original implementation
Main difference is substitution of LOD parameter in exchange for two, more related to the paper parameters:
alpha. The inspiration for this is also the Chainer implementation (see acknowledgements). Also, this implementation is completely agnostic of a term of
resolution and everything can be specified just by proper
depth_offset parameters. In this implementation, parameter
DepthManager.max_lod corresponds to
model_dataset_depth_offset (in the selected dataset) corresponds to
log2(initial_resolution). This is motivated by the following logic: neural models should always have their parameter
depth corresponding to the index of deepest layer trainable, while datasets can also provide versions of data for other depths (it works this way also in original implementation, where dataset always creates inputs downsampled down to 2^0 x 2^0). If you provide
--DepthManager.depth_offset parameters, you can have old-style lod displayed in lod, equal to
max_lod - depth_offset - depth + 1 - alpha.
There is a lot of unimplemented features of the original paper, such as conditional training, a few network parameters etc. which we didn't find useful (for now!) during spectrogram generation which is main purpose of this implementation. All the default parameters for generating the 1024x1024 faces from the original paper are implemented.
The training process in Trainer class resembles the original implementation notion of "nimg", "kimg" and "ticks" instead of "iterations" or "epochs". For the ease of use of several built-in plugins and to satisfy both PyTorch and the original implementation, "tick" is equivalent to "epoch".
Take note, that with simple proper initializations of depth / plugin parameters, you can train your own progressively growing networks, which only need to specify their max_depth and make use of depth and alpha parameters. Or, you may wish to train regular (non-progressive) GANs with this framework - for ease of use, you just need to turn pass
--progressive_growing False param.
Original implementation: https://github.com/tkarras/progressive_growing_of_gans Few other inspirations:
- chainer-PGGAN: https://github.com/joisino/chainer-PGGAN
- PyTorch - progressive growing of gans: https://github.com/github-pengge/PyTorch-progressive_growing_of_gans
- wgan-gp pytorch - https://github.com/caogang/wgan-gp
We want to highlight, that these were strong, but still just an inspiration for rewriting the code from scratch, even though a clear resemblence of the code from mentioned implementation can be seen.
The code is in constant development. :) We publish it now, because you can always improve something. It is possible, that much more detailed implementations may arise for both PGGAN-general framework and Signal-Generating Progressively Growing GANs (SGPGGAN acronym isn't hopefully taken yet).