# 🤪 WGAN - CelebA Faces

In [None]:
# Repository root and this experiment's folder (note the trailing slash).
working_dir = "/home/mary/work/repos/generative_deep_Learning_2nd_edition_pytorch"
exp_dir = f"{working_dir}/notebooks/04_gan/02_wgan_gp/"

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os

# Add the path to the notebooks folder
notebooks_path = os.path.abspath(working_dir)
if notebooks_path not in sys.path:
    sys.path.append(notebooks_path)

utils_path = os.path.abspath(exp_dir)
if utils_path not in sys.path:
    sys.path.append(utils_path)

In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torchinfo import summary
from torch import optim
from torch import autograd
import torch
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt

from notebooks.utils import display

import math

In [None]:
# --- data ---
IMAGE_SIZE = 64        # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 3
BATCH_SIZE = 512

# --- model ---
Z_DIM = 128            # latent vector size fed to the generator

# --- optimisation (Adam, shared by generator and critic) ---
LEARNING_RATE = 2e-4
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9

# --- WGAN-GP schedule ---
EPOCHS = 200
CRITIC_STEPS = 3       # critic updates per generator update
GP_WEIGHT = 10.0       # weight of the gradient-penalty term in the critic loss

# --- checkpointing ---
LOAD_MODEL = False

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# Dataset locations inside the repository.
data_dir = f"{working_dir}/data"
dataset_dir = f"{data_dir}/celeba-dataset"

In [None]:
import torch.utils


transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1],
    # the same range as the generator's tanh output.
    transforms.Normalize(mean=0.5, std=0.5),
])

full_data = datasets.ImageFolder(dataset_dir, transform=transform)

# Train on a random 10% subset. A fixed-seed generator makes the subset
# reproducible, so a run resumed from a checkpoint (LOAD_MODEL) continues on
# the same images instead of a freshly drawn subset.
SPLIT_SEED = 42
train_data, _ = torch.utils.data.random_split(
    full_data, [0.1, 0.9], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

print("full dataset size = ", len(full_data))
print("training subset size = ", len(train_data))


In [None]:
# Pull a single batch to sanity-check the data pipeline.
train_iter = iter(train_data_loader)
first_batch = next(train_iter)
# Each batch is (images, labels); the ImageFolder labels are not needed.
sample_images, _ = first_batch

In [None]:
# Show the sampled batch of real training images. They were normalized to
# [-1, 1] by the transform; presumably `display` (notebooks.utils) rescales
# them for plotting — TODO confirm.
display(sample_images)

## 2. Build the GAN <a name="build"></a>

In [None]:
class Critic(nn.Module):
    """Convolutional WGAN critic producing one unbounded score per image.

    Four 4x4, stride-2 convolutions (channels -> 64 -> 128 -> 256 -> 512), each
    with padding chosen so the spatial width halves, followed by
    LeakyReLU(0.2); dropout(0.3) follows blocks 2-4. A final 4x4, stride-1,
    unpadded convolution maps to a single channel, and the result is reshaped
    to (batch, 1), which requires that final map to be 1x1 (a 64x64 input
    gives 64 -> 32 -> 16 -> 8 -> 4 -> 1). No sigmoid is applied.

    Args:
        channels: Number of input image channels.
        image_size: Side length of the (square) input images.
    """

    def __init__(self, channels, image_size):
        super().__init__()
        self.channels = channels
        self.image_size = image_size

        self.dropout = nn.Dropout(0.3)

        def strided_conv(c_in, c_out, in_width):
            # 4x4 stride-2 conv whose padding halves a feature map of width in_width.
            pad = self._get_padding_size(input_w=in_width, stride=2, kernal_size=4)
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=4, stride=2, padding=pad)

        self.conv_1 = strided_conv(self.channels, 64, self.image_size)
        self.conv_2 = strided_conv(64, 128, self.image_size / 2)
        self.conv_3 = strided_conv(128, 256, self.image_size / 4)
        self.conv_4 = strided_conv(256, 512, self.image_size / 8)
        self.conv_5 = nn.Conv2d(in_channels=512, out_channels=1,
                                kernel_size=4, stride=1, padding=0)

    @staticmethod
    def _get_padding_size(input_w, stride, kernal_size):
        # Solve floor((w + 2p - k) / s) + 1 == w / 2 for p, rounding up.
        numerator = (input_w / 2 - 1) * stride - input_w + kernal_size
        return math.ceil(numerator / 2)

    def forward(self, x):
        batch = x.shape[0]
        strided = (self.conv_1, self.conv_2, self.conv_3, self.conv_4)
        for depth, conv in enumerate(strided):
            x = F.leaky_relu(conv(x), 0.2)
            # The first block is the only one without dropout.
            if depth:
                x = self.dropout(x)
        return self.conv_5(x).view((batch, 1))


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
critic = Critic(CHANNELS, IMAGE_SIZE)

