# 9.1 Generating New Fashion Items with a GAN
*We use a GAN to create new fashion items.*

Before implementing a GAN, let's look at its structure in more detail.

A GAN consists of two neural networks: a generator and a discriminator.
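
For reference, the two networks are trained against each other under the minimax objective from the original GAN paper (Goodfellow et al., 2014, the arXiv paper linked in the training loop below). The symbols follow that paper rather than this notebook: $x$ is a real image drawn from the data distribution $p_{\text{data}}$, $z$ is a noise vector drawn from a prior $p_z$, $D(\cdot)$ is the discriminator's estimated probability that its input is real, and $G(z)$ is a generated image.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_z(z)}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]
```

The discriminator tries to make $V$ large by scoring real images near 1 and generated images near 0, while the generator tries to make $V$ small by producing images that the discriminator scores near 1.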



## Implementing the GAN

In [1]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image

In [2]:
torch.manual_seed(1)    # reproducible

<torch._C.Generator at 0x7f448fd11bf0>

In [3]:
# Hyperparameters
EPOCHS = 100
BATCH_SIZE = 100
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("Using Device:", DEVICE)

Using Device: cuda


In [4]:
# Fashion MNIST dataset (clothing images, not digits)
trainset = datasets.FashionMNIST('./.data',
    train=True,
    download=True,
    transform=transforms.Compose([
       transforms.ToTensor(),
       transforms.Normalize((0.5,), (0.5,))
    ]))
train_loader = torch.utils.data.DataLoader(
    dataset     = trainset,
    batch_size  = BATCH_SIZE,
    shuffle     = True)

Processing...
Done!





In [5]:
# Discriminator
D = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1),
        nn.Sigmoid())

In [6]:
# Generator 
G = nn.Sequential(
        nn.Linear(64, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 784),
        nn.Tanh())
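
Before training, it can help to confirm that the two networks fit together: the generator maps a 64-dimensional noise vector to a flattened 28x28 image (784 values), and the discriminator maps such an image to a single probability. The sketch below rebuilds both networks under the hypothetical names `sketch_G` and `sketch_D` so that it runs on its own without touching the notebook's `G` and `D`. The batch size of 5 is arbitrary.

```python
import torch
import torch.nn as nn

# Same layer layout as the notebook's generator and discriminator,
# rebuilt under different names so this check is self-contained
sketch_G = nn.Sequential(
    nn.Linear(64, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 784), nn.Tanh())
sketch_D = nn.Sequential(
    nn.Linear(784, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid())

z = torch.randn(5, 64)                  # 5 noise vectors of dimension 64
with torch.no_grad():                   # no gradients needed for a shape check
    fake_images = sketch_G(z)           # (5, 784), bounded to [-1, 1] by Tanh
    scores = sketch_D(fake_images)      # (5, 1), bounded to [0, 1] by Sigmoid

print(fake_images.shape, scores.shape)
```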

In [7]:

# Device setting
D = D.to(DEVICE)
G = G.to(DEVICE)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
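
The training loop below relies on how `nn.BCELoss` behaves when every label is 1 or every label is 0. The following sketch checks that behavior on a few made-up discriminator outputs (`d_out` and `bce` are hypothetical names used only here): with labels of 1 the loss reduces to the mean of `-log(D(x))`, and with labels of 0 it reduces to the mean of `-log(1 - D(x))`.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()

# Hypothetical discriminator outputs for a batch of three images
d_out = torch.tensor([[0.9], [0.6], [0.2]])

# Labels of 1: only the -y * log(D(x)) term survives
loss_real = bce(d_out, torch.ones(3, 1))
manual_real = (-torch.log(d_out)).mean()

# Labels of 0: only the -(1 - y) * log(1 - D(x)) term survives
loss_fake = bce(d_out, torch.zeros(3, 1))
manual_fake = (-torch.log(1 - d_out)).mean()

print(loss_real.item(), manual_real.item())
print(loss_fake.item(), manual_fake.item())
```

This is also why the generator step can pass `real_labels` together with generated images: the resulting loss is `-log(D(G(z)))`, so minimizing it pushes the discriminator's score for generated images toward 1.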

In [8]:
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(BATCH_SIZE, -1).to(DEVICE)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
        fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

        # Train Discriminator

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # Train Generator

        # Compute loss with fake images
        z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
        # For the reasoning, see the last paragraph of Section 3 of https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, EPOCHS, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))

Epoch [0/100], Step [200/600], d_loss: 0.0787, g_loss: 4.1506, D(x): 0.98, D(G(z)): 0.06
Epoch [0/100], Step [400/600], d_loss: 0.2156, g_loss: 4.7861, D(x): 0.93, D(G(z)): 0.10
Epoch [0/100], Step [600/600], d_loss: 0.0326, g_loss: 5.2619, D(x): 0.99, D(G(z)): 0.02
Epoch [1/100], Step [200/600], d_loss: 0.0656, g_loss: 5.0974, D(x): 0.99, D(G(z)): 0.03
Epoch [1/100], Step [400/600], d_loss: 0.1571, g_loss: 3.6610, D(x): 0.95, D(G(z)): 0.07
Epoch [1/100], Step [600/600], d_loss: 0.0500, g_loss: 4.5240, D(x): 0.99, D(G(z)): 0.03
Epoch [2/100], Step [200/600], d_loss: 0.0376, g_loss: 6.1814, D(x): 0.98, D(G(z)): 0.01
Epoch [2/100], Step [400/600], d_loss: 0.0241, g_loss: 6.5856, D(x): 0.99, D(G(z)): 0.01
Epoch [2/100], Step [600/600], d_loss: 0.1581, g_loss: 6.0980, D(x): 0.96, D(G(z)): 0.02
Epoch [3/100], Step [200/600], d_loss: 0.0641, g_loss: 6.9642, D(x): 0.97, D(G(z)): 0.00
Epoch [3/100], Step [400/600], d_loss: 0.1090, g_loss: 4.8299, D(x): 0.96, D(G(z)): 0.02
...

Epoch [30/100], Step [600/600], d_loss: 0.5736, g_loss: 2.8101, D(x): 0.80, D(G(z)): 0.10
Epoch [31/100], Step [200/600], d_loss: 0.5028, g_loss: 3.7064, D(x): 0.81, D(G(z)): 0.07
Epoch [31/100], Step [400/600], d_loss: 0.6508, g_loss: 2.7957, D(x): 0.79, D(G(z)): 0.18
Epoch [31/100], Step [600/600], d_loss: 0.5380, g_loss: 2.9571, D(x): 0.84, D(G(z)): 0.16
Epoch [32/100], Step [200/600], d_loss: 0.4424, g_loss: 2.4877, D(x): 0.87, D(G(z)): 0.17
Epoch [32/100], Step [400/600], d_loss: 0.4118, g_loss: 2.5834, D(x): 0.90, D(G(z)): 0.17
Epoch [32/100], Step [600/600], d_loss: 0.7510, g_loss: 2.4698, D(x): 0.73, D(G(z)): 0.12
Epoch [33/100], Step [200/600], d_loss: 0.6722, g_loss: 2.7541, D(x): 0.80, D(G(z)): 0.15
Epoch [33/100], Step [400/600], d_loss: 0.4703, g_loss: 3.2515, D(x): 0.84, D(G(z)): 0.13
Epoch [33/100], Step [600/600], d_loss: 0.8260, g_loss: 2.3634, D(x): 0.74, D(G(z)): 0.19
Epoch [34/100], Step [200/600], d_loss: 0.5580, g_loss: 2.2795, D(x): 0.83, D(G(z)): 0.17
...

Epoch [61/100], Step [400/600], d_loss: 0.9973, g_loss: 2.8761, D(x): 0.73, D(G(z)): 0.21
Epoch [61/100], Step [600/600], d_loss: 0.5597, g_loss: 3.2259, D(x): 0.87, D(G(z)): 0.24
Epoch [62/100], Step [200/600], d_loss: 0.9815, g_loss: 1.9172, D(x): 0.71, D(G(z)): 0.25
Epoch [62/100], Step [400/600], d_loss: 0.7774, g_loss: 2.0705, D(x): 0.71, D(G(z)): 0.16
Epoch [62/100], Step [600/600], d_loss: 0.7258, g_loss: 2.2306, D(x): 0.83, D(G(z)): 0.28
Epoch [63/100], Step [200/600], d_loss: 0.7889, g_loss: 2.1573, D(x): 0.70, D(G(z)): 0.15
Epoch [63/100], Step [400/600], d_loss: 0.5854, g_loss: 2.8319, D(x): 0.82, D(G(z)): 0.18
Epoch [63/100], Step [600/600], d_loss: 0.7550, g_loss: 3.0140, D(x): 0.77, D(G(z)): 0.18
Epoch [64/100], Step [200/600], d_loss: 0.8088, g_loss: 3.1151, D(x): 0.76, D(G(z)): 0.22
Epoch [64/100], Step [400/600], d_loss: 0.7828, g_loss: 2.3200, D(x): 0.71, D(G(z)): 0.17
Epoch [64/100], Step [600/600], d_loss: 0.7833, g_loss: 2.2829, D(x): 0.79, D(G(z)): 0.27
...

Epoch [92/100], Step [200/600], d_loss: 0.7545, g_loss: 1.8934, D(x): 0.72, D(G(z)): 0.19
Epoch [92/100], Step [400/600], d_loss: 0.8663, g_loss: 1.8379, D(x): 0.72, D(G(z)): 0.26
Epoch [92/100], Step [600/600], d_loss: 0.6683, g_loss: 2.8028, D(x): 0.74, D(G(z)): 0.14
Epoch [93/100], Step [200/600], d_loss: 0.8399, g_loss: 1.3640, D(x): 0.80, D(G(z)): 0.34
Epoch [93/100], Step [400/600], d_loss: 0.8975, g_loss: 1.6973, D(x): 0.69, D(G(z)): 0.22
Epoch [93/100], Step [600/600], d_loss: 0.8928, g_loss: 1.9126, D(x): 0.81, D(G(z)): 0.32
Epoch [94/100], Step [200/600], d_loss: 0.6605, g_loss: 1.5847, D(x): 0.83, D(G(z)): 0.28
Epoch [94/100], Step [400/600], d_loss: 1.0173, g_loss: 1.5576, D(x): 0.71, D(G(z)): 0.37
Epoch [94/100], Step [600/600], d_loss: 0.7937, g_loss: 2.1991, D(x): 0.76, D(G(z)): 0.25
Epoch [95/100], Step [200/600], d_loss: 0.7031, g_loss: 1.6621, D(x): 0.75, D(G(z)): 0.24
Epoch [95/100], Step [400/600], d_loss: 0.7104, g_loss: 2.2526, D(x): 0.74, D(G(z)): 0.22
...

## References
This tutorial was written with reference to the following material.

* [yunjey/pytorch-tutorial](https://github.com/yunjey/pytorch-tutorial) - MIT License