In [2]:
from IPython.display import Image

합성곱 GAN과 바서슈타인 GAN으로 합성 이미지 품질 높이기

In [3]:
# Display figure 17_09.png from the ml-with-pytorch repository (ch17/figures), width=700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_09.png', width=700)

배치 정규화

In [4]:
# Display figure 17_11.png from the ml-with-pytorch repository (ch17/figures), width=700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_11.png', width=700)

생성자와 판별자

In [7]:
# Display figure 17_12.png from the ml-with-pytorch repository (ch17/figures), width=700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_12.png', width=700)

In [6]:
# Display figure 17_13.png from the ml-with-pytorch repository (ch17/figures), width=700.
Image(url='https://raw.githubusercontent.com/rickiepark/ml-with-pytorch/main/ch17/figures/17_13.png', width=700)

In [10]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare last expression so the notebook displays the selected device.
device

'cuda'

In [11]:
import torch.nn as nn
import numpy as np

In [12]:
import matplotlib.pyplot as plt

In [13]:
import torchvision
from torchvision import transforms

from torch.utils.data import DataLoader

# Directory under which torchvision downloads / looks for the MNIST files.
image_path = './'

# Normalize computes (x - 0.5) / 0.5 on the tensor produced by ToTensor, so the
# real images end up in the same [-1, 1] range as the generator's Tanh output
# (given ToTensor's documented [0, 1] scaling).
# `(0.5)` is just the float 0.5; the trailing comma makes the one-element
# tuple (one value per channel) that Normalize documents as its input.
# The variable name `transfom` (sic) is kept as-is because it is a notebook
# global that other cells may reference.
transfom = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

mnist_dataset = torchvision.datasets.MNIST(
    root=image_path, train=True, transform=transfom, download=True
)

batch_size = 64
# Seed the global RNG before the loader is created.
torch.manual_seed(1)

# drop_last=True discards a trailing partial batch, so every batch has
# exactly `batch_size` samples.
mnist_dl = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 158567297.92it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 30505085.32it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 59042057.43it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 22281320.20it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [15]:
def make_generator_network(input_size, n_filters):
  """Assemble the transposed-convolution generator as an ``nn.Sequential``.

  Three up-sampling stages (ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2))
  are followed by a final ConvTranspose2d to one channel and a Tanh.
  For an input of shape (N, input_size, 1, 1) the kernel/stride/padding
  choices grow the spatial size 1 -> 4 -> 7 -> 14 -> 28.

  Args:
    input_size: number of channels of the input noise tensor.
    n_filters: base channel count; the stages use 4x, 2x and 1x this value.

  Returns:
    nn.Sequential containing the 11 layers in order.
  """
  # (in_channels, out_channels, kernel_size, stride, padding) per hidden stage.
  hidden_stages = [
      (input_size, n_filters*4, 4, 1, 0),
      (n_filters*4, n_filters*2, 3, 2, 1),
      (n_filters*2, n_filters, 4, 2, 1),
  ]

  layers = []
  for in_ch, out_ch, kernel, stride, pad in hidden_stages:
    layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, pad, bias=False))
    layers.append(nn.BatchNorm2d(out_ch))
    layers.append(nn.LeakyReLU(0.2))

  # Output stage: single channel, values squashed into [-1, 1] by Tanh.
  layers.append(nn.ConvTranspose2d(n_filters, 1, 4, 2, 1, bias=False))
  layers.append(nn.Tanh())

  return nn.Sequential(*layers)

class Discriminator(nn.Module):
  """Convolutional discriminator producing one probability per input image.

  Three stride-2 convolutions followed by a 4x4 convolution to a single
  channel and a Sigmoid. For (N, 1, 28, 28) inputs the spatial size shrinks
  28 -> 14 -> 7 -> 4 -> 1.

  Args:
    n_filters: base channel count; the convolutions use 1x, 2x and 4x this value.
  """
  def __init__(self,n_filters):
    super().__init__()
    self.network = nn.Sequential(
        # No BatchNorm after the first convolution.
        nn.Conv2d(1,n_filters, 4,2,1,bias=False),
        nn.LeakyReLU(0.2),

        nn.Conv2d(n_filters, n_filters*2,4,2,1, bias=False),
        nn.BatchNorm2d(n_filters*2),
        nn.LeakyReLU(0.2),

        nn.Conv2d(n_filters*2, n_filters*4,3,2,1, bias=False),
        nn.BatchNorm2d(n_filters*4),
        nn.LeakyReLU(0.2),

        nn.Conv2d(n_filters*4,1,4,1,0,bias=False),
        nn.Sigmoid()
    )
  def forward(self, input):
    """Return the Sigmoid scores for a batch of images.

    Args:
      input: image batch fed to the convolution stack.

    Returns:
      Tensor of shape (batch, 1) with values in (0, 1).
    """
    output = self.network(input)
    # Always return (batch, 1). The previous `.squeeze(0)` was a no-op for
    # batch > 1 but collapsed a single-image batch to shape (1,), which no
    # longer matches the (batch, 1) label tensors used with BCELoss.
    return output.view(-1,1)

