# Conditional GAN으로 생성 컨트롤하기

In [1]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Hyperparameters: number of training epochs and mini-batch size.
EPOCHS, BATCH_SIZE = 300, 100

# Run on the GPU when one is available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")
print("Using Device:", DEVICE)

Using Device: cpu


In [3]:
# Fashion-MNIST training split. Images are converted to tensors and then
# normalized with mean 0.5 / std 0.5, which lines up with the Tanh output
# range of the generator.
trainset = datasets.FashionMNIST(
    root='./.data',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
    download=True,
)

# Shuffled mini-batches of BATCH_SIZE samples for training.
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True
)

In [5]:
# Generator
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to a flat image.

    The class label is looked up in a 10x10 embedding table and concatenated
    with a 100-dimensional noise vector, giving the 110 input features of the
    first linear layer. The output has 784 values squashed by Tanh.
    """

    def __init__(self):
        super().__init__()

        # One 10-dimensional learned vector per class (10 classes).
        self.embed = nn.Embedding(10, 10)

        # Hidden stack: Linear + LeakyReLU pairs with growing width.
        widths = (110, 256, 512, 1024)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Output projection to 784 values in [-1, 1].
        layers.append(nn.Linear(1024, 784))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate images from noise ``z`` conditioned on ``labels``.

        Args:
            z: noise tensor whose second dimension, together with the
                10-dimensional label embedding, must total 110 features.
            labels: integer class indices, one per row of ``z``.

        Returns:
            Tensor with 784 features per row.
        """
        conditioned = torch.cat((z, self.embed(labels)), dim=1)
        return self.model(conditioned)

In [6]:
# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator: scores (flat image, class label) pairs.

    The class label is looked up in a 10x10 embedding table and concatenated
    with the 784-value flattened image, giving the 794 input features of the
    first linear layer. The final Sigmoid yields one probability per sample.
    """

    def __init__(self):
        super().__init__()

        # One 10-dimensional learned vector per class (10 classes).
        self.embed = nn.Embedding(10, 10)

        self.model = nn.Sequential(
            nn.Linear(794, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Return the probability that each image/label pair is real.

        Args:
            x: flattened images with 784 features per row.
            labels: integer class indices, one per row of ``x``.

        Returns:
            1-D tensor of shape ``(batch,)`` with values in [0, 1].
        """
        c = self.embed(labels)
        x = torch.cat([x, c], 1)
        out = self.model(x)
        # Drop only the trailing feature axis. A bare squeeze() would also
        # remove the batch axis for a batch of one and return a 0-d scalar.
        return out.squeeze(1)

In [7]:
# Create the two networks and move their weights to the selected device.
# (nn.Module.to() moves parameters in place, so the names stay bound to the
# same module objects.)
D = Discriminator()
D.to(DEVICE)
G = Generator()
G.to(DEVICE)

# Binary cross-entropy loss, plus one Adam optimizer per network so the
# generator and discriminator can be updated independently.
criterion = nn.BCELoss()
d_optimizer = optim.Adam(params=D.parameters(), lr=2e-4)
g_optimizer = optim.Adam(params=G.parameters(), lr=2e-4)

