# Wasserstein GAN with gradient penalty implementation

[Original video](https://youtu.be/pG0QZ7OddX4)

[Read-through: Wasserstein GAN](https://www.alexirpan.com/2017/02/22/wasserstein-gan.html)

[Wasserstein GAN paper](https://arxiv.org/abs/1701.07875)

[Improved Training of Wasserstein GANs paper](https://arxiv.org/abs/1704.00028)

## Download and prepare the dataset, import libraries

In [None]:
# Get the CelebA dataset from Kaggle.
# NOTE: this cell only runs inside Google Colab / IPython: it relies on the
# `google.colab` package and on `!` shell escapes.

# Colab's file access feature
from google.colab import files

# Open the upload dialog; select the `kaggle.json` API token file.
# `uploaded` is indexed by file name below and `len()` is taken of each value.
uploaded = files.upload()

# Report the name and size of every uploaded file
for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))


# Then copy kaggle.json into the folder where the API expects to find it.
# chmod 600 restricts the token file to read/write by its owner only.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!ls ~/.kaggle

# Download the dataset (saved as celeba-dataset.zip in the working directory)
# Optional: search Kaggle for matching datasets first
# !kaggle datasets list -s celeba
!kaggle datasets download -d jessicali9530/celeba-dataset

Saving kaggle.json to kaggle.json
User uploaded file "kaggle.json" with length 65 bytes
kaggle.json
Downloading celeba-dataset.zip to /content
 99% 1.32G/1.33G [00:25<00:00, 77.5MB/s]
100% 1.33G/1.33G [00:26<00:00, 54.8MB/s]


In [None]:
# Extract the downloaded archive into the current working directory
import zipfile

archive_name = 'celeba-dataset.zip'
with zipfile.ZipFile(archive_name, mode='r') as archive:
    archive.extractall(path='.')

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import multiprocessing

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

In [None]:
class MyDataset(Dataset):
    """Unlabelled image dataset backed by a single flat directory.

    Every entry of ``root`` is treated as an image file. There are no
    targets: ``__getitem__`` returns only the transformed image.
    """

    def __init__(self, root, transform):
        """Index the files found in ``root``.

        Args:
            root: Path of the directory that holds the image files.
            transform: Callable applied to every loaded PIL image; its
                result is what ``__getitem__`` returns.
        """
        self.root = root
        self.transform = transform

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file name mapping reproducible across runs and machines.
        self.img_name = dict(enumerate(sorted(os.listdir(self.root))))

    def __len__(self):
        """Return the number of indexed files (202 599 for full CelebA)."""
        return len(self.img_name)

    def __getitem__(self, index):
        """Load, convert to RGB and transform the image at ``index``."""
        filepath = os.path.join(self.root, self.img_name[index])
        # The context manager closes the file handle right away instead of
        # leaving it to the garbage collector. convert("RGB") both forces the
        # pixel data to be read and guarantees three channels, which the
        # 3-channel normalisation downstream relies on.
        with Image.open(filepath) as img:
            image = img.convert("RGB")
        return self.transform(image)


# Side length, in pixels, of the square images fed to the networks
IMAGE_SIZE = 64
# Directory that holds the extracted CelebA images
image_folder = './img_align_celeba/img_align_celeba'

# Preprocessing applied to every image, in order
_preprocessing_steps = [
    # Proportional resize: the image stays rectangular
    transforms.Resize(IMAGE_SIZE),
    # Cut a random IMAGE_SIZE x IMAGE_SIZE square out of the rectangle
    transforms.RandomCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    # Per channel: (x - 0.5) / 0.5
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)

dataset = MyDataset(image_folder, transform)


def save_checkpoint(state, filename):
    """Announce the save on stdout, then write ``state`` to ``filename``.

    Args:
        state: Object to serialise (the training loop passes a dict).
        filename: Destination path handed to ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` from an in-memory checkpoint.

    Args:
        checkpoint: Mapping with the keys ``"state_dict"``, ``"optimizer"``
            and ``"step"``.
        model: Object whose ``load_state_dict`` receives
            ``checkpoint["state_dict"]``.
        optimizer: Object whose ``load_state_dict`` receives
            ``checkpoint["optimizer"]``.

    Returns:
        The value stored under ``checkpoint["step"]``.
    """
    print("=> Loading checkpoint")
    # Model first, optimizer second
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])
    return checkpoint["step"]

## Set and test the model

