In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import random
from datetime import datetime
import matplotlib.pyplot as plt
import glob
import PIL.Image as Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.utils as vutils
from torchvision import datasets, transforms

## Architecture

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator that maps a latent tensor to a 3-channel image.

    The network is a stack of transposed convolutions. For an input of shape
    ``(N, latent_space, 1, 1)`` the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, so the output has shape ``(N, 3, 64, 64)``.
    The final ``Tanh`` bounds every output value to ``[-1, 1]``.

    Args:
        latent_space: Number of channels of the latent input tensor.
    """

    def __init__(self,
                latent_space : int = 100,
                ) -> None:
        super(Generator, self).__init__()
        self.latent_space = latent_space
        self.main = nn.Sequential(
            # Input : (N, latent_space, 1, 1) -> (N, 512, 4, 4)
            # The input width follows `latent_space` instead of a fixed 100 so
            # the constructor argument is actually honoured.
            nn.ConvTranspose2d(latent_space, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            # (N, 512, 4, 4) -> (N, 256, 8, 8)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # (N, 256, 8, 8) -> (N, 128, 16, 16)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # (N, 128, 16, 16) -> (N, 64, 32, 32)
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # (N, 64, 32, 32) -> (N, 3, 64, 64); no BatchNorm follows, so the
            # bias term is kept on this last layer.
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=True),
            nn.Tanh()
        )

    def forward(self, x: Tensor) -> Tensor:
        """Generate images from latent tensor ``x`` of shape ``(N, latent_space, 1, 1)``."""
        out = self.main(x)
        return out

In [None]:
# Class : ConvBlock
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) building block.

    With the default ``kernel_size=4, stride=2, padding=1`` the block halves
    the spatial resolution of its input.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied on every side.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self,
                in_channels : int = 3,
                out_channels : int = 128,
                kernel_size : int = 4,
                stride : int = 2,
                padding : int = 1,
                bias : bool = False,
                ) -> None:
        super().__init__()

        # Build the three layers separately, then chain them in order.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride, padding, bias = bias)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace = True)

        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x : Tensor) -> Tensor :
        """Apply convolution, batch normalisation and leaky ReLU to ``x``."""
        return self.block(x)

In [None]:
# Class : Discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator producing one probability per image.

    Layout: a stride-2 convolution (3 -> 64 channels), one ``ConvBlock`` per
    entry of ``channels`` (each halves the resolution with its defaults), and
    a final 4x4 convolution followed by ``Sigmoid``. With the default
    ``channels`` and a 64x64 input the resolution goes
    64 -> 32 -> 16 -> 8 -> 4 -> 1, so the flattened output is ``(N, 1)``.

    Args:
        channels: Output channel count of each intermediate ``ConvBlock``,
            in order. The list is only read, never mutated.
    """

    def __init__(self,
                channels : list = [128, 256, 512],
                ) -> None:
        super(Discriminator, self).__init__()

        # Top layer: no BatchNorm on the first convolution.
        self.top = nn.Sequential(
            nn.Conv2d(in_channels = 3,
                      out_channels = 64,
                      kernel_size = 4,
                      stride = 2,
                      padding = 1,),
            nn.LeakyReLU(0.2, inplace = True),
        )

        # Block of layers: chain ConvBlocks, feeding each block's output
        # width into the next one.
        blocks = []
        in_channels = 64
        for out_channels in channels:
            blocks.append(ConvBlock(in_channels = in_channels, out_channels = out_channels))
            in_channels = out_channels
        self.blocks = nn.Sequential(*blocks)

        # Bottom layer: its input width must follow the last block's output
        # (it was previously fixed at 512, which broke any custom `channels`
        # that did not end in 512).
        self.bottom = nn.Sequential(
            nn.Conv2d(in_channels = in_channels,
                      out_channels = 1,
                      kernel_size = 4,
                      stride = 1,
                      padding = 0,),
            nn.Sigmoid(),
        )

    def forward(self, x : Tensor) -> Tensor:
        """Return sigmoid scores for ``x``, flattened from dimension 1 onward."""
        out = self.top(x)
        out = self.blocks(out)
        out = self.bottom(out)
        out = torch.flatten(out, 1)
        return out

