# Generative Adversarial Networks (NIPS 2014)
<br>
<br>
<b>References</b>
<br>
This notebook was written while studying the following notebook:

- https://github.com/ndb796/Deep-Learning-Paper-Review-and-Practice/blob/master/code_practices/GAN_for_MNIST_Tutorial.ipynb

In [22]:
import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

## 생성자(Generator) 및 판별자(Discriminator) 모델 정의

In [23]:
latent_dim = 100  # length of the noise vector fed to the generator


# Generator class definition
class Generator(nn.Module):
    """Fully connected generator that turns noise vectors into images.

    The network is a stack of Linear -> (BatchNorm1d) -> LeakyReLU blocks
    followed by a Linear layer and Tanh, so every output value lies in
    [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def block(input_dim, output_dim, normalize=True):
            """Return the layers of one block as a list.

            Args:
                input_dim: number of input features of the Linear layer.
                output_dim: number of output features of the Linear layer.
                normalize: when True, insert BatchNorm1d after the Linear.
            """
            layers = [nn.Linear(input_dim, output_dim)]
            if normalize:
                # The second positional parameter of BatchNorm1d is `eps`,
                # so `BatchNorm1d(output_dim, 0.8)` silently added 0.8 to
                # the variance instead of configuring the running-stat
                # momentum. Pass the value by keyword (presumably what was
                # intended) and keep the default eps.
                layers.append(nn.BatchNorm1d(output_dim, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # The generator is a chain of several blocks; the first block
        # skips batch normalization.
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, 1 * 28 * 28),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: tensor whose first dimension is the batch size and whose
                last dimension is ``latent_dim``.

        Returns:
            Tensor reshaped to (batch, 1, 28, 28).
        """
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

In [24]:
# Discriminator class definition
class Discriminator(nn.Module):
    """Fully connected classifier that scores images as real or generated.

    The final Sigmoid squashes the single output unit into (0, 1).
    """

    def __init__(self):
        super().__init__()

        n_pixels = 1 * 28 * 28
        layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return one score per image.

        Args:
            img: tensor whose first dimension is the batch size; all
                remaining dimensions are flattened into one feature axis.

        Returns:
            Tensor of shape (batch, 1).
        """
        batch_size = img.size(0)
        return self.model(img.view(batch_size, -1))

## 학습 데이터셋 불러오기
- 학습을 위해 MNIST 데이터셋을 불러온다.

In [25]:
# Preprocessing applied to every training image: resize to 28 pixels,
# convert to a tensor, then normalize with mean 0.5 and std 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, downloaded into ./dataset when not already present.
train_dataset = datasets.MNIST(
    root="./dataset",
    train=True,
    download=True,
    transform=transforms_train,
)

# Shuffled mini-batches of 256 images loaded by 4 worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
)

## 모델 학습 및 샘플링
- 학습을 위해 생성자와 판별자 모델을 초기화
- 적절한 하이퍼 파라미터를 설정

In [26]:
# Instantiate the generator and the discriminator.
generator = Generator()
discriminator = Discriminator()

# Run on the GPU when one is available and fall back to the CPU otherwise,
# instead of calling .cuda() unconditionally (which raises on machines
# without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# Loss function: binary cross-entropy on the discriminator's outputs.
adversarial_loss = nn.BCELoss()
adversarial_loss.to(device)

# Learning rate shared by both optimizers.
lr = 0.0002

# Separate Adam optimizers for the generator and the discriminator.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr,
                               betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr,
                               betas=(0.5, 0.999))

- 모델을 학습하면서 주기적으로 샘플링하여 결과를 확인할 수 있습니다.

In [27]:
import time

n_epochs = 200              # number of passes over the training set
sample_interval = 2000      # save a grid of samples every this many batches
start_time = time.time()

# Train on whichever device the generator's parameters already live on, so
# the loop does not hard-code CUDA.
device = next(generator.parameters()).device

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch_size = imgs.size(0)

        # Target labels: 1.0 for real images, 0.0 for generated ones.
        # torch.ones/zeros replace the deprecated torch.cuda.FloatTensor
        # constructor and allocate directly on the training device.
        real = torch.ones((batch_size, 1), device=device)
        fake = torch.zeros((batch_size, 1), device=device)

        real_imgs = imgs.to(device)

        # ---- Generator step ----
        optimizer_G.zero_grad()

        # Sample standard-normal noise directly on the training device
        # (avoids a CPU allocation followed by a host-to-device copy).
        z = torch.randn((batch_size, latent_dim), device=device)

        # Generate a batch of images.
        generated_imgs = generator(z)

        # The generator is rewarded when the discriminator labels its
        # output as real.
        g_loss = adversarial_loss(discriminator(generated_imgs), real)

        # Update the generator.
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # zero_grad also clears the discriminator gradients accumulated by
        # g_loss.backward() above.
        optimizer_D.zero_grad()

        # detach() keeps the discriminator loss from back-propagating into
        # the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), real)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)

        d_loss = (real_loss + fake_loss) / 2

        # Update the discriminator.
        d_loss.backward()
        optimizer_D.step()

        # Global batch counter across epochs.
        done = epoch * len(dataloader) + i
        if done % sample_interval == 0:
            # Save the first 25 generated images as a 5x5 grid.
            save_image(generated_imgs.detach()[:25], f"{done}.png", nrow=5, normalize=True)

    # Log once per epoch, using the losses of the last batch.
    print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")




[Epoch 0/200] [D loss: 0.542029] [G loss: 0.782794] [Elapsed time: 1.76s]
[Epoch 1/200] [D loss: 0.483996] [G loss: 0.920153] [Elapsed time: 3.48s]
[Epoch 2/200] [D loss: 0.546040] [G loss: 1.222871] [Elapsed time: 5.25s]
[Epoch 3/200] [D loss: 0.366962] [G loss: 1.638766] [Elapsed time: 6.91s]
[Epoch 4/200] [D loss: 0.576263] [G loss: 0.487400] [Elapsed time: 8.72s]
[Epoch 5/200] [D loss: 0.490790] [G loss: 1.669144] [Elapsed time: 10.44s]
[Epoch 6/200] [D loss: 0.878131] [G loss: 3.614223] [Elapsed time: 12.03s]
[Epoch 7/200] [D loss: 0.434505] [G loss: 2.760345] [Elapsed time: 13.69s]
[Epoch 8/200] [D loss: 0.574028] [G loss: 3.446070] [Elapsed time: 15.45s]
[Epoch 9/200] [D loss: 0.269721] [G loss: 1.253381] [Elapsed time: 17.26s]
[Epoch 10/200] [D loss: 0.413095] [G loss: 0.735877] [Elapsed time: 18.90s]
[Epoch 11/200] [D loss: 0.301562] [G loss: 1.641016] [Elapsed time: 20.69s]
[Epoch 12/200] [D loss: 0.808327] [G loss: 4.553882] [Elapsed time: 22.35s]
[Epoch 13/200] [D loss: 0.2