In [None]:
class Critic(nn.Module):
    """Convolutional critic: maps N x channels_img x 64 x 64 images to N x 1 scores.

    There is no final sigmoid, so the output is an unbounded score rather
    than a probability (hence "critic" instead of "discriminator").
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        # Channel count after each down-sampling stage; with features_d=64
        # the feature maps are 64x32x32, 128x16x16, 256x8x8 and 512x4x4.
        widths = [channels_img] + [features_d * mult for mult in (1, 2, 4, 8)]
        stages = [
            self._block(c_in, c_out, kernel_size=4, stride=2, padding=1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # The 4x4 valid convolution collapses the 4x4 map to N x 1 x 1 x 1
        head = nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0)
        # Flatten turns N x 1 x 1 x 1 into N x 1
        self.critic = nn.Sequential(*stages, head, nn.Flatten())

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one stage: convolution -> instance norm -> LeakyReLU(0.2)."""
        # No conv bias: the affine normalisation right after has its own shift
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        # Normalise per sample (instance) rather than across the batch
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        activation = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Return the critic scores for the image batch ``x``."""
        return self.critic(x)


class Generator(nn.Module):
    """Transposed-convolution generator: N x z_dim x 1 x 1 noise -> N x channels_img x 64 x 64 images."""

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each up-sampling
        # stage; shapes in the comments assume features_g=64.
        plan = (
            (z_dim, features_g * 8, 1, 0),             # N x 512 x 4 x 4
            (features_g * 8, features_g * 4, 2, 1),    # N x 256 x 8 x 8
            (features_g * 4, features_g * 2, 2, 1),    # N x 128 x 16 x 16
            (features_g * 2, features_g, 2, 1),        # N x 64 x 32 x 32
        )
        layers = [
            self._block(c_in, c_out, kernel_size=4, stride=stride, padding=pad)
            for c_in, c_out, stride, pad in plan
        ]
        # Last up-sampling step to N x channels_img x 64 x 64
        layers.append(
            nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=1)
        )
        # Tanh squashes the pixels into (-1, 1)
        layers.append(nn.Tanh())
        self.gen = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one stage: transposed convolution -> batch norm -> ReLU (as in DCGAN)."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        return nn.Sequential(upsample, norm, nn.ReLU())

    def forward(self, x):
        """Return the images generated from the noise batch ``x``."""
        return self.gen(x)


def init_weights(model):
    """Initialise the weights of ``model`` in place, DCGAN style.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0.0, 0.02).
    * ``BatchNorm2d`` scales are drawn from N(1.0, 0.02) and their shifts are
      set to zero. Centring the scale on 0 would multiply every normalised
      activation by roughly zero at the start of training.

    Convolution biases and every other module type keep their default
    initialisation.

    Args:
        model: ``nn.Module`` whose sub-modules are initialised.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        # weight is None when the layer was built with affine=False
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
            nn.init.zeros_(m.bias)


def test():
    """Smoke-test the output shapes of a freshly initialised Critic and Generator."""
    batch, channels, side = 8, 3, 64
    latent_dim = 100
    width = 64  # base feature count used for both networks

    # Critic: image batch -> one score per image
    images = torch.randn((batch, channels, side, side))
    critic = Critic(channels, width)
    init_weights(critic)
    scores = critic(images)
    assert scores.shape == (batch, 1)

    # Generator: noise batch -> image batch
    gen = Generator(latent_dim, channels, width)
    init_weights(gen)
    latent = torch.randn((batch, latent_dim, 1, 1))
    samples = gen(latent)
    assert samples.shape == (batch, channels, side, side)

    print('Test is OK')

In [None]:
# Run the shape smoke test
test()

Test is OK


## Run TensorBoard

In [None]:
# Run TensorBoard
# NOTE: this cell only runs inside IPython/Colab: it uses `!` shell escapes
# and `%` line magics.

# Start from a clean slate: delete the logs directory of a previous run
logs_dir = 'logs_dir'
if os.path.exists(logs_dir):
    !rm -rf $logs_dir

# Workaround for an error raised when both PyTorch and TensorFlow are installed:
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'
# It replaces TensorFlow's gfile with the stub that ships with TensorBoard.
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard before training to monitor it in progress
%tensorboard --logdir $logs_dir

# Reload the TensorBoard extension
%reload_ext tensorboard

## Prepare the model

In [None]:
# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# LEARNING_RATE, the Adam betas below, CRITIC_ITERATIONS and LAMBDA_GP follow
# the default values given in the WGAN-GP paper.
LEARNING_RATE = 1e-4  # could also use two lrs, one for gen and one for critic
BATCH_SIZE = 64
CHANNELS_IMG = 3
NOISE_DIM = 128  # was 100
NUM_EPOCHS = 5
FEATURES_CRITIC = FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update
LAMBDA_GP = 10  # weight of the gradient penalty in the critic loss

loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True,
                    num_workers=multiprocessing.cpu_count(), pin_memory=True)

gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Critic(CHANNELS_IMG, FEATURES_CRITIC).to(device)
init_weights(gen)
init_weights(critic)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Noise that is reused for every TensorBoard sample grid, so that successive
# grids show the same latent points. Created directly on the target device.
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1, device=device)
writer_real = SummaryWriter(os.path.join(logs_dir, "real"))
writer_fake = SummaryWriter(os.path.join(logs_dir, "fake"))
step = 0  # TensorBoard global step

# Load models
gen_checkpoint_name = 'generator_celeb.pth.tar'
critic_checkpoint_name = 'critic_celeb.pth.tar'

# Resume only when BOTH checkpoints are present.
if os.path.exists(gen_checkpoint_name) and os.path.exists(critic_checkpoint_name):
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU runtime can be resumed on a CPU-only one.
    step = load_checkpoint(torch.load(gen_checkpoint_name, map_location=device),
                           gen, opt_gen)
    # Both checkpoints are written with the same step by the training loop;
    # the value from the critic checkpoint is the one that is kept.
    step = load_checkpoint(torch.load(critic_checkpoint_name, map_location=device),
                           critic, opt_critic)


def gradient_penalty(critic, real, fake, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    Each sample is a random convex combination of one real and one fake
    image. The penalty is the batch mean of ``(||grad||_2 - 1) ** 2``, where
    ``grad`` is the gradient of the critic score with respect to that
    interpolated image.

    The inputs are detached and the interpolation is explicitly marked as
    requiring grad, so the function works whether or not ``fake`` still
    carries the generator's graph, and the penalty never backpropagates into
    the generator.

    Args:
        critic: Module mapping an ``N x C x H x W`` batch to per-sample scores.
        real: Batch of real images, ``N x C x H x W``.
        fake: Batch of generated images with the same shape as ``real``.
        device: Device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor holding the penalty, differentiable with respect to the
        critic's parameters.
    """
    batch_size = real.shape[0]
    # One coefficient per sample; broadcasting spreads it over C, H and W.
    epsilon = torch.rand((batch_size, 1, 1, 1), device=device, dtype=real.dtype)
    interpolated_images = epsilon * real.detach() + (1 - epsilon) * fake.detach()
    # autograd.grad needs the differentiated input to require grad
    interpolated_images.requires_grad_(True)

    # calculate critic scores
    mixed_scores = critic(interpolated_images)
    # calculate gradient; create_graph=True keeps the result differentiable
    # (and implies retain_graph=True) so the penalty can be backpropagated.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
    )[0]  # N x C x H x W
    gradient = gradient.reshape(gradient.shape[0], -1)  # N x (C*H*W)
    gradient_norm = gradient.norm(2, dim=1)  # per-sample L2 norm
    return torch.mean((gradient_norm - 1) ** 2)

## Train the model

In [None]:
def train(step=step):
    """Run the WGAN-GP training loop for ``NUM_EPOCHS`` epochs.

    For every batch the critic is updated ``CRITIC_ITERATIONS`` times and the
    generator once. Every 100 batches the latest losses are printed and a
    grid of real and generated images is written to TensorBoard. After each
    epoch both networks are checkpointed together with their optimizers.

    Args:
        step: TensorBoard global step to start logging from. The default is
            the value of the module-level ``step`` when this function was
            defined.

    Returns:
        The global step following the last logged one, so a caller can pass
        it to a later ``train`` call and continue the TensorBoard numbering.
    """
    gen.train()
    critic.train()

    for epoch in range(NUM_EPOCHS):
        # The loader yields images only: no target labels are needed.
        for batch_idx, real in enumerate(loader):
            real = real.to(device)
            cur_batch_size = real.shape[0]

            # Train Critic: max (E(critic(real)) - E(critic(fake)))
            # or min (-1) * (E(critic(real)) - E(critic(fake)))
            for _ in range(CRITIC_ITERATIONS):
                noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1, device=device)
                fake = gen(noise)
                critic_real = critic(real)
                # The generator is not updated in this loop, so score a
                # detached copy and skip backpropagating through it.
                critic_fake = critic(fake.detach())
                # `fake` is passed with its graph so that the interpolated
                # images are guaranteed to require grad inside the penalty.
                gp = gradient_penalty(critic, real, fake, device)

                loss_critic = torch.mean(critic_fake) - torch.mean(critic_real) + LAMBDA_GP*gp

                opt_critic.zero_grad()
                # A fresh graph is built on every iteration, so nothing has
                # to be retained after the backward pass.
                loss_critic.backward()
                opt_critic.step()

            # Train Generator: min (E(critic(real)) - E(critic(fake)))
            # which is the same as: min (-1) * E(critic(fake))
            # because the generator cannot influence E(critic(real)).
            # A fresh noise batch is scored by the just-updated critic.
            noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1, device=device)
            output = critic(gen(noise))
            loss_gen = -torch.mean(output)

            # zero_grad also discards any generator gradients that may have
            # accumulated during the critic updates.
            opt_gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # Print losses occasionally and print to tensorboard
            if batch_idx % 100 == 0:
                print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] "
                    f"Batch {batch_idx}/{len(loader)} "
                    f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}")

                with torch.no_grad():
                    # gen is still in train mode here, so its BatchNorm
                    # layers normalise with the statistics of this batch.
                    fake_samples = gen(fixed_noise)
                    # take out (up to) 32 examples
                    img_grid_real = torchvision.utils.make_grid(
                        real[:32], normalize=True
                    )
                    img_grid_fake = torchvision.utils.make_grid(
                        fake_samples[:32], normalize=True
                    )

                    writer_real.add_image("Real", img_grid_real, global_step=step)
                    writer_fake.add_image("Fake", img_grid_fake, global_step=step)

                # step counts logging events, not batches
                step += 1

        # Save models at the end of every epoch
        gen_checkpoint = {
            'state_dict': gen.state_dict(),
            'optimizer': opt_gen.state_dict(),
            'step': step,
        }
        critic_checkpoint = {
            'state_dict': critic.state_dict(),
            'optimizer': opt_critic.state_dict(),
            'step': step,
        }
        save_checkpoint(gen_checkpoint, gen_checkpoint_name)
        save_checkpoint(critic_checkpoint, critic_checkpoint_name)

    return step

In [None]:
# Start training; TensorBoard logging starts from the current value of `step`
train(step=step)

Epoch [1/5] Batch 0/3166 Loss D: 7935.7070, loss G: -0.1983
Epoch [1/5] Batch 100/3166 Loss D: -26.8365, loss G: 32.4119
Epoch [1/5] Batch 200/3166 Loss D: -20.4779, loss G: 29.1151
Epoch [1/5] Batch 300/3166 Loss D: -28.7844, loss G: 42.3356
Epoch [1/5] Batch 400/3166 Loss D: -46.5882, loss G: 59.1718
Epoch [1/5] Batch 500/3166 Loss D: -38.3854, loss G: 66.0575
Epoch [1/5] Batch 600/3166 Loss D: -30.5604, loss G: 73.1024
Epoch [1/5] Batch 700/3166 Loss D: -28.6866, loss G: 52.6345
Epoch [1/5] Batch 800/3166 Loss D: -14.9460, loss G: 72.3076
Epoch [1/5] Batch 900/3166 Loss D: -20.1652, loss G: 69.6176
Epoch [1/5] Batch 1000/3166 Loss D: -17.4340, loss G: 69.3884
Epoch [1/5] Batch 1100/3166 Loss D: -17.2508, loss G: 68.5212
Epoch [1/5] Batch 1200/3166 Loss D: -15.9447, loss G: 72.8945
Epoch [1/5] Batch 1300/3166 Loss D: -16.1062, loss G: 68.3909
Epoch [1/5] Batch 1400/3166 Loss D: -17.9006, loss G: 70.9348
Epoch [1/5] Batch 1500/3166 Loss D: -18.3491, loss G: 72.8673
Epoch [1/5] Batch 1

## Save models if necessary

In [None]:
# Copy both checkpoints to Google Drive.
# NOTE: Colab-only cell: it uses `google.colab` and `!` shell escapes.
from google.colab import drive
drive.mount('/content/gdrive')

# List the destination folder, then copy the two checkpoint files into it
!ls -hal '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp critic_celeb.pth.tar '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!cp generator_celeb.pth.tar     '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'

Mounted at /content/gdrive
total 193M
-rw------- 1 root root   32K Mar  1 12:56 '2021.02.26 basics.ipynb'
-rw------- 1 root root   33K Mar  1 12:59 '2021.03.01-1 Pytorch Neural Network example.ipynb'
-rw------- 1 root root   33K Mar  5 08:04 '2021.03.01-2 Convolutional Neural Network example.ipynb'
-rw------- 1 root root   20K Mar 10 10:20 '2021.03.01-3 Recurrent Neural Network example.ipynb'
-rw------- 1 root root   34K Mar  4 22:21 '2021.03.01-4 Bidirectional LSTM example.ipynb'
-rw------- 1 root root   11K Mar  5 14:30 '2021.03.02-1 How to save and load models in Pytorch.ipynb'
-rw------- 1 root root   16K Mar  2 14:21 '2021.03.02-2 Transfer Learning and Fine Tuning.ipynb'
-rw------- 1 root root   49K Mar  5 08:09 '2021.03.02-3 Build custom dataset.ipynb'
-rw------- 1 root root  1.5M Mar 10 08:06 '2021.03.03-1 How to build custom Datasets for Text in Pytorch.ipynb'
-rw------- 1 root root 1013K Mar  4 17:47 '2021.03.03-2 Data Augmentation using Torchvision.ipynb'
-rw------- 1 root ro