In [None]:
# Initialize weights of layers
def weights_init(m):
    """Initialise one layer in place; meant to be passed to ``Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``BatchNorm2d``: weights drawn from N(1, 0.02), bias set to 0.
    * Any other module is left untouched.

    The previous name-based test, ``(find(a) or find(b)) != -1``, never matched
    either convolution type: ``find`` returns 0 (falsy) for an exact
    ``"Conv2d"`` match and -1 (truthy) for a miss, so the expression always
    collapsed to ``-1 != -1``. Type checks avoid that trap.
    """
    if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, torch.nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has no learnable weight/bias to initialise.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)

## Custom Dataset

In [None]:
class CalebDataset(Dataset):
    """Map-style dataset over the ``*.jpg`` files found directly in a folder.

    Each item is a ``(image, index)`` pair where ``image`` is the picture
    resized to ``image_size x image_size`` (bicubic) and converted with
    ``ToTensor``, and ``index`` is the index that was requested.

    Args:
        datapath: Directory that contains the ``.jpg`` images.
        image_size: Side length, in pixels, of the square output image.

    Raises:
        FileNotFoundError: If ``datapath`` contains no ``.jpg`` file.
    """

    def __init__(self,
                 datapath : str,
                 image_size : int = 64,
                ) -> None:
        # This is a data container, not a network, so it derives from
        # torch.utils.data.Dataset rather than nn.Module.
        super(CalebDataset, self).__init__()

        self.image_size = image_size
        # List of images ends with .jpg (sorted for a deterministic order)
        self.images = sorted(glob.glob(datapath+"/*.jpg"))
        if not self.images:
            # Fail early with a clear message instead of a ZeroDivisionError
            # in __getitem__ or an obscure sampler error later on.
            raise FileNotFoundError(f"No .jpg images found in '{datapath}'")

        # NOTE(review): no Normalize step is applied after ToTensor, while the
        # Generator in this file ends in Tanh ([-1, 1]) — confirm whether the
        # real images should be rescaled to the same range.
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size),
                              transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of image files discovered at construction time."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, resize and tensorise the image at ``index`` (wraps around)."""
        path = self.images[index % len(self.images)]
        # The context manager closes the file handle once the pixels have been
        # read; convert("RGB") guarantees three channels even for grayscale or
        # palette JPEGs.
        with Image.open(path) as img:
            image = self.transform(img.convert("RGB"))
        return image, index

## Utils

In [None]:
def build_model(args):
    """Create the generator/discriminator pair and move both to ``args.device``.

    Reads ``args.latent_space`` (passed to ``Generator``) and ``args.device``.
    Returns the tuple ``(generator, discriminator)``.
    """
    device = args.device
    return Generator(args.latent_space).to(device), Discriminator().to(device)

def load_dataset(args):
    """Wrap a ``CalebDataset`` in a shuffling ``DataLoader``.

    Reads ``args.datapath``, ``args.image_size`` and ``args.batch_size``.
    """
    dataset = CalebDataset(args.datapath, args.image_size)
    return DataLoader(dataset, batch_size = args.batch_size, shuffle = True)

def loss_function(args = None):
    """Build the adversarial criterion (binary cross-entropy on probabilities).

    Args:
        args: Optional configuration object. When given, the criterion is
            moved to ``args.device``. ``nn.BCELoss`` created without a
            ``weight`` holds no parameters or buffers, so skipping the move
            does not change the result; making it optional removes the
            previous reliance on a module-level ``args`` variable.

    Returns:
        An ``nn.BCELoss`` instance.
    """
    adversarial_loss = nn.BCELoss()
    if args is not None:
        adversarial_loss = adversarial_loss.to(args.device)

    return adversarial_loss
    
