In [1]:
from IPython.display import Image

합성곱 GAN과 바서슈타인 GAN으로 합성 이미지 품질 높이기

In [2]:
# Display figure 17_09.png from the rickiepark/ml-with-pytorch repository at width 700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_09.png', width=700)

배치 정규화

In [3]:
# Display figure 17_11.png from the rickiepark/ml-with-pytorch repository at width 700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_11.png', width=700)

생성자와 판별자

In [4]:
# Display figure 17_12.png from the rickiepark/ml-with-pytorch repository at width 700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_12.png', width=700)

In [5]:
# Display figure 17_13.png from the rickiepark/ml-with-pytorch repository at width 700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_13.png', width=700)

In [6]:
import torch

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [7]:
import torch.nn as nn
import numpy as np

In [8]:
import matplotlib.pyplot as plt

In [9]:
import torchvision
from torchvision import transforms

# Directory that the MNIST files are downloaded into and read from.
image_path = './'

# Convert each image to a tensor, then normalize it as (x - 0.5) / 0.5.
transfom = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

mnist_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=True,
    transform=transfom,
    download=True,
)

batch_size = 64
torch.manual_seed(1)

from torch.utils.data import DataLoader
# Reshuffle on every pass and drop the final incomplete batch, so every
# batch holds exactly `batch_size` images.
mnist_dl = DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [10]:
def make_generator_network(input_size, n_filters):
  """Build the DCGAN generator as an ``nn.Sequential``.

  Three transposed-convolution stages (each followed by batch normalization
  and LeakyReLU(0.2)) are followed by a final transposed convolution with a
  Tanh output. With a 1x1 spatial input the stages upsample
  1 -> 4 -> 7 -> 14 -> 28, giving a single-channel 28x28 image in [-1, 1].

  Args:
    input_size: Number of channels of the noise input (the latent size).
    n_filters: Base channel width; the stages use 4x, 2x and 1x this value.

  Returns:
    The assembled ``nn.Sequential`` generator.
  """
  def upsample_stage(in_channels, out_channels, kernel, stride, padding):
    # Transposed convolution (no bias) -> batch norm -> LeakyReLU(0.2).
    return [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel,
                           stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    ]

  layers = []
  layers += upsample_stage(input_size, n_filters*4, 4, 1, 0)
  layers += upsample_stage(n_filters*4, n_filters*2, 3, 2, 1)
  layers += upsample_stage(n_filters*2, n_filters, 4, 2, 1)
  # Output stage: one image channel squashed into [-1, 1] by Tanh.
  layers += [
      nn.ConvTranspose2d(n_filters, 1, kernel_size=4, stride=2, padding=1,
                         bias=False),
      nn.Tanh(),
  ]
  return nn.Sequential(*layers)

class Discriminator(nn.Module):
  """Convolutional discriminator that scores single-channel images.

  Three strided convolutions with LeakyReLU(0.2) activations (batch
  normalization on the second and third) are followed by a final 4x4
  convolution and a sigmoid, so each score lies in [0, 1].
  """
  def __init__(self,n_filters):
    """Build the layer stack.

    Args:
      n_filters: Output channels of the first convolution; the following
        convolutions use 2x and 4x this value.
    """
    super().__init__()
    self.network = nn.Sequential(
        nn.Conv2d(1,n_filters, 4,2,1,bias=False),
        nn.LeakyReLU(0.2),

        nn.Conv2d(n_filters, n_filters*2,4,2,1, bias=False),
        nn.BatchNorm2d(n_filters*2),
        nn.LeakyReLU(0.2),

        nn.Conv2d(n_filters*2, n_filters*4,3,2,1, bias=False),
        nn.BatchNorm2d(n_filters*4),
        nn.LeakyReLU(0.2),

        nn.Conv2d(n_filters*4,1,4,1,0,bias=False),
        nn.Sigmoid()
    )
  def forward(self, input):
    """Score a batch of images.

    Args:
      input: Batch of single-channel images. For 28x28 inputs the
        convolutions reduce the spatial size 28 -> 14 -> 7 -> 4 -> 1, so
        each image produces exactly one score.

    Returns:
      Tensor of shape (batch, 1) holding the sigmoid scores.
    """
    output = self.network(input)
    # Always return (batch, 1). The previous trailing .squeeze(0) was a no-op
    # for batch > 1 but collapsed a batch of one to shape (1,), which no
    # longer matched (batch, 1) label tensors passed to the loss.
    return output.view(-1,1)

In [11]:
# Latent size, output image size and base channel width used by the networks.
z_size = 100
image_size = (28, 28)
n_filters = 32

# Build the generator first, then the discriminator, moving each to `device`.
gen_model = make_generator_network(z_size, n_filters)
gen_model = gen_model.to(device)
print(gen_model)

disc_model = Discriminator(n_filters)
disc_model = disc_model.to(device)
print(disc_model)

