Progressive Growing of GANs in Flax
Flax (JAX) implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation. This code is meant to be a starting point you can fork for your own needs rather than a full re-implementation.
Below are some curated samples from the generator trained on the CelebA (not CelebA-HQ) dataset.
All hyperparameters are in
They are not as good as the samples in the original paper because of the significantly shorter training time.
- Download and extract CelebA.
- Install JAX (instructions vary by system).
- Install the other dependencies (ideally in a virtual environment) via pip install -r requirements.txt. Requires Python >= 3.6.
- Run the code via python src/main.py data_dir=<celeba directory>.
The code was run on a TPUv3-8. You will need to adjust the hyperparameters in src/conf/config.yaml for your local system (e.g., set distributed: False, decrease the batch size, etc.).
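When moving from the 8-device TPUv3-8 to a smaller machine, the batch size has to fit the number of devices. The sketch below assumes the distributed mode uses ordinary data parallelism (for example jax.pmap, which maps over a leading axis of size equal to the device count), so each device gets an equal slice of the global batch. The helper per_device_batch_size is hypothetical and not part of this repo, and the batch sizes shown are illustrative, not the values in config.yaml.

```python
def per_device_batch_size(global_batch_size: int, num_devices: int) -> int:
    """Split a global batch evenly across devices (hypothetical helper).

    With data parallelism each device processes an equal slice of the batch,
    so the global batch size must be divisible by the device count.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be >= 1")
    if global_batch_size % num_devices != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_devices} devices"
        )
    return global_batch_size // num_devices


# A TPUv3-8 exposes 8 devices to JAX; a single local GPU usually exposes 1.
print(per_device_batch_size(256, 8))  # 32
print(per_device_batch_size(16, 1))   # 16
```

On a single device the whole batch must fit in one accelerator's memory, which is why the batch size usually has to shrink along with setting distributed: False.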
Differences from Original Paper
- Different learning rates and batch sizes.
- The transition (interpolation between the previous and current stage) lasts only 80% of each stage instead of the entire stage.
- A slightly smaller model (with the same architecture), since this implementation targets CelebA at resolutions up to 128x128.
- Trained with bfloat16 without loss scaling (as opposed to float16 with loss scaling).
- tanh activation on the generator outputs.
- The gain used for the equalized learning rate is adjusted for each activation instead of using sqrt(2) throughout (gains computed following PyTorch).
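The shortened transition can be pictured as a blend weight alpha that ramps from 0 to 1 and then stays at 1. This is a minimal sketch, assuming (as in the original paper) that the fade-in happens at the start of the stage and is linear; the function names fade_in_alpha and blend are hypothetical, and scalars stand in for image tensors.

```python
def fade_in_alpha(step_in_stage: int, stage_steps: int,
                  transition_fraction: float = 0.8) -> float:
    """Blend weight for the newly added layers during a stage.

    Ramps linearly from 0 to 1 over the first `transition_fraction` of the
    stage, then stays at 1 for the remainder of the stage.
    """
    transition_steps = transition_fraction * stage_steps
    if transition_steps <= 0:
        return 1.0
    return min(1.0, step_in_stage / transition_steps)


def blend(upsampled_prev_rgb: float, current_rgb: float, alpha: float) -> float:
    """Interpolate between the previous stage's upsampled output and the new one."""
    return (1.0 - alpha) * upsampled_prev_rgb + alpha * current_rgb


stage_steps = 1000
for step in (0, 400, 800, 1000):
    print(step, fade_in_alpha(step, stage_steps))
# 0 0.0
# 400 0.5
# 800 1.0
# 1000 1.0
```

With JAX arrays in place of the scalars, the same blend expression applies elementwise to the two RGB outputs.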
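Equalized learning rate keeps weights at unit-variance initialization and multiplies them at runtime by c = gain / sqrt(fan_in). The sketch below reproduces the gain formulas documented for torch.nn.init.calculate_gain (a subset of its nonlinearities) in plain Python; it is an illustration of the idea, not this repo's code. The leaky ReLU slope of 0.2 is the value used in the original paper and is an assumption here.

```python
import math


def calculate_gain(nonlinearity: str, negative_slope: float = 0.01) -> float:
    """Gains following the formulas of torch.nn.init.calculate_gain (subset)."""
    if nonlinearity in ("linear", "conv2d", "sigmoid"):
        return 1.0
    if nonlinearity == "tanh":
        return 5.0 / 3.0
    if nonlinearity == "relu":
        return math.sqrt(2.0)
    if nonlinearity == "leaky_relu":
        return math.sqrt(2.0 / (1.0 + negative_slope ** 2))
    raise ValueError(f"unsupported nonlinearity: {nonlinearity}")


def equalized_lr_scale(fan_in: int, gain: float) -> float:
    """Runtime weight multiplier c = gain / sqrt(fan_in).

    Weights are stored with a N(0, 1) init and scaled by c in the forward
    pass, so updates have a comparable effect in every layer.
    """
    return gain / math.sqrt(fan_in)


# Example: a 3x3 conv with 512 input channels followed by LeakyReLU(0.2).
fan_in = 3 * 3 * 512
gain = calculate_gain("leaky_relu", negative_slope=0.2)
print(gain, equalized_lr_scale(fan_in, gain))
```

Using sqrt(2) everywhere corresponds to the plain ReLU gain; a per-activation gain gives slightly smaller values for leaky ReLU and a larger one before a tanh.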
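Loss scaling exists because float16's 5-bit exponent cannot represent very small gradients, which underflow to zero. bfloat16 keeps float32's 8-bit exponent (trading away mantissa precision), so its range matches float32 and loss scaling is usually unnecessary. The small helper below, named hypothetically float_format_limits, computes each format's largest finite value and smallest positive normal value from its bit layout.

```python
from typing import Tuple


def float_format_limits(exponent_bits: int, mantissa_bits: int) -> Tuple[float, float]:
    """Largest finite and smallest positive normal value of an IEEE-style binary float."""
    bias = 2 ** (exponent_bits - 1) - 1
    # The all-ones exponent is reserved for inf/NaN, so the top usable exponent equals the bias.
    largest = (2.0 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    smallest_normal = 2.0 ** (1 - bias)
    return largest, smallest_normal


for name, exp_bits, man_bits in [("float16", 5, 10), ("bfloat16", 8, 7), ("float32", 8, 23)]:
    largest, smallest = float_format_limits(exp_bits, man_bits)
    print(f"{name:8s} max={largest:.6g} min_normal={smallest:.6g}")
```

float16 tops out at 65504 with a smallest normal of about 6.1e-5, while bfloat16 spans roughly 1.2e-38 to 3.4e38, the same range as float32.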
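A tanh output layer produces values in [-1, 1], so training images are typically rescaled to that range and generated samples are mapped back to 8-bit pixels for viewing. This is a minimal sketch, assuming that [-1, 1] scaling (the README does not state the exact preprocessing); both helper names are hypothetical.

```python
def to_tanh_range(pixel: int) -> float:
    """Map an 8-bit pixel value in [0, 255] to [-1, 1]."""
    return pixel / 127.5 - 1.0


def from_tanh_range(x: float) -> int:
    """Map a generator output in [-1, 1] back to an 8-bit pixel, clamping out-of-range values."""
    pixel = round((x + 1.0) * 127.5)
    return max(0, min(255, pixel))


print(to_tanh_range(0), to_tanh_range(255))
print(from_tanh_range(-1.0), from_tanh_range(1.0))
```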
Below are training curves with the configuration in
Each vertical grey line marks the transition from one stage to the next.
The spikes in time per step at the beginning of each stage are due to compilation.
The total training time was 11 hours 6 minutes on a TPUv3-8.
Training for longer would likely give better results.
Full training logs and checkpoints can be found here.