<a href="https://colab.research.google.com/github/choiking10/ML-tutorial/blob/main/mnist_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Test Page

이 페이지는 Google Colab과 GitHub 연동을 테스트하기 위한 페이지입니다.

깃허브 정리는... 언젠가 하겠지 뭐... 

In [6]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from IPython.display import Image
import matplotlib.pyplot as plt

def grid_image(tensor_images, size=10, nrow=5):
  """Tile the leading images of a batch into one grid image.

  Every value is mapped with (x + 1) / 2 and clamped to [0, 1]; the first
  `size` images are then arranged by `torchvision.utils.make_grid` with
  `nrow` images per row.

  Args:
    tensor_images: batch of images, indexed along the first dimension.
    size: how many leading images to keep.
    nrow: number of images per grid row.

  Returns:
    The tensor produced by `make_grid`.
  """
  rescaled = torch.clamp((tensor_images + 1) / 2, 0, 1)
  selected = rescaled[:size]
  return torchvision.utils.make_grid(selected, nrow=nrow)

def show_image(tensor_images, size=10, nrow=5):
  """Display the leading images of a batch as a grid with matplotlib.

  Args:
    tensor_images: batch of images accepted by `grid_image`.
    size: number of leading images to include in the grid.
    nrow: number of images per grid row.
  """
  to_pil = transforms.ToPILImage()
  # Forward size/nrow to grid_image. They used to be accepted here but never
  # passed on, so the grid always used grid_image's own defaults.
  grid_img = grid_image(tensor_images, size=size, nrow=nrow)
  plt.imshow(to_pil(grid_img), interpolation="bicubic")

class Discriminator(nn.Module):
  """Fully connected scorer that ends in a sigmoid.

  An input with `image_size` features goes through two hidden layers of
  `hidden_size` units, each followed by LeakyReLU with negative slope 0.2,
  and is reduced to a single sigmoid output per row.
  """

  def __init__(self, image_size, hidden_size):
    super().__init__()
    layers = [
        nn.Linear(image_size, hidden_size),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, hidden_size),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, 1),
        nn.Sigmoid(),
    ]
    # The attribute stays named `main` so parameter names are unchanged.
    self.main = nn.Sequential(*layers)

  def forward(self, x):
    """Return one sigmoid score per row of `x`."""
    scores = self.main(x)
    return scores

class Generator(nn.Module):
  """Fully connected generator that ends in a tanh.

  A latent vector with `latent_size` features goes through two hidden
  layers of `hidden_size` units with ReLU activations and is projected to
  `image_size` features; the final tanh keeps every output in [-1, 1].

  Args:
    latent_size: number of features of the input noise vector.
    hidden_size: width of both hidden layers.
    image_size: number of output features per sample.
  """

  def __init__(self, latent_size, hidden_size, image_size):
    super(Generator, self).__init__()
    self.main = nn.Sequential(
        nn.Linear(latent_size, hidden_size),
        # nn.ReLU's only argument is the boolean `inplace`, not a slope.
        # The former `nn.ReLU(0.2)` therefore just enabled in-place mode by
        # accident; plain ReLU gives the same numerical output.
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, image_size),
        nn.Tanh()
    )

  def forward(self, x):
    """Return the generated samples for the latent batch `x`."""
    return self.main(x)