In [None]:
# Adversarial training of the conditional GAN. Each step first updates the
# discriminator on one real and one generated batch, then updates the
# generator so that its samples are scored as real.
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        # Use the actual batch size so a smaller final batch still works.
        batch_size = images.size(0)
        images = images.reshape(batch_size, -1).to(DEVICE)
        # Class labels must be on the same device as D's embedding table.
        labels = labels.to(DEVICE)

        # "Real" and "fake" targets. They are 1-D so that they match the
        # discriminator output; BCELoss requires input and target to have
        # the same shape.
        real_labels = torch.ones(batch_size, device=DEVICE)
        fake_labels = torch.zeros(batch_size, device=DEVICE)

        # Discriminator loss on real images paired with their dataset labels.
        # reshape(-1) guarantees a 1-D output even for a batch of one.
        outputs = D(images, labels).reshape(-1)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Feed random noise and random class labels to the generator.
        z = torch.randn(batch_size, 100, device=DEVICE)
        g_label = torch.randint(0, 10, (batch_size,), device=DEVICE)
        fake_images = G(z, g_label)

        # Discriminator loss on generated images. detach() stops this
        # backward pass from flowing into the generator.
        outputs = D(fake_images.detach(), g_label).reshape(-1)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Total discriminator loss is the sum of the real and fake terms.
        d_loss = d_loss_real + d_loss_fake

        # Backpropagate and update the discriminator.
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Generator loss: how well the same fake batch fools the freshly
        # updated discriminator. G's weights have not changed since
        # fake_images was produced, so it is reused instead of recomputed.
        outputs = D(fake_images, g_label).reshape(-1)
        g_loss = criterion(outputs, real_labels)

        # Backpropagate and update the generator. Gradients that land on D
        # here are cleared by d_optimizer.zero_grad() in the next step.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    print('이폭 [{}/{}] d_loss:{:.4f} g_loss: {:.4f} D(x):{:.2f} D(G(z)):{:.2f}' 
          .format(epoch, EPOCHS, d_loss.item(), g_loss.item(), 
                  real_score.mean().item(), fake_score.mean().item()))

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


이폭 [0/300] d_loss:0.3738 g_loss: 4.4556 D(x):0.89 D(G(z)):0.16
이폭 [1/300] d_loss:0.4419 g_loss: 3.6021 D(x):0.87 D(G(z)):0.15
이폭 [2/300] d_loss:0.4561 g_loss: 3.0607 D(x):0.91 D(G(z)):0.18
이폭 [3/300] d_loss:0.3792 g_loss: 4.0255 D(x):0.86 D(G(z)):0.05
이폭 [4/300] d_loss:0.5510 g_loss: 2.6251 D(x):0.83 D(G(z)):0.18
이폭 [5/300] d_loss:0.3710 g_loss: 2.6054 D(x):0.89 D(G(z)):0.16
이폭 [6/300] d_loss:0.8116 g_loss: 1.6371 D(x):0.78 D(G(z)):0.30
이폭 [7/300] d_loss:0.7591 g_loss: 2.3216 D(x):0.80 D(G(z)):0.26
이폭 [8/300] d_loss:0.9191 g_loss: 1.3288 D(x):0.75 D(G(z)):0.33
이폭 [9/300] d_loss:1.3485 g_loss: 1.2583 D(x):0.53 D(G(z)):0.34
이폭 [10/300] d_loss:0.5799 g_loss: 2.2198 D(x):0.83 D(G(z)):0.22
이폭 [11/300] d_loss:0.8559 g_loss: 1.7682 D(x):0.74 D(G(z)):0.26
이폭 [12/300] d_loss:0.8210 g_loss: 1.6988 D(x):0.68 D(G(z)):0.25
이폭 [13/300] d_loss:0.8274 g_loss: 1.6738 D(x):0.70 D(G(z)):0.26
이폭 [14/300] d_loss:1.1184 g_loss: 1.3756 D(x):0.65 D(G(z)):0.34
이폭 [15/300] d_loss:0.7997 g_loss: 1.6533 D(x):0.75

이폭 [128/300] d_loss:1.3620 g_loss: 0.7812 D(x):0.56 D(G(z)):0.48
이폭 [129/300] d_loss:1.2398 g_loss: 0.9357 D(x):0.60 D(G(z)):0.43
이폭 [130/300] d_loss:1.1394 g_loss: 1.0412 D(x):0.63 D(G(z)):0.41
이폭 [131/300] d_loss:1.2373 g_loss: 1.2306 D(x):0.54 D(G(z)):0.35
이폭 [132/300] d_loss:1.1669 g_loss: 1.0370 D(x):0.58 D(G(z)):0.39
이폭 [133/300] d_loss:1.1910 g_loss: 0.9166 D(x):0.57 D(G(z)):0.42
이폭 [134/300] d_loss:1.1920 g_loss: 1.0117 D(x):0.61 D(G(z)):0.42
이폭 [135/300] d_loss:1.3474 g_loss: 1.0011 D(x):0.50 D(G(z)):0.41
이폭 [136/300] d_loss:1.1307 g_loss: 1.0116 D(x):0.59 D(G(z)):0.39
이폭 [137/300] d_loss:1.2815 g_loss: 1.2175 D(x):0.55 D(G(z)):0.36
이폭 [138/300] d_loss:1.2454 g_loss: 0.9527 D(x):0.54 D(G(z)):0.41
이폭 [139/300] d_loss:1.3799 g_loss: 0.8091 D(x):0.57 D(G(z)):0.48
이폭 [140/300] d_loss:1.2009 g_loss: 0.9644 D(x):0.62 D(G(z)):0.44
이폭 [141/300] d_loss:1.3053 g_loss: 0.8883 D(x):0.57 D(G(z)):0.45
이폭 [142/300] d_loss:1.2064 g_loss: 1.0585 D(x):0.56 D(G(z)):0.40
이폭 [143/300] d_loss:1.102

