<a href="https://colab.research.google.com/github/martinpius/GANS/blob/main/Wasserstein_GAN_Implementantion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f">>>> You are on CoLaB with torch version: {torch.__version__}")
except Exception as e:
  print(f">>>> {type(e)}: {e}\n>>>> please correct {type(e)} and reload")
  COLAB = False
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
def time_fmt(t: float = 123.187)->float:
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) /60)
  s = int(t % 60)
  return f" {h} hrs: {m:>02} min: {s:>05.2f} sec"
print(f">>>> time formating\t.......................\n>>>> time elapsed\t{time_fmt()}")

Mounted at /content/drive
>>>> You are on CoLaB with torch version: 1.8.1+cu101
>>>> time formating	.......................
>>>> time elapsed	 0 hrs: 02 min: 03.00 sec


In [2]:
#In this notebook we are going to train a DCGAN using modified technique to penalize the loss function.
#Insteady of gradient clipping which may bring convergence problems we apply Wasserstein loss technique.
#This technique results ito meaningfull loss function. To compersate for BCELoss and gradient clipping 
#fake and real images are intersected and the penalty loss improves the model's performance:


In [3]:
#Import necessary modules 
import torch, time, math, random, sys, os
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import PIL
from tqdm import tqdm

In [4]:
#Set the seed values for reproducability and the device to deterministic to avoid errors during training
seed = 9182
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [5]:
#This model is a complete DCGAN architectured with the only difference in the loss function
#Here we use the Wasserstein Loss hence the name WGAN.

In [6]:
class Discriminator(nn.Module):
  def __init__(self, img_channels, d_features):
    super(Discriminator, self).__init__()
    self.discriminator = nn.Sequential(
        nn.Conv2d(img_channels, d_features, kernel_size = 4, stride = 2, padding = 1),
        nn.LeakyReLU(0.2),
        self.disc_block(d_features, 2*d_features, 4, 2, 1),
        self.disc_block(2*d_features, 4*d_features, 4, 2, 1),
        self.disc_block(4*d_features, 8*d_features, 4, 2, 1),
        nn.Conv2d(8*d_features, 1, kernel_size = 4, stride = 2, padding = 0))

  def disc_block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias = False
        ),
        nn.InstanceNorm2d(out_channels, affine = True),
        nn.LeakyReLU(0.2))
  
  def forward(self, input_tensor):
    return self.discriminator(input_tensor)
  

In [7]:
class Generator(nn.Module):
  def __init__(self, img_channels, z_dim, g_features):
    super(Generator, self).__init__()
    self.generator = nn.Sequential(
        self.gen_block(z_dim, 16*g_features, kernel_size = 4, stride = 2, padding = 0),
        self.gen_block(16*g_features, 8*g_features, 4, 2, 1),
        self.gen_block(8*g_features, 4*g_features, 4, 2, 1),
        self.gen_block(4*g_features, 2*g_features, 4, 2, 1),
        nn.ConvTranspose2d(2*g_features, img_channels, kernel_size = 4, stride = 2, padding = 1),
        nn.Tanh())
  def gen_block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels,
                           out_channels,
                           kernel_size,
                           stride,
                           padding,
                           bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU())
  
  def forward(self, input_tensor):
    return self.generator(input_tensor)

In [8]:
#Testing the network to see if they gives desired outputs:
def __test__():
  img_channels = 3
  H,W,batch_size = 64,64,64
  z_dim = 100
  rand_img = torch.randn(batch_size, img_channels, W,H)#for discriminator
  noise = torch.randn(batch_size, z_dim, 1, 1)#for generator
  gen = Generator(img_channels, z_dim, 8)
  disc = Discriminator(img_channels, 8)
  gen_out = gen(noise)# expected shape [batch_size, img_channels, W, H]
  disc_out = disc(rand_img) #expected shape [batch_size, 1, 1, 1]
  return f"gen_out_shape: {gen_out.shape}\tdisc_out_shape: {disc_out.shape}"

In [9]:
#Initialize the model parameters to random normal distribution with mean 0 and std 0.01
def __initializer__(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, mean = 0.0, std = 0.01)

In [10]:
__test__()

'gen_out_shape: torch.Size([64, 3, 64, 64])\tdisc_out_shape: torch.Size([64, 1, 1, 1])'

In [11]:
#Hyperparameters
batch_size = 128
lambda_gp = 10 #parameter for the gradient penalty
img_channels = 1
img_size = 64
d_features = 64
g_features = 64
learning_rate = 1e-4
EPOCHS = 2
z_dim = 100
disc_iter = 5
noise = torch.randn(batch_size, z_dim, 1, 1)
fixed_noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
rand_img = torch.randn(batch_size, img_channels, d_features, g_features)
discriminator = Discriminator(img_channels, d_features).to(device = device)
generator = Generator(img_channels, z_dim, g_features).to(device)
__initializer__(generator)# wt initialization for the generator
__initializer__(discriminator) #wt initialization for the discriminator
print(f">>>> generator graph:\n\n{generator}\n\n>>>> discriminator graph:\n{discriminator}")

>>>> generator graph:

Generator(
  (generator): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(2, 2), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    

In [12]:
#Loading and preprocessing data from torchvision
transforms = transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Resize(img_size),
                                 transforms.Normalize([0.5 for _ in range(img_channels)],[0.5 for _ in range(img_channels)])])
