# **[A4-004] 딥러닝 코딩 실습 (Deep Learning Coding Practice)**
# **Lecture 02: Generative Adversarial Networks (GANs)**


# **[Preliminary]** Load Packages for PyTorch & GPU Setup

In [1]:
# Load Packages
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import imageio #### install with "pip install imageio"
from IPython.display import HTML

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torch.utils.data import DataLoader
from torchvision.utils import make_grid

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import random
# Expose only GPU 0 to this process. CUDA reads this variable when it is
# initialized, so it has to be set before the first .cuda() call below.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [3]:
# Create Folders
# Output folders used by this notebook: model checkpoints, dataset storage and
# the real/fake image dumps. os.makedirs(exist_ok=True) is idempotent (safe to
# re-run the cell), creates missing parents, and avoids the check-then-create
# race of os.path.exists() followed by os.mkdir().
for _folder in ('./checkpoint', './dataset', './img', './img/real', './img/fake'):
    os.makedirs(_folder, exist_ok=True)

# **[Preliminary]** Define Visualization & Image Save

In [4]:
# First Image Visualization from the Torch Tensor
def vis_img(image):
    """Display ``image[0]`` as a grayscale picture.

    The selected element is detached from the autograd graph and copied to
    host memory before plotting. Presumably ``image`` is a (1, H, W) tensor so
    that ``image[0]`` is a 2-D (H, W) array -- TODO confirm at call sites.
    """
    first = image[0]
    pixels = first.detach().cpu().numpy()
    plt.imshow(pixels, cmap='gray')
    plt.show()

In [5]:
# GIF File Save
def save_gif(train_prog_imgs, images, gif_path='./img/training_progress.gif'):
    '''
    Append a grid of ``images`` to the frame list and rewrite the GIF.

    Args:
        train_prog_imgs: list of training frames generated so far (H x W x C
            uint8 arrays); the new frame is appended to it in place.
        images: image batch generated in this iteration. Pixel values are
            expected in [0, 1]; anything outside that range is clipped.
        gif_path: file the animated GIF is written to. It is overwritten on
            every call with all frames collected so far.

    Returns:
        The same ``train_prog_imgs`` list, now including the new frame.
    '''
    # make_grid tiles the batch into one channel-first (C, H, W) tensor;
    # transpose it to (H, W, C) as image writers expect.
    img_grid = make_grid(images.detach())
    img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
    # Clip before casting: the float -> uint8 conversion does not saturate,
    # so values outside [0, 1] would produce platform-dependent garbage.
    img_grid = (255. * np.clip(img_grid, 0., 1.)).astype(np.uint8)

    train_prog_imgs.append(img_grid)

    # The whole GIF is re-encoded from every frame on each call.
    imageio.mimsave(gif_path, train_prog_imgs)

    return train_prog_imgs

In [6]:
# GIF File Visualization
def vis_gif(train_prog_imgs):
    """Build an HTML5-video animation of the saved training frames.

    Make ``vis_gif(...)`` the last expression of a notebook cell (or pass the
    result to ``IPython.display.display``) to render it.

    Args:
        train_prog_imgs: list of frames, e.g. the list built by ``save_gif``.

    Returns:
        ``IPython.display.HTML`` wrapping the encoded video.
    """
    fig = plt.figure()

    # ArtistAnimation expects one list of artists per frame.
    frames = []
    for frame in train_prog_imgs:
        frames.append([plt.imshow(frame, animated=True)])

    ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)

    html = ani.to_html5_video()
    # Close the figure so the notebook does not additionally render the last
    # frame as a static image below the video.
    plt.close(fig)
    # Return the HTML object: creating it without returning or displaying it
    # inside a function shows nothing.
    return HTML(html)

