In [1]:
"""
MNIST를 이용한 GAN 모델 학습시키기
"""

'\nMNIST를 이용한 GAN 모델 학습시키기\n'

In [2]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

In [3]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# print(torch.cuda.is_available())

In [4]:
# 하이퍼파라미터 설정하기
# Latent : Generator 가 만들어내는 가상의 공간
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

In [5]:
# 폴더를 생성하는 코드
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [6]:
# image processing
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],   # 1 for greyscale channels
                                     std=[0.5])])

In [7]:
# MNIST dataset을 torchvision으로부터 불러오기
mnist = torchvision.datasets.MNIST(root='../../data/', # ..은 이전폴더를 의미
                                   train=True,
                                   transform=transform,
                                   download=True)

In [8]:
# dataloader를 불러오는 과정
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

In [9]:
# Discriminator 모델 선언
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

In [10]:
# Generator 모델 선언
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

In [11]:
# 연산을 진행할 device 설정
D = D.to(device)
G = G.to(device)

In [12]:
# Loss와 Optimizer 설정
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [13]:
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)# clamp함수를 통해 0과 1을 반환

In [14]:
# autograd 설정
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [15]:
# 훈련이 진행되는 부분
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        # label을 만드는 과정
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # discriminator trainig
        
        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1 
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # Generator training
        
        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
            
    # 실제 이미지를 저장하는 과정
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # 샘플 이미지를 저장하는 과정
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints 
# 훈련 모델을 저장하는 과정
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

