<a href="https://colab.research.google.com/github/martinpius/GANS/blob/main/Friday_night_WGANS_GP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [54]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f">>>> You are on CoLaB with torch version: {torch.__version__}")
except Exception as e:
  print(f">>>> {type(e)}: {e}\n>>>> please correct {type(e)} and reload")
  COLAB = False
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
def time_fmt(t: float = 128.98)->float:
  h = int(t /(60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f"{h} hrs: {m:>02} min: {s:>05.2f} sec"
print(f">>>> time formating\tprinting formated time for the demo\n>>>> time elapsed\t{time_fmt()}")

Mounted at /content/drive
>>>> You are on CoLaB with torch version: 1.8.1+cu101
>>>> time formating	printing formated time for the demo
>>>> time elapsed	0 hrs: 02 min: 08.00 sec


In [55]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from albumentations.pytorch import ToTensor
import albumentations as A
import numpy as np
import PIL
import random
import math
import os
import sys
import time
import copy

In [56]:
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [57]:
class Discriminator(nn.Module):
  def __init__(self, img_channels, d_features):
    super(Discriminator, self).__init__()
    self.discriminator = nn.Sequential(
        nn.Conv2d(img_channels, d_features, kernel_size = 4, stride = 2, padding = 1),
        nn.LeakyReLU(0.2),
        self.__dblock__(d_features, 2*d_features, 4, 2, 1),
        self.__dblock__(2*d_features, 4*d_features, 4, 2, 1),
        self.__dblock__(4*d_features, 8*d_features, 4, 2, 1),
        nn.Conv2d(8*d_features, 1, kernel_size = 4, stride = 2, padding = 0))
    
  def __dblock__(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding, bias = False
        ),
        nn.InstanceNorm2d(out_channels, affine = True),
        nn.LeakyReLU(0.2))
  
  def forward(self, input_tensor):
    return self.discriminator(input_tensor)

In [58]:
class Generator(nn.Module):
  def __init__(self, img_channels, z_dim, g_features):
    super(Generator, self).__init__()
    self.generator = nn.Sequential(
        self.__gblock__(z_dim, 16*g_features, kernel_size = 4, stride = 2, padding = 0),
        self.__gblock__(16*g_features, 8*g_features, 4, 2, 1),
        self.__gblock__(8*g_features, 4*g_features, 4, 2, 1),
        self.__gblock__(4*g_features, 2*g_features, 4, 2, 1),
        nn.ConvTranspose2d(2*g_features, img_channels, kernel_size = 4, stride = 2, padding = 1),
        nn.Tanh())
    
  def __gblock__(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
  
  def forward(self, input_tensor):
    return self.generator(input_tensor)


In [59]:
def __test__():
  img_channels = 3
  H,W,batch_size = 64,64,64
  z_dim = 100
  rand_img = torch.randn(batch_size, img_channels, W,H)#for discriminator
  noise = torch.randn(batch_size, z_dim, 1, 1)#for generator
  gen = Generator(img_channels, z_dim, 8)
  disc = Discriminator(img_channels, 8)
  gen_out = gen(noise)# expected shape [batch_size, img_channels, W, H]
  disc_out = disc(rand_img) #expected shape [batch_size, 1, 1, 1]
  return f"gen_out_shape: {gen_out.shape}\tdisc_out_shape: {disc_out.shape}"

In [60]:
__test__()

'gen_out_shape: torch.Size([64, 3, 64, 64])\tdisc_out_shape: torch.Size([64, 1, 1, 1])'

In [61]:
def __initializer__(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, mean = 0.00, std = 0.02)
      

In [62]:
batch_size = 32
img_size = 64
img_channels =1
g_features = 64
d_features = 64
z_dim = 100
EPOCHS = 10
lambda_gp = 10
disc_iter = 5
learning_rate = 1e-4
fixed_noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
discriminator = Discriminator(img_channels, d_features).to(device)
generator = Generator(img_channels, z_dim, g_features).to(device)
__initializer__(generator)
__initializer__(discriminator)
transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(img_size),
            transforms.Normalize([0.5 for _ in range(img_channels)],
                                [0.5 for _ in range(img_channels)])])
dfm = datasets.MNIST(root = "friday_wgans/", transform = transforms, download = True)
loader = DataLoader(dataset = dfm, batch_size = batch_size, shuffle = True)
x_loader, y_loader = next(iter(loader))
print(f">>>> x_loader_shape: {x_loader.shape}\ty_loader_shape: {y_loader.shape}")
print(f"\n\n>>>> discriminator graph:\n{discriminator}\n\ngenerator graph:\n{generator}")

>>>> x_loader_shape: torch.Size([32, 1, 64, 64])	y_loader_shape: torch.Size([32])


