In [None]:
import numpy as np
import os
import random
import PIL.Image as Image
from datetime import datetime
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch import Tensor

import torchvision
from torchvision import datasets, transforms
import torchvision.utils as vutils

## Architecture

In [None]:
# Class : Generator
class Generator(nn.Module):
    """Conditional GAN generator: maps (noise, class label) to an image.

    The class label is looked up in a learned ``num_classes``-dim embedding,
    concatenated with the noise vector and passed through an MLP whose final
    ``Tanh`` bounds every pixel to [-1, 1].

    Args:
        latent_space: Dimension of the input noise vector.
        image_size: Height and width of the (square) output image.
        channels: Number of output image channels.
        num_classes: Number of conditioning classes.
    """

    def __init__(self,
                 latent_space : int = 100,
                 image_size : int = 28,
                 channels : int = 1,
                 num_classes : int = 10,
                 ) -> None:
        super(Generator, self).__init__()
        self.latent_space = latent_space
        self.image_size = image_size
        self.channels = channels
        self.num_classes = num_classes

        # One learned num_classes-dim vector per class label.
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        def block(in_planes, out_planes, normalize = True):
            """Return the layers Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_planes, out_planes)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_planes))
            layers.append(nn.LeakyReLU(0.2, inplace = True))
            return layers

        self.blocks = nn.Sequential(
            # The first layer is deliberately left without BatchNorm.
            *block(latent_space + num_classes, 128, normalize = False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, channels * image_size * image_size),
            nn.Tanh()
        )

    def forward(self, x : Tensor,
                labels : Tensor = None,
               ) -> Tensor:
        """Generate a batch of images conditioned on ``labels``.

        Args:
            x: Noise of shape (batch, latent_space).
            labels: Integer class indices of shape (batch,), each in
                [0, num_classes). Required despite the ``None`` default.

        Returns:
            Tensor of shape (batch, channels, image_size, image_size) in [-1, 1].
        """
        x = torch.cat([x, self.label_embedding(labels)], dim = 1)
        out = self.blocks(x)
        # Use the configured geometry; a hard-coded (1, 28, 28) breaks any other size.
        return out.view(out.size(0), self.channels, self.image_size, self.image_size)
        

In [None]:
# Class : Discriminator
class Discriminator(nn.Module):
    """Conditional GAN discriminator that scores (image, class label) pairs.

    The image is flattened, concatenated with a learned ``num_classes``-dim
    label embedding, and fed through an MLP (512 -> 256 -> 128 -> 1, LeakyReLU
    0.2 between layers) that ends in a sigmoid, giving one score in (0, 1)
    per sample.

    Args:
        image_size: Height and width of the (square) input image.
        channels: Number of input image channels.
        num_classes: Number of conditioning classes.
    """

    def __init__(self,
                 image_size : int = 28,
                 channels : int = 1,
                 num_classes : int = 10,
                ) -> None:
        super().__init__()

        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # Hidden widths, starting from the flattened image plus the label embedding.
        widths = [channels * image_size * image_size + num_classes, 512, 256, 128]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.block = nn.Sequential(*layers)

    def forward(self, x : Tensor,
                labels : list = None,
               ) -> Tensor:
        """Score each (image, label) pair.

        Args:
            x: Images; every dimension after the first is flattened.
            labels: Integer class indices of shape (batch,). Required despite
                the ``None`` default.

        Returns:
            Tensor of shape (batch, 1) with sigmoid scores.
        """
        features = torch.cat((torch.flatten(x, 1), self.label_embedding(labels)), dim = 1)
        return self.block(features)

## Utils

In [None]:
def build_model(args):
    """Instantiate the conditional generator and discriminator on ``args.device``.

    The generator's noise dimension comes from ``args.latent_space`` when the
    config defines it, falling back to 100. The original code ignored it and
    always used the class default. All other architecture settings use the
    class defaults.

    Args:
        args: Config object; reads ``device`` and, optionally, ``latent_space``.

    Returns:
        ``(generator, discriminator)``, both already moved to ``args.device``.
    """
    latent_space = getattr(args, "latent_space", 100)
    generator = Generator(latent_space = latent_space).to(args.device)
    discriminator = Discriminator().to(args.device)
    return generator, discriminator
    
    
def load_dataset(args):
    """Download the MNIST training split into ./data/mnist and wrap it in a DataLoader.

    Each image is resized to 28, converted to a tensor, and normalized with
    mean 0.5 / std 0.5, which maps pixel values from [0, 1] to [-1, 1].

    Args:
        args: Config object; only ``args.batch_size`` is read.

    Returns:
        A shuffling ``DataLoader`` that yields ``(image, label)`` batches.
    """
    root = "./data/mnist"
    os.makedirs(root, exist_ok = True)

    preprocess = transforms.Compose([
        # MNIST digits are presumably already 28x28, so this is likely a no-op.
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize([0.5,], [0.5,]),
    ])
    mnist_train = datasets.MNIST(root, train = True, transform = preprocess, download = True)
    return DataLoader(mnist_train, batch_size = args.batch_size, shuffle = True)
        
def init_weights(m):
    """Initialise a single module in place; intended for ``model.apply(init_weights)``.

    * ``nn.Linear``: weight ~ N(0, 0.02), bias = 0.01.
    * ``nn.BatchNorm1d``: weight ~ N(1, 0.02), bias = 0.01.

    Other module types are left untouched. Parameters that are absent
    (``bias=False`` or ``affine=False``) are skipped instead of crashing.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, nn.Linear):
        weight_mean = 0.0
    elif isinstance(m, nn.BatchNorm1d):
        weight_mean = 1.0
    else:
        return
    # Test the Parameter itself: ``m.bias.data`` raises AttributeError when bias is None.
    if m.weight is not None:
        nn.init.normal_(m.weight, weight_mean, 0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)

