In [None]:
import numpy as np # linear algebra
import os
import glob
import random
import PIL.Image as Image
import matplotlib.pyplot as plt
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

## Architecture

In [None]:
class Down(nn.Module):
    """Encoder block: Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2) -> [Dropout(0.5)].

    With the default ``kernel_size=4, stride=2, padding=1`` the convolution maps
    an even spatial size (H, W) to (H/2, W/2).

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution zero-padding.
        bias: whether the convolution has a bias term.
        normalize: if True, insert ``BatchNorm2d`` right after the convolution.
        dropout: if True, append ``Dropout(0.5)`` after the activation.
    """
    def __init__(self,
                 in_planes : int,
                 out_planes: int,
                 kernel_size : int = 4,
                 stride : int = 2,
                 padding : int = 1,
                 bias : bool = False,
                 normalize : bool = True,
                 dropout : bool = False,
                 ) -> None:
        super(Down, self).__init__()

        block = [nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, bias = bias)]
        if normalize:
            block.append(nn.BatchNorm2d(out_planes))
        block.append(nn.LeakyReLU(0.2))
        if dropout:
            block.append(nn.Dropout(0.5))

        self.block = nn.Sequential(*block)

    def forward(self,
                x : Tensor,
                ) -> Tensor:
        """Apply the block to ``x`` (N, in_planes, H, W) and return the feature map."""
        return self.block(x)

In [None]:
class Up(nn.Module):
    """Decoder block: ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2) -> [Dropout(0.5)],
    followed by channel-wise concatenation with a skip connection.

    With the default ``kernel_size=4, stride=2, padding=1`` the transposed
    convolution maps (H, W) to (2H, 2W). The output has
    ``out_planes + skip.shape[1]`` channels.

    Args:
        in_planes: number of input channels.
        out_planes: number of channels produced by the transposed convolution.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
        bias: whether the transposed convolution has a bias term.
        dropout: if True, append ``Dropout(0.5)`` after the activation.
    """
    def __init__(self,
                 in_planes : int,
                 out_planes : int,
                 kernel_size : int = 4,
                 stride : int = 2,
                 padding : int = 1,
                 bias : bool = False,
                 dropout : bool = False,
                 ) -> None:
        super(Up, self).__init__()

        block = [nn.ConvTranspose2d(in_planes,
                                    out_planes,
                                    kernel_size,
                                    stride,
                                    padding,
                                    bias = bias)]
        block.append(nn.BatchNorm2d(out_planes))
        # NOTE(review): the pix2pix paper uses non-leaky ReLU in the decoder;
        # LeakyReLU here is a deviation - confirm it is intended.
        block.append(nn.LeakyReLU(0.2))

        if dropout:
            block.append(nn.Dropout(0.5))

        self.block = nn.Sequential(*block)

    def forward(self,
                x : Tensor,
                skip : Tensor,
                ) -> Tensor:
        """Upsample ``x`` and concatenate ``skip`` along the channel dimension.

        ``skip`` must have the same batch size and spatial size as the upsampled ``x``.
        """
        x = self.block(x)
        return torch.cat((x, skip), 1)

In [None]:
class Generator(nn.Module):
    """U-Net generator: eight ``Down`` blocks, seven ``Up`` blocks with skip
    connections, and a final ConvTranspose2d + Tanh.

    Each Down halves and each Up doubles the spatial size, so the input height
    and width must be multiples of 256. Otherwise a halving is inexact and the
    skip tensors no longer match in size. For a 256x256 input the bottleneck
    (``down8``) is 1x1.

    Args:
        in_planes: channels of the input image; the output has the same count.
        planes: base channel width (doubled per level, capped at ``planes * 8``).
        normalize: whether the inner encoder blocks (down2-down7) use BatchNorm.
            ``down1`` and ``down8`` never normalize, and the ``Up`` blocks always do.
    """
    def __init__(self,
                 in_planes : int = 3,
                 planes : int = 64,
                 normalize : bool = True,
                 ) -> None:
        super(Generator, self).__init__()
        ##########################################
        # Downward layers
        #########################################

        self.down1 = Down(in_planes, planes, normalize = False)

        self.down2 = Down(planes, planes * 2, normalize = normalize)
        self.down3 = Down(planes * 2, planes * 4, normalize = normalize)

        # NOTE(review): the pix2pix paper applies dropout only in the decoder;
        # dropout in these encoder blocks is a deviation - confirm it is intended.
        self.down4 = Down(planes * 4, planes * 8, normalize = normalize, dropout = True)
        self.down5 = Down(planes * 8, planes * 8, normalize = normalize, dropout = True)
        self.down6 = Down(planes * 8, planes * 8, normalize = normalize, dropout = True)
        self.down7 = Down(planes * 8, planes * 8, normalize = normalize, dropout = True)

        # down8 outputs 1x1 maps for 256x256 inputs; BatchNorm there would fail in
        # training mode with a batch size of 1, so it is always disabled.
        self.down8 = Down(planes * 8,
                          planes * 8,
                          normalize = False,
                          dropout = True)

        #########################################
        # Upward layers
        ########################################
        # Input channels are doubled from up2 on because each Up concatenates its skip.
        self.up1 = Up(planes * 8, planes * 8, dropout = True)
        self.up2 = Up(planes * 16, planes * 8, dropout = True)
        self.up3 = Up(planes * 16, planes * 8, dropout = True)
        self.up4 = Up(planes * 16, planes * 8, dropout = True)
        self.up5 = Up(planes * 16, planes * 4)
        self.up6 = Up(planes * 8, planes * 2)
        self.up7 = Up(planes * 4, planes)

        # Final Layer: back to full resolution with in_planes channels in (-1, 1).
        self.final = nn.Sequential(
            nn.ConvTranspose2d(planes * 2, in_planes, 4, 2, 1, bias = True),
            nn.Tanh()
        )

    def forward(self, x : Tensor) -> Tensor:
        """Map ``x`` (N, in_planes, H, W) to an image of the same shape.

        The output passes through Tanh, so its values lie in (-1, 1).
        """
        down1 = self.down1(x)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        down5 = self.down5(down4)
        down6 = self.down6(down5)
        down7 = self.down7(down6)
        down8 = self.down8(down7)

        # Each decoder stage is paired with the encoder output of matching size.
        up1 = self.up1(down8, down7)
        up2 = self.up2(up1, down6)
        up3 = self.up3(up2, down5)
        up4 = self.up4(up3, down4)
        up5 = self.up5(up4, down3)
        up6 = self.up6(up5, down2)
        up7 = self.up7(up6, down1)

        return self.final(up7)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style conditional discriminator.

    It concatenates the conditioning image and a candidate image along the channel
    dimension. Four stride-2 conv blocks (planes, 2x, 4x, 8x) follow, then a
    stride-1 conv to a single channel. The output is a map of raw logits (no
    sigmoid), so pair it with ``BCEWithLogitsLoss``. A 256x256 input gives a
    15x15 map.

    Args:
        in_planes: channels of each of the two images (the first conv sees 2x this).
        planes: base channel width.
    """
    def __init__(self,
                 in_planes : int = 3,
                 planes : int = 64,
                 ) -> None:
        super(Discriminator, self).__init__()

        def blocks(in_planes, planes, normalize = True):
            # Conv(4, stride 2, pad 1) halves H and W.
            block = [nn.Conv2d(in_planes, planes, 4, 2, 1, bias = False)]
            if normalize:
                block.append(nn.BatchNorm2d(planes))
            # NOTE(review): the pix2pix paper uses LeakyReLU(0.2) in the
            # discriminator; plain ReLU here is a deviation - confirm intent.
            block.append(nn.ReLU(True))

            return block

        self.block = nn.Sequential(
            *blocks(in_planes * 2, planes, normalize = False),
            *blocks(planes, planes * 2),
            *blocks(planes * 2, planes * 4),
            *blocks(planes * 4, planes * 8),
            # Derived from ``planes`` (previously hard-coded to 512, which only
            # worked for planes == 64).
            nn.Conv2d(planes * 8, 1, 4, 1, 1),
        )

    def forward(self, x, y):
        """Return patch logits for the pair (``x`` = condition, ``y`` = real or generated)."""
        x = torch.cat([x, y], dim = 1)
        return self.block(x)

## Dataset

In [None]:
class Pix2PixDataset(Dataset):
    """Paired-image dataset for the Kaggle pix2pix data.

    Source: https://www.kaggle.com/datasets/vikramtiwari/pix2pix-dataset

    Each file is one JPEG holding two images side by side. The right half is
    returned as the model input and the left half as the target. This split is
    intended for the cityscapes and facades data. Other datasets from the same
    link (e.g. maps, edges2shoes) may need changes, so check their image layout
    and shape first.

    Images are collected from ``root/<name>/<name>/<mode>/*.jpg`` for every
    ``name`` in ``dataset_name_list``.

    Args:
        root: dataset root directory.
        dataset_name_list: sub-dataset folder names, e.g. ["cityscapes", "facades"].
        mode: split folder name, e.g. "train" or "val".
        shape: both halves are resized (bicubic) to ``shape x shape``.

    Items are ``(input_image, target_image)`` float tensors of shape
    (3, shape, shape) produced by ``ToTensor``, i.e. with values in [0, 1].
    NOTE(review): the Generator ends in Tanh (range (-1, 1)); targets are not
    normalized to that range - confirm this is intended.
    """
    def __init__(self,
                 root : str,
                 dataset_name_list : list,
                 mode : str = "train",
                 shape : int = 256,
                 ):
        super(Pix2PixDataset, self).__init__()

        self.image_list = self._collect_images(root, dataset_name_list, mode)

        self.transform = transforms.Compose([
            transforms.Resize((shape, shape), Image.BICUBIC),
            transforms.ToTensor()
        ])

    @staticmethod
    def _collect_images(root, dataset_name_list, mode):
        """Return the sorted ``*.jpg`` paths of every requested sub-dataset split."""
        import warnings

        image_list = []
        for name in dataset_name_list:
            folder = os.path.join(root, name, name, mode)
            # Sorted so the sample order does not depend on the filesystem listing order.
            files = sorted(glob.glob(os.path.join(folder, "*.jpg")))
            if not files:
                # A misspelled name or wrong root would otherwise silently drop data.
                warnings.warn(f"No .jpg images found in {folder!r}")
            image_list.extend(files)
        return image_list

    def image_separate(self, image):
        """Split a side-by-side image into ``(input, target)`` PIL images.

        The right half becomes the input and the left half the target. Expects an
        (H, W, C) image; for odd W the right half is one column wider.
        """
        image = np.array(image, dtype = np.uint8)
        _, width, _ = image.shape
        half = width // 2
        Input = Image.fromarray(image[:, half:, :])
        Target = Image.fromarray(image[:, :half, :])

        return Input, Target

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index : int):
        # The context manager closes the file handle. convert("RGB") guarantees a
        # 3-channel (H, W, 3) array even for grayscale, RGBA or CMYK files.
        with Image.open(self.image_list[index]) as raw:
            image = raw.convert("RGB")
        input_image, target_image = self.image_separate(image)
        input_image = self.transform(input_image)
        target_image = self.transform(target_image)
        return input_image, target_image
        

## Utils

In [None]:
def build_model(args):
    """Instantiate a default ``Generator`` and ``Discriminator`` on ``args.device``.

    Returns:
        ``(generator, discriminator)``, in that order.
    """
    device = args.device
    # Tuple elements are evaluated left to right, so the generator is still built first.
    return Generator().to(device), Discriminator().to(device)
    
    
def load_dataset(args, mode):
    """Return a shuffling ``DataLoader`` over ``Pix2PixDataset`` for split ``mode``.

    Reads ``root``, ``dataset_name_list``, ``shape``, ``batch_size`` and
    ``num_workers`` from ``args``.
    """
    dataset = Pix2PixDataset(args.root, args.dataset_name_list, mode, args.shape)
    # NOTE(review): shuffle is enabled for every split, including evaluation.
    return DataLoader(dataset,
                      batch_size=args.batch_size,
                      shuffle=True,
                      num_workers=args.num_workers)

def loss_function(args):
    """Return ``(bce_loss, l1_loss)`` criterion modules moved to ``args.device``.

    ``BCEWithLogitsLoss`` takes raw logits; ``L1Loss`` is the pixel-wise term.
    """
    criteria = (nn.BCEWithLogitsLoss(), nn.L1Loss())
    return tuple(criterion.to(args.device) for criterion in criteria)

def optimizer(generator, discriminator, args):
    """Create one Adam optimizer per model with shared ``args.lr`` and ``(args.b1, args.b2)``.

    Returns:
        ``(optimizer_g, optimizer_d)`` for the generator and discriminator respectively.
    """
    def make_adam(model):
        return optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    return make_adam(generator), make_adam(discriminator)

## Train

In [None]:
def _discriminator_step(discriminator, optimizer_d, bce_loss, inputs, targets, fake):
    """Run one discriminator update and return the loss that was backpropagated."""
    # Also clears the gradients the previous generator step wrote into D's parameters.
    discriminator.zero_grad()
    disc_real = discriminator(inputs, targets)
    d_real_loss = bce_loss(disc_real, torch.ones_like(disc_real))
    # detach() keeps this update from backpropagating into the generator.
    disc_fake = discriminator(inputs, fake.detach())
    d_fake_loss = bce_loss(disc_fake, torch.zeros_like(disc_fake))
    # pix2pix halves the discriminator objective. Backprop exactly the value that is
    # logged; previously the un-halved sum was backpropagated (Adam is nearly
    # invariant to this scale, so dynamics are essentially unchanged).
    d_loss = (d_real_loss + d_fake_loss) * 0.5
    d_loss.backward()
    optimizer_d.step()
    return d_loss


def _generator_step(generator, discriminator, optimizer_g, bce_loss, l1_loss,
                    inputs, targets, fake, lambda_pixel):
    """Run one generator update (adversarial + weighted L1) and return its loss."""
    generator.zero_grad()
    # Not detached: gradients must flow through the discriminator into the generator.
    gen_output = discriminator(inputs, fake)
    g_bce_loss = bce_loss(gen_output, torch.ones_like(gen_output))
    g_pixel_wise_loss = lambda_pixel * l1_loss(fake, targets)
    g_loss = g_bce_loss + g_pixel_wise_loss
    g_loss.backward()
    optimizer_g.step()
    return g_loss


def trainer(loader,
            generator,
            discriminator,
            bce_loss,
            l1_loss,
            optimizer_g,
            optimizer_d,
            args):
    """Train a pix2pix generator/discriminator pair and return per-epoch mean losses.

    Every batch first updates the discriminator once. Real pairs
    ``(inputs, targets)`` are labelled 1 and generated pairs
    ``(inputs, generator(inputs))`` are labelled 0. The generator is then updated
    once, using adversarial BCE against label 1 plus
    ``args.lambda_pixel * l1_loss(generated, targets)``.

    On the first batch of every epoch, grids of generated and target images are
    saved to ``args.outputs_dir`` as ``gen_images_<epoch>.jpg`` and
    ``target_images_<epoch>.jpg``.

    Args:
        loader: sized iterable of ``(inputs, targets)`` tensor batches.
        generator: model mapping ``inputs`` to images.
        discriminator: model called as ``discriminator(inputs, images)`` -> logits.
        bce_loss: adversarial criterion applied to discriminator logits.
        l1_loss: pixel-wise reconstruction criterion.
        optimizer_g: optimizer for the generator parameters.
        optimizer_d: optimizer for the discriminator parameters.
        args: reads ``device``, ``epochs``, ``lambda_pixel`` and ``outputs_dir``.

    Returns:
        ``(g_losses, d_losses)``: lists with one mean loss per epoch
        (``args.epochs`` entries each).
    """
    g_losses = []
    d_losses = []

    # Make sure the directory for the sample image grids exists.
    os.makedirs(args.outputs_dir, exist_ok=True)

    start = datetime.now()

    generator.train()
    discriminator.train()

    n_batches = len(loader)

    # Epochs are numbered 1..args.epochs inclusive, so exactly args.epochs run.
    for epoch in range(1, args.epochs + 1):
        gepoch_loss = 0.0
        depoch_loss = 0.0

        epoch_time = datetime.now()
        for index, (inputs, targets) in enumerate(loader):
            inputs = inputs.to(args.device)
            targets = targets.to(args.device)

            # One forward pass of G is shared by the D and G updates.
            generated_image = generator(inputs)

            d_loss = _discriminator_step(discriminator, optimizer_d, bce_loss,
                                         inputs, targets, generated_image)
            g_loss = _generator_step(generator, discriminator, optimizer_g,
                                     bce_loss, l1_loss, inputs, targets,
                                     generated_image, args.lambda_pixel)

            gepoch_loss += g_loss.item()
            depoch_loss += d_loss.item()

            # Save images to monitor training stability.
            if index == 0:
                vutils.save_image(vutils.make_grid(generated_image, normalize = False),
                                  os.path.join(args.outputs_dir, f"gen_images_{epoch}.jpg"))
                vutils.save_image(vutils.make_grid(targets, normalize = False),
                                  os.path.join(args.outputs_dir, f"target_images_{epoch}.jpg"))

            # Log every ten iterations and on the last iteration of the epoch
            # (enumerate's last index is n_batches - 1).
            if index % 10 == 0 or index == n_batches - 1:
                print(f"Epoch[{epoch:04d}/{args.epochs:04d}]({index:05d}/{len(loader):05d}) "
                      f"D Loss: {d_loss.item():.6f} G Loss: {g_loss.item():.6f} ")

        # Mean per-batch losses; max() avoids dividing by zero on an empty loader.
        gepoch_loss = gepoch_loss / max(n_batches, 1)
        depoch_loss = depoch_loss / max(n_batches, 1)
        g_losses.append(gepoch_loss)
        d_losses.append(depoch_loss)

        print(f"D Epoch Loss: {depoch_loss:.6f} G Loss: {gepoch_loss:.6f} "
              f"Epoch training completed in: {datetime.now() - epoch_time}")

    print("Training completed in: " + str(datetime.now() - start))

    return g_losses, d_losses

## Main

In [None]:
class Args():
    """Notebook-wide configuration: paths, hyper-parameters, device and seeds.

    Defining the class has side effects: it creates ``outputs_dir`` and seeds
    ``random``, NumPy and PyTorch (CPU and all CUDA devices).
    """
    # Directory where trainer() saves sample image grids.
    outputs_dir = 'result'
    os.makedirs(outputs_dir, exist_ok=True)

    # Sub-dataset folders under ``root``, looked up as root/<name>/<name>/<mode>/.
    # ("facades" was previously misspelled "facedes", which silently loaded no images.)
    dataset_name_list = ["cityscapes", "facades"]
    root = "../input/pix2pix-dataset"
    mode_train = "train"
    mode_eval = "val"
    # Misspelled legacy name kept as an alias so existing references still work.
    most_eval = mode_eval
    shape = 256
    batch_size = 64
    num_workers = 2
    epochs = 200
    lr = 0.0002
    # Adam betas.
    b1 = 0.5
    b2 = 0.999
    # Weight of the L1 term in the generator loss.
    lambda_pixel = 100
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

args = Args()

In [None]:
def main(args):
    """Load data, build models, losses and optimizers from ``args``, train, and plot losses.

    Args:
        args: configuration object (see ``Args``); ``args.mode_train`` selects the
            split used for training.
    """
    loader = load_dataset(args, args.mode_train)
    print('Loaded dataset successfully')
    generator, discriminator = build_model(args)
    print('Built model successfully')
    bce_loss, l1_loss = loss_function(args)
    print('Define loss function succesfully')
    optimizer_g, optimizer_d = optimizer(generator, discriminator, args)
    print('Define all optimization function successfully')

    # Train
    g_losses, d_losses = trainer(loader, generator, discriminator,
                                 bce_loss, l1_loss,
                                 optimizer_g, optimizer_d, args)

    # trainer returns one mean loss per epoch, so the x-axis counts epochs (from 1).
    epoch_axis = range(1, len(g_losses) + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title("Generator and Discriminator Loss During Training")
    ax.plot(epoch_axis, g_losses, label="G")
    ax.plot(epoch_axis, d_losses, label="D")
    ax.set_xlabel("epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()

In [None]:
if __name__ == "__main__":
    # Notebook cells execute in the ``__main__`` namespace, so running this cell
    # starts training with the module-level ``args``.
    main(args=args)