# Print the module tree. (`critic.state_dict` without a call would only print a
# bound-method wrapper around this same repr, not the weights.)
print(critic)

In [None]:
# Layer-by-layer summary for a batch containing a single image.
summary(critic, input_size=(1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE))

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 64x64 image.

    A 4x4, stride-1, unpadded transposed convolution expands the
    (num_dim, 1, 1) latent to a 512 x 4 x 4 map. Four 4x4 stride-2 transposed
    convolutions then double the width (4 -> 8 -> 16 -> 32 -> 64) while the
    channels go 512 -> 256 -> 128 -> 64 -> ``channels``. Hidden layers use
    BatchNorm + LeakyReLU(0.2); the output uses tanh, so pixels lie in [-1, 1].

    Args:
        num_dim: Latent vector size.
        channels: Number of output image channels.
        bn_momentum: PyTorch ``BatchNorm2d`` momentum, i.e. the weight given to
            the *current* batch when updating running statistics. PyTorch's
            convention is the complement of Keras': the default 0.1 equals
            Keras ``BatchNormalization(momentum=0.9)``. (Passing 0.9 here, as
            before, made the running stats track essentially the last batch,
            which is what eval-mode sampling then used.)
    """

    def __init__(self, num_dim, channels, bn_momentum=0.1):
        super().__init__()

        self.num_dim = num_dim
        self.channels = channels

        self.conv_trans_1 = nn.ConvTranspose2d(in_channels=self.num_dim, out_channels=512,
                                               kernel_size=4, stride=1, padding=0, bias=False, output_padding=0)
        self.bn_1 = nn.BatchNorm2d(num_features=512, momentum=bn_momentum)

        p = self._get_padding_size(input_w=4, stride=2, kernal_size=4)
        self.conv_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)
        self.bn_2 = nn.BatchNorm2d(num_features=256, momentum=bn_momentum)

        p = self._get_padding_size(input_w=4*2, stride=2, kernal_size=4)
        self.conv_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)
        self.bn_3 = nn.BatchNorm2d(num_features=128, momentum=bn_momentum)

        p = self._get_padding_size(input_w=4*4, stride=2, kernal_size=4)
        self.conv_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)
        self.bn_4 = nn.BatchNorm2d(num_features=64, momentum=bn_momentum)

        p = self._get_padding_size(input_w=4*8, stride=2, kernal_size=4)
        self.conv_trans_5 = nn.ConvTranspose2d(in_channels=64, out_channels=self.channels,
                                               kernel_size=4, stride=2, padding=p, output_padding=0, bias=False)

    @staticmethod
    def _get_padding_size(input_w, stride, kernal_size):
        # Padding p for which a transposed conv doubles the width:
        # (w - 1) * s - 2p + k == 2w  =>  p = (w - 1) * s / 2 - w + k / 2,
        # rounded to the nearest integer (the +1/2 followed by floor).
        p = ((input_w - 1) * stride) / 2
        p = p - input_w
        p = p + (kernal_size / 2)
        p = p + 1/2
        return math.floor(p)

    def forward(self, x):
        B = x.shape[0]
        # Treat the flat latent vector as a 1x1 feature map with num_dim channels.
        x = x.view((B, self.num_dim, 1, 1))
        x = self.conv_trans_1(x)
        x = self.bn_1(x)
        x = F.leaky_relu(x, 0.2)

        x = self.conv_trans_2(x)
        x = self.bn_2(x)
        x = F.leaky_relu(x, 0.2)

        x = self.conv_trans_3(x)
        x = self.bn_3(x)
        x = F.leaky_relu(x, 0.2)

        x = self.conv_trans_4(x)
        x = self.bn_4(x)
        x = F.leaky_relu(x, 0.2)

        x = self.conv_trans_5(x)
        # tanh keeps pixels in [-1, 1], matching the data normalization.
        x = torch.tanh(x)

        return x


In [None]:
generator = Generator(Z_DIM, CHANNELS)
# Print the module tree (calling `print(generator.state_dict)` without
# parentheses would print a bound-method wrapper instead).
print(generator)

In [None]:
# Layer-by-layer summary for a single latent vector.
summary(generator, input_size=(1, Z_DIM))

Loss functions

In [None]:
def wasserstein_critic_loss(fake_pred, real_pred):
    """Critic loss: mean score on fakes minus mean score on reals.

    Minimizing it pushes the critic to score real images higher than fakes.
    """
    return fake_pred.mean() - real_pred.mean()

def wasserstein_generator_loss(fake_pred):
    """Generator loss: negated mean critic score on generated images."""
    return -fake_pred.mean()

def gradient_penalty(critic, real_images, fake_images):
    """Compute the WGAN-GP gradient penalty.

    Samples one random point per pair on the straight line between a real and
    a fake image, evaluates the critic there, and penalizes the squared
    deviation of the per-sample input-gradient L2 norm from 1.

    Args:
        critic: Callable mapping a batch of images to a batch of scores.
        real_images: Real batch, shape (B, ...).
        fake_images: Generated batch with the same shape as ``real_images``.

    Returns:
        Scalar tensor: mean over the batch of (||grad||_2 - 1) ** 2.
    """
    batch_size = real_images.shape[0]
    # Mixing weights must be uniform in [0, 1) so each point lies *between*
    # the real and fake image; standard-normal weights (torch.randn) would
    # extrapolate outside that segment. The shape broadcasts over all
    # non-batch dimensions.
    alpha_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_images.device, dtype=real_images.dtype)

    # Detach so the penalty's double-backward does not flow into whatever
    # produced fake_images (e.g. the generator); only the critic should
    # receive gradients from this term.
    interpolated = (alpha * fake_images + (1 - alpha) * real_images).detach()
    interpolated.requires_grad_(True)
    pred = critic(interpolated)

    # d(pred)/d(interpolated); create_graph=True so the penalty itself can be
    # backpropagated into the critic's parameters.
    gradients = autograd.grad(outputs=pred, inputs=interpolated, grad_outputs=torch.ones_like(pred),
                              create_graph=True, only_inputs=True)[0]
    # Per-sample L2 norm over all non-batch dimensions.
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((grad_norm - 1) ** 2)


WGAN class

In [None]:
class WGAN(nn.Module):
    """Wasserstein GAN with gradient penalty (WGAN-GP).

    Wraps a ``Generator`` and a ``Critic`` and provides the training loop.
    ``fit`` stores the optimizers and loss functions on the instance and calls
    ``train_step`` on every batch. Each step performs ``critic_steps`` critic
    updates followed by one generator update. Per-epoch mean losses are
    printed, written to TensorBoard and passed to optional callbacks.

    Args:
        num_dim: Size of the latent noise vector.
        channels: Number of image channels.
        image_size: Side length of the square images.
        log_dir: TensorBoard root; training scalars are written to
            ``log_dir + "/train"``.
    """

    def __init__(self, num_dim, channels, image_size, log_dir="./log/"):
        super().__init__()
        self.num_dim = num_dim
        self.channels = channels
        self.image_size = image_size

        self.generator = Generator(num_dim, channels)
        self.critic = Critic(channels, image_size)

        self.writer_train = SummaryWriter(log_dir + "/train")

    def forward(self, x):
        """Generator followed by critic.

        Not used for training; it exists only so torchinfo's ``summary`` can
        trace the whole model.
        """
        x = self.generator(x)
        x = self.critic(x)
        return x

    def _sample_noise(self, batch_size):
        """Return a (batch_size, num_dim) standard-normal latent batch on self.device."""
        return torch.randn((batch_size, self.num_dim), device=self.device)

    def train_step(self, real_images):
        """Run ``critic_steps`` critic updates, then one generator update.

        Must be called after ``fit`` has set the optimizers, loss functions,
        ``gp_lambda``, ``critic_steps`` and ``device`` attributes.

        Args:
            real_images: Batch of real images; moved to ``self.device``.

        Returns:
            dict of floats: ``c_w_loss``, ``c_gp_loss`` and ``c_total_loss``
            from the last critic update, and ``g_w_loss`` from the generator
            update.
        """
        self.generator.train()
        self.critic.train()

        real_images = real_images.to(self.device)
        batch_size = real_images.shape[0]

        # The critic is updated several times per generator update.
        for _ in range(self.critic_steps):
            self.c_optimizer.zero_grad()

            # Critic updates must not backpropagate into the generator, so the
            # fake batch is produced without building a generator graph.
            with torch.no_grad():
                fake_images = self.generator(self._sample_noise(batch_size))

            fake_pred = self.critic(fake_images)
            real_pred = self.critic(real_images)

            c_w_loss = self.c_w_loss_fn(fake_pred=fake_pred, real_pred=real_pred)
            c_gp_loss = self.c_gp_loss_fn(critic=self.critic, real_images=real_images, fake_images=fake_images)
            c_total_loss = c_w_loss + (self.gp_lambda * c_gp_loss)

            c_total_loss.backward()
            self.c_optimizer.step()

        # Generator update on a fresh latent batch. The backward pass also
        # leaves gradients on the critic's parameters; they are cleared by
        # c_optimizer.zero_grad() at the start of the next critic update.
        self.g_optimizer.zero_grad()
        fake_images = self.generator(self._sample_noise(batch_size))
        fake_pred = self.critic(fake_images)

        g_w_loss = self.g_w_loss_fn(fake_pred)
        g_w_loss.backward()
        self.g_optimizer.step()

        loss_dict = {"c_w_loss": c_w_loss.item(), "c_gp_loss": c_gp_loss.item(),
                     "c_total_loss": c_total_loss.item(), "g_w_loss": g_w_loss.item()}

        return loss_dict

    def fit(self, training_dataloader, epochs, g_optimizer, c_optimizer,
            c_w_loss_fn, c_gp_loss, g_w_loss_fn, gp_lambda, device,
            critic_steps=3, callbacks=None):
        """Train the GAN for ``epochs`` passes over ``training_dataloader``.

        Args:
            training_dataloader: Iterable of ``(images, labels)`` batches;
                labels are ignored.
            epochs: Number of epochs; epochs are numbered from 1.
            g_optimizer: Optimizer over the generator's parameters.
            c_optimizer: Optimizer over the critic's parameters.
            c_w_loss_fn: ``f(fake_pred=..., real_pred=...)`` -> critic loss.
            c_gp_loss: ``f(critic=..., real_images=..., fake_images=...)`` ->
                gradient penalty.
            g_w_loss_fn: ``f(fake_pred)`` -> generator loss.
            gp_lambda: Weight of the gradient penalty in the critic loss.
            device: Device that batches and latent noise are placed on;
                presumably the device the model was moved to by the caller.
            critic_steps: Critic updates per generator update (>= 1).
            callbacks: Optional iterable of objects exposing
                ``on_epoch_end(epoch, logs=...)``.

        Raises:
            ValueError: If ``critic_steps`` is smaller than 1 (a step would
                then have no critic losses to report).
        """
        if critic_steps < 1:
            raise ValueError(f"critic_steps must be >= 1, got {critic_steps}")

        self.g_optimizer = g_optimizer
        self.c_optimizer = c_optimizer
        self.c_w_loss_fn = c_w_loss_fn
        self.g_w_loss_fn = g_w_loss_fn
        self.c_gp_loss_fn = c_gp_loss
        self.gp_lambda = gp_lambda
        self.critic_steps = critic_steps
        self.device = device

        loss_names = ("c_w_loss", "c_gp_loss", "c_total_loss", "g_w_loss")

        for i in range(1, epochs + 1):
            losses = {f"{name}_acc": 0.0 for name in loss_names}
            num_batches = 0

            for images, _ in training_dataloader:
                loss_dict = self.train_step(images)
                for name in loss_names:
                    losses[f"{name}_acc"] += loss_dict[name]
                num_batches += 1

            # train_step reports per-batch *mean* losses, so the epoch mean is
            # the average over batches, not over samples. max() guards an
            # empty loader.
            for key in losses:
                losses[key] /= max(num_batches, 1)

            print(
                f"Epoch {i}/{epochs}: Training: c_w_loss: {losses['c_w_loss_acc'] :.4f} "
                f" c_gp_loss: {losses['c_gp_loss_acc']:.4f} "
                f" c_total_loss: {losses['c_total_loss_acc']:.4f}"
                f" g_w_loss: {losses['g_w_loss_acc']:.4f}"
            )
            self.writer_train.add_scalar("c_w_loss", losses["c_w_loss_acc"], global_step=i)
            self.writer_train.add_scalar("c_gp_loss", losses["c_gp_loss_acc"], global_step=i)
            self.writer_train.add_scalar("c_total_loss", losses["c_total_loss_acc"], global_step=i)
            self.writer_train.add_scalar("g_w_loss", losses["g_w_loss_acc"], global_step=i)

            if callbacks is not None:
                logs = {"device": self.device,
                        "generator": self.generator,
                        "model_state_dict": self.state_dict(),
                        "loss": losses
                }

                for callback in callbacks:
                    callback.on_epoch_end(i, logs=logs)


Create the required callbacks

In [None]:
class Callback:
    """Base class for training callbacks; every hook is a no-op by default."""

    def on_epoch_end(self, epoch, logs=None):
        """Called after each epoch with the epoch number and a logs dict. Override in subclasses."""
        return None

In [None]:
class GenerateImages(Callback):
    """Callback that samples images from the generator after every epoch.

    Draws ``num_images`` standard-normal latent vectors, runs the generator
    from ``logs["generator"]`` in eval mode without gradients, rescales the
    output with ``x * 127.5 + 127.5`` and hands it to ``display`` with
    ``save_to=<save_dir>/epoch_<epoch>.png``.
    """

    def __init__(self, num_images, latent_dim, save_dir="./gen_examples"):
        super().__init__()
        self.num_images = num_images
        self.latent_dim = latent_dim
        self.save_dir = save_dir

    def on_epoch_end(self, epoch, logs=None):
        gen = logs["generator"]
        z = torch.randn((self.num_images, self.latent_dim)).to(logs["device"])
        target = self.save_dir + f"/epoch_{epoch}.png"

        with torch.no_grad():
            gen.eval()
            # Maps [-1, 1] (the range of Generator's tanh output) to [0, 255].
            pixels = gen(z).detach() * 127.5 + 127.5
            display(pixels, save_to=target)

        return


In [None]:
class SaveCheckpoint(Callback):
    """Callback that writes a checkpoint every ``save_every`` epochs.

    The checkpoint is a dict with ``epoch``, ``model_state_dict`` and ``loss``
    (the latter two taken from ``logs``), saved with ``torch.save`` to
    ``<save_dir>/checkpoint_<epoch>.pth``.
    """

    def __init__(self, save_dir, save_every=10):
        super().__init__()
        self.save_dir = save_dir
        self.save_every = save_every

    def on_epoch_end(self, epoch, logs=None):
        # Only epochs that are exact multiples of save_every are written.
        if epoch % self.save_every != 0:
            return
        target = self.save_dir + f"/checkpoint_{epoch}.pth"
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": logs["model_state_dict"],
                "loss": logs["loss"],
            },
            target,
        )

Create the WGAN object and train it

In [None]:
# Output folders for TensorBoard logs, per-epoch samples and checkpoints.
log_dir = exp_dir + "/log"
sample_dir = exp_dir + "/sample_gen"
checkpoint_dir = exp_dir + "/checkpoints"

for out_dir in (log_dir, sample_dir, checkpoint_dir):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
wgan = WGAN(Z_DIM, CHANNELS, IMAGE_SIZE, log_dir).to(device)
# Print the module tree (without parentheses, `wgan.state_dict` would print a
# bound-method wrapper instead).
print(wgan)

In [None]:
# Summary of generator -> critic for a single latent vector.
summary(wgan, input_size=(1, Z_DIM))

In [None]:
# Loss callables handed to WGAN.fit.
g_loss_function = wasserstein_generator_loss
c_loss_function = wasserstein_critic_loss
c_gp_loss = gradient_penalty


def _make_adam(params):
    """Adam with the learning rate and betas shared by both networks."""
    return optim.Adam(params, lr=LEARNING_RATE, betas=[ADAM_BETA_1, ADAM_BETA_2])


g_optimizer = _make_adam(wgan.generator.parameters())
c_optimizer = _make_adam(wgan.critic.parameters())

In [None]:
# Per-epoch hooks: sample 10 images every epoch; checkpoint every 30 epochs.
callbacks = [
    GenerateImages(num_images=10, latent_dim=Z_DIM, save_dir=sample_dir),
    SaveCheckpoint(save_dir=checkpoint_dir, save_every=30),
]

In [None]:
# check if we have checkpoint to load
# Optionally resume from a saved checkpoint. Only model weights are restored;
# SaveCheckpoint does not store optimizer state.
# NOTE(review): fit() still numbers epochs from 1, so a resumed run will
# overwrite checkpoints/samples with the same epoch numbers.
if LOAD_MODEL:
    checkpoint_file = checkpoint_dir + "/checkpoint_120.pth"
    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only
    # machine (and vice versa); load_state_dict then copies into wgan's tensors.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    wgan.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Train the WGAN-GP; losses go to stdout and TensorBoard, callbacks handle
# sample images and checkpoints.
wgan.fit(
    train_data_loader,
    epochs=EPOCHS,
    g_optimizer=g_optimizer,
    c_optimizer=c_optimizer,
    c_w_loss_fn=c_loss_function,
    g_w_loss_fn=g_loss_function,
    c_gp_loss=c_gp_loss,
    gp_lambda=GP_WEIGHT,
    critic_steps=CRITIC_STEPS,
    device=device,
    callbacks=callbacks,
)

## 3. Generate new images <a name="decode"></a>

In [None]:
# Sample 10 faces from the trained generator. This is inference only:
# eval() makes BatchNorm use its running statistics, and no_grad() avoids
# building an autograd graph. (If `display` converts to NumPy, a tensor that
# requires grad would fail there — TODO confirm against notebooks.utils.)
wgan.generator.eval()
with torch.no_grad():
    z_sample = torch.randn(size=(10, Z_DIM)).to(device)
    imgs = wgan.generator(z_sample)
display(imgs, cmap=None)