<a href="https://colab.research.google.com/github/martinpius/GANS/blob/main/DCGAN_with_Gradient_clipping_Implementation_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f">>>> You are in Google CoLaB with torch version: {torch.__version__}")
except Exception as e:
  print(f">>>> {type(e)}: {e}\n>>>> please correct {type(e)} and reload")
  COLAB = False
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
def time_fmt(t:float = 123.98)->float:
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f"{h} hrs: {m:>02} min: {s:>05.2f} sec"
print(f">>>> time formating\tplease wait\n>>>> time elapsed\t{time_fmt()}")

Mounted at /content/drive
>>>> You are in Google CoLaB with torch version: 1.8.1+cu101
>>>> time formating	please wait
>>>> time elapsed	0 hrs: 02 min: 03.00 sec


In [2]:
#In this notebook we are going to train a GAN network with the gradient clipping technique from scratch.
# This is a modification of DCGAN where the 
#Loss function is modified (BCELoss is replaced with the Wasserstein loss). Also the discriminator will be trained
#several times in a loop before the generator at every epoch. The architecture of the network is completely similar 
#to DCGAN but without a sigmoid layer in the discriminator. 

In [3]:
#Importing necessary modules:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import math, random, time, os, sys
import numpy as np


In [4]:
#setup the seeds for reproducability

In [5]:
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [6]:
#The discriminator:
class Discriminator(nn.Module):
  def __init__(self, img_channels, d_features):
    super(Discriminator, self).__init__()
    self.discriminator = nn.Sequential(
        nn.Conv2d(img_channels, d_features, kernel_size = 4, stride = 2, padding = 1),
        nn.LeakyReLU(0.2),
        self.__dblock__(d_features, 2*d_features, 4, 2, 1),
        self.__dblock__(2*d_features, 4*d_features, 4, 2, 1),
        self.__dblock__(4*d_features, 8*d_features, 4, 2, 1),
        nn.Conv2d(8*d_features, 1, kernel_size = 4, stride = 1, padding = 0))
    
  def __dblock__(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels,
                  out_channels,
                  kernel_size,
                  stride,
                  padding,
                  bias = False),
                  nn.BatchNorm2d(out_channels),
                  nn.LeakyReLU(0.2))
  
  def forward(self, input_tensor):
    return self.discriminator(input_tensor)

In [7]:
#The generator:
class Generator(nn.Module):
  def __init__(self, img_channels, z_dim, g_features):
    super(Generator, self).__init__()
    self.generator = nn.Sequential(
        self.__gblock__(z_dim, 16*g_features, kernel_size = 4, stride = 2, padding = 0),
        self.__gblock__(16*g_features, 8*g_features, kernel_size = 4, stride = 2,padding = 1),
        self.__gblock__(8*g_features, 4*g_features, kernel_size = 4, stride = 2, padding = 1),
        self.__gblock__(4*g_features, 2*g_features, kernel_size = 4, stride = 2, padding = 1),
        nn.ConvTranspose2d(2*g_features, img_channels, kernel_size = 4, stride = 2, padding = 1),
        nn.Tanh())

  def __gblock__(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels,
                           out_channels,
                           kernel_size,
                           stride,
                           padding, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU())
  
  def forward(self, input_tensor):
    return self.generator(input_tensor)

In [8]:
#Testing the classes if they deliver the desired outputs
def __test__():
  H, W,img_channels = 64, 64, 3
  z_dim = 100
  batch_size = 64
  fake_img = torch.randn(batch_size,img_channels,W,H)# for the discriminator
  noise_img = torch.randn(batch_size, z_dim, 1, 1) #for the generator
  disc = Discriminator(img_channels, 8)#instantiate the discriminator network
  gen = Generator(img_channels, z_dim, 8)#instantiating the generator network
  disc_out = disc(fake_img)#shape expected [batch_size, 1, 1, 1]
  gen_out = gen(noise_img)# shape expected [batch_size, img_channels, W,H]
  return f"gen_out_shape: {gen_out.shape}\tdisc_out_shape: {disc_out.shape}"

In [9]:
__test__()

'gen_out_shape: torch.Size([64, 3, 64, 64])\tdisc_out_shape: torch.Size([64, 1, 1, 1])'

In [10]:
#weight initializer(we initialize the weight to random normal with mean 0 and std 0.01)
def __initializer__(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, mean = 0.00, std = 0.01)