def main():
  """Train a fully connected GAN on MNIST and save sample grids.

  Downloads the MNIST training split to '../../data/', then trains
  `Discriminator` and `Generator` with binary cross-entropy and Adam
  (lr=0.0002) for `num_epochs` epochs. Losses and mean discriminator scores
  are printed every 200 steps, and every 10 epochs a grid of generated
  images is written into the 'samples' directory.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f"using device {device}")

  # Hyper-parameters
  latent_size = 64
  hidden_size = 256
  image_size = 784

  num_epochs = 200
  batch_size = 300
  sample_dir = 'samples'

  os.makedirs(sample_dir, exist_ok=True)

  # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to
  # [-1, 1], the same range as the generator's tanh output.
  transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
  ])

  mnist = torchvision.datasets.MNIST(root='../../data/',
                                     train=True,
                                     transform=transform,
                                     download=True)

  data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                            batch_size=batch_size,
                                            shuffle=True)

  D = Discriminator(image_size, hidden_size).to(device)
  G = Generator(latent_size, hidden_size, image_size).to(device)

  criterion = nn.BCELoss()
  d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
  g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

  total_step = len(data_loader)

  def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

  def rand_z(n):
    return torch.randn(n, latent_size).to(device)

  for epoch in range(num_epochs):
    fake_images = None
    for i, (images, _) in enumerate(data_loader):
      # Use the size of the batch actually delivered: the last batch is
      # smaller whenever batch_size does not divide the dataset size, and a
      # hard-coded batch_size would make view() and the loss fail there.
      cur_batch = images.size(0)
      images = images.view(cur_batch, -1).to(device)

      real_label = torch.ones(cur_batch, 1, device=device)
      fake_label = torch.zeros(cur_batch, 1, device=device)

      # ---- D ----
      z = rand_z(cur_batch)

      # real
      real_outputs = D(images)
      d_loss_real = criterion(real_outputs, real_label)
      real_score = real_outputs

      # fake: detached so the discriminator update does not backpropagate
      # through the generator (those gradients were discarded anyway).
      fake_images = G(z).detach()
      fake_outputs = D(fake_images)
      d_loss_fake = criterion(fake_outputs, fake_label)
      fake_score = fake_outputs

      # loss
      d_loss = d_loss_real + d_loss_fake

      # backprop
      reset_grad()
      d_loss.backward()
      d_optimizer.step()

      # ---- G ----
      # The generator is trained to make D label its samples as real.
      z = rand_z(cur_batch)
      fake_images = G(z)
      outputs = D(fake_images)
      g_loss = criterion(outputs, real_label)

      reset_grad()
      g_loss.backward()
      g_optimizer.step()

      if (i + 1) % 200 == 0:
        print(f'Epoch [{epoch}/{num_epochs}], Step [{i+1}/{total_step}],' +
              f'd_loss: {d_loss.item():.8f}, g_loss: {g_loss.item():.8f}, ' +
              f'D(x): {real_score.mean().item():.2f}, ' +
              f'D(G(z)): {fake_score.mean().item():.2f}')
    if (epoch + 1) % 10 == 0:
      # NOTE: the file is named 'real_images_*' although it holds generated
      # samples; the name is kept so anything reading this directory still
      # finds the files.
      save_image(grid_image(fake_images.view(-1, 1, 28, 28)),
                 os.path.join(sample_dir, f'real_images_{epoch:03d}.png'))
      
      

# Start training only when this code is executed directly (which is also the
# case inside a notebook cell), not when the module is imported.
if __name__ == "__main__":
  main()

Epoch [79/200], Step [400/600],d_loss: 0.61423039, g_loss: 1.96425319, D(x): 0.80, D(G(z)): 0.23
Epoch [79/200], Step [600/600],d_loss: 0.67728984, g_loss: 1.89632070, D(x): 0.74, D(G(z)): 0.23
Epoch [80/200], Step [200/600],d_loss: 0.78896964, g_loss: 1.53994799, D(x): 0.78, D(G(z)): 0.32
Epoch [80/200], Step [400/600],d_loss: 0.71771699, g_loss: 1.77206588, D(x): 0.72, D(G(z)): 0.23
Epoch [80/200], Step [600/600],d_loss: 0.89282376, g_loss: 1.89902067, D(x): 0.73, D(G(z)): 0.33
Epoch [81/200], Step [200/600],d_loss: 0.89679956, g_loss: 1.69003677, D(x): 0.75, D(G(z)): 0.32
Epoch [81/200], Step [400/600],d_loss: 0.86619776, g_loss: 1.51929557, D(x): 0.68, D(G(z)): 0.26
Epoch [81/200], Step [600/600],d_loss: 0.93006009, g_loss: 1.60771835, D(x): 0.62, D(G(z)): 0.22
Epoch [82/200], Step [200/600],d_loss: 0.92236561, g_loss: 1.82908964, D(x): 0.69, D(G(z)): 0.29
Epoch [82/200], Step [400/600],d_loss: 0.82173753, g_loss: 1.76325464, D(x): 0.69, D(G(z)): 0.23
Epoch [82/200], Step [600/600]

KeyboardInterrupt: ignored