<a href="https://colab.research.google.com/github/martinpius/GANS/blob/main/Wasserstein_GAN_WGAN_with_Gradient_Penalty_Application_with_Flickr30k_images.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f">>>> You are on CoLaB with torch version {torch.__version__}")
except Exception as e:
  print(f">>>> {type(e)} {e}\n>>>> please correct {type(e)} and re-load")
  COLAB = False
def time_fmt(t: float = 123.198)->float:
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f">>>> hrs: {h} min: {m:>02} sec: {s:>05.2f}"
print(f">>>> testing the time formating function............\n>>>> time elapsed\t{time_fmt()}")
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")


Mounted at /content/drive
>>>> You are on CoLaB with torch version 1.9.0+cu102
>>>> testing the time formating function............
>>>> time elapsed	>>>> hrs: 0 min: 02 sec: 03.00


In [17]:
# In this manuscript we are going to implement the Wasserstein GAN with gradient Penalty.
# This is a typically WGAN with minor changes in the discriminator normalization layer. Here
# we use layer-normalization (suggested in the paper). During training we do not employ 
# gradient cliping since its delay training. The penalt factor introduced to the cost function
# is bassically a loss accounted by a merged imageges (the generated and the real images) sfter
# passing through the discriminator network. The learning rate used is 1e-5 with a lambda value of 10
# as hyper-parameter for the gradient penalty.

In [18]:
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import os, time, datetime, random
from tqdm import tqdm
from PIL import Image
from tensorflow import summary
%load_ext tensorboard


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [19]:
# setup the seed value for reproducability and the GPU device to deterministic.
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True


In [20]:
# Setup the directories for tensorboard:
current_time = datetime.datetime.now().timestamp()
fake_path = "logs/tensorboard/wgan_gp_fake/" + str(current_time)
real_path = "logs/tensorboard/wgan_gp_real/" + str(current_time)
fake_writer = summary.create_file_writer(fake_path)
real_writer = summary.create_file_writer(real_path)

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-20-af34e50a6e77>", line 5, in <module>
    fake_writer = summary.create_file_writer(fake_path)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/summary_ops_v2.py", line 516, in create_file_writer_v2
    metadata={"logdir": logdir})
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/summary_ops_v2.py", line 285, in __init__
    self._init_op = init_op_fn(self._resource)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_summary_ops.py", line 146, in create_summary_file_writer
    _ops.raise_from_not_ok_status(e, name)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py", line 6897, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "

FailedPreconditionError: ignored

In [None]:
# The discriminator class:
class Discriminator(nn.Module):
  def __init__(self, img_channels, d_features):
    super(Discriminator, self).__init__()
    self.discriminator = nn.Sequential(
        nn.Conv2d(img_channels, d_features, kernel_size = 4, stride = 2, padding = 1),
        nn.LeakyReLU(0.2),
        self.__dblock__(d_features, 2*d_features, 4, 2, 1),
        self.__dblock__(2*d_features, 4*d_features, 4, 2, 1),
        self.__dblock__(4*d_features, 8*d_features, 4, 2, 1),
        nn.Conv2d(8*d_features, 1, kernel_size = 4, stride = 2, padding = 0)
        )
  def __dblock__(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels,
                  out_channels,
                  kernel_size,
                  stride,
                  padding,
                  bias = False),
          nn.InstanceNorm2d(out_channels, affine = True),
          nn.LeakyReLU(0.2))
  def forward(self, input_tensor):
    return self.discriminator(input_tensor)


In [None]:
# The generator class -> same as in WGAN-grad-clipping
class Generator(nn.Module):
  def __init__(self, z_dim, img_channels, g_features):
    super(Generator, self).__init__()
    self.generator = nn.Sequential(
        self.__gblock__(z_dim, 16*g_features, kernel_size = 4, stride = 2, padding = 0),
        self.__gblock__(16*g_features, 8*g_features, 4, 2, 1),
        self.__gblock__(8*g_features, 4*g_features, 4, 2, 1),
        self.__gblock__(4*g_features, 2*g_features, 4, 2, 1),
        nn.ConvTranspose2d(2*g_features, img_channels, kernel_size = 4, stride = 2, padding = 1),
        nn.Tanh())
    
  def __gblock__(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels,
                           out_channels,
                           kernel_size,
                           stride,
                           padding,
                           bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU())
  
  def forward(self, input_tensor):
    return self.generator(input_tensor)

In [None]:
# parameter initializer -> we initialize the parameters to random normal dst
def __initializer__(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, mean = 0.0, std = 0.02)

In [None]:
# testing the model if it yield the desired outputs:
def __test__():
  img_channels = 3
  g_features = 64
  d_features = 64
  z_dim = 100
  batch_size = 128
  noise = torch.randn(batch_size,z_dim, 1, 1)
  img = torch.randn(batch_size, img_channels, g_features, d_features)
  disc = Discriminator(img_channels, d_features)
  gen = Generator(z_dim, img_channels, g_features)
  print(f">>>> generator's output shape: {gen(noise).shape}\n>>>> discriminator's output shape: {disc(img).shape}")

In [None]:
__test__()