In [None]:
# Hyperparameters: noise channel count fed to the generator, target image
# size, and the base filter count shared by both networks.
z_size = 100
image_size = (28,28)
n_filters = 32
# Build both networks, move them to the selected device and print their layers.
gen_model = make_generator_network(z_size, n_filters).to(device)
print(gen_model)
disc_model = Discriminator(n_filters).to(device)
print(disc_model)

In [17]:
# Binary cross-entropy loss plus one Adam optimizer per network.
loss_fn = nn.BCELoss()
g_optimizer = torch.optim.Adam(params=gen_model.parameters(), lr=0.0003)
d_optimizer = torch.optim.Adam(params=disc_model.parameters(), lr=0.0002)

In [18]:
def create_noise(batch_size, z_size,mode_z):
  """Sample a batch of noise tensors shaped (batch_size, z_size, 1, 1).

  Args:
    batch_size: number of noise vectors to draw.
    z_size: number of channels of each noise vector.
    mode_z: 'uniform' for values in [-1, 1), 'normal' for standard normal.

  Returns:
    CPU float tensor of shape (batch_size, z_size, 1, 1).

  Raises:
    ValueError: if mode_z is neither 'uniform' nor 'normal'.
  """
  if mode_z == 'uniform':
    # torch.rand samples [0, 1); rescale to [-1, 1).
    return torch.rand(batch_size, z_size,1,1)*2 - 1
  if mode_z == 'normal':
    return torch.randn(batch_size, z_size,1,1)
  # Previously an unknown mode fell through to `return input_z` and failed
  # with an opaque UnboundLocalError.
  raise ValueError(
      f"mode_z must be 'uniform' or 'normal', got {mode_z!r}")

In [21]:
# Discriminator training
def d_train(x):
  """Run one discriminator update on a real batch and an equally sized fake batch.

  Uses the notebook globals disc_model, gen_model, loss_fn, d_optimizer,
  z_size, mode_z and device.

  Args:
    x: batch of real images; moved to `device` before use.

  Returns:
    Tuple of (summed real+fake loss as a Python float,
    detached discriminator outputs on the real batch,
    detached discriminator outputs on the fake batch).
  """
  disc_model.zero_grad()

  # Train the discriminator on a real batch (target label 1).
  batch_size = x.size(0)
  x = x.to(device)
  d_labels_real = torch.ones(batch_size, 1, device=device)

  d_proba_real = disc_model(x)
  d_loss_real = loss_fn(d_proba_real, d_labels_real)

  # Train the discriminator on a fake batch (target label 0).
  input_z = create_noise(batch_size, z_size, mode_z).to(device)
  # Only the discriminator is stepped here, so the generator forward pass
  # does not need an autograd graph. Skipping it avoids backpropagating into
  # the generator for gradients that g_train discards via gen_model.zero_grad().
  with torch.no_grad():
    g_output = gen_model(input_z)

  d_proba_fake = disc_model(g_output)
  d_labels_fake = torch.zeros(batch_size, 1, device=device)
  d_loss_fake = loss_fn(d_proba_fake, d_labels_fake)

  # Backpropagate and update the discriminator parameters.
  d_loss = d_loss_real + d_loss_fake
  d_loss.backward()
  d_optimizer.step()

  # .item() already yields a detached Python float; no need to go through .data.
  return d_loss.item(), d_proba_real.detach(), d_proba_fake.detach()

In [None]:
# Scratch check: displays a (batch_size, 1) tensor of ones on `device`,
# the same expression g_train uses for its "real" labels.
torch.ones((batch_size,1),device=device)

In [41]:
# Generator training
def g_train(x):
  """Run one generator update using freshly sampled noise.

  Uses the notebook globals gen_model, disc_model, loss_fn, g_optimizer,
  z_size, mode_z and device.

  Args:
    x: current batch of real images; only its batch size (x.size(0)) is used.

  Returns:
    The generator loss as a Python float.
  """
  gen_model.zero_grad()
  batch_size = x.size(0)
  input_z = create_noise(batch_size, z_size, mode_z).to(device)
  # The generator is trained against "real" labels: its loss is low when the
  # discriminator scores the generated images close to 1.
  g_labels_real = torch.ones((batch_size,1),device=device)

  g_output = gen_model(input_z)
  d_proba_fake = disc_model(g_output)
  g_loss = loss_fn(d_proba_fake, g_labels_real)

  # Backpropagate and update the generator parameters only; the discriminator
  # receives gradients here too, but d_train zeroes them before its next step.
  g_loss.backward()
  g_optimizer.step()

  # .item() already yields a detached Python float; no need to go through .data.
  return g_loss.item()