>>>> discriminator graph:
Discriminator(
  (discriminator): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): Lea

In [63]:
gen_opt = optim.Adam(params = generator.parameters(), lr = learning_rate, betas = (0.00, 0.999))
disc_opt = optim.Adam(params = discriminator.parameters(), lr = learning_rate, betas = (0.00, 0.999))

In [64]:
def __gp__(discriminator, real_img, fake_img, device = device):
  batch_size, C, H, W = real_img.shape
  e = torch.randn(batch_size, 1, 1, 1).repeat(1, C, H, W).to(device = device)
  ip_img = e*real_img + (1-e)*fake_img
  ip_scores = discriminator(ip_img)
  grads = torch.autograd.grad(
      inputs = ip_img,
      outputs = ip_scores,
      grad_outputs = torch.ones_like(ip_scores),
      create_graph = True,
      retain_graph = True)[0]
  grads = grads.view(grads.shape[0], -1)
  grads_norm = grads.norm(2, dim = 1)
  gp = torch.mean((grads_norm-1)**2)
  return gp
    

In [None]:
real_writer = SummaryWriter(f"logs/real_images")
fake_writer = SummaryWriter(f"logs/fake_images")
step = 0
global_tic = time.time()
for epoch in range(EPOCHS):
  tic = time.time()
  print(f"\n>>>> training starts for epoch: {epoch + 1}\n>>>> please wait while the model is training....................")
  for idx, (real, _) in enumerate(tqdm(loader)):
    real = real.to(device = device)
    for _ in range(disc_iter):
      noise = torch.randn(batch_size, z_dim, 1, 1).to(device = device)
      fake_image = generator(noise)
      real_disc_out = discriminator(real).reshape(-1)
      fake_disc_out = discriminator(fake_image).reshape(-1)
      disc_real_loss = torch.mean(real_disc_out)
      disc_fake_loss = torch.mean(fake_disc_out)
      gp = __gp__(discriminator, real, fake_image)
      disc_loss = (-(disc_real_loss - disc_fake_loss) + lambda_gp * gp)
      discriminator.zero_grad()
      disc_loss.backward(retain_graph = True)
      disc_opt.step()
      gen_out = discriminator(fake_image).reshape(-1)
      gen_loss = -torch.mean(gen_out)
      generator.zero_grad()
      gen_loss.backward()
      gen_opt.step()
      toc = time.time()
      if idx % 200 == 0:
        print(f"\n>>>> time elapsed at the end of epoch {epoch + 1} of batch {idx} is {time_fmt(toc - tic)}")
        print(f">>>> generator loss: {gen_loss:.4f} | generator PPL: {math.exp(gen_loss):7.4f}")
        print(f">>>> discriminator loss: {disc_loss:.4f} | discriminator PPL: {math.exp(disc_loss):7.4f}")
        with torch.no_grad():
          fake_image = generator(fixed_noise)
          real_img_grid = torchvision.utils.make_grid(real[:32], normalize = True)
          fake_img_grid = torchvision.utils.make_grid(fake_image[:32], normalize = True)
          real_writer.add_image("real_images", real_img_grid, global_step = step)
          fake_writer.add_image('fake_images', fake_img_grid, global_step = step)
        step+=1
global_toc = time.time()
print(f"\n>>>> time elapsed at the end of training: {time_fmt(global_toc - global_tic)}")

  0%|          | 0/1875 [00:00<?, ?it/s]


>>>> training starts for epoch: 1
>>>> please wait while the model is training....................

>>>> time elapsed at the end of epoch 1 of batch 0 is 0 hrs: 00 min: 07.00 sec
>>>> generator loss: 0.4419 | generator PPL:  1.5557
>>>> discriminator loss: 138.5393 | discriminator PPL: 1468373439783610395984339173235211649821434926801716217118720.0000

>>>> time elapsed at the end of epoch 1 of batch 0 is 0 hrs: 00 min: 15.00 sec
>>>> generator loss: 1.0993 | generator PPL:  3.0020
>>>> discriminator loss: 88.4659 | discriminator PPL: 263180705147594256973721438383216001024.0000

>>>> time elapsed at the end of epoch 1 of batch 0 is 0 hrs: 00 min: 24.00 sec
>>>> generator loss: 1.7239 | generator PPL:  5.6066
>>>> discriminator loss: 72.7807 | discriminator PPL: 40576376625350739227661177978880.0000

>>>> time elapsed at the end of epoch 1 of batch 0 is 0 hrs: 00 min: 32.00 sec
>>>> generator loss: 2.5063 | generator PPL: 12.2596
>>>> discriminator loss: 39.1812 | discriminator PPL: 1

  0%|          | 2/1875 [01:17<20:47:08, 39.95s/it]