dfm = datasets.MNIST(root = "wgans_mnist/", transform = transforms, download = True)
loader = DataLoader(dataset = dfm, batch_size = batch_size, shuffle = True)
x_loader, y_loader = next(iter(loader))
print(f"x_loader_shape: {x_loader.shape}\ty_loader_shape: {y_loader.shape}")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to wgans_mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting wgans_mnist/MNIST/raw/train-images-idx3-ubyte.gz to wgans_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to wgans_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting wgans_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to wgans_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to wgans_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to wgans_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting wgans_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to wgans_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to wgans_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting wgans_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to wgans_mnist/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


x_loader_shape: torch.Size([128, 1, 64, 64])	y_loader_shape: torch.Size([128])


In [13]:
#The gradient penalty: This function will compute the penalt factor to penalize the slopes of the wight to improve 
#the discriminator performance
def __gp__(discriminator, real_img, fake_img, device = device):
  batch_size, channels, H, W = real_img.shape
  #compute the epsilon factor to enable images interpolation
  e = torch.randn(batch_size, 1,1,1).repeat(1, channels, H, W).to(device = device)
  ip_imgs = e * real_img + (1-e) * fake_img #results into a hybrid images
  ip_scores = discriminator(ip_imgs) #compute score if we apply the hybrid image to the discriminator
  #computing the gradient penalty. (see the paper (WGAN-paper for more details))
  grads = torch.autograd.grad(
      inputs = ip_imgs,
      outputs = ip_scores,
      grad_outputs = torch.ones_like(ip_scores),
      create_graph = True, 
      retain_graph = True)[0]
  grads = grads.view(grads.shape[0], -1)
  grad_norm = grads.norm(2, dim = 1)
  gp = torch.mean((grad_norm - 1)**2)
  return gp


In [14]:
#Get the optimizer, tensorboard roots
gen_opt = optim.Adam(params = generator.parameters(), lr = learning_rate, betas = (0.0, 0.99))
disc_opt = optim.Adam(params = discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.99))

In [None]:
#The training loop
step = 0
real_writer = SummaryWriter(f"runs/real_images")
fake_writer = SummaryWriter(f"runs/fake_images")
global_tic = time.time()
for epoch in range(EPOCHS):
  tic = time.time()
  print(f"\n>>>> training starts for epoch {epoch + 1}\n>>>> please wait while model is training........")
  for idx, (real, _) in enumerate(tqdm(loader)):
    real = real.to(device = device)
    #training the discriminator
    noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
    fake_image = generator(noise)
    for _ in range(disc_iter):
      noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
      fake_image = generator(noise)
      disc_real_out = discriminator(real).reshape(-1)#flattening to 1d
      disc_fake_out = discriminator(fake_image).reshape(-1)#flattening to 1d
      real_loss = torch.mean(disc_real_out)
      fake_loss = torch.mean(disc_fake_out)
      gp = __gp__(discriminator,real, fake_image,device)
      disc_loss = (-(real_loss - fake_loss) + lambda_gp *gp)
      discriminator.zero_grad()
      disc_loss.backward(retain_graph = True)
      disc_opt.step()
    #training the generator
    gen_out = discriminator(fake_image).reshape(-1)
    gen_loss = -torch.mean(gen_out)
    generator.zero_grad()
    gen_loss.backward()
    gen_opt.step()
    toc = time.time()
    if idx % 200 == 0:
      print(f"\n>>>> time elapsed at the end of epoch {epoch + 1} for batch {idx} is {time_fmt(toc - tic)}")
      print(f">>>> generator loss: {gen_loss:.4f} | generatot PPL: {math.exp(gen_loss):7.4f}")
      print(f">>>> discriminator loss: {disc_loss:.4f} | discriminator PPL: {math.exp(disc_loss):7.4f}")
      #Recorning to tensorboard
      with torch.no_grad():
        fake_image = generator(fixed_noise)
        real_img_grid = torchvision.utils.make_grid(real[:32], normalize = True)
        fake_img_grid = torchvision.utils.make_grid(fake_image[:32], normalize = True)
        real_writer.add_image("real_images", real_img_grid, global_step = step)
        fake_writer.add_image("fake_images", fake_img_grid, global_step = step)
      step += 1
global_toc = time.time()
print(f"\n>>>> time elapsed after 10 epochs training is: {time_fmt(global_toc - global_tic)}")


  0%|          | 0/469 [00:00<?, ?it/s][A


>>>> training starts for epoch 1
>>>> please wait while model is training........

>>>> time elapsed at the end of epoch 1 for batch 0 is  0 hrs: 00 min: 04.00 sec
>>>> generator loss: 59.5622 | generatot PPL: 73711602955997443725983744.0000
>>>> discriminator loss: -32.3376 | discriminator PPL:  0.0000



  0%|          | 1/469 [00:05<43:34,  5.59s/it][A
  0%|          | 2/469 [00:10<41:20,  5.31s/it][A
  1%|          | 3/469 [00:15<40:36,  5.23s/it][A
  1%|          | 4/469 [00:20<40:07,  5.18s/it][A
  1%|          | 5/469 [00:25<39:46,  5.14s/it][A
  1%|▏         | 6/469 [00:30<39:32,  5.12s/it][A
  1%|▏         | 7/469 [00:35<39:21,  5.11s/it][A