In [None]:
# Noise distribution passed to create_noise ('uniform' or 'normal').
mode_z = 'uniform'
# One fixed noise batch, reused by the training loop after every epoch to
# generate the sample images it records.
fixed_z = create_noise(batch_size,z_size,mode_z).to(device)
def create_samples(g_model, input_z):
  """Generate images from noise and rescale them from [-1, 1] to [0, 1].

  Uses the notebook global image_size for the per-image shape.

  Args:
    g_model: generator module to run.
    input_z: noise batch fed to the generator.

  Returns:
    Tensor of shape (batch, *image_size); still attached to the autograd
    graph unless the caller detaches it.
  """
  g_output = g_model(input_z)
  # Use the actual batch size of the generator output instead of the global
  # `batch_size`, so any number of noise vectors can be sampled.
  images = torch.reshape(g_output, (g_output.size(0), *image_size))
  # The generator's Tanh output lies in [-1, 1]; map it to [0, 1].
  return (images+1) / 2.0

# One entry per epoch: images generated from fixed_z after that epoch.
epoch_samples = []
num_epochs = 100
torch.manual_seed(1)

for epoch in range(1, num_epochs+1):
  gen_model.train()
  # Per-batch losses. They are reset at the start of every epoch, so after the
  # loop finishes they hold only the final epoch's values.
  d_losses, g_losses = [],[]
  for i, (x, _) in enumerate(mnist_dl):
    # One discriminator step followed by one generator step per batch.
    d_loss, d_proba_real, d_proba_fake = d_train(x)
    d_losses.append(d_loss)
    g_losses.append(g_train(x))

  # Report this epoch's mean generator / discriminator loss.
  print(f'에포크 {epoch:03d} | 평균 손실 >>'
        f' 생성자/판별자 {torch.FloatTensor(g_losses).mean():.4f}'
        f'/{torch.FloatTensor(d_losses).mean():.4f}')
  # Sample in eval mode; train mode is restored at the top of the next epoch.
  gen_model.eval()
  epoch_samples.append(
      create_samples(gen_model, fixed_z).detach().cpu().numpy()
  )

In [None]:
# Save the variables collected for visualisation so they can be reloaded later.
# Note: d_losses / g_losses are reset at the start of every epoch in the
# training loop, so only the last epoch's per-batch losses are stored here.
saved_variables = {
    'epoch_samples': epoch_samples,
    'd_losses': d_losses,
    'g_losses': g_losses
}
# NOTE(review): ',pth' (comma) looks like a typo for '.pth'. The loading cell
# reads the same file name, so the two must be changed together.
torch.save(saved_variables,'/content/drive/MyDrive/saved_variables2,pth')


# Discriminator model (the whole module object is saved, not just its state_dict)
torch.save(disc_model,'/content/drive/MyDrive/disc_model2.pth')
# Generator model (saved the same way)
torch.save(gen_model,'/content/drive/MyDrive/gen_model2.pth')

In [None]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device : {device}")

# Load the saved visualisation variables.
# - map_location=device remaps tensors saved on the GPU, so the checkpoints
#   also load on a CPU-only runtime (previously `device` was computed but
#   never used).
# - weights_only=False is required because these files contain pickled NumPy
#   arrays and whole nn.Module objects, which the restricted "weights only"
#   unpickler rejects.
# SECURITY: weights_only=False runs the full pickle machinery; only load
# files you created yourself.
loaded_variables = torch.load(
    '/content/drive/MyDrive/saved_variables2,pth',
    map_location=device, weights_only=False)
# Unpack the individual variables.
epoch_samples = loaded_variables['epoch_samples']
d_losses = loaded_variables['d_losses']
g_losses = loaded_variables['g_losses']
# Load the discriminator and generator models.
# Unpickling disc_model needs the Discriminator class to be defined in the
# current session, so run its definition cell first.
disc_model = torch.load(
    '/content/drive/MyDrive/disc_model2.pth',
    map_location=device, weights_only=False)
gen_model = torch.load(
    '/content/drive/MyDrive/gen_model2.pth',
    map_location=device, weights_only=False)