Data-Efficient GANs with DiffAugment

Images generated by models trained on only 100 images each of Obama, grumpy cats, pandas, the Bridge of Sighs, the Medici Fountain, and the Temple of Heaven, without pre-training.

This repository contains our implementation of Differentiable Augmentation (DiffAugment) in both PyTorch and TensorFlow. It can be used to significantly improve the data efficiency of GAN training. We provide DiffAugment-stylegan2 (TensorFlow), DiffAugment-stylegan2-pytorch, and DiffAugment-biggan-cifar (PyTorch) for GPU training, and DiffAugment-biggan-imagenet (TensorFlow) for TPU training.

Differentiable Augmentation for Data-Efficient GAN Training
Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
MIT, Tsinghua University, Adobe Research, CMU
NeurIPS'20

Usage

Third-party code is not bundled with this repo; it is linked dynamically, and users are free to obtain it on their own.

First, clone NVIDIA's stylegan2 repos and apply our patch files:

git clone https://github.com/NVlabs/stylegan2 DiffAugment-stylegan2
cd DiffAugment-stylegan2
git apply ../DiffAugment-stylegan2.patch
cd ..

git clone https://github.com/NVlabs/stylegan2-ada-pytorch DiffAugment-stylegan2-pytorch
cd DiffAugment-stylegan2-pytorch
git apply ../DiffAugment-stylegan2-pytorch.patch
cd ..

Overview

Overview of DiffAugment for updating D (left) and G (right). DiffAugment applies the augmentation T to both the real sample x and the generated output G(z). When we update G, gradients need to be back-propagated through T (iii), which requires T to be differentiable w.r.t. the input.
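
In equation form, the caption above corresponds to objectives of roughly the following shape. The symbols f_D and f_G are generic loss functions introduced here for illustration (for example, softplus for the non-saturating loss); they are an assumption of this sketch and are not defined elsewhere in this README.

```latex
\begin{aligned}
L_D &= \mathbb{E}_{x \sim p_{\text{data}}}\big[f_D\big(-D(T(x))\big)\big]
     + \mathbb{E}_{z \sim p(z)}\big[f_D\big(D(T(G(z)))\big)\big], \\
L_G &= \mathbb{E}_{z \sim p(z)}\big[f_G\big(-D(T(G(z)))\big)\big].
\end{aligned}
```

Because G appears in L_G only inside T, the chain rule puts the Jacobian of T with respect to its input into the gradient of L_G. A non-differentiable T would cut that path, which is why the augmentation must be differentiable.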

Training and Generation with 100 Images

To generate an interpolation video using our pre-trained models:

cd DiffAugment-stylegan2
python generate_gif.py -r mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl -o obama.gif

or to train a new model:

python run_low_shot.py --dataset=100-shot-obama --num-gpus=4

You may also try 100-shot-grumpy_cat, 100-shot-panda, 100-shot-bridge_of_sighs, 100-shot-medici_fountain, 100-shot-temple_of_heaven, 100-shot-wuzhen, or a folder containing your own training images. Please refer to the DiffAugment-stylegan2 README for the dependencies and details.

[NEW!] PyTorch training is now available:

cd DiffAugment-stylegan2-pytorch
python train.py --outdir=training-runs --data=https://hanlab.mit.edu/projects/data-efficient-gans/datasets/100-shot-obama.zip --gpus=1

DiffAugment for StyleGAN2

To run StyleGAN2 + DiffAugment for unconditional generation on the 100-shot datasets, CIFAR, FFHQ, or LSUN, please refer to the DiffAugment-stylegan2 README, or to DiffAugment-stylegan2-pytorch for the PyTorch version.

DiffAugment for BigGAN

Please refer to the DiffAugment-biggan-cifar README to run BigGAN + DiffAugment for conditional generation on CIFAR (using GPUs), and the DiffAugment-biggan-imagenet README to run on ImageNet (using TPUs).

Using DiffAugment for Your Own Training

To help you use DiffAugment in your own codebase, we provide portable DiffAugment operations for both TensorFlow and PyTorch in DiffAugment_tf.py and DiffAugment_pytorch.py. In general, DiffAugment can be adopted in any model by replacing every D(x) with D(T(x)), where x is a batch of real or fake images, D is the discriminator, and T is the DiffAugment operation. For example,

from DiffAugment_pytorch import DiffAugment
# from DiffAugment_tf import DiffAugment
policy = 'color,translation,cutout'
# If your dataset is as small as ours (e.g., hundreds of images), we recommend
# the strongest policy: Color + Translation + Cutout.
# For large datasets, try a subset of ['color', 'translation', 'cutout'].
# You are welcome to explore additional DiffAugment transformations.

...
# Training loop: update D
reals = sample_real_images() # a batch of real images
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
real_scores = Discriminator(DiffAugment(reals, policy=policy))
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating D's loss based on real_scores and fake_scores...
...

...
# Training loop: update G
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating G's loss based on fake_scores...
...
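
For readers who want to see the substitution end to end, here is a minimal PyTorch sketch of one D update and one G update. Everything in it is a placeholder chosen for illustration: the tiny `G` and `D` networks, the random `reals` batch, the softplus (non-saturating) losses, and `toy_augment`, a hypothetical stand-in for `DiffAugment(x, policy=policy)` that applies only a random brightness shift. It also detaches the fakes during the D update, a common PyTorch practice that the pseudocode above does not spell out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def toy_augment(x):
    # Hypothetical stand-in for DiffAugment(x, policy=policy): one random
    # brightness offset per sample. It is differentiable with respect to x.
    shift = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + shift

latent_dim, image_shape = 8, (3, 16, 16)
num_pixels = 3 * 16 * 16

# Tiny placeholder networks, only large enough to exercise the training step.
G = nn.Sequential(nn.Linear(latent_dim, num_pixels), nn.Tanh(),
                  nn.Unflatten(1, image_shape))
D = nn.Sequential(nn.Flatten(), nn.Linear(num_pixels, 1))
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

reals = torch.rand(4, *image_shape) * 2 - 1  # placeholder batch in [-1, 1]

# Update D: reals and fakes are both augmented before D sees them.
# The fakes are detached because this step does not update G.
z = torch.randn(4, latent_dim)
fakes = G(z).detach()
real_scores = D(toy_augment(reals))
fake_scores = D(toy_augment(fakes))
loss_D = F.softplus(-real_scores).mean() + F.softplus(fake_scores).mean()
opt_D.zero_grad()
loss_D.backward()
opt_D.step()

# Update G: gradients flow from D, through the augmentation, back into G.
z = torch.randn(4, latent_dim)
fake_scores = D(toy_augment(G(z)))
loss_G = F.softplus(-fake_scores).mean()
opt_G.zero_grad()
loss_G.backward()
g_grad_norm = sum(p.grad.abs().sum().item() for p in G.parameters())
opt_G.step()
```

A nonzero `g_grad_norm` after `loss_G.backward()` indicates that the generator received a gradient through the augmentation. To use the real operations, replace `toy_augment(x)` with `DiffAugment(x, policy=policy)`.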

We have implemented the Color, Translation, and Cutout DiffAugment transformations.
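
To give a feel for what such transformations can look like, the sketch below implements simplified versions of the three families in PyTorch and chains them according to a comma-separated policy string. These are assumptions written for illustration, not the code in DiffAugment_pytorch.py: `toy_diff_augment` and the helper names are hypothetical, Color is reduced to a brightness shift, Translation uses a fixed 1/8 zero-padded shift range, and Cutout keeps its half-size square fully inside the image.

```python
import torch
import torch.nn.functional as F

def rand_brightness(x):
    # Color (simplified): add one random offset in [-0.5, 0.5) per sample.
    return x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)

def rand_translation(x, ratio=0.125):
    # Translation: zero-pad, then crop each sample at a random offset.
    n, _, h, w = x.shape
    pad_h, pad_w = int(h * ratio + 0.5), int(w * ratio + 0.5)
    padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h])  # (left, right, top, bottom)
    crops = []
    for i in range(n):
        top = int(torch.randint(0, 2 * pad_h + 1, (1,)))
        left = int(torch.randint(0, 2 * pad_w + 1, (1,)))
        crops.append(padded[i, :, top:top + h, left:left + w])
    return torch.stack(crops)

def rand_cutout(x, ratio=0.5):
    # Cutout: multiply by a 0/1 mask that zeroes one random rectangle per sample.
    n, _, h, w = x.shape
    cut_h, cut_w = int(h * ratio + 0.5), int(w * ratio + 0.5)
    mask = torch.ones(n, 1, h, w, dtype=x.dtype, device=x.device)
    for i in range(n):
        top = int(torch.randint(0, h - cut_h + 1, (1,)))
        left = int(torch.randint(0, w - cut_w + 1, (1,)))
        mask[i, :, top:top + cut_h, left:left + cut_w] = 0
    return x * mask

# Hypothetical lookup table from policy names to lists of operations.
AUGMENT_FNS = {
    'color': [rand_brightness],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

def toy_diff_augment(x, policy=''):
    # Apply the operations named in the policy string, in order.
    for name in filter(None, policy.split(',')):
        for fn in AUGMENT_FNS[name]:
            x = fn(x)
    return x

torch.manual_seed(0)
x = torch.rand(4, 3, 16, 16, requires_grad=True)
y = toy_diff_augment(x, policy='color,translation,cutout')
y.sum().backward()  # x.grad is populated: every op is differentiable in x
```

Each operation is built from additions, slicing, and multiplication by a constant mask, so autograd can propagate gradients from the augmented output back to the input image.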

Citation

If you find this code helpful, please cite our paper:

@inproceedings{zhao2020diffaugment,
  title={Differentiable Augmentation for Data-Efficient GAN Training},
  author={Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2020}
}

Acknowledgements

We thank NSF Career Award #1943349, MIT-IBM Watson AI Lab, Google, Adobe, and Sony for supporting this research. This research was also supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC). We thank William S. Peebles and Yijun Li for helpful comments.
