In [101]:
'''
    The following is a PyTorch Implementation of a Conditional GAN
    This GAN leverages DCGAN architecture in the generator and the 
    discriminator
'''


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data.dataloader as DataLoader
import torch.utils.data.dataset as Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.autograd import Variable

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [193]:
'''
    We use DCGAN discriminator and generator in this network. The post related to 
    DCGAN covers the DCGAN specifics in more detail. Here I will only highlight the 
    parts which are relevant to the Conditional GAN. 
'''

class DCGAN_Discriminator(nn.Module):
    """DCGAN-style discriminator conditioned on a class label.

    The label is embedded, projected to a single 32x32 feature map and
    concatenated with the image along the channel axis. The result goes
    through a stack of strided convolutions ending in a per-sample sigmoid
    score.

    Args:
        featmap_dim: channels of the first conv layer (doubled at each later layer).
        n_channel: channels of the input image (the label map adds one more).
        cond_size: number of distinct class labels accepted by the embedding.
        batch_size: stored on the instance; not used by this module.
        embed_dim: width of the label embedding before its projection to 32*32.
            Previously this was read from a global ``noise_dim`` (256 in the
            training cell), which made construction fail on a fresh kernel.
            256 is kept as the default so the architecture is unchanged.
    """

    # Spatial size the label embedding is projected to; must match the input images.
    _image_size = 32

    def __init__(self, featmap_dim=64, n_channel=1, cond_size=10, batch_size=128, embed_dim=256):
        super(DCGAN_Discriminator, self).__init__()
        self.featmap_dim = featmap_dim
        self.batch_size = batch_size
        self.main = nn.Sequential(
            # (n_channel + 1) x 32 x 32 -> featmap_dim x 16 x 16; no BatchNorm on the input layer
            nn.Conv2d(n_channel + 1, featmap_dim, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> featmap_dim*2 x 8 x 8
            nn.Conv2d(featmap_dim, featmap_dim * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(featmap_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> featmap_dim*4 x 4 x 4
            nn.Conv2d(featmap_dim * 2, featmap_dim * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(featmap_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1 score in [0, 1]
            nn.Conv2d(featmap_dim * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

        # Label -> embed_dim embedding -> dense projection to one 32x32 feature map.
        self.embed = nn.Sequential(
            nn.Embedding(cond_size, embed_dim),
            nn.Linear(embed_dim, self._image_size * self._image_size),
        )

    def forward(self, x, label):
        """Score images ``x`` given integer class labels ``label``.

        Strided convolutions, BatchNorm after every conv except the first,
        LeakyReLU(0.2) activations.

        Args:
            x: image batch of shape N x n_channel x 32 x 32 (the label map is 32x32).
            label: integer class indices with N elements in total.

        Returns:
            Tensor of shape N x 1 x 1 x 1 with values in [0, 1] (Sigmoid output).
        """
        label_map = torch.reshape(
            self.embed(label), (x.shape[0], 1, self._image_size, self._image_size)
        )
        return self.main(torch.cat((x, label_map), dim=1))


class DCGAN_Generator(nn.Module):
    """DCGAN-style generator conditioned on a class label.

    The noise vector is projected to a 512 x 4 x 4 feature map and the label
    is embedded into one extra 4 x 4 map. The concatenated 513 x 4 x 4 tensor
    is upsampled by three stride-2 transposed convolutions to
    n_channel x 32 x 32.

    Args:
        featmap_dim: width multiplier for the transposed-conv stack.
        n_channel: channels of the generated image (previously ignored; the
            output was always single-channel).
        noise_dim: length of each latent noise vector.
        cond_size: number of distinct class labels (previously ignored; the
            embedding was hard-coded to 10 classes).
        batch_size: stored on the instance; not used by this module.
    """

    # Channels of the projected latent map before the label map is concatenated.
    _latent_channels = 512

    def __init__(self, featmap_dim=128, n_channel=1, noise_dim=100, cond_size=10, batch_size=128):
        super(DCGAN_Generator, self).__init__()
        self.featmap_dim = featmap_dim
        self.batch_size = batch_size
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            # (512 + 1) x 4 x 4 -> featmap_dim*4 x 8 x 8
            nn.ConvTranspose2d(self._latent_channels + 1, featmap_dim * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(featmap_dim * 4),
            nn.ReLU(True),
            # -> featmap_dim*2 x 16 x 16
            nn.ConvTranspose2d(featmap_dim * 4, featmap_dim * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(featmap_dim * 2),
            nn.ReLU(True),
            # -> n_channel x 32 x 32 in [-1, 1]; no BatchNorm on the output layer
            nn.ConvTranspose2d(featmap_dim * 2, n_channel, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # Label -> noise_dim embedding -> dense projection to one 4x4 feature map.
        self.embed = nn.Sequential(nn.Embedding(cond_size, noise_dim), nn.Linear(noise_dim, 4 * 4))
        # Noise -> 512 * 4 * 4 values, reshaped to 512 x 4 x 4 in forward().
        self.latent_encode = nn.Sequential(
            nn.Linear(noise_dim, self._latent_channels * 4 * 4), nn.ReLU(True)
        )

    def forward(self, x, label):
        """Generate images from noise ``x`` conditioned on class labels ``label``.

        Args:
            x: noise batch of shape N x noise_dim.
            label: integer class indices with N elements in total.

        Returns:
            Tensor of shape N x n_channel x 32 x 32 with values in [-1, 1] (Tanh output).
        """
        batch = x.shape[0]
        label_map = torch.reshape(self.embed(label), (batch, 1, 4, 4))
        latent_map = torch.reshape(self.latent_encode(x), (batch, self._latent_channels, 4, 4))
        return self.main(torch.cat((latent_map, label_map), dim=1))


def weights_init(m):
    """DCGAN weight initialisation for a single module; pass to ``nn.Module.apply``.

    Any module whose class name contains ``Conv`` (e.g. Conv2d,
    ConvTranspose2d) gets weights drawn from N(0, 0.02). Any module whose
    class name contains ``BatchNorm`` gets weights from N(1, 0.02) and a zero
    bias. All other modules are left untouched.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


In [201]:
# Load MNIST, upsampled to 32x32 to match the generator's output size.
# Normalize((0.5,), (0.5,)) maps pixels from [0, 1] to [-1, 1], the same range
# as the generator's Tanh output. The previous Normalize(0, 1) was an identity
# transform, so real images stayed in [0, 1].
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transform,
    download=True,
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    transform=transform,
)

batch_size = 256

# Pinned host memory only helps host->GPU copies; skip it on CPU-only machines.
_loader_kwargs = dict(batch_size=batch_size, num_workers=1, pin_memory=torch.cuda.is_available())
loaders = {
    'train': torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_kwargs),
    # A fixed order keeps any evaluation on the test split reproducible.
    'test': torch.utils.data.DataLoader(test_data, shuffle=False, **_loader_kwargs),
}

# Release cached GPU memory left over from earlier runs of this notebook.
torch.cuda.empty_cache()


In [203]:
import time

from PIL import Image

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------


def _render_fixed_samples(generator, fixed_noise, n_classes=10, tile=64):
    """Build a sample grid: one column per class label, one row per fixed noise vector.

    Assumes the generator emits single-channel 32x32 images in [-1, 1]. Each
    sample is upscaled to ``tile`` x ``tile`` and negative values are shown as
    black.

    Returns:
        A grayscale ``PIL.Image`` that is ``n_classes * tile`` wide and
        ``len(fixed_noise) * tile`` tall.
    """
    n_samples = fixed_noise.shape[0]
    resize = transforms.Resize(tile)
    columns = []
    with torch.no_grad():  # display only; no autograd graph needed
        for cls in range(n_classes):
            labels = torch.full((n_samples,), cls, dtype=torch.long, device=fixed_noise.device)
            samples = generator(fixed_noise, labels).reshape(n_samples, 32, 32)
            samples = resize(samples)  # n_samples x tile x tile
            # Stack this class's samples vertically into one tile-wide column.
            columns.append(samples.reshape(n_samples * tile, tile).cpu().numpy())
    grid = np.clip(np.concatenate(columns, axis=1), 0.0, 1.0)
    return Image.fromarray(np.uint8(grid * 255))


def train_GAN(epochs, discriminator, generator, loaders, dOptim, gOptim, dLoss, gLoss, disc_scheduler, gen_scheduler, noise_dim=100, batch_size=500):
    """Train a conditional GAN and return ``(generator, discriminator)``.

    Each step does the following:
    1. Update the discriminator on a real batch and on a *detached* fake
       batch, then step ``dOptim``.
    2. Update the generator so that the (already updated) discriminator
       labels its fakes as real.

    Twice per epoch the losses are printed and a grid of samples from fixed
    noise is displayed. Both LR schedulers are stepped once per epoch.

    Args:
        epochs: number of passes over ``loaders['train']``.
        discriminator, generator: models called as ``model(inputs, labels)``;
            training runs on the device of the discriminator's parameters.
        loaders: dict whose ``'train'`` entry yields ``(images, labels)`` batches.
        dOptim, gOptim: optimizers for the discriminator / generator.
        dLoss, gLoss: losses taking ``(prediction, target)`` with 1 = real, 0 = fake.
        disc_scheduler, gen_scheduler: LR schedulers stepped at the end of each epoch.
        noise_dim: length of the latent vectors fed to the generator.
        batch_size: unused (the loader decides the batch size); kept for compatibility.

    Returns:
        The trained ``(generator, discriminator)`` pair (the same objects passed in).
    """
    # Follow the models instead of assuming CUDA, so the loop also runs on CPU.
    device = next(discriminator.parameters()).device
    torch.cuda.empty_cache()
    total_step = len(loaders['train'])
    # Log twice per epoch. Use integer division: with a float `total_step / 2`
    # and an odd `total_step`, the modulo test only fires on the final step.
    log_every = max(1, total_step // 2)
    n_classes = 10  # label range sampled for fake batches (MNIST digits)

    # Fixed latent codes reused at every log point so sample grids are comparable.
    dim_fixed_noise = 5
    fixed_noise = torch.randn(dim_fixed_noise, noise_dim, device=device)

    for epoch in range(epochs):
        epoch_start = time.time()
        step_times = []
        for i, (true_inputs, to_encode) in enumerate(loaders['train']):
            true_inputs = true_inputs.to(device)
            to_encode = to_encode.to(device)
            act_batch_size = true_inputs.shape[0]  # the last batch may be smaller
            real_label = torch.ones(act_batch_size, device=device)
            fake_label = torch.zeros(act_batch_size, device=device)

            step_start = time.time()

            # -- Discriminator: real batch --
            discriminator.zero_grad()
            output = discriminator(true_inputs, to_encode).view(-1)
            errD_real = dLoss(output, real_label)
            errD_real.backward()
            D_x = output.mean().item()

            # -- Discriminator: fake batch --
            noise = torch.randn(act_batch_size, noise_dim, device=device)
            raw_labels = torch.randint(0, n_classes, (act_batch_size,), device=device)
            false_inputs = generator(noise, raw_labels)
            # detach(): the discriminator loss must not backpropagate into the generator.
            output = discriminator(false_inputs.detach(), raw_labels).view(-1)
            errD_fake = dLoss(output, fake_label)
            errD_fake.backward()
            errD = (errD_real + errD_fake) / 2  # reported only
            # Step D before the generator's backward pass. Otherwise errG's gradients
            # would accumulate into D's parameters and be applied by dOptim.step().
            dOptim.step()

            # -- Generator --
            generator.zero_grad()
            output = discriminator(false_inputs, raw_labels).view(-1)
            errG = gLoss(output, real_label)
            errG.backward()
            D_G_z2 = output.mean().item()
            gOptim.step()

            step_times.append(time.time() - step_start)

            if (i + 1) % log_every == 0:
                print('Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.8f}, Generator Loss: {:.8f}, Pass Rates (real/fake): {:.8f} {:.8f}'
                      .format(epoch + 1, epochs, i + 1, total_step, errD.item(), errG.item(), D_x, D_G_z2))
                print('test Image: \n')
                display(_render_fixed_samples(generator, fixed_noise, n_classes=n_classes))

        epoch_time = time.time() - epoch_start
        mean_step_time = sum(step_times) / max(len(step_times), 1)
        print('average step time: {:.8f}, total time for epoch: {:2.4f}'.format(mean_step_time, epoch_time))
        disc_scheduler.step()
        gen_scheduler.step()
    return generator, discriminator

In [None]:
from PIL import Image

torch.cuda.empty_cache()
torch.manual_seed(0)  # reproducible weight init and noise; change the seed for a different run

noise_dim = 256
g_net = DCGAN_Generator(noise_dim=noise_dim)
d_net = DCGAN_Discriminator()
# `device` (first cell) falls back to the CPU when CUDA is unavailable,
# unlike a hard-coded .cuda().
g_net.to(device)
d_net.to(device)

g_net.apply(weights_init)
d_net.apply(weights_init)

# MSE loss on the discriminator's sigmoid output for both players.
# Alternatives previously left commented out here: BCELoss, RMSprop, SGD.
dis_loss = nn.MSELoss()
gen_loss = nn.MSELoss()
dis_optim = optim.Adam(d_net.parameters(), lr=.0001, betas=(0.5, 0.99))
gen_optim = optim.Adam(g_net.parameters(), lr=.001, betas=(0.5, 0.99))
# Decay both learning rates 10x every 10 epochs (stepped once per epoch by train_GAN).
schedulerD = torch.optim.lr_scheduler.StepLR(dis_optim, step_size=10, gamma=.1)
schedulerG = torch.optim.lr_scheduler.StepLR(gen_optim, step_size=10, gamma=.1)

G, D = train_GAN(40, d_net, g_net, loaders, dis_optim, gen_optim, dis_loss, gen_loss, schedulerD, schedulerG, noise_dim=noise_dim, batch_size=batch_size)

In [None]:
# Save the trained models' weights (state_dicts only): generator first, then discriminator.
for trained_net, checkpoint_path in ((G, 'cGAN_gen.model'), (D, 'cGAN_dis.model')):
    torch.save(trained_net.state_dict(), checkpoint_path)