<a href="https://colab.research.google.com/github/jinkyukim-me/Summary-Seocho-Pytorch/blob/master/Simple_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-only setup: authenticate the current user, then mount Google Drive
# under /content/gdrive so later cells can read and write files there.
from google.colab import auth, drive

auth.authenticate_user()
print('Authenticated!')

drive.mount('/content/gdrive')
print('Mounted!')

Authenticated!
Mounted at /content/gdrive
Mounted!


In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import pickle

In [0]:
# Preprocessing applied to every MNIST image: convert it to a tensor, then
# subtract 0.5 and divide by 0.5 on the single channel. For inputs in [0, 1]
# this yields values in [-1, 1], the same range as the Generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [0]:
# MNIST is downloaded into ./data on first use; every image goes through
# `transform` before it is returned.
mnist = datasets.MNIST('data', transform=transform, download=True)

# Reshuffled mini-batches of 60 images for the training loop.
dataloader = DataLoader(dataset=mnist, shuffle=True, batch_size=60)

In [0]:
import os
import imageio

# Whether to run the models on the GPU. Always assign the flag: the original
# only defined it when CUDA was present, so every later `if use_gpu:` raised
# NameError on a CPU-only runtime.
use_gpu = torch.cuda.is_available()

# When True, per-epoch losses are recorded and sample images are written to disk.
leave_log = True

if leave_log:
    # Directory that receives one generated-sample grid per epoch.
    result_dir = 'GAN_generated_images'
    # exist_ok=True makes re-running the cell safe without a separate isdir check.
    os.makedirs(result_dir, exist_ok=True)