이폭 [255/300] d_loss:1.2203 g_loss: 0.9086 D(x):0.59 D(G(z)):0.43
이폭 [256/300] d_loss:1.2320 g_loss: 1.0966 D(x):0.56 D(G(z)):0.40
이폭 [257/300] d_loss:1.3268 g_loss: 0.9774 D(x):0.53 D(G(z)):0.42
이폭 [258/300] d_loss:1.2899 g_loss: 1.0394 D(x):0.54 D(G(z)):0.42
이폭 [259/300] d_loss:1.3372 g_loss: 0.7894 D(x):0.52 D(G(z)):0.46
이폭 [260/300] d_loss:1.4328 g_loss: 0.8044 D(x):0.53 D(G(z)):0.47
이폭 [261/300] d_loss:1.2954 g_loss: 0.9023 D(x):0.54 D(G(z)):0.42
이폭 [262/300] d_loss:1.2584 g_loss: 0.9658 D(x):0.54 D(G(z)):0.40
이폭 [263/300] d_loss:1.2811 g_loss: 0.8787 D(x):0.58 D(G(z)):0.47
이폭 [264/300] d_loss:1.3819 g_loss: 0.9128 D(x):0.49 D(G(z)):0.43
이폭 [265/300] d_loss:1.3775 g_loss: 0.7944 D(x):0.55 D(G(z)):0.49
이폭 [266/300] d_loss:1.3488 g_loss: 0.8240 D(x):0.56 D(G(z)):0.48
이폭 [267/300] d_loss:1.3836 g_loss: 0.9073 D(x):0.51 D(G(z)):0.44
이폭 [268/300] d_loss:1.2127 g_loss: 1.0495 D(x):0.61 D(G(z)):0.40
이폭 [269/300] d_loss:1.2328 g_loss: 0.9838 D(x):0.55 D(G(z)):0.41
이폭 [270/300] d_loss:1.069

In [None]:
# Generate and display 100 samples that are all conditioned on class index 4.
# The generator takes (noise, integer labels) and embeds the label itself, so
# no one-hot encoding or manual concatenation is needed; the noise must be
# 100-dimensional to match the generator's 110-feature input.
label = torch.tensor([4]).to(DEVICE)
for _ in range(100):
    z = torch.randn(1, 100).to(DEVICE)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        fake_images = G(z, label)
    fake_images = np.reshape(fake_images.cpu().numpy()[0], (28, 28))
    plt.imshow(fake_images, cmap = 'gray')
    plt.show()

In [None]:
# Visualize images produced by the generator for random class labels.
# The generator takes (noise, integer labels): the noise must be
# 100-dimensional and the labels a 1-D tensor of class indices in [0, 10).
z = torch.randn(BATCH_SIZE, 100).to(DEVICE)
g_label = torch.randint(0, 10, (BATCH_SIZE,)).to(DEVICE)
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    fake_images = G(z, g_label)
# Move the whole batch to the host once rather than once per image.
fake_images_np = fake_images.cpu().numpy()
# Show at most 64 images, never more than were generated.
for i in range(min(64, fake_images_np.shape[0])):
    fake_images_img = np.reshape(fake_images_np[i],(28, 28))
    plt.imshow(fake_images_img, cmap = 'gray')
    plt.show()