In [None]:
# Hyperparameters and other useful transformation objects
learning_rate = 1e-4
batch_size = 128
d_features = 64
g_features = 64
img_size = 64
lambda_value = 10
z_dim = 100
EPOCHS = 3
img_channels = 3
fixed_noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
discriminator = Discriminator(img_channels, d_features).to(device = device)
generator = Generator(z_dim, img_channels, g_features).to(device = device)
__initializer__(discriminator)
__initializer__(generator)
transforms = transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Resize((img_size, img_size)),
                                 transforms.Normalize([0.5 for _ in range(img_channels)],
                                                      [0.5 for _ in range(img_channels)])])

In [None]:
# Loading and pre-processing images from the directory <Google drive>
os.chdir("/content/drive/MyDrive/WGAN-FLICKR30K") # save all checkpoints here
class MyLoader(Dataset):
  def __init__(self, root_dir, csv_dir, transform = None):
    super(MyLoader, self).__init__()
    self.transform = transform
    self.root_dir = root_dir
    self.csv_dir = csv_dir
    self.dfm = pd.read_csv(csv_dir, error_bad_lines = False)
  
  def __len__(self):
    return len(self.dfm)
  
  def __getitem__(self, index):
    img_path = os.path.join(self.root_dir, self.dfm.iloc[index, 0])
    image = Image.open(img_path)
    if self.transform:
      image = self.transform(image)
    return image

dataset = MyLoader(root_dir = "/content/drive/MyDrive/flickr30k_images/flickr8k/images",
                   csv_dir = "/content/drive/MyDrive/flickr30k_images/flickr8k/captions.txt",
                   transform = transforms)
loader = DataLoader(dataset = dataset, shuffle = True, batch_size = batch_size)
x_train_image_batch = next(iter(loader))
print(f">>>> the shape of pre-processed image is: {x_train_image_batch.shape}")

In [None]:
# Get the optimizers for the generator and discriminator (Adam with zero-momentum):
gen_optim = optim.Adam(params = generator.parameters(), lr = learning_rate, betas = (0, 0.999))
disc_optim = optim.Adam(params = discriminator.parameters(), lr = learning_rate, betas = (0, 0.99))


In [None]:
# We now define our gradient penalty function:
def gradient_penalty(discriminator, real_img, fake_img, device = 'cpu'):
  batch_size, C, H, W = real.shape
  # epsilon value is a tensor of numbers btn 0/1 to be used to interpolate the imgs
  eps = torch.rand(batch_size, 1, 1, 1).repeat(1, C, H, W).to(device = device)
  mixed_img = real_img* eps + fake_img * (1 - eps) # interpolate the real and fake image
  mixed_preds = discriminator(mixed_img)

  #compute the penalty factor (using l2-norm of the gradients for the mixed images)
  grads = torch.autograd.grad(
      inputs = mixed_img,
      outputs = mixed_preds,
      grad_outputs = torch.ones_like(mixed_preds),
      create_graph = True,
      retain_graph = True)[0]
      #reshape the gradient tensor by combining all other columns but batch-size
  grads = grads.view(grads.shape[0], -1)
  grad_norm = grads.norm(2, dim = 1) # compute the norm "l2" across the gradients
  GP = (torch.mean((grad_norm - 1)**2))
  return GP


In [None]:
# The training loop of WGGAN with GP:
tic = time.time()
step = 0
for epoch in range(EPOCHS):
  print(f"\n>>>> train begins for epoch [{epoch + 1}]\n>>>> please wait while the model is training..............................")
  for idx, real in enumerate(tqdm(loader)):
    real = real.to(device)
    noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
    fake_imgs = generator(noise)

    # training the discriminator using GP:
    real_preds = discriminator(real).reshape(-1)
    fake_preds = discriminator(fake_imgs).reshape(-1)
    real_loss = torch.mean(real_preds)
    fake_loss = torch.mean(fake_preds)
    gp = gradient_penalty(discriminator, real, fake_imgs, device = device)
    disc_loss = (-(real_loss - fake_loss) + lambda_value * gp)
    discriminator.zero_grad()
    disc_loss.backward(retain_graph = True)
    disc_optim.step()

    # training the generator as in WGAN with clipping (max [log(D(G(z)))])
    gen_preds = discriminator(fake_imgs).reshape(-1)
    gen_loss = -torch.mean(gen_preds)
    generator.zero_grad()
    gen_loss.backward()
    gen_optim.step()

    # printing summary on screen and tensorboard
    if idx % 100 == 0:
      print(f"\n>>>> end of epoch {epoch + 1}, generator loss: {gen_loss:.4f}, discriminator loss: {disc_loss:.4f}")
      fake_img = generator(fixed_noise)
      fake_img_grids = torchvision.utils.make_grid(fake_img[:32], normalize = True)
      real_img_grids = torchvision.utils.make_grid(real[:32], normalize = True)
      with fake_writer.as_default():
        summary.scalar("generator loss", gen_loss.cpu().detach().numpy(), step = step)
      with real_writer.as_default():
        summary.scalar("dscriminator loss", disc_loss.cpu().detach().numpy(), step = step)
      step += 1
%tensorboard --logdir logs/tensorboard
toc = time.time()
print(f"\n>>>> time elapsed for training WGAN with GP for 10 epochs is {time_fmt(toc - tic)}")
