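## A Brief Explanation of GANs
A GAN (Generative Adversarial Network) sets two separate neural networks against each other in an adversarial relationship, so that each gradually improves the other's performance. <br>
- Generator G: a model that learns the distribution of the data <br>
- Discriminator D: a model that classifies an image as real (training data) or fake (generated data) <br>
- Put very simply, it is a game of cop (D) and thief (G) <br>
- Implemented below on MNIST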
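
For reference, the adversarial game between G and D is usually written as the minimax objective below. The formula is not part of the original notes; $p_{\text{data}}$ and $p_z$ are assumed names for the data distribution and the noise distribution.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
$$

D tries to push both terms up (score real images near 1 and fakes near 0), while G tries to push the second term down by making $D(G(z))$ large.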

## Library

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.nn.modules import loss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

## Discriminator

In [2]:
class Discriminator(nn.Module):
  def __init__(self, in_features):
    super().__init__()
    self.disc = nn.Sequential(
        nn.Linear(in_features, 128),
        nn.LeakyReLU(0.1),
        nn.Linear(128, 1),
        nn.Sigmoid(), # outputs a probability; 0.5 is the real/fake decision threshold
    )

  def forward(self, x):
    return self.disc(x)

## Generator

In [3]:
class Generator(nn.Module):
  def __init__(self, z_dim, img_dim):
    super().__init__()
    self.gen = nn.Sequential(
        nn.Linear(z_dim, 256),
        nn.LeakyReLU(0.1),
        nn.Linear(256, img_dim),
        nn.Tanh(),
    )

  def forward(self, x):
    return self.gen(x)

## Hyperparameters

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64 # why 64 rather than 784?
image_dim = 28 * 28 * 1
batch_size = 32
num_epochs = 50

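The input (noise) vector has 64 dimensions rather than 784 because: <br>
1. MNIST shows only 10 kinds of digits, so its data distribution is likely concentrated in far fewer dimensions than 784, and a low-dimensional input is enough to cover it <br>
2. Even if the MNIST images had a higher resolution, that underlying information would not change, so a larger input vector would not be needed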
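
As a rough illustration of the cost side of this choice, the sketch below counts the parameters of the two-layer Generator (`z_dim -> 256 -> 784`) for `z_dim = 64` and `z_dim = 784`. The helper names `linear_params` and `generator_params` are hypothetical, and the count says nothing about sample quality; it only shows how much the first layer grows with the input size.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Parameter count of a fully connected layer: weights plus biases."""
    return in_features * out_features + out_features


def generator_params(z_dim: int, img_dim: int = 28 * 28, hidden: int = 256) -> int:
    """Parameter count of a z_dim -> hidden -> img_dim two-layer MLP."""
    return linear_params(z_dim, hidden) + linear_params(hidden, img_dim)


small = generator_params(z_dim=64)
large = generator_params(z_dim=784)
print(small, large)  # 218128 402448
```

With a 784-dimensional input the Generator would have nearly twice as many parameters, almost all of the increase sitting in the first layer.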

In [5]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device) # fixed noise, reused to visualize the Generator's progress
transform = transforms.Compose(  # named `transform` so the `transforms` module is not shadowed
    [transforms.ToTensor(),
     transforms.Normalize((0.5, ), (0.5, )), ]  # scale pixels to [-1, 1], the output range of Tanh
)
dataset = datasets.MNIST(root = "dataset/", transform = transform, download = True)
loader = DataLoader(dataset, batch_size = batch_size, shuffle = True)
opt_disc = optim.Adam(disc.parameters(), lr = lr)
opt_gen = optim.Adam(gen.parameters(), lr = lr)
criterion = nn.BCELoss() # binary cross-entropy for the real/fake decision

writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
# SummaryWriter instances for logging images to TensorBoard

step = 0

In [6]:
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
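    real = real.view(-1, image_dim).to(device)  # flatten each 28x28 image into a 784-vector
    cur_batch_size = real.shape[0]  # actual size of this batch; leaves the batch_size hyperparameter untouched

    # Train Discriminator : max log(D(x)) + log(1 - D(G(z)))
    noise = torch.randn(cur_batch_size, z_dim).to(device)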
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    disc_fake = disc(fake).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossD_real + lossD_fake) / 2
    disc.zero_grad()
    lossD.backward(retain_graph = True)
    opt_disc.step()

    # Train Generator : min log(1 - D(G(z))) <-> max log(D(G(z)))
    # Maximizing log(D(G(z))) keeps the Generator's gradients from saturating, so training is more stable
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    if batch_idx == 0:
      print(
          f'Epoch [{epoch} / {num_epochs}], Loss D : {lossD:.4f}, Loss G : {lossG:.4f}'
      )

      with torch.no_grad():
        fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        data = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize = True)
        img_grid_real = torchvision.utils.make_grid(data, normalize = True)

        writer_fake.add_image('MNIST Fake Image', img_grid_fake, global_step = step)
        writer_real.add_image('MNIST Real Image', img_grid_real, global_step = step)
        step += 1


Epoch [0 / 50], Loss D : 0.6289, Loss G : 0.6685
Epoch [1 / 50], Loss D : 0.6519, Loss G : 0.9258
Epoch [2 / 50], Loss D : 0.3328, Loss G : 1.3609
Epoch [3 / 50], Loss D : 0.8028, Loss G : 0.5892
Epoch [4 / 50], Loss D : 0.5987, Loss G : 1.1845
Epoch [5 / 50], Loss D : 0.4644, Loss G : 1.1955
Epoch [6 / 50], Loss D : 1.1846, Loss G : 0.5425
Epoch [7 / 50], Loss D : 0.4152, Loss G : 1.3380
Epoch [8 / 50], Loss D : 0.7330, Loss G : 1.0478
Epoch [9 / 50], Loss D : 0.5563, Loss G : 1.0605
Epoch [10 / 50], Loss D : 0.5873, Loss G : 1.1433
Epoch [11 / 50], Loss D : 0.5487, Loss G : 1.3203
Epoch [12 / 50], Loss D : 0.7453, Loss G : 0.7413
Epoch [13 / 50], Loss D : 0.5126, Loss G : 1.0858
Epoch [14 / 50], Loss D : 0.5114, Loss G : 1.1865
Epoch [15 / 50], Loss D : 0.5785, Loss G : 0.8795
Epoch [16 / 50], Loss D : 0.6545, Loss G : 1.0566
Epoch [17 / 50], Loss D : 0.6320, Loss G : 1.2448
Epoch [18 / 50], Loss D : 0.5495, Loss G : 1.1825
Epoch [19 / 50], Loss D : 0.8595, Loss G : 0.8963
Epoch [20 

The Discriminator is trained as follows: <br>
1. In each training iteration, sample noise and generate fake images from it (`fake`) <br>
2. Compute `lossD_real = criterion(disc_real, torch.ones_like(disc_real))`, the loss on real images, and <br>
`lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))`, the loss on fake images, <br>
then average the two and call backward <br>
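
The Discriminator loss can be worked through by hand for single scalar predictions. This is a minimal sketch in plain Python, not the notebook's code: `bce` and `disc_loss` are hypothetical helpers, and `nn.BCELoss` additionally averages over the whole batch.

```python
import math


def bce(prediction: float, target: float) -> float:
    """Binary cross-entropy for a single prediction in (0, 1)."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


def disc_loss(d_real: float, d_fake: float) -> float:
    """Average of the loss on a real sample (target 1) and a fake sample (target 0)."""
    return (bce(d_real, 1.0) + bce(d_fake, 0.0)) / 2


undecided = disc_loss(d_real=0.5, d_fake=0.5)  # D cannot tell real from fake
confident = disc_loss(d_real=0.9, d_fake=0.1)  # D separates the two well
print(f"{undecided:.4f} {confident:.4f}")  # 0.6931 0.1054
```

A Discriminator that outputs 0.5 for everything has a loss of ln 2 ≈ 0.693, which is a useful reference point when reading the `Loss D` values printed during training.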

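The Generator is trained as follows: <br>
1. Compute `lossG`, which scores the fake images against the label "real" (`torch.ones_like(output)`), and call backward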
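
To see why the Generator maximizes log(D(G(z))) rather than minimizing log(1 - D(G(z))), the sketch below compares the gradients of both losses with respect to the Discriminator's pre-sigmoid output `a`, where D(G(z)) = sigmoid(a). It is an illustration in plain Python under that assumption; the function names are hypothetical and do not appear in the notebook.

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def grad_saturating(a: float) -> float:
    """d/da of log(1 - sigmoid(a)), the loss the Generator would minimize."""
    return -sigmoid(a)


def grad_non_saturating(a: float) -> float:
    """d/da of -log(sigmoid(a)), the form that lossG = BCE(output, ones) takes."""
    return -(1.0 - sigmoid(a))


a = -5.0  # D confidently rejects the fake: D(G(z)) = sigmoid(-5), roughly 0.0067
print(abs(grad_saturating(a)), abs(grad_non_saturating(a)))
```

Early in training, when D rejects fakes easily, the saturating loss passes back a gradient close to 0, while the non-saturating loss passes back a gradient close to 1 in magnitude, so the Generator keeps receiving a usable learning signal.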