In [0]:
class Generator(nn.Module):
    """Fully connected generator: 100-d latent vector -> 1x28x28 image.

    The hidden layers widen 100 -> 256 -> 512 -> 1024, each followed by
    LeakyReLU(0.2). The last layer produces 28*28 values squashed by Tanh,
    so every output pixel lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        widths = [100, 256, 512, 1024]
        layers = []
        # Hidden stack: Linear followed by LeakyReLU for each width transition.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(in_features=fan_in, out_features=fan_out))
            layers.append(nn.LeakyReLU(0.2))

        # Output layer: one value per pixel, bounded by Tanh.
        layers.append(nn.Linear(in_features=widths[-1], out_features=28 * 28))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        """Return generated images of shape (batch, 1, 28, 28)."""
        flat_images = self.main(inputs)
        return flat_images.view(-1, 1, 28, 28)

In [0]:
class Discriminator(nn.Module):
    """Fully connected discriminator: 28x28 image -> probability in (0, 1).

    The hidden layers narrow 784 -> 1024 -> 512 -> 256, each followed by
    LeakyReLU(0.2) and an in-place Dropout with the default drop probability.
    A final Linear + Sigmoid produces one score per image.
    """

    def __init__(self):
        super().__init__()

        widths = [28 * 28, 1024, 512, 256]
        layers = []
        # Hidden stack: Linear -> LeakyReLU -> Dropout for each width transition.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([
                nn.Linear(in_features=fan_in, out_features=fan_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(inplace=True),
            ])

        # Output head: a single sigmoid-activated unit.
        layers.extend([
            nn.Linear(in_features=widths[-1], out_features=1),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        """Flatten each image to 784 values and return scores of shape (batch, 1)."""
        flattened = inputs.view(-1, 28 * 28)
        return self.main(flattened)

In [0]:
# Build the two networks (discriminator first, then generator).
D, G = Discriminator(), Generator()

# Move both networks to the GPU when one is in use.
if use_gpu:
    G.cuda()
    D.cuda()

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with learning rate 2e-4.
D_optimizer = Adam(D.parameters(), lr=2e-4)
G_optimizer = Adam(G.parameters(), lr=2e-4)

In [0]:
from matplotlib import pyplot as plt
import numpy as np
def square_plot(data, path):
    """Tile a stack of images into a roughly square grid and save it to ``path``.

    Args:
        data: array of shape (n, height, width) or (n, height, width, 3), or a
            list/tuple of such arrays, which is concatenated along axis 0.
        path: destination handed to ``plt.imsave``.

    The images are rescaled to [0, 1], padded with white (1.0) to fill a
    ceil(sqrt(n)) x ceil(sqrt(n)) grid with a one-pixel gap after each tile,
    and written as a single image using the 'gray' colormap.
    """
    if isinstance(data, (list, tuple)):
        data = np.concatenate(data)

    # Normalize for display. A constant array has zero range; map it to all
    # zeros rather than dividing by zero and filling the image with NaN.
    lowest = data.min()
    value_range = data.max() - lowest
    if value_range > 0:
        data = (data - lowest) / value_range
    else:
        data = np.zeros_like(data, dtype=float)

    # Side length of the smallest square grid that holds all n images.
    n = int(np.ceil(np.sqrt(data.shape[0])))

    padding = (
        ((0, n ** 2 - data.shape[0]),  # blank tiles to complete the grid
         (0, 1), (0, 1))               # one-pixel gap below/right of each tile
        + ((0, 0),) * (data.ndim - 3)  # leave a trailing channel axis untouched
    )
    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)

    # Arrange tiles as (grid_row, tile_row, grid_col, tile_col[, channel]) and
    # then collapse the row pair and the column pair into one image.
    axes = (0, 2, 1, 3) + tuple(range(4, data.ndim + 1))
    grid = data.reshape((n, n) + data.shape[1:]).transpose(axes)
    grid = grid.reshape((n * grid.shape[1], n * grid.shape[3]) + grid.shape[4:])

    plt.imsave(path, grid, cmap='gray')

In [10]:
if leave_log:
    # Per-epoch mean losses, appended to by the training loop.
    train_hist = {'D_losses': [], 'G_losses': []}
    # Paths of the sample grids written after each epoch.
    generated_images = []

# Fixed batch of 25 latent vectors (a 5x5 grid) reused every epoch so the
# saved samples are comparable across epochs.
# The original wrapped this in Variable(..., volatile=True); `volatile` was
# removed in PyTorch 0.4, has no effect, and only emits a UserWarning, so a
# plain tensor is equivalent.
z_fixed = torch.randn(5 * 5, 100)

if use_gpu:
    z_fixed = z_fixed.cuda()
   

  import sys


In [11]:
# Alternating GAN training: for every mini-batch, update the discriminator D
# once on real + generated images, then update the generator G once.
for epoch in range(100):

    if leave_log:
        D_losses = []
        G_losses = []

    for real_data, _ in dataloader:
        batch_size = real_data.size(0)

        # BCE targets: 1 for "real", 0 for "fake". Plain tensors are used
        # because the Variable wrapper is a no-op since PyTorch 0.4.
        target_real = torch.ones(batch_size, 1)
        target_fake = torch.zeros(batch_size, 1)

        if use_gpu:
            real_data, target_real, target_fake = real_data.cuda(), target_real.cuda(), target_fake.cuda()

        # ---- Discriminator step ----
        D_result_from_real = D(real_data)
        D_loss_real = criterion(D_result_from_real, target_real)

        z = torch.randn((batch_size, 100))
        if use_gpu:
            z = z.cuda()

        # detach(): D's update does not need gradients w.r.t. G, so cut the
        # graph here instead of backpropagating through the generator. G's
        # gradients were zeroed before its own step anyway, so the resulting
        # parameter updates are unchanged.
        fake_data = G(z).detach()

        D_result_from_fake = D(fake_data)
        D_loss_fake = criterion(D_result_from_fake, target_fake)

        D_loss = D_loss_real + D_loss_fake

        D.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        if leave_log:
            D_losses.append(D_loss.item())

        # ---- Generator step ----
        z = torch.randn((batch_size, 100))
        if use_gpu:
            z = z.cuda()

        fake_data = G(z)
        D_result_from_fake = D(fake_data)

        # G is rewarded when D labels its samples as real.
        G_loss = criterion(D_result_from_fake, target_real)

        G.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if leave_log:
            G_losses.append(G_loss.item())

    if leave_log:
        # NOTE: both rates are computed on the LAST mini-batch of the epoch
        # only; the losses in the message are averaged over the whole epoch.
        true_positive_rate = (D_result_from_real > 0.5).float().mean().item()
        true_negative_rate = (D_result_from_fake < 0.5).float().mean().item()
        base_message = ("Epoch: {epoch:<3d} D Loss: {d_loss:<8.6} G Loss: {g_loss:<8.6} "
                        "True Positive Rate: {tpr:<5.1%} True Negative Rate: {tnr:<5.1%}"
                       )
        message = base_message.format(
                    epoch=epoch,
                    d_loss=sum(D_losses)/len(D_losses),
                    g_loss=sum(G_losses)/len(G_losses),
                    tpr=true_positive_rate,
                    tnr=true_negative_rate
        )
        print(message)

        # Snapshot of the generator on the fixed latent batch. No gradients
        # are needed here, so skip building the autograd graph.
        with torch.no_grad():
            fake_data_fixed = G(z_fixed)
        image_path = result_dir + '/epoch{}.png'.format(epoch)
        square_plot(fake_data_fixed.view(25, 28, 28).cpu().numpy(), path=image_path)
        generated_images.append(image_path)

        # Mean loss over the epoch, stored as 0-dim float tensors.
        train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
        train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))

Epoch: 0   D Loss: 0.946694 G Loss: 2.77823  True Positive Rate: 91.7% True Negative Rate: 83.3%
Epoch: 1   D Loss: 0.739339 G Loss: 2.3576   True Positive Rate: 90.0% True Negative Rate: 96.7%
Epoch: 2   D Loss: 0.539076 G Loss: 2.84856  True Positive Rate: 76.7% True Negative Rate: 88.3%
Epoch: 3   D Loss: 0.678398 G Loss: 2.23925  True Positive Rate: 93.3% True Negative Rate: 95.0%
Epoch: 4   D Loss: 0.705851 G Loss: 2.05034  True Positive Rate: 70.0% True Negative Rate: 95.0%
Epoch: 5   D Loss: 0.824256 G Loss: 1.79897  True Positive Rate: 88.3% True Negative Rate: 95.0%
Epoch: 6   D Loss: 0.881752 G Loss: 1.62285  True Positive Rate: 71.7% True Negative Rate: 83.3%
Epoch: 7   D Loss: 0.911446 G Loss: 1.58236  True Positive Rate: 70.0% True Negative Rate: 91.7%
Epoch: 8   D Loss: 0.945653 G Loss: 1.52441  True Positive Rate: 53.3% True Negative Rate: 81.7%
Epoch: 9   D Loss: 0.978359 G Loss: 1.44012  True Positive Rate: 63.3% True Negative Rate: 80.0%
Epoch: 10  D Loss: 0.982334 G 