In [11]:
#Hyperparameters:
batch_size = 128
EPOCHS = 10
disc_iter = 5
wt_clip = 0.01
learning_rate = 5e-5
d_features = 64
g_features = 64
img_size = 64
img_channels = 1
z_dim = 100
fixed_noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
discriminator = Discriminator(img_channels, d_features).to(device = device)
generator = Generator(img_channels, z_dim, g_features).to(device = device)
__initializer__(discriminator)
__initializer__(generator)
print(f">>>> discriminator network:\n{discriminator}")
print(f"\n\n>>>> generator network:\n{generator}")

>>>> discriminator network:
Discriminator(
  (discriminator): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
  

In [12]:
#Loading and preprocess the data from torchvision:

In [13]:
transforms = transforms.Compose([transforms.ToTensor(), transforms.Resize(img_size),
    transforms.Normalize([0.5 for _ in range(img_channels)],[0.5 for _ in range(img_channels)])])

In [14]:
dfm = datasets.MNIST(root = "mnist/", transform = transforms, download = True,)
loader = DataLoader(dfm, batch_size = batch_size, shuffle = True)
x_dfm, y_dfm = next(iter(loader))
print(f"x_dfm_shape: {x_dfm.shape}\ty_dfm_shape: {y_dfm.shape}")

x_dfm_shape: torch.Size([128, 1, 64, 64])	y_dfm_shape: torch.Size([128])


In [None]:
step = 0
real_writer = SummaryWriter(f"runs/real_images")
fake_writer = SummaryWriter(f"runs/fake_images")
disc_opt = optim.RMSprop(params = discriminator.parameters(), lr = learning_rate)
gen_opt = optim.RMSprop(params = generator.parameters(), lr = learning_rate)
global_tic = time.time()
for epoch in range(EPOCHS):
  tic = time.time()
  print(f"\n>>>> training begins for epoch {epoch+1}\nplease wait while the model is training........")
  for idx, (real, _) in enumerate(tqdm(loader)):
    real = real.to(device = device)
    #training the discriminator (in a loop of 5 iterations)
    for _ in range(disc_iter):
      fake = torch.randn(batch_size, z_dim, 1, 1).to(device)
      noise = generator(fake)
      real_disc_out = discriminator(real).reshape(-1) #flattening to 1d
      fake_disc_out = discriminator(noise).reshape(-1)#flattening to 1d
      disc_loss = -(torch.mean(real_disc_out) - torch.mean(fake_disc_out))
      discriminator.zero_grad()
      disc_loss.backward(retain_graph = True)
      disc_opt.step()
      #clipping the gradients
      for m in discriminator.parameters():
        m.data.clamp_(-wt_clip, wt_clip)

      #training the generator (max (log(D(G(z)))))
      gen_out = discriminator(noise).reshape(-1)
      gen_loss = -torch.mean(gen_out)
      generator.zero_grad()
      gen_loss.backward()
      gen_opt.step()
      toc = time.time()
      if idx % 200 == 0:
        print(f"\n>>>> time at the end of epoch {epoch +1} of batch {idx} is {time_fmt(toc - tic)}")
        print(f">>>> generator Loss: {gen_loss:.4f} | Generator PPL: {math.exp(gen_loss):7.4f}")
        print(f">>>> discriminator loss: {disc_loss:.4f} | discriminator PPL: {math.exp(disc_loss):7.4f}")
        #fetching and printing the fake and real images to tensorboard
        with torch.no_grad():
          fake_images = generator(fixed_noise).to(device)
          fake_grid = torchvision.utils.make_grid(fake_images[:32], normalize = True)
          real_grid = torchvision.utils.make_grid(real[:32], normalize = True)
          real_writer.add_image('real_image', real_grid, global_step = step)
          fake_writer.add_image('fake_image', fake_grid, global_step = step)
        step += 1
gloabal_toc = time.time()
print(f"\n>>>> total time elapsed for 10 iterations is: {time_fmt(gloabal_toc - global_tic)}")


  0%|          | 0/469 [00:00<?, ?it/s][A


>>>> training begins for epoch 1
please wait while the model is training........

>>>> time at the end of epoch 1 of batch 0 is 0 hrs: 00 min: 19.00 sec
>>>> generator Loss: -0.0024 | Generator PPL:  0.9976
>>>> discriminator loss: -0.0098 | discriminator PPL:  0.9903

>>>> time at the end of epoch 1 of batch 0 is 0 hrs: 00 min: 42.00 sec
>>>> generator Loss: 0.0047 | Generator PPL:  1.0047
>>>> discriminator loss: -0.0528 | discriminator PPL:  0.9486