def optimizer(generator, discriminator, args):
    """Create one Adam optimizer per network with shared hyper-parameters.

    Reads ``args.lr`` (learning rate) and ``args.b1`` / ``args.b2`` (Adam
    betas). Returns ``(optimizer_g, optimizer_d)``.
    """
    adam_settings = {"lr": args.lr, "betas": (args.b1, args.b2)}
    optimizer_g = optim.Adam(generator.parameters(), **adam_settings)
    optimizer_d = optim.Adam(discriminator.parameters(), **adam_settings)

    return optimizer_g, optimizer_d

## Trainer

In [None]:
def trainer(loader,
            generator,
            discriminator,
            optimizer_g,
            optimizer_d,
            loss_fn,
            args,
           ):
    """Train a generator/discriminator pair adversarially.

    For every batch the discriminator is updated first (real batch labelled 1,
    generated batch labelled 0) and then the generator is updated so that the
    discriminator scores its samples as real.

    Args:
        loader: Sized iterable yielding ``(real_images, _)`` batches.
        generator: Maps noise of shape ``(N, latent, 1, 1)`` to images.
        discriminator: Maps images to scores of shape ``(N, 1)``.
        optimizer_g: Optimizer for the generator parameters.
        optimizer_d: Optimizer for the discriminator parameters.
        loss_fn: Criterion called as ``loss_fn(scores, labels)``.
        args: Configuration with ``epochs``, ``device`` and ``outputs_dir``;
            ``latent_space`` is used for the noise width when present
            (falls back to 100).

    Returns:
        ``(g_losses, d_losses)``: per-iteration generator and discriminator
        losses as Python floats, in training order.
    """

    # lists to keep track of progress
    g_losses = []
    d_losses = []

    start = datetime.now()

    # set both models to training mode
    generator.train()
    discriminator.train()

    latent_dim = getattr(args, "latent_space", 100)
    num_batches = len(loader)
    last_index = num_batches - 1

    # Epochs are numbered 1..args.epochs inclusive, so exactly `args.epochs`
    # passes over the data are made (the old range stopped one epoch early).
    for epoch in range(1, args.epochs + 1):
        for index, (real, _) in enumerate(loader):
            real = real.to(args.device)
            size = real.size(0)
            # create ground truth: real samples are labelled 1, fake samples 0
            real_sample = torch.full([size, 1], 1.0, dtype=real.dtype, device=args.device)
            fake_sample = torch.full([size, 1], 0.0, dtype=real.dtype, device=args.device)

            # create a noise sample for generator input
            noise = torch.randn([size, latent_dim, 1, 1], device = args.device)

            # ---- Train Discriminator ----
            # reset discriminator gradients
            discriminator.zero_grad()
            # loss of the discriminator on real images
            output = discriminator(real)
            loss_real_d = loss_fn(output, real_sample)
            loss_real_d.backward()
            D_x = output.mean().item()

            # generate fake images
            fake = generator(noise)
            # detach so this backward pass does not reach the generator
            output = discriminator(fake.detach())
            loss_fake_d = loss_fn(output, fake_sample)
            # gradients accumulate on top of those from the real batch
            loss_fake_d.backward()
            D_G_z1 = output.mean().item()
            # total discriminator loss (used for reporting only)
            loss_d = loss_real_d + loss_fake_d
            # update discriminator weights
            optimizer_d.step()

            # ---- Train Generator ----
            # reset generator gradients
            generator.zero_grad()
            # score the same fakes with the freshly updated discriminator,
            # this time labelled as real
            output = discriminator(fake)
            loss_g = loss_fn(output, real_sample)
            loss_g.backward()
            D_G_z2 = output.mean().item()
            # update generator weights
            optimizer_g.step()

            # save images once per epoch to monitor training stability
            if index == 0:
                # Generated values may be negative; min-max normalise them so
                # they are not clipped to black when written to disk.
                vutils.save_image(vutils.make_grid(fake.detach(), normalize = True), os.path.join(args.outputs_dir, f"fake_image_{epoch}.jpg"))
                vutils.save_image(vutils.make_grid(real, normalize = False), os.path.join(args.outputs_dir, f"real_image_{epoch}.jpg"))

            # Print the losses every ten iterations and on the last iteration
            # of the epoch (the last valid index is len(loader) - 1).
            if index % 10 == 0 or index == last_index:
                print(f"Train stage: adversarial "
                      f"Epoch[{epoch:04d}/{args.epochs:04d}]({index:05d}/{num_batches:05d}) "
                      f"D Loss: {loss_d.item():.6f} G Loss: {loss_g.item():.6f} "
                      f"D(D_x): {D_x:.6f} D(D_G_z1)/D(D_G_z2): {D_G_z1:.6f}/{D_G_z2:.6f}.")

            # save losses for plotting later
            g_losses.append(loss_g.item())
            d_losses.append(loss_d.item())

    # Reported once, after the final epoch.
    print("Training complete in: " + str(datetime.now() - start))

    return g_losses, d_losses