Sequential(
  (0): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): LeakyReLU(negative_slope=0.2)
  (3): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): LeakyReLU(negative_slope=0.2)
  (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): LeakyReLU(negative_slope=0.2)
  (9): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (10): Tanh()
)
Discriminator(
  (network): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 

In [12]:
# Loss function and optimizers: binary cross-entropy for both players, with
# separate Adam optimizers (generator lr 0.0003, discriminator lr 0.0002).
loss_fn = nn.BCELoss()
g_optimizer = torch.optim.Adam(params=gen_model.parameters(), lr=0.0003)
d_optimizer = torch.optim.Adam(params=disc_model.parameters(), lr=0.0002)

In [13]:
def create_noise(batch_size, z_size,mode_z):
  """Sample a batch of latent noise vectors shaped for the generator.

  Args:
    batch_size: Number of noise vectors to draw.
    z_size: Number of channels (latent dimensions) per vector.
    mode_z: 'uniform' for values in [-1, 1), or 'normal' for standard
      normal values.

  Returns:
    Tensor of shape (batch_size, z_size, 1, 1).

  Raises:
    ValueError: If `mode_z` is neither 'uniform' nor 'normal'.
  """
  if mode_z == 'uniform':
    # torch.rand draws from [0, 1); rescale to [-1, 1).
    input_z = torch.rand(batch_size, z_size,1,1)*2 - 1
  elif mode_z == 'normal':
    input_z = torch.randn(batch_size, z_size,1,1)
  else:
    # Previously an unknown mode fell through to the return statement and
    # failed with an opaque UnboundLocalError.
    raise ValueError(
        f"mode_z must be 'uniform' or 'normal', got {mode_z!r}")
  return input_z

In [14]:
# Discriminator training
def d_train(x):
  """Run one discriminator update on a real batch and a generated batch.

  Reads the notebook globals `disc_model`, `gen_model`, `loss_fn`,
  `d_optimizer`, `device`, `z_size` and `mode_z`.

  Args:
    x: Batch of real images; its first dimension sets the batch size.

  Returns:
    Tuple of (discriminator loss as a Python float, detached scores on the
    real batch, detached scores on the fake batch).
  """
  disc_model.zero_grad()

  # Train the discriminator on the real batch (target label 1).
  batch_size = x.size(0)
  x = x.to(device)
  d_labels_real = torch.ones(batch_size, 1, device=device)

  d_proba_real = disc_model(x)
  d_loss_real = loss_fn(d_proba_real, d_labels_real)

  # Train the discriminator on a fake batch (target label 0).
  input_z = create_noise(batch_size, z_size, mode_z).to(device)
  # Only the discriminator is updated here, so generate the fakes without
  # recording an autograd graph. This avoids backpropagating through the
  # generator and accumulating gradients there that are never used.
  with torch.no_grad():
    g_output = gen_model(input_z)

  d_proba_fake = disc_model(g_output)
  d_labels_fake = torch.zeros(batch_size, 1, device=device)
  d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)

  # Backpropagate the combined loss and update the discriminator parameters.
  d_loss = d_loss_real + d_loss_fake
  d_loss.backward()
  d_optimizer.step()

  return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()

In [15]:
# Generator training
def g_train(x):
  """Run one generator update.

  Reads the notebook globals `gen_model`, `disc_model`, `loss_fn`,
  `g_optimizer`, `device`, `z_size` and `mode_z`.

  Args:
    x: Current batch of real images; only its first dimension (the batch
      size) is used.

  Returns:
    The generator loss as a Python float.
  """
  gen_model.zero_grad()

  n_samples = x.size(0)
  noise = create_noise(n_samples, z_size, mode_z).to(device)

  # Score freshly generated images with the discriminator.
  fake_images = gen_model(noise)
  fake_scores = disc_model(fake_images)

  # The generator is rewarded when its fakes are scored as real (label 1).
  real_targets = torch.ones((n_samples, 1), device=device)
  g_loss = loss_fn(fake_scores, real_targets)

  # Backpropagate and update the generator parameters.
  g_loss.backward()
  g_optimizer.step()

  return g_loss.item()

In [None]:
# Noise distribution; d_train and g_train read this name as a global.
mode_z = 'uniform'
# One fixed noise batch, reused after every epoch to render sample images.
fixed_z = create_noise(batch_size, z_size, mode_z).to(device)

def create_samples(g_model, input_z):
    """Generate images from noise and rescale them for display.

    Reads the notebook global `image_size`.

    Args:
        g_model: Generator network to call on `input_z`.
        input_z: Batch of noise vectors; any batch size is accepted.

    Returns:
        Tensor of shape (batch, *image_size) with the generator output
        mapped from [-1, 1] to [0, 1] via (x + 1) / 2.
    """
    g_output = g_model(input_z)
    # -1 infers the batch dimension from the generator output, so the helper
    # no longer requires input_z to hold exactly the global `batch_size` rows.
    images = torch.reshape(g_output, (-1, *image_size))
    return (images+1)/2.0

# Sample images generated from fixed_z after each epoch (one NumPy array per epoch).
epoch_samples = []

num_epochs = 100
torch.manual_seed(1)

