## 0. 데이터셋 다운로드

In [1]:
# !chmod +x download.sh
# !./download.sh
# !unzip -q CelebA_128crop_FD.zip?dl=0 -d ./data/

## 1. 모듈/라이브러리 임포트

In [2]:
import torch, os
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.utils.data import DataLoader

## 하이퍼 파라메터

In [3]:
lr = 0.0002
max_epoch = 20
batch_size = 32
z_dim = 100
image_size = 32
g_conv_dim = 64
d_conv_dim = 64
log_step = 100
sample_step = 500
sample_num = 32
image_path = './data/CelebA/'
sample_path = './output2/'
if not os.path.exists(sample_path) : os.makedirs(sample_path)

## 2. 몇 가지 함수 정의

In [4]:
# [-1,1] -> [0,1]
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

## 3. 데이터 준비

In [5]:
transform = transforms.Compose([
                transforms.Scale(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

dataset = ImageFolder(image_path, transform)
data_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=8,
                          drop_last=True)

## 4. 모델 정의

In [6]:
def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom deconvolutional layer for simplicity."""
    layers = []
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """Generator containing 7 deconvolutional layers."""
    def __init__(self, z_dim=256, image_size=128, conv_dim=64):
        super(Generator, self).__init__()
        self.fc = deconv(z_dim, conv_dim*8, int(image_size/16), 1, 0, bn=False)
        self.deconv1 = deconv(conv_dim*8, conv_dim*4, 4)
        self.deconv2 = deconv(conv_dim*4, conv_dim*2, 4)
        self.deconv3 = deconv(conv_dim*2, conv_dim, 4)
        self.deconv4 = deconv(conv_dim, 3, 4, bn=False)
        
    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)      # If image_size is 64, output shape is as below.
        out = self.fc(z)                            # (?, 512, 4, 4)
        out = F.leaky_relu(self.deconv1(out), 0.05)  # (?, 256, 8, 8)
        out = F.leaky_relu(self.deconv2(out), 0.05)  # (?, 128, 16, 16)
        out = F.leaky_relu(self.deconv3(out), 0.05)  # (?, 64, 32, 32)
        out = F.tanh(self.deconv4(out))             # (?, 3, 64, 64)
        return out
    
    
def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Custom convolutional layer for simplicity."""
    layers = []
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


class Discriminator(nn.Module):
    """Discriminator containing 4 convolutional layers."""
    def __init__(self, image_size=128, conv_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = conv(3, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim*2, 4)
        self.conv3 = conv(conv_dim*2, conv_dim*4, 4)
        self.conv4 = conv(conv_dim*4, conv_dim*8, 4)
        self.fc = conv(conv_dim*8, 1, int(image_size/16), 1, 0, False)
        
    def forward(self, x):                         # If image_size is 64, output shape is as below.
        out = F.leaky_relu(self.conv1(x), 0.05)    # (?, 64, 32, 32)
        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 16, 16)
        out = F.leaky_relu(self.conv3(out), 0.05)  # (?, 256, 8, 8)
        out = F.leaky_relu(self.conv4(out), 0.05)  # (?, 512, 4, 4)
        out = F.sigmoid(self.fc(out)).squeeze()
#         out = self.fc(out).squeeze()
        return out
    
D = Discriminator(image_size)
G = Generator(z_dim,image_size,g_conv_dim)

D.cuda()
G.cuda()

Generator (
  (fc): Sequential (
    (0): ConvTranspose2d(100, 512, kernel_size=(2, 2), stride=(1, 1))
  )
  (deconv1): Sequential (
    (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
  )
  (deconv2): Sequential (
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
  )
  (deconv3): Sequential (
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  )
  (deconv4): Sequential (
    (0): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)

## 5. Loss/Optimizer

In [7]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss().cuda()
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5,0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))

## 6. 트레이닝

In [8]:
# Start training
total_batch = len(data_loader.dataset)//batch_size
fixed_z = Variable(torch.randn(sample_num, z_dim)).cuda()
for epoch in range(max_epoch):
    for i, (images, _) in enumerate(data_loader):
        # Build mini-batch dataset
#         batch_size = images.size(0)
        images = Variable(images).cuda()
        # Create the labels which are later used as input for the BCE loss
        real_labels = Variable(torch.ones(batch_size)).cuda()
        fake_labels = Variable(torch.zeros(batch_size)).cuda()

        #============= Train the discriminator =============#
        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
#         d_loss_real = torch.mean((outputs - 1) ** 2)
        
        real_score = outputs

        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = Variable(torch.randn(batch_size, z_dim)).cuda()
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
#         d_loss_fake = torch.mean(outputs ** 2)
        fake_score = outputs

        # Backprop + Optimize
        d_loss = d_loss_real + d_loss_fake
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #=============== Train the generator ===============#
        # Compute loss with fake images
        z = Variable(torch.randn(batch_size, z_dim)).cuda()
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
#         g_loss = criterion(outputs, real_labels)
        g_loss = torch.mean((outputs - 1) ** 2)

        # Backprop + Optimize
        D.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % log_step == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, '
                  'g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f'
                  %(epoch, max_epoch, i+1, total_batch, d_loss.data[0], g_loss.data[0],
                    real_score.data.mean(), fake_score.data.mean()))
            
        if (i+1) % sample_step == 0:
            fake_images = G(fixed_z)
            torchvision.utils.save_image(denorm(fake_images.data), 
                                         os.path.join(sample_path,'fake_samples-%d-%d.png' %(epoch+1, i+1)),
                                         nrow=8)

# Save the trained parameters
torch.save(G.state_dict(), './generator.pkl')
torch.save(D.state_dict(), './discriminator.pkl')

Epoch [0/20], Step[100/4882], d_loss: 0.2562, g_loss: 1.7654, D(x): 0.97, D(G(z)): 0.26
Epoch [0/20], Step[200/4882], d_loss: 0.4887, g_loss: 0.9513, D(x): 0.40, D(G(z)): -0.32
Epoch [0/20], Step[300/4882], d_loss: 0.1654, g_loss: 0.7554, D(x): 0.66, D(G(z)): 0.05
Epoch [0/20], Step[400/4882], d_loss: 0.3200, g_loss: 1.0616, D(x): 0.71, D(G(z)): 0.37
Epoch [0/20], Step[500/4882], d_loss: 0.2313, g_loss: 1.0304, D(x): 0.68, D(G(z)): 0.25
Epoch [0/20], Step[600/4882], d_loss: 0.2370, g_loss: 0.9404, D(x): 0.73, D(G(z)): 0.32
Epoch [0/20], Step[700/4882], d_loss: 0.3025, g_loss: 0.7942, D(x): 0.67, D(G(z)): 0.35
Epoch [0/20], Step[800/4882], d_loss: 0.1530, g_loss: 0.7627, D(x): 0.76, D(G(z)): 0.23
Epoch [0/20], Step[900/4882], d_loss: 0.3491, g_loss: 0.9568, D(x): 0.64, D(G(z)): 0.40
Epoch [0/20], Step[1000/4882], d_loss: 0.3177, g_loss: 0.5798, D(x): 0.53, D(G(z)): 0.18
Epoch [0/20], Step[1100/4882], d_loss: 0.2076, g_loss: 0.6759, D(x): 0.72, D(G(z)): 0.26
Epoch [0/20], Step[1200/4882]

Epoch [1/20], Step[4600/4882], d_loss: 0.2225, g_loss: 0.5236, D(x): 0.67, D(G(z)): 0.23
Epoch [1/20], Step[4700/4882], d_loss: 0.2738, g_loss: 0.9566, D(x): 0.91, D(G(z)): 0.45
Epoch [1/20], Step[4800/4882], d_loss: 0.2100, g_loss: 0.9334, D(x): 1.05, D(G(z)): 0.36
Epoch [2/20], Step[100/4882], d_loss: 0.1126, g_loss: 0.8564, D(x): 0.90, D(G(z)): 0.19
Epoch [2/20], Step[200/4882], d_loss: 0.2409, g_loss: 1.1705, D(x): 0.88, D(G(z)): 0.40
Epoch [2/20], Step[300/4882], d_loss: 0.3572, g_loss: 0.7377, D(x): 0.67, D(G(z)): 0.42
Epoch [2/20], Step[400/4882], d_loss: 0.2330, g_loss: 0.8885, D(x): 0.59, D(G(z)): 0.01
Epoch [2/20], Step[500/4882], d_loss: 0.1257, g_loss: 0.5736, D(x): 0.74, D(G(z)): 0.06
Epoch [2/20], Step[600/4882], d_loss: 0.1306, g_loss: 0.3770, D(x): 0.75, D(G(z)): 0.01
Epoch [2/20], Step[700/4882], d_loss: 0.3707, g_loss: 0.7898, D(x): 0.49, D(G(z)): 0.13
Epoch [2/20], Step[800/4882], d_loss: 0.3379, g_loss: 0.8086, D(x): 0.91, D(G(z)): 0.51
Epoch [2/20], Step[900/4882],

Epoch [3/20], Step[4300/4882], d_loss: 0.0803, g_loss: 0.7470, D(x): 0.86, D(G(z)): 0.09
Epoch [3/20], Step[4400/4882], d_loss: 0.1520, g_loss: 0.8249, D(x): 1.00, D(G(z)): 0.30
Epoch [3/20], Step[4500/4882], d_loss: 0.1583, g_loss: 0.3985, D(x): 0.69, D(G(z)): 0.11
Epoch [3/20], Step[4600/4882], d_loss: 0.1243, g_loss: 0.6717, D(x): 0.76, D(G(z)): 0.07
Epoch [3/20], Step[4700/4882], d_loss: 0.1016, g_loss: 1.2138, D(x): 0.95, D(G(z)): 0.20
Epoch [3/20], Step[4800/4882], d_loss: 0.1207, g_loss: 0.2666, D(x): 0.92, D(G(z)): 0.26
Epoch [4/20], Step[100/4882], d_loss: 0.1770, g_loss: 0.8605, D(x): 0.66, D(G(z)): 0.12
Epoch [4/20], Step[200/4882], d_loss: 0.1633, g_loss: 0.4966, D(x): 0.87, D(G(z)): 0.28
Epoch [4/20], Step[300/4882], d_loss: 0.1472, g_loss: 0.6886, D(x): 0.72, D(G(z)): -0.14
Epoch [4/20], Step[400/4882], d_loss: 0.2112, g_loss: 0.3570, D(x): 0.65, D(G(z)): -0.18
Epoch [4/20], Step[500/4882], d_loss: 0.0893, g_loss: 0.9722, D(x): 0.96, D(G(z)): 0.14
Epoch [4/20], Step[600/4

Epoch [5/20], Step[4000/4882], d_loss: 0.1261, g_loss: 0.7021, D(x): 1.13, D(G(z)): 0.25
Epoch [5/20], Step[4100/4882], d_loss: 0.0771, g_loss: 0.8830, D(x): 0.87, D(G(z)): 0.11
Epoch [5/20], Step[4200/4882], d_loss: 0.2034, g_loss: 0.9202, D(x): 0.64, D(G(z)): -0.10
Epoch [5/20], Step[4300/4882], d_loss: 0.2306, g_loss: 0.4402, D(x): 0.60, D(G(z)): 0.05
Epoch [5/20], Step[4400/4882], d_loss: 0.1250, g_loss: 0.9432, D(x): 0.87, D(G(z)): 0.21
Epoch [5/20], Step[4500/4882], d_loss: 0.1167, g_loss: 0.5929, D(x): 0.73, D(G(z)): 0.09
Epoch [5/20], Step[4600/4882], d_loss: 0.0855, g_loss: 0.8548, D(x): 0.85, D(G(z)): 0.14
Epoch [5/20], Step[4700/4882], d_loss: 0.1255, g_loss: 0.9348, D(x): 0.71, D(G(z)): 0.08
Epoch [5/20], Step[4800/4882], d_loss: 0.0735, g_loss: 1.1816, D(x): 0.98, D(G(z)): 0.16
Epoch [6/20], Step[100/4882], d_loss: 0.2229, g_loss: 1.6231, D(x): 1.23, D(G(z)): 0.34
Epoch [6/20], Step[200/4882], d_loss: 0.0406, g_loss: 0.9000, D(x): 1.00, D(G(z)): 0.07
Epoch [6/20], Step[300

Epoch [7/20], Step[3700/4882], d_loss: 0.1155, g_loss: 0.8699, D(x): 0.98, D(G(z)): 0.28
Epoch [7/20], Step[3800/4882], d_loss: 0.0945, g_loss: 0.8639, D(x): 0.86, D(G(z)): 0.21
Epoch [7/20], Step[3900/4882], d_loss: 0.1549, g_loss: 0.6649, D(x): 0.67, D(G(z)): 0.11
Epoch [7/20], Step[4000/4882], d_loss: 0.0487, g_loss: 0.7417, D(x): 0.90, D(G(z)): 0.07
Epoch [7/20], Step[4100/4882], d_loss: 0.1025, g_loss: 1.7089, D(x): 1.10, D(G(z)): 0.26
Epoch [7/20], Step[4200/4882], d_loss: 0.1551, g_loss: 0.5985, D(x): 0.65, D(G(z)): -0.04
Epoch [7/20], Step[4300/4882], d_loss: 0.0543, g_loss: 1.0365, D(x): 1.06, D(G(z)): 0.09
Epoch [7/20], Step[4400/4882], d_loss: 0.1414, g_loss: 1.2223, D(x): 0.70, D(G(z)): -0.02
Epoch [7/20], Step[4500/4882], d_loss: 0.1009, g_loss: 0.9822, D(x): 0.97, D(G(z)): 0.24


FileNotFoundError: [Errno 2] No such file or directory: './output2/fake_samples-8-4500.png'