def loss_function(args):
    """Return the binary cross-entropy criterion moved to ``args.device``.

    ``nn.BCELoss`` expects probabilities as input, i.e. values already passed
    through a sigmoid.
    """
    return nn.BCELoss().to(args.device)

def optimizer(generator, discriminator, args):
    """Build one Adam optimizer per network, sharing the same hyper-parameters.

    Args:
        generator: Module whose parameters the first optimizer updates.
        discriminator: Module whose parameters the second optimizer updates.
        args: Config object; reads ``lr``, ``b1`` and ``b2``.

    Returns:
        ``(optimizer_g, optimizer_d)``.
    """
    adam_kwargs = dict(lr = args.lr, betas = (args.b1, args.b2))
    return tuple(optim.Adam(net.parameters(), **adam_kwargs)
                 for net in (generator, discriminator))

## Train

In [None]:
def _save_sample_grids(fake, real, epoch, outputs_dir):
    """Write normalized image grids of the current fake and real batches to ``outputs_dir``."""
    for prefix, batch in (("fake", fake), ("real", real)):
        vutils.save_image(vutils.make_grid(batch.detach().cpu(), normalize = True),
                          os.path.join(outputs_dir, f"{prefix}_image_{epoch}.jpg"))


def trainer(loader,
            generator,
            discriminator,
            optimizer_g,
            optimizer_d,
            loss_fn,
            args):
    """Train a conditional GAN and return the per-iteration losses.

    Each iteration does one discriminator update, trained on real pairs
    labelled 1 and generated pairs labelled 0. It then does one generator
    update against the "real" target; with a BCE ``loss_fn`` this is the
    non-saturating loss -log D(G(z|y)|y).

    Args:
        loader: Iterable of ``(image, label)`` batches that supports ``len()``.
        generator: Module called as ``generator(noise, labels)``.
        discriminator: Module called as ``discriminator(images, labels)``;
            expected to return probabilities of shape (batch, 1).
        optimizer_g: Optimizer over the generator parameters.
        optimizer_d: Optimizer over the discriminator parameters.
        loss_fn: Criterion called as ``loss_fn(prediction, target)``.
        args: Config; reads ``epochs``, ``device``, ``outputs_dir`` and,
            if present, ``latent_space`` (default 100).

    Returns:
        ``(g_losses, d_losses)``: lists with one float per iteration.
    """
    latent_space = getattr(args, "latent_space", 100)
    # Sample fake labels over the generator's whole label range; fall back to
    # 10 classes for generators that do not expose ``num_classes``.
    num_classes = getattr(generator, "num_classes", 10)
    os.makedirs(args.outputs_dir, exist_ok = True)

    # Lists used to plot the losses afterwards.
    g_losses = []
    d_losses = []
    num_batches = len(loader)

    start = datetime.now()

    generator.train()
    discriminator.train()

    # Epochs are numbered 1..args.epochs inclusive.
    for epoch in range(1, args.epochs + 1):
        for index, (image, label) in enumerate(loader):
            image = image.to(args.device)
            label = label.to(args.device)
            b_size = image.size(0)

            # Ground-truth targets: 1 for real samples, 0 for generated ones.
            real_target = torch.ones((b_size, 1), dtype = image.dtype, device = args.device)
            fake_target = torch.zeros((b_size, 1), dtype = image.dtype, device = args.device)

            # Generator inputs; torch.randint's upper bound is exclusive, so
            # this covers every class 0..num_classes-1.
            noise = torch.randn(b_size, latent_space, device = args.device)
            gen_label = torch.randint(0, num_classes, (b_size,), device = args.device)

            # ---- Discriminator: maximise log D(x|y) + log(1 - D(G(z|y')|y')) ----
            discriminator.zero_grad()

            real_output = discriminator(image, label)
            loss_real_d = loss_fn(real_output, real_target)
            loss_real_d.backward()
            d_x = real_output.mean().item()

            fake = generator(noise, gen_label)
            # detach(): the discriminator step must not backprop into the generator.
            fake_output = discriminator(fake.detach(), gen_label)
            loss_fake_d = loss_fn(fake_output, fake_target)
            loss_fake_d.backward()
            d_g_z1 = fake_output.mean().item()

            # Logged D loss is the mean of the real and fake terms. The gradients
            # were already accumulated by the two backward() calls above.
            loss_d = (loss_real_d + loss_fake_d) / 2
            optimizer_d.step()

            # ---- Generator: make D classify generated samples as real ----
            generator.zero_grad()

            # Re-score with the just-updated discriminator. Gradients also land
            # in D's params but are cleared by discriminator.zero_grad() next iteration.
            fake_output = discriminator(fake, gen_label)
            loss_g = loss_fn(fake_output, real_target)
            loss_g.backward()
            d_g_z2 = fake_output.mean().item()
            optimizer_g.step()

            # Save one grid per epoch to eyeball training stability.
            if index == 0:
                _save_sample_grids(fake, image, epoch, args.outputs_dir)

            if index % 10 == 0 or index == num_batches - 1:
                print(f"Train stage: adversarial "
                      f"Epoch[{epoch:04d}/{args.epochs:04d}]({index:05d}/{num_batches:05d}) "
                      f"D Loss: {loss_d.item():.6f} G Loss: {loss_g.item():.6f} "
                      f"D_X: {d_x:.6f} D_G_Z1/ D_G_Z2: {d_g_z1:.6f}/{d_g_z2:.6f}.")

            g_losses.append(loss_g.item())
            d_losses.append(loss_d.item())

    print("Training complete in: " + str(datetime.now() - start))
    return g_losses, d_losses

