# Data-Efficient GANs with DiffAugment

In this tutorial, we will demonstrate:
- How to visualize and evaluate the pre-trained models for 100-shot generation; and
- How to train a new network with only 100 images.

**[Differentiable Augmentation (DiffAugment)](https://github.com/mit-han-lab/data-efficient-gans)** is a simple, general method that enables **[Data-Efficient GAN Training](https://github.com/mit-han-lab/data-efficient-gans)** by applying various types of differentiable augmentation to both real and fake samples, for both generator and discriminator training.
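To make the idea concrete, here is a minimal NumPy sketch of three augmentation types (brightness, translation, cutout) applied to NHWC image batches. It is an illustration, not the repository's implementation: the function names, the ratios, the simplified cutout (the square always stays inside the image), and the assumption that pixels lie roughly in [-1, 1] are ours, and the comma-separated `policy` string is only loosely modeled on the repository's naming. In real training these operations have to be written in the training framework (TensorFlow here) so that gradients flow through them back to the generator.

```python
import numpy as np


def rand_brightness(x, rng):
    # Add one random offset in [-0.5, 0.5) per image (assumes pixels roughly in [-1, 1]).
    return x + (rng.random((x.shape[0], 1, 1, 1)) - 0.5)


def rand_translation(x, rng, ratio=0.125):
    # Shift each image by up to `ratio` of its size, filling exposed pixels with zeros.
    n, h, w, _ = x.shape
    sh, sw = int(h * ratio + 0.5), int(w * ratio + 0.5)
    padded = np.pad(x, ((0, 0), (sh, sh), (sw, sw), (0, 0)))
    out = np.empty_like(x)
    for i in range(n):
        dy = rng.integers(-sh, sh + 1)
        dx = rng.integers(-sw, sw + 1)
        out[i] = padded[i, sh + dy:sh + dy + h, sw + dx:sw + dx + w]
    return out


def rand_cutout(x, rng, ratio=0.5):
    # Zero out one random square whose side is `ratio` of the image side (kept inside the image).
    n, h, w, _ = x.shape
    ch, cw = int(h * ratio + 0.5), int(w * ratio + 0.5)
    out = x.copy()
    for i in range(n):
        y0 = rng.integers(0, h - ch + 1)
        x0 = rng.integers(0, w - cw + 1)
        out[i, y0:y0 + ch, x0:x0 + cw, :] = 0
    return out


AUGMENTS = {'color': rand_brightness, 'translation': rand_translation, 'cutout': rand_cutout}


def diff_augment(x, policy, rng):
    # Apply each augmentation named in the comma-separated policy, in order.
    for name in filter(None, policy.split(',')):
        x = AUGMENTS[name](x, rng)
    return x


# The discriminator only ever sees augmented images, and the same policy is used for
# real and generated batches (each image draws its own random parameters).
rng = np.random.default_rng(0)
reals = rng.uniform(-1.0, 1.0, size=(4, 32, 32, 3))
fakes = rng.uniform(-1.0, 1.0, size=(4, 32, 32, 3))  # stand-in for generator output G(z)
policy = 'color,translation,cutout'
d_input_real = diff_augment(reals, policy, rng)
d_input_fake = diff_augment(fakes, policy, rng)
```

All three transforms are differentiable with respect to the pixel values (additions, shifts, and masking), which is what lets the generator be trained through them.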

In [None]:
#@title Check GPU
#@markdown P100/V100 recommended for StyleGAN2 training!

gpu = !nvidia-smi --query-gpu=gpu_name --format=csv
print("GPU: " + (gpu[1] if len(gpu) > 1 else "none detected (enable one under Runtime > Change runtime type)"))

## Setup Environment

1. Clone the repository:
2. Change into the `DiffAugment-stylegan2` folder:

In [None]:
%cd /content/
!git clone https://github.com/n00mkrad/data-efficient-gans
%cd data-efficient-gans/DiffAugment-stylegan2

3. Install the dependencies, import the required modules, and define helper functions for later use:

In [None]:
!pip uninstall -y tensorflow tensorflow-probability
!pip install tensorflow-gpu==1.15.0

import tensorflow as tf
import os
import numpy as np
import PIL.Image
import IPython
from multiprocessing import Pool
import matplotlib.pyplot as plt

from dnnlib import tflib, EasyDict
from training import misc, dataset_tool
from metrics import metric_base
from metrics.metric_defaults import metric_defaults

def _generate(network_name, num_rows, num_cols, seed, resolution):
  # Sample num_rows * num_cols images from the generator and tile them into one grid image.
  if seed is not None:
    np.random.seed(seed)
  with tf.Session():
    _, _, Gs = misc.load_pkl(network_name)  # (G, D, Gs); Gs is the moving-average generator
    z = np.random.randn(num_rows * num_cols, Gs.input_shape[1])
    outputs = Gs.run(z, None, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
    outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
    outputs = np.concatenate(outputs, axis=1)  # stack the rows vertically
    outputs = np.concatenate(outputs, axis=1)  # place the columns side by side
    img = PIL.Image.fromarray(outputs)
    # LANCZOS is the filter formerly named ANTIALIAS, which Pillow 10 removed.
    img = img.resize((resolution * num_cols, resolution * num_rows), PIL.Image.LANCZOS)
  return img

def generate(network_name, num_rows, num_cols, seed=None, resolution=128):
  # Run in a one-worker child process; TensorFlow's graph and GPU memory are freed when it exits.
  with Pool(1) as pool:
    return pool.apply(_generate, (network_name, num_rows, num_cols, seed, resolution))

def _evaluate(network_name, dataset, resolution, metric):
  # Prepare the dataset at the requested resolution, then compute the metric (e.g. FID) against it.
  dataset = dataset_tool.create_dataset(dataset, resolution)
  dataset_args = EasyDict(tfrecord_dir=dataset, resolution=resolution, from_tfrecords=True)
  metric_group = metric_base.MetricGroup([metric_defaults[metric]])
  metric_group.run(network_name, dataset_args=dataset_args, log_results=False)
  return metric_group.metrics[0]._results[0].value

def evaluate(network_name, dataset, resolution=256, metric='fid5k-train'):
  # Same child-process pattern as generate().
  with Pool(1) as pool:
    return pool.apply(_evaluate, (network_name, dataset, resolution, metric))
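
The reshape followed by two `np.concatenate` calls in `_generate` is a compact way to tile a flat batch into a grid. The standalone sketch below (the name `make_grid` is ours) isolates that trick and checks where each image lands, using constant-valued dummy images instead of generator output.

```python
import numpy as np


def make_grid(images, num_rows, num_cols):
    # images: array of shape [num_rows * num_cols, H, W, C] (NHWC), filled row by row.
    grid = np.reshape(images, [num_rows, num_cols, *images.shape[1:]])  # [R, Cols, H, W, C]
    grid = np.concatenate(grid, axis=1)  # join the R rows vertically:       [Cols, R*H, W, C]
    grid = np.concatenate(grid, axis=1)  # join the columns horizontally:    [R*H, Cols*W, C]
    return grid


# Six 2x2 RGB dummy images; image i is filled with the value i.
batch = np.arange(6, dtype=np.uint8).reshape(6, 1, 1, 1) * np.ones((6, 2, 2, 3), dtype=np.uint8)
grid = make_grid(batch, num_rows=2, num_cols=3)
print(grid.shape)  # (4, 6, 3)
# Image i occupies grid row-block i // 3 and column-block i % 3.
```

The same pattern tiles the 100 training images into a 5 x 20 grid in the dataset cell further down.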

## Using the Pre-Trained Models

### 100-Shot Generation Datasets

Let's first visualize the 100-shot Obama dataset. Such a small-scale dataset can be easily collected from the Internet.

In [None]:
data_dir = dataset_tool.create_dataset('100-shot-obama')
training_images = []
for fname in sorted(os.listdir(data_dir)):  # sorted for a deterministic order
  if fname.endswith('.jpg'):
    training_images.append(np.array(PIL.Image.open(os.path.join(data_dir, fname))))
# Tile the 100 images into a 5 x 20 grid, then shrink it for display.
imgs = np.reshape(training_images, [5, 20, *training_images[0].shape])
imgs = np.concatenate(imgs, axis=1)
imgs = np.concatenate(imgs, axis=1)
PIL.Image.fromarray(imgs).resize((1000, 250), PIL.Image.LANCZOS)

### StyleGAN2 (baseline)

How do vanilla GANs perform given such a small dataset? Let's visualize the (pre-trained) StyleGAN2 baseline model (this will take a minute):

In [None]:
generate('mit-han-lab:stylegan2-100-shot-obama.pkl', num_rows=2, num_cols=5, seed=1000)

As you can see, most of the images generated by the baseline model are heavily distorted. This is mainly because the discriminator memorizes the exact training images, i.e., it overfits to the 100 examples. Let's resolve this problem with DiffAugment.
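Structurally, DiffAugment only changes where the discriminator's inputs come from. The sketch below uses the non-saturating logistic GAN loss, which StyleGAN2 uses by default (its R1 regularizer is omitted here), and applies an augmentation `T` to both real and generated images in the discriminator and generator losses. The toy `D`, `G`, and `T` are hypothetical stand-ins on flat 16-pixel "images", not the repository's networks.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def gan_losses(D, G, T, reals, z):
    # Non-saturating logistic GAN losses (R1 penalty omitted).
    # T is applied to BOTH real and generated images before every call to D.
    fakes = G(z)
    real_logits = D(T(reals))
    fake_logits = D(T(fakes))
    d_loss = np.mean(softplus(fake_logits)) + np.mean(softplus(-real_logits))
    # The generator loss also passes through T, so T must be differentiable for G to learn.
    g_loss = np.mean(softplus(-fake_logits))
    return d_loss, g_loss


rng = np.random.default_rng(0)
w = rng.standard_normal(16)


def D(x):
    return x @ w  # toy linear "discriminator": one logit per image


def G(z):
    return np.tanh(z)  # toy "generator": latent already has the image shape


def T(x):
    return x + rng.uniform(-0.5, 0.5, size=(len(x), 1))  # random brightness per image


reals = rng.uniform(-1.0, 1.0, size=(8, 16))
z = rng.standard_normal((8, 16))
d_loss, g_loss = gan_losses(D, G, T, reals, z)
```

Augmenting only the real images would change the distribution the generator is pushed to match, which is why the same kind of augmentation is applied to the fakes as well.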

### StyleGAN2 + DiffAugment (ours)

DiffAugment can dramatically improve image quality, even with only 100 training images of Obama portraits. Here we visualize the (pre-trained) StyleGAN2 + DiffAugment model (this will take a minute):

In [None]:
generate('mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl', num_rows=2, num_cols=5, seed=1000)

Besides `100-shot-obama`, you may also try out `100-shot-grumpy_cat`, `100-shot-panda`, `AnimalFace-cat`, or `AnimalFace-dog` in the code above, to compare DiffAugment with the baseline models on other datasets.

### Calculating FID

Fréchet Inception Distance (FID) quantitatively measures the visual fidelity of GAN-generated images; lower FID indicates better performance. Let's calculate the FID of both the baseline model and ours. This will take about 15 minutes.
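Under the hood, FID embeds real and generated images with an Inception-v3 network, fits a Gaussian to each set of features, and takes the Fréchet distance between the two Gaussians:

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)$$

The NumPy sketch below computes only this last step from arbitrary feature arrays; it does not reproduce the repository's Inception feature extraction, and the function names are our own. Judging by its name, the `fid5k-train` metric used above compares 5,000 generated images against the training set.

```python
import numpy as np


def _sqrtm_psd(a):
    # Square root of a symmetric positive semi-definite matrix via its eigendecomposition.
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def frechet_distance(feats_real, feats_gen):
    # feats_*: [num_samples, dim] feature arrays (Inception-v3 features in the real metric).
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_g = np.cov(feats_gen, rowvar=False)
    # Tr((S_r S_g)^{1/2}) == Tr((S_r^{1/2} S_g S_r^{1/2})^{1/2}); the inner matrix is symmetric,
    # so the trace is the sum of the square roots of its eigenvalues.
    root_r = _sqrtm_psd(sigma_r)
    middle = root_r @ sigma_g @ root_r
    middle = (middle + middle.T) / 2.0  # remove floating-point asymmetry
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)))
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_covmean)


rng = np.random.default_rng(0)
real = rng.standard_normal((2000, 8))
print(frechet_distance(real, real))        # ~0: identical feature distributions
print(frechet_distance(real, real + 3.0))  # ~72 = 8 dims * 3**2: only the means differ
```

Rewriting the trace term in its symmetric form avoids taking a general (non-symmetric) matrix square root, at the cost of one extra eigendecomposition.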

In [None]:
print('Evaluating StyleGAN2 (baseline)...')
fid_baseline = evaluate('mit-han-lab:stylegan2-100-shot-obama.pkl', dataset='100-shot-obama')
print('Baseline FID:', fid_baseline, '\n')

print('Evaluating StyleGAN2 + DiffAugment (ours)...')
fid_ours = evaluate('mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl', dataset='100-shot-obama')
print('Ours FID:', fid_ours, '\n')

plt.figure(figsize=(2, 3))
plt.bar([0, 1], [fid_baseline, fid_ours], color=['gray', 'darkred'])
plt.xticks([0, 1], ['Baseline', 'Ours'])
plt.ylabel("FID")
plt.show()

With DiffAugment, our model achieves about **1.7x** lower (better) FID than the baseline model!

### Generating an Interpolation Video

Finally, let's generate an interpolation video using the pre-trained DiffAugment models. Besides `obama`, you may also try `grumpy_cat`, `panda`, `bridge_of_sighs`, `medici_fountain`, or `temple_of_heaven` in the code below. The smooth interpolations suggest that our method overfits little, even on small datasets. This will take a minute.
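The exact interpolation scheme used by `generate_gif.py` is not shown in this notebook, so the following is only a sketch of one common approach: spherical linear interpolation (slerp) between random latent vectors. Slerp is often preferred over straight linear interpolation because high-dimensional Gaussian samples lie close to a sphere of radius about √d, while linear midpoints fall inside it. The latent size 512 is an assumption (StyleGAN2's default); in this notebook the actual size is `Gs.input_shape[1]`.

```python
import numpy as np


def slerp(a, b, t):
    # Spherical linear interpolation between latent vectors a and b at fraction t in [0, 1].
    a_n, b_n = a / np.linalg.norm(a), b / np.linalg.norm(b)
    omega = np.arccos(np.clip(np.dot(a_n, b_n), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        return (1.0 - t) * a + t * b  # nearly parallel vectors: fall back to linear interpolation
    return (np.sin((1.0 - t) * omega) * a + np.sin(t * omega) * b) / np.sin(omega)


rng = np.random.default_rng(1)
z0, z1 = rng.standard_normal(512), rng.standard_normal(512)
frames = np.stack([slerp(z0, z1, t) for t in np.linspace(0.0, 1.0, num=30)])
# Each row of `frames` would be passed to Gs.run(...) to render one frame of the video.
```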

In [None]:
!python3 generate_gif.py -r mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl -o interp.gif --num-rows=2 --num-cols=3 --seed=1
IPython.display.Image(open('interp.gif', 'rb').read())

## Training

In [None]:
#@title Mount Google Drive at /content/drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!cp "YOUR_GOOGLE_DRIVE_DATASET_ZIP_PATH_HERE" "/content/data-efficient-gans/DiffAugment-stylegan2/datasets/dataset1.zip"
!unzip -q "/content/data-efficient-gans/DiffAugment-stylegan2/datasets/dataset1.zip" -d "/content/data-efficient-gans/DiffAugment-stylegan2/datasets/"
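
If you prefer plain Python to the shell commands above, the sketch below does the same unpacking with the standard-library `zipfile` module and also reports what it found. It assumes your zip contains one top-level folder of images; presumably that folder's name is what goes into `YOUR_DATASET_FOLDER_NAME` in the training command. The accepted extensions are an assumption, not a list taken from `run_few_shot.py`.

```python
import os
import zipfile

IMAGE_EXTS = ('.jpg', '.jpeg', '.png')  # assumption: adjust to the formats in your dataset


def extract_dataset(zip_path, out_dir):
    # Unpack the archive; return (sorted image paths, top-level folder names inside the zip).
    with zipfile.ZipFile(zip_path) as zf:
        names = zf.namelist()
        zf.extractall(out_dir)
    images = sorted(os.path.join(out_dir, n) for n in names if n.lower().endswith(IMAGE_EXTS))
    top_dirs = sorted({n.split('/')[0] for n in names if '/' in n})  # zip paths always use '/'
    return images, top_dirs


# Usage in this notebook (paths as in the cell above):
# images, top_dirs = extract_dataset(
#     '/content/data-efficient-gans/DiffAugment-stylegan2/datasets/dataset1.zip',
#     '/content/data-efficient-gans/DiffAugment-stylegan2/datasets/')
# print(len(images), 'images; top-level folders:', top_dirs)
```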

Start training with the cell below.

To resume from a checkpoint, append these arguments to the training command:

`--resume "PATH_TO_YOUR_LATEST_CHECKPOINT.pkl" --resume-kimg KIMG_NUMBER_HERE`

Replace the path with your latest checkpoint file and replace `KIMG_NUMBER_HERE` with that checkpoint's kimg count.
Passing `--resume-kimg` is optional, but without it you will lose track of the total training progress, because the kimg count is not stored in the checkpoint.
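kimg measures training progress in thousands of real images shown to the discriminator, so with `--batch-size 4` one kimg corresponds to 1000 / 4 = 250 iterations. The helper below is a hypothetical convenience for finding the newest checkpoint and building the resume arguments. It assumes the StyleGAN2 convention of naming snapshots `network-snapshot-<kimg>.pkl` (zero-padded); check your checkpoint folder before relying on it. It picks the highest kimg found, which is only the newest checkpoint if earlier resumes also passed `--resume-kimg`.

```python
import os
import re

# Assumption: snapshots follow StyleGAN2's naming, e.g. network-snapshot-000200.pkl.
SNAPSHOT_RE = re.compile(r'network-snapshot-(\d+)\.pkl$')


def latest_snapshot(checkpoint_dir):
    # Return (path, kimg) of the highest-kimg snapshot anywhere under checkpoint_dir, or None.
    best = None
    for root, _, files in os.walk(checkpoint_dir):
        for fname in files:
            m = SNAPSHOT_RE.search(fname)
            if m and (best is None or int(m.group(1)) > best[1]):
                best = (os.path.join(root, fname), int(m.group(1)))
    return best


def resume_flags(checkpoint_dir):
    # Build the extra arguments for run_few_shot.py, or '' when there is nothing to resume from.
    found = latest_snapshot(checkpoint_dir)
    if found is None:
        return ''
    path, kimg = found
    return f'--resume "{path}" --resume-kimg {kimg}'
```

In a Colab cell, IPython expands `{flags}` inside `!` shell commands, so you could set `flags = resume_flags("YOUR_GOOGLE_DRIVE_CHECKPOINT_PATH")` and append `{flags}` to the `!python3 run_few_shot.py ...` line.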

In [None]:
# Train at 256x256; replace the dataset folder and checkpoint path placeholders first.
!python3 run_few_shot.py --dataset="/content/data-efficient-gans/DiffAugment-stylegan2/datasets/YOUR_DATASET_FOLDER_NAME" --resolution=256 --batch-size 4 --result-dir "YOUR_GOOGLE_DRIVE_CHECKPOINT_PATH"