## Configuration

In [None]:
class ARGS():
    """Static run configuration, exposed as class attributes.

    The class body has side effects that run once, when the class is defined:
    it creates the output directory and seeds the Python, NumPy and PyTorch
    random number generators.
    """
    # --- paths ---
    outputs_dir = 'output'
    datapath = "../input/celeba-dataset/img_align_celeba/img_align_celeba"

    # --- model / data ---
    latent_space = 100   # channels of the generator's noise input
    image_size = 64      # images are resized to image_size x image_size
    batch_size = 256

    # --- optimisation ---
    epochs = 20
    b1 = 0.5             # Adam beta1
    b2 = 0.99            # Adam beta2
    # NOTE(review): 0.002 is ten times the learning rate commonly used for
    # DCGAN (0.0002) — confirm this value is intentional.
    lr = 0.002
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Generated artefacts go to <outputs_dir>/dcgan; exist_ok makes the call
    # safe to repeat when the notebook cell is re-run.
    outputs_dir = os.path.join(outputs_dir,  "dcgan")
    os.makedirs(outputs_dir, exist_ok=True)

    # --- reproducibility: seed every RNG exactly once ---
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Shared configuration instance used by the rest of the notebook.
args = ARGS()

## Main

In [None]:
def main(args):
    """Run the whole pipeline: data, models, training, then the loss plot.

    ``args`` is forwarded to the dataset, model, optimizer and trainer
    helpers. Progress messages are printed after each setup stage and the
    per-iteration losses are plotted once training has finished.
    """
    # --- data ---
    data_loader = load_dataset(args)
    print('Loaded dataset successfully')

    # --- networks ---
    net_g, net_d = build_model(args)
    print('Built model successfully')

    # custom weight initialisation, applied recursively to every sub-module
    for network in (net_g, net_d):
        network.apply(weights_init)
    print('Weights initialize successfully')

    # --- criterion and optimizers ---
    criterion = loss_function()
    print('Define loss function succesfully')
    opt_g, opt_d = optimizer(net_g, net_d, args)
    print('Define all optimization function successfully')

    # --- training ---
    g_losses, d_losses = trainer(data_loader, net_g, net_d,
                                 opt_g, opt_d, criterion, args)

    # --- loss curves ---
    plt.figure(figsize=(10,5))
    plt.title("Generator and Discriminator Loss During Training")
    for series, tag in ((g_losses, "G"), (d_losses, "D")):
        plt.plot(series, label=tag)
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [None]:
# Entry point: start the pipeline with the module-level `args` configuration,
# but only when this file is executed directly rather than imported.
if __name__ == "__main__":
    main(args)