In [7]:
# GIF File Plotting
def plot_gif(train_prog_imgs, plot_length=10):
    """Show up to ``plot_length`` evenly spaced frames, one figure each.

    Frames are taken at indices ``0, step, 2*step, ...`` with
    ``step = len(train_prog_imgs) // plot_length``. If fewer frames than
    ``plot_length`` exist (e.g. training was interrupted early), every saved
    frame is shown once. An empty list or a non-positive ``plot_length``
    shows nothing.

    Args:
        train_prog_imgs: list of frames, e.g. the list built by ``save_gif``.
        plot_length: maximum number of frames to display.
    """
    total_len = len(train_prog_imgs)
    n_plots = min(plot_length, total_len)
    if n_plots <= 0:
        return

    plt.close()
    # Clamp to 1 so short lists do not yield step 0 (frame 0 repeated).
    step = max(total_len // plot_length, 1)
    for i in range(n_plots):
        plt.figure()
        plt.imshow(train_prog_imgs[step * i])
        plt.show()

In [8]:
# Image List Save
def save_image_list(dataset, real):
    """Write every image in ``dataset`` to disk as ``image_<index>.png``.

    Images go to ``./img/real`` when ``real`` is true and to ``./img/fake``
    otherwise. Files left over from an earlier call with a larger dataset
    are not removed.

    Args:
        dataset: sequence of image tensors (e.g. a batch tensor), each accepted
            by ``torchvision.utils.save_image``.
        real: selects the real-image or the fake-image folder.

    Returns:
        The folder path the images were written to (a directory path, not a
        list of file paths).
    """
    base_path = './img/real' if real else './img/fake'
    # Create the folder on demand so this does not depend on the
    # folder-setup cell having been run first.
    os.makedirs(base_path, exist_ok=True)

    for i, image in enumerate(dataset):
        vutils.save_image(image, f'{base_path}/image_{i}.png')

    return base_path

# **[Preliminary]** Load Dataset: MNIST

In [9]:
# MNIST Dataset
# NOTE(review): a ./dataset folder is created earlier, but MNIST is stored
# under root='./' -- confirm which location is intended.
mnist_train = torchvision.datasets.MNIST(root='./', train=True, transform=transforms.ToTensor(), target_transform=None, download=True)
mnist_test  = torchvision.datasets.MNIST(root='./', train=False, transform=transforms.ToTensor(), target_transform=None, download=True)
# Seed the split so the 50k/10k train/validation partition is identical on
# every run; without an explicit generator random_split draws from the global
# RNG, which is not seeded before this point.
mnist_train, mnist_val = torch.utils.data.random_split(
    mnist_train, [50000, 10000], generator=torch.Generator().manual_seed(0))

# Data Loader for MNIST (ToTensor() yields pixel values in [0, 1])
mnist_train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
mnist_val_loader   = DataLoader(mnist_val, batch_size=128, shuffle=False)
mnist_test_loader  = DataLoader(mnist_test, batch_size=128, shuffle=False)

# **Part I:** Define Generator ($G$) and Discriminator ($D$)

## 1-1) Define the Generator ($G$)

In [10]:
class Generator(nn.Module):
    """MLP generator mapping a 100-d latent vector to a 1x28x28 image.

    Layers: Linear(100->256) -> ReLU -> Linear(256->784) -> Sigmoid, so every
    output pixel lies in [0, 1]. The layers are kept in ``self.main`` (an
    ``nn.Sequential``), giving checkpoint keys ``main.0.*`` and ``main.2.*``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(100, 256),   # latent z -> hidden features
            nn.ReLU(),
            nn.Linear(256, 784),   # hidden -> flattened 28*28 image
            nn.Sigmoid(),          # squash pixel values into [0, 1]
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Generate a batch of images.

        Args:
            input: latent batch whose last dimension is 100 (presumably
                shaped (N, 100)).

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        flat = self.main(input)
        return flat.view(-1, 1, 28, 28)

## 1-2) Define the Discriminator ($D$)

In [11]:
class Discriminator(nn.Module):
    """MLP discriminator giving each 28x28 image a probability of being real.

    Layers: Linear(784->256) -> LeakyReLU(0.2) -> Linear(256->1) -> Sigmoid.
    The layers are kept in ``self.main`` so checkpoint keys are ``main.0.*``
    and ``main.2.*``.
    """

    def __init__(self):
        super().__init__()
        hidden = nn.Linear(784, 256)
        score = nn.Linear(256, 1)
        self.main = nn.Sequential(hidden, nn.LeakyReLU(0.2), score, nn.Sigmoid())

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: image batch with 784 elements per sample (e.g. (N, 1, 28, 28)).

        Returns:
            1-D tensor of shape (N,) with values in [0, 1].
        """
        flat = input.view(-1, 28*28)   # flatten each image to a 784-vector
        probs = self.main(flat)        # (N, 1)
        return probs.squeeze(dim=1)    # (N,)

## 1-3) Define Optimizer

In [12]:
# Build both networks on the GPU (generator first, then discriminator).
netG, netD = Generator().cuda(), Discriminator().cuda()

# One Adam optimizer per network, both with learning rate 2e-4.
optimizerG = optim.Adam(netG.parameters(), lr=2e-4)
optimizerD = optim.Adam(netD.parameters(), lr=2e-4)

# **Part II:** Train the Generator and the Discriminator

## 2-1) Train the GAN Model with Adversarial Learning

In [13]:
# Fixed latent batch (128 x 100), reused after every epoch so the progress
# frames always show the generator's output for the same z vectors.
fixed_noise = torch.randn(128, 100).cuda()

# Binary cross-entropy on D's probabilities (the Discriminator ends in Sigmoid).
criterion = nn.BCELoss()

n_epoch = 200
train_prog_imgs_list = []
for epoch in range(n_epoch):
    for i, (real_img, _) in enumerate(mnist_train_loader):
        ###########################################################
        # Update D network: Maximize log(D(x)) + log(1 - D(G(z))) #
        ###########################################################
        # Train the Discriminator with Real Images
        netD.zero_grad()
        real_img = real_img.cuda()  # Real image
        batch_size = real_img.size(0)  # the last batch may be smaller than the loader's batch size
        label = torch.ones((batch_size,)).cuda()  # Real label = 1
        output = netD(real_img) # D(x)
        errD_real = criterion(output, label)
        D_x = output.mean().item()

        # Train the Discriminator with Fake Images
        noise = torch.randn(batch_size, 100).cuda() # Latent vector, z
        fake_img = netG(noise)  # Fake image (= the generated image by G)
        label = torch.zeros((batch_size,)).cuda() # Fake label = 0
        output = netD(fake_img.detach())  # No gradient backpropagation to the generator
        errD_fake = criterion(output, label)
        D_G_z1 = output.mean().item()

        # Loss backward: a single D step on the summed real + fake loss
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        ###########################################
        # Update G network: Maximize log(D(G(z))) #
        ###########################################
        # Train the Generator with Fake Images. fake_img is NOT detached here,
        # so gradients flow through D into G. D also accumulates gradients from
        # this pass; they are cleared by netD.zero_grad() at the start of the
        # next iteration, before D's next step.
        netG.zero_grad()
        label = torch.ones((batch_size,)).cuda()  # Fake labels are real label (= 1) for the generator loss
        output = netD(fake_img)
        errG = criterion(output, label)
        D_G_z2 = output.mean().item()

        errG.backward()
        optimizerG.step()

    # Statistics of the last mini-batch of this epoch
    print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' 
              % (epoch+1, n_epoch, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # Save the output. Sampling for the progress GIF is inference only, so run
    # it under no_grad instead of building an autograd graph every epoch.
    with torch.no_grad():
        fake = netG(fixed_noise)
    train_prog_imgs_list = save_gif(train_prog_imgs_list, fake)  # Save fake image while training!

    # Check pointing for every epoch (the file suffix is the 0-based epoch index)
    torch.save(netG.state_dict(), './checkpoint/netG_epoch_%d.pth' % (epoch))
    torch.save(netD.state_dict(), './checkpoint/netD_epoch_%d.pth' % (epoch))

[1/200] Loss_D: 0.0241 Loss_G: 8.8593 D(x): 0.9871 D(G(z)): 0.0063 / 0.0003
[2/200] Loss_D: 0.0038 Loss_G: 7.7983 D(x): 0.9979 D(G(z)): 0.0017 / 0.0004
[3/200] Loss_D: 0.0032 Loss_G: 9.4515 D(x): 0.9974 D(G(z)): 0.0003 / 0.0001
[4/200] Loss_D: 0.0003 Loss_G: 9.1278 D(x): 0.9999 D(G(z)): 0.0002 / 0.0001
[5/200] Loss_D: 0.0108 Loss_G: 11.0628 D(x): 0.9918 D(G(z)): 0.0016 / 0.0000
[6/200] Loss_D: 0.0035 Loss_G: 6.0623 D(x): 0.9999 D(G(z)): 0.0034 / 0.0025
[7/200] Loss_D: 0.0002 Loss_G: 11.0517 D(x): 0.9999 D(G(z)): 0.0001 / 0.0000
[8/200] Loss_D: 0.0005 Loss_G: 13.5468 D(x): 0.9996 D(G(z)): 0.0001 / 0.0000
[9/200] Loss_D: 0.0046 Loss_G: 9.0462 D(x): 0.9990 D(G(z)): 0.0035 / 0.0004
[10/200] Loss_D: 0.0057 Loss_G: 7.3132 D(x): 1.0000 D(G(z)): 0.0056 / 0.0022
[11/200] Loss_D: 0.0035 Loss_G: 9.9997 D(x): 0.9972 D(G(z)): 0.0004 / 0.0001
[12/200] Loss_D: 0.0033 Loss_G: 8.1474 D(x): 1.0000 D(G(z)): 0.0033 / 0.0004
[13/200] Loss_D: 0.0030 Loss_G: 8.2620 D(x): 1.0000 D(G(z)): 0.0029 / 0.0005
[14/2

KeyboardInterrupt: 

## 2-2) Visualize or Plot the Generated Samples

In [None]:
# Visualization of the Generated Images
# Alternative (disabled): animate every saved frame instead of sampling some:
# vis_gif(train_prog_imgs_list)
plot_gif(train_prog_imgs_list, plot_length=10)

# **Part III:** Test the Trained GANs


In [None]:
# Take the first test batch (the loader does not shuffle) as the real set.
real_dataset, _ = next(iter(mnist_test_loader))

# Draw the same number of fake samples from the trained generator. This is
# inference only, so no autograd graph is needed.
noise = torch.randn(real_dataset.size(0), 100).cuda()
with torch.no_grad():
    fake_dataset = netG(noise)

In [None]:
# Save Real Images and Fake Images
# Each call writes one PNG per image and returns the folder it wrote to.
real_image_path_list = save_image_list(real_dataset, real=True)
fake_image_path_list = save_image_list(fake_dataset, real=False)