Epoch [0/200], Step [200/600], d_loss: 0.0367, g_loss: 4.3965, D(x): 0.99, D(G(z)): 0.03
Epoch [0/200], Step [400/600], d_loss: 0.2260, g_loss: 4.4613, D(x): 0.93, D(G(z)): 0.12
Epoch [0/200], Step [600/600], d_loss: 0.0903, g_loss: 4.6187, D(x): 0.96, D(G(z)): 0.05
Epoch [1/200], Step [200/600], d_loss: 0.2268, g_loss: 4.0707, D(x): 0.94, D(G(z)): 0.13
Epoch [1/200], Step [400/600], d_loss: 0.1637, g_loss: 3.7801, D(x): 0.94, D(G(z)): 0.07
Epoch [1/200], Step [600/600], d_loss: 0.5623, g_loss: 4.4578, D(x): 0.86, D(G(z)): 0.21
Epoch [2/200], Step [200/600], d_loss: 0.0975, g_loss: 4.6959, D(x): 0.95, D(G(z)): 0.03
Epoch [2/200], Step [400/600], d_loss: 0.1025, g_loss: 4.1881, D(x): 0.95, D(G(z)): 0.05
Epoch [2/200], Step [600/600], d_loss: 0.5488, g_loss: 3.2134, D(x): 0.78, D(G(z)): 0.09
Epoch [3/200], Step [200/600], d_loss: 0.8556, g_loss: 2.3169, D(x): 0.83, D(G(z)): 0.37
Epoch [3/200], Step [400/600], d_loss: 0.2034, g_loss: 4.0470, D(x): 0.95, D(G(z)): 0.12
Epoch [3/200], Step [

Epoch [30/200], Step [600/600], d_loss: 0.6339, g_loss: 2.9603, D(x): 0.84, D(G(z)): 0.19
Epoch [31/200], Step [200/600], d_loss: 0.6270, g_loss: 2.8055, D(x): 0.83, D(G(z)): 0.20
Epoch [31/200], Step [400/600], d_loss: 0.3406, g_loss: 4.5430, D(x): 0.87, D(G(z)): 0.08
Epoch [31/200], Step [600/600], d_loss: 0.4228, g_loss: 3.8033, D(x): 0.85, D(G(z)): 0.11
Epoch [32/200], Step [200/600], d_loss: 0.5602, g_loss: 3.3939, D(x): 0.85, D(G(z)): 0.20
Epoch [32/200], Step [400/600], d_loss: 0.3819, g_loss: 3.4433, D(x): 0.85, D(G(z)): 0.08
Epoch [32/200], Step [600/600], d_loss: 0.6017, g_loss: 2.9691, D(x): 0.77, D(G(z)): 0.09
Epoch [33/200], Step [200/600], d_loss: 0.5114, g_loss: 3.0566, D(x): 0.80, D(G(z)): 0.07
Epoch [33/200], Step [400/600], d_loss: 0.4298, g_loss: 3.7148, D(x): 0.89, D(G(z)): 0.17
Epoch [33/200], Step [600/600], d_loss: 0.4070, g_loss: 3.9979, D(x): 0.84, D(G(z)): 0.09
Epoch [34/200], Step [200/600], d_loss: 0.3789, g_loss: 4.1989, D(x): 0.88, D(G(z)): 0.13
Epoch [34/

Epoch [61/200], Step [400/600], d_loss: 0.8794, g_loss: 2.5103, D(x): 0.66, D(G(z)): 0.16
Epoch [61/200], Step [600/600], d_loss: 0.7396, g_loss: 2.2723, D(x): 0.79, D(G(z)): 0.25
Epoch [62/200], Step [200/600], d_loss: 0.6384, g_loss: 1.9669, D(x): 0.81, D(G(z)): 0.23
Epoch [62/200], Step [400/600], d_loss: 0.6997, g_loss: 2.3348, D(x): 0.78, D(G(z)): 0.21
Epoch [62/200], Step [600/600], d_loss: 0.4718, g_loss: 2.0963, D(x): 0.86, D(G(z)): 0.20
Epoch [63/200], Step [200/600], d_loss: 0.8048, g_loss: 2.2954, D(x): 0.76, D(G(z)): 0.27
Epoch [63/200], Step [400/600], d_loss: 0.5985, g_loss: 2.0520, D(x): 0.78, D(G(z)): 0.16
Epoch [63/200], Step [600/600], d_loss: 0.7281, g_loss: 2.6914, D(x): 0.82, D(G(z)): 0.26
Epoch [64/200], Step [200/600], d_loss: 0.6783, g_loss: 2.3776, D(x): 0.80, D(G(z)): 0.23
Epoch [64/200], Step [400/600], d_loss: 0.7061, g_loss: 1.7348, D(x): 0.77, D(G(z)): 0.25
Epoch [64/200], Step [600/600], d_loss: 0.5106, g_loss: 2.8978, D(x): 0.85, D(G(z)): 0.20
Epoch [65/

Epoch [92/200], Step [200/600], d_loss: 0.9088, g_loss: 1.9058, D(x): 0.73, D(G(z)): 0.31
Epoch [92/200], Step [400/600], d_loss: 1.0245, g_loss: 2.3199, D(x): 0.87, D(G(z)): 0.43
Epoch [92/200], Step [600/600], d_loss: 0.7813, g_loss: 1.7273, D(x): 0.77, D(G(z)): 0.29
Epoch [93/200], Step [200/600], d_loss: 1.0465, g_loss: 1.5407, D(x): 0.70, D(G(z)): 0.31
Epoch [93/200], Step [400/600], d_loss: 0.6598, g_loss: 2.1540, D(x): 0.86, D(G(z)): 0.29
Epoch [93/200], Step [600/600], d_loss: 0.7233, g_loss: 2.5487, D(x): 0.75, D(G(z)): 0.19
Epoch [94/200], Step [200/600], d_loss: 0.6879, g_loss: 2.4908, D(x): 0.79, D(G(z)): 0.25
Epoch [94/200], Step [400/600], d_loss: 0.8418, g_loss: 1.8157, D(x): 0.72, D(G(z)): 0.30
Epoch [94/200], Step [600/600], d_loss: 0.6717, g_loss: 2.1985, D(x): 0.74, D(G(z)): 0.23
Epoch [95/200], Step [200/600], d_loss: 0.6218, g_loss: 1.9220, D(x): 0.81, D(G(z)): 0.26
Epoch [95/200], Step [400/600], d_loss: 0.6337, g_loss: 2.2690, D(x): 0.79, D(G(z)): 0.19
Epoch [95/

Epoch [122/200], Step [400/600], d_loss: 0.8633, g_loss: 1.6645, D(x): 0.74, D(G(z)): 0.33
Epoch [122/200], Step [600/600], d_loss: 0.8606, g_loss: 1.6375, D(x): 0.73, D(G(z)): 0.27
Epoch [123/200], Step [200/600], d_loss: 0.7029, g_loss: 2.0930, D(x): 0.80, D(G(z)): 0.29
Epoch [123/200], Step [400/600], d_loss: 0.9469, g_loss: 1.7440, D(x): 0.73, D(G(z)): 0.31
Epoch [123/200], Step [600/600], d_loss: 0.9033, g_loss: 1.4572, D(x): 0.78, D(G(z)): 0.38
Epoch [124/200], Step [200/600], d_loss: 0.8169, g_loss: 1.4233, D(x): 0.72, D(G(z)): 0.28
Epoch [124/200], Step [400/600], d_loss: 0.8656, g_loss: 1.5553, D(x): 0.76, D(G(z)): 0.33
Epoch [124/200], Step [600/600], d_loss: 0.9827, g_loss: 1.5257, D(x): 0.72, D(G(z)): 0.32
Epoch [125/200], Step [200/600], d_loss: 1.0148, g_loss: 1.8342, D(x): 0.67, D(G(z)): 0.33
Epoch [125/200], Step [400/600], d_loss: 0.6614, g_loss: 2.0857, D(x): 0.76, D(G(z)): 0.22
Epoch [125/200], Step [600/600], d_loss: 1.0119, g_loss: 1.5749, D(x): 0.71, D(G(z)): 0.34

Epoch [152/200], Step [600/600], d_loss: 0.9003, g_loss: 1.4777, D(x): 0.72, D(G(z)): 0.32
Epoch [153/200], Step [200/600], d_loss: 0.8280, g_loss: 1.8714, D(x): 0.70, D(G(z)): 0.27
Epoch [153/200], Step [400/600], d_loss: 0.9207, g_loss: 1.4425, D(x): 0.70, D(G(z)): 0.34
Epoch [153/200], Step [600/600], d_loss: 1.1886, g_loss: 1.5774, D(x): 0.56, D(G(z)): 0.29
Epoch [154/200], Step [200/600], d_loss: 0.8660, g_loss: 1.5757, D(x): 0.69, D(G(z)): 0.27
Epoch [154/200], Step [400/600], d_loss: 1.0202, g_loss: 1.5430, D(x): 0.67, D(G(z)): 0.34
Epoch [154/200], Step [600/600], d_loss: 0.9753, g_loss: 1.2473, D(x): 0.63, D(G(z)): 0.25
Epoch [155/200], Step [200/600], d_loss: 0.8839, g_loss: 1.6349, D(x): 0.67, D(G(z)): 0.28
Epoch [155/200], Step [400/600], d_loss: 1.2290, g_loss: 1.9258, D(x): 0.60, D(G(z)): 0.33
Epoch [155/200], Step [600/600], d_loss: 0.8643, g_loss: 1.1746, D(x): 0.78, D(G(z)): 0.37
Epoch [156/200], Step [200/600], d_loss: 0.8301, g_loss: 1.3377, D(x): 0.75, D(G(z)): 0.33

Epoch [183/200], Step [200/600], d_loss: 0.8571, g_loss: 1.7567, D(x): 0.72, D(G(z)): 0.33
Epoch [183/200], Step [400/600], d_loss: 0.8483, g_loss: 1.3930, D(x): 0.73, D(G(z)): 0.31
Epoch [183/200], Step [600/600], d_loss: 1.0457, g_loss: 1.5737, D(x): 0.67, D(G(z)): 0.32
Epoch [184/200], Step [200/600], d_loss: 0.8535, g_loss: 1.7206, D(x): 0.69, D(G(z)): 0.27
Epoch [184/200], Step [400/600], d_loss: 1.0034, g_loss: 1.3741, D(x): 0.64, D(G(z)): 0.29
Epoch [184/200], Step [600/600], d_loss: 1.1397, g_loss: 1.5199, D(x): 0.75, D(G(z)): 0.43
Epoch [185/200], Step [200/600], d_loss: 0.9640, g_loss: 1.5114, D(x): 0.72, D(G(z)): 0.36
Epoch [185/200], Step [400/600], d_loss: 1.1469, g_loss: 1.5161, D(x): 0.55, D(G(z)): 0.28
Epoch [185/200], Step [600/600], d_loss: 1.0745, g_loss: 1.2729, D(x): 0.64, D(G(z)): 0.33
Epoch [186/200], Step [200/600], d_loss: 0.8850, g_loss: 1.2960, D(x): 0.64, D(G(z)): 0.22
Epoch [186/200], Step [400/600], d_loss: 1.0148, g_loss: 1.7993, D(x): 0.66, D(G(z)): 0.28

In [25]:
# 이미지 연산 처리를 이용한 이미지 합성하기
# transpose를 해야함
import torch.nn.functional as F

print(type(fake_images))
print(type(real_labels))

def _ssim(img1, img2, window, window_size, channel, size_average=True): 
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

<class 'torch.Tensor'>
<class 'torch.Tensor'>


In [26]:
# _ssim('samples/fake_images-200.png', 'real_images.png')