<a href="https://colab.research.google.com/github/castlechoi/studyingDL/blob/main/Gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [17]:
# Per-image preprocessing: resize to 28 px, convert to a tensor, then apply
# (x - 0.5) / 0.5 to the single channel.
transforms_train = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split, downloaded into ./dataset on first use.
train_dataset = datasets.MNIST(
    root="./dataset",
    train=True,
    download=True,
    transform=transforms_train,
)

# Shuffled mini-batches of 128 images, loaded by 4 worker processes.
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

In [24]:
class Generator(nn.Module):
  """Fully connected GAN generator mapping noise vectors to 1x28x28 images.

  The network is four Linear -> BatchNorm1d -> LeakyReLU(0.2) stages whose
  widths are input_dim * (1, 2, 4, 8), followed by a Linear projection to
  ``output_dim`` values passed through Tanh.

  Args:
    latent_dim: Number of features in each noise vector given to
      ``forward``. Defaults to 100, the value that used to be hard-coded
      in the first layer.

  Attributes:
    latent_dim: Size of the input noise vector.
    input_dim: Width of the first hidden layer (128).
    output_dim: Number of values produced per image (28 * 28 * 1).
    model: The ``nn.Sequential`` stack described above.
  """

  def __init__(self, latent_dim: int = 100):
    super(Generator, self).__init__()
    self.latent_dim = latent_dim
    self.input_dim = 128
    self.output_dim = 28 * 28 * 1

    layers = []
    in_features = latent_dim
    # Every hidden width is derived from input_dim so the BatchNorm sizes can
    # never drift out of sync with the Linear layers feeding them.
    for multiplier in (1, 2, 4, 8):
      out_features = self.input_dim * multiplier
      layers += [
          nn.Linear(in_features, out_features),
          # The second positional argument of BatchNorm1d is `eps`, not
          # `momentum`; it is spelled out here and kept at 0.8 so the
          # numerical behavior is unchanged.
          # NOTE(review): momentum may have been intended - confirm before
          # changing, since it alters training.
          nn.BatchNorm1d(out_features, eps=0.8),
          nn.LeakyReLU(0.2, inplace=True),
      ]
      in_features = out_features

    layers += [
        nn.Linear(in_features, self.output_dim),
        nn.Tanh(),
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Generate images from a batch of noise vectors.

    Args:
      x: Tensor of shape (batch, latent_dim).

    Returns:
      Tensor of shape (batch, 1, 28, 28) holding the Tanh outputs.
    """
    fake = self.model(x)
    fake = fake.view(fake.size(0), 1, 28, 28)  # batch * channel * h * w
    return fake

In [20]:
class Discriminator(nn.Module):
  """Fully connected GAN discriminator scoring 1x28x28 images.

  Flattened images pass through Linear(784, 512) -> LeakyReLU(0.2) ->
  Linear(512, 256) -> LeakyReLU(0.2) -> Linear(256, 1) -> Sigmoid, giving one
  value per image.
  """

  def __init__(self):
    super().__init__()

    n_pixels = 1 * 28 * 28
    stack = [
        nn.Linear(n_pixels, 512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, 256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    ]
    self.model = nn.Sequential(*stack)

  def forward(self, img):
    """Score a batch of images.

    Args:
      img: Tensor whose first dimension is the batch and whose remaining
        dimensions hold 784 values per sample.

    Returns:
      Tensor of shape (batch, 1) with the Sigmoid output for each image.
    """
    batch_size = img.size(0)
    # Collapse everything except the batch dimension before the MLP.
    return self.model(img.view(batch_size, -1))

In [35]:
# Build the two adversaries.
generator = Generator()
discriminator = Discriminator()

# Use the GPU when one is available, otherwise stay on the CPU instead of
# crashing in an unconditional .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()
adversarial_loss.to(device)

lr = 0.0002

# One Adam optimizer per network, sharing the learning rate and betas.
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
import time

n_epochs = 200
sample_interval = 2000  # save a grid of samples every this many batches
start_time = time.time()

# Run on whatever device the generator already lives on, so the loop does not
# hard-code CUDA.
device = next(generator.parameters()).device
# Noise size is read from the generator's first Linear layer rather than
# repeating the constant here.
latent_dim = generator.model[0].in_features

for epoch in range(n_epochs):
  for i, (img, _) in enumerate(dataloader):
    batch_size = img.size(0)

    # Create target labels for the current batch: 1.0 = real, 0.0 = fake.
    real = torch.ones(batch_size, 1, device=device)
    fake = torch.zeros(batch_size, 1, device=device)

    real_imgs = img.to(device)

    # ---- Generator step: try to make the discriminator output "real". ----
    optimizer_g.zero_grad()
    z = torch.randn(batch_size, latent_dim, device=device)
    generated_imgs = generator(z)

    g_loss = adversarial_loss(discriminator(generated_imgs), real)
    g_loss.backward()
    optimizer_g.step()

    # ---- Discriminator step: separate real images from generated ones. ----
    # zero_grad also clears the discriminator gradients accumulated by
    # g_loss.backward() above.
    optimizer_d.zero_grad()

    real_loss = adversarial_loss(discriminator(real_imgs), real)
    # detach() keeps this loss from back-propagating into the generator.
    fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2

    d_loss.backward()
    optimizer_d.step()

    # Global batch counter across epochs; also used as the sample file name.
    done = epoch * len(dataloader) + i
    if done % sample_interval == 0:
      save_image(generated_imgs.detach()[:25], f"{done}.png", nrow=5, normalize=True)

  # Print a log line at the end of every epoch (losses are from the last batch).
  print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")

[Epoch 0/200] [D loss: 0.375986] [G loss: 1.061186] [Elapsed time: 18.39s]
[Epoch 1/200] [D loss: 0.311376] [G loss: 1.573003] [Elapsed time: 37.77s]
[Epoch 2/200] [D loss: 0.441066] [G loss: 0.666339] [Elapsed time: 55.98s]
[Epoch 3/200] [D loss: 0.343097] [G loss: 1.621822] [Elapsed time: 75.30s]
[Epoch 4/200] [D loss: 0.477819] [G loss: 0.642458] [Elapsed time: 95.40s]
[Epoch 5/200] [D loss: 0.230828] [G loss: 1.960929] [Elapsed time: 115.20s]
[Epoch 6/200] [D loss: 0.220737] [G loss: 2.191481] [Elapsed time: 136.19s]
[Epoch 7/200] [D loss: 0.251085] [G loss: 1.633245] [Elapsed time: 155.37s]
[Epoch 8/200] [D loss: 0.619106] [G loss: 3.613500] [Elapsed time: 175.40s]
[Epoch 9/200] [D loss: 0.335777] [G loss: 2.202023] [Elapsed time: 194.59s]
[Epoch 10/200] [D loss: 0.178077] [G loss: 2.168049] [Elapsed time: 214.67s]
[Epoch 11/200] [D loss: 0.299057] [G loss: 2.200274] [Elapsed time: 234.49s]
[Epoch 12/200] [D loss: 0.253040] [G loss: 3.009355] [Elapsed time: 254.94s]
[Epoch 13/200]

In [None]:
from IPython.display import Image

# Display a saved sample grid inline in the notebook. The name follows the
# f"{done}.png" pattern used by the training cell (92000 is a multiple of its
# sample_interval of 2000) - presumably the file exists only after training
# has run that far; verify before executing this cell.
Image('92000.png')