for epoch in range(1, num_epochs+1):
    # Put the generator back into training mode (it is set to eval() below).
    gen_model.train()
    # NOTE: both lists are re-created every epoch, so after the loop they
    # hold only the per-batch losses of the final epoch.
    d_losses, g_losses = [], []
    for i, (x, _) in enumerate(mnist_dl):
        # One discriminator step followed by one generator step per batch.
        d_loss, d_proba_real, d_proba_fake = d_train(x)
        d_losses.append(d_loss)
        g_losses.append(g_train(x))

    # Report the mean generator / discriminator loss over this epoch's batches.
    print(f'에포크 {epoch:03d} | 평균 손실 >>'
          f' 생성자/판별자 {torch.FloatTensor(g_losses).mean():.4f}'
          f'/{torch.FloatTensor(d_losses).mean():.4f}')
    # Switch the generator to eval mode before sampling from the fixed noise.
    gen_model.eval()
    epoch_samples.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())

에포크 001 | 평균 손실 >> 생성자/판별자 5.0750/0.0856
에포크 002 | 평균 손실 >> 생성자/판별자 4.9231/0.1243
에포크 003 | 평균 손실 >> 생성자/판별자 3.9919/0.2308
에포크 004 | 평균 손실 >> 생성자/판별자 3.1527/0.3175
에포크 005 | 평균 손실 >> 생성자/판별자 3.1185/0.2937
에포크 006 | 평균 손실 >> 생성자/판별자 2.9826/0.2963
에포크 007 | 평균 손실 >> 생성자/판별자 3.0435/0.3099
에포크 008 | 평균 손실 >> 생성자/판별자 3.0469/0.3026
에포크 009 | 평균 손실 >> 생성자/판별자 3.0169/0.2856
에포크 010 | 평균 손실 >> 생성자/판별자 3.1379/0.2836
에포크 011 | 평균 손실 >> 생성자/판별자 3.1667/0.2654
에포크 012 | 평균 손실 >> 생성자/판별자 3.1919/0.2516
에포크 013 | 평균 손실 >> 생성자/판별자 3.2643/0.2369
에포크 014 | 평균 손실 >> 생성자/판별자 3.3087/0.2507
에포크 015 | 평균 손실 >> 생성자/판별자 3.3580/0.2278
에포크 016 | 평균 손실 >> 생성자/판별자 3.3583/0.2484
에포크 017 | 평균 손실 >> 생성자/판별자 3.4085/0.2352
에포크 018 | 평균 손실 >> 생성자/판별자 3.3736/0.2457
에포크 019 | 평균 손실 >> 생성자/판별자 3.5007/0.2160
에포크 020 | 평균 손실 >> 생성자/판별자 3.4770/0.2164
에포크 021 | 평균 손실 >> 생성자/판별자 3.5242/0.2347
에포크 022 | 평균 손실 >> 생성자/판별자 3.5354/0.2170
에포크 023 | 평균 손실 >> 생성자/판별자 3.5554/0.2115
에포크 024 | 평균 손실 >> 생성자/판별자 3.6500/0.2027
에포크 025 | 평균 손실 

In [None]:
# Save the variables collected for visualization so they can be reloaded later.
# d_losses / g_losses are the lists as they stand when this cell runs.
saved_variables = {
    'epoch_samples': epoch_samples,
    'd_losses': d_losses,
    'g_losses': g_losses
}
# NOTE(review): the file name ends in ',pth' (comma), not '.pth'. The load
# cell reads the same path, so rename both together if this is corrected.
torch.save(saved_variables,'/content/drive/MyDrive/saved_variables2,pth')


# Discriminator model (the whole module object is pickled, not a state_dict)
torch.save(disc_model,'/content/drive/MyDrive/disc_model2.pth')
# Generator model (the whole module object is pickled, not a state_dict)
torch.save(gen_model,'/content/drive/MyDrive/gen_model2.pth')

In [None]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device : {device}")

# These files hold pickled Python objects (NumPy arrays and whole nn.Module
# instances), not plain state_dicts, so they need weights_only=False; since
# PyTorch 2.6 the default (weights_only=True) refuses to load them.
# SECURITY: weights_only=False runs the unpickler on the file contents, so
# only load files you created yourself or otherwise trust.
# map_location=device lets a checkpoint saved on a GPU load on a CPU-only runtime.

# Load the saved visualization variables.
# NOTE(review): the file name ends in ',pth' (comma), matching the path used
# by the save cell; rename both together if this is corrected.
loaded_variables = torch.load('/content/drive/MyDrive/saved_variables2,pth',
                              map_location=device, weights_only=False)
# Unpack each variable.
epoch_samples = loaded_variables['epoch_samples']
d_losses = loaded_variables['d_losses']
g_losses = loaded_variables['g_losses']
# Load the discriminator and generator models. Unpickling disc_model requires
# the Discriminator class to be defined in this session first.
disc_model = torch.load('/content/drive/MyDrive/disc_model2.pth',
                        map_location=device, weights_only=False)
gen_model = torch.load('/content/drive/MyDrive/gen_model2.pth',
                       map_location=device, weights_only=False)