In [None]:
class ARGS():
    """Run configuration and hyper-parameters for the cGAN experiment.

    Note: the statements in this class body execute when the class is defined.
    Running this cell therefore creates ``output/cgan`` and seeds the Python,
    NumPy and PyTorch (CPU and CUDA) RNGs.
    """
    outputs_dir = 'output'
    latent_space = 100   # size of the generator's noise vector
    batch_size = 64
    epochs = 50
    b1 = 0.5             # Adam beta1
    b2 = 0.9             # Adam beta2
    # NOTE(review): GAN setups with Adam betas (0.5, ...) often use lr=2e-4;
    # confirm that 2e-3 is intended.
    lr = 0.002
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    outputs_dir = os.path.join(outputs_dir,  "cgan")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(outputs_dir, exist_ok = True)

    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

args = ARGS()

## Main

In [None]:
def main(args):
    """Run the full pipeline and plot the training losses.

    The steps are: load the data, build the models, initialise the weights,
    create the loss and optimizers, then train.

    Args:
        args: Config object passed through to every helper.

    Returns:
        ``(g_losses, d_losses)`` from ``trainer``, so the curves can be
        inspected after the run.
    """
    loader = load_dataset(args)
    print('Loaded dataset successfully')
    generator, discriminator = build_model(args)
    print('Built model successfully')
    generator.apply(init_weights)
    discriminator.apply(init_weights)
    print('Weights initialize successfully')
    loss_fn = loss_function(args)
    print('Define loss function successfully')
    optimizer_g, optimizer_d = optimizer(generator, discriminator, args)
    print('Define all optimization function successfully')

    # training
    g_losses, d_losses = trainer(loader, generator, discriminator,
                                 optimizer_g, optimizer_d, loss_fn, args)

    # Explicit figure/axes interface instead of the pyplot state machine.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(g_losses, label="G")
    ax.plot(d_losses, label="D")
    ax.set(title="Generator and Discriminator Loss During Training",
           xlabel="iterations", ylabel="Loss")
    ax.legend()
    plt.show()
    return g_losses, d_losses

In [None]:
if __name__ == "__main__":
    # Runs the whole training pipeline when this cell executes in a notebook
    # kernel, or when the file runs as a script.
    main(args=args)