<a href="https://colab.research.google.com/github/hansong0219/Advanced-DeepLearning-Study/blob/master/improved_GAN/WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WGAN


In [1]:
from tensorflow.keras.layers import Activation, Dense, Input
from tensorflow.keras.layers import Conv2D, Flatten
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import load_model
from tensorflow.keras import backend as K


import numpy as np
import math
import matplotlib.pyplot as plt

# 생성기와 판별기 등 함수 구성 
생성기와 판별기의 함수는 DCGAN의 구성을 그대로 사용한다.

In [2]:

def generator(inputs,
              image_size,
              activation='sigmoid'):
    """Build the DCGAN-style generator model.

    The network projects `inputs` through a Dense layer, reshapes the
    result to a square feature map of side ``image_size // 4`` and then
    applies four BatchNormalization -> ReLU -> Conv2DTranspose blocks.
    The first two blocks use strides of 2, so the spatial side grows by a
    factor of 4 overall (back to `image_size` when it is divisible by 4);
    the last block has a single filter.

    Args:
        inputs: Keras tensor fed to the first Dense layer (presumably the
            latent noise vector; confirm against the caller).
        image_size: Target side length of the square output image.
        activation: Name of the output activation, or None to return the
            raw output of the last transposed convolution.

    Returns:
        A Keras `Model` named 'generator' mapping `inputs` to the image.
    """
    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]

    # The Dense projection must start from `inputs`; the original code
    # read an unassigned local `x` here and failed with UnboundLocalError.
    x = Dense(image_resize * image_resize * layer_filters[0])(inputs)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x)

    for filters in layer_filters:
        # Layers with more filters than the second-to-last entry upsample
        # (strides = 2); the remaining layers keep the resolution.
        strides = 2 if filters > layer_filters[-2] else 1
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='same')(x)

    if activation is not None:
        x = Activation(activation)(x)

    return Model(inputs, x, name='generator')


def discriminator(inputs,
                  activation='sigmoid'):
    """Build the DCGAN-style discriminator model.

    Four LeakyReLU -> Conv2D blocks (32, 64, 128, 256 filters) are stacked
    on `inputs`, followed by Flatten and a single-unit Dense layer.

    Args:
        inputs: Keras tensor holding the input image batch.
        activation: Name of the activation applied to the Dense output,
            or None to return the raw Dense output. When given, the name
            is also printed.

    Returns:
        A Keras `Model` named 'discriminator'.
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]
    last_filters = layer_filters[-1]

    features = inputs
    for n_filters in layer_filters:
        # Every block downsamples with strides = 2 except the final one.
        stride = 1 if n_filters == last_filters else 2
        features = LeakyReLU(alpha=0.2)(features)
        features = Conv2D(filters=n_filters,
                          kernel_size=kernel_size,
                          strides=stride,
                          padding='same')(features)

    flat = Flatten()(features)
    # With the default sigmoid activation the output is the probability
    # that the image is real.
    score = Dense(1)(flat)
    if activation is not None:
        print(activation)
        score = Activation(activation)(score)

    return Model(inputs, score, name='discriminator')


def plot_images(generator,
                noise_input,
                noise_label=None,
                noise_codes=None,
                show=False,
                step=0,
                model_name="gan"):
    """Generate images with `generator`, tile them in a grid and save a PNG.

    The figure is written to ``<model_name>/<step as 5 digits>.png``; the
    directory is created if it does not exist.

    Args:
        generator: Object with a `predict` method that returns an array of
            square single-channel images (the first axis indexes images).
        noise_input: Array of noise vectors, one row per image.
        noise_label: Optional labels passed to the generator together
            with the noise.
        noise_codes: Optional list of extra code inputs appended after
            the labels; only used when `noise_label` is given.
        show: If True display the figure, otherwise close all figures.
        step: Training step used to build the output file name.
        model_name: Directory in which the image is saved.
    """
    # Imported locally: the notebook's import cell does not bring `os`
    # into scope, so the calls below would otherwise raise NameError.
    import os

    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "%05d.png" % step)
    # At least one row so the grid is valid even for very small batches.
    rows = max(1, int(math.sqrt(noise_input.shape[0])))
    if noise_label is not None:
        noise_input = [noise_input, noise_label]
        if noise_codes is not None:
            noise_input += noise_codes

    images = generator.predict(noise_input)
    num_images = images.shape[0]
    image_size = images.shape[1]
    # Use enough columns to hold every image; for a perfect-square count
    # this equals `rows`, reproducing the original square grid.
    cols = max(rows, math.ceil(num_images / rows))

    plt.figure(figsize=(2.2, 2.2))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        image = np.reshape(images[i], [image_size, image_size])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.savefig(filename)
    if show:
        plt.show()
    else:
        plt.close('all')


def test_generator(generator, num_images=16, latent_size=100):
    """Plot a grid of images produced by `generator` from random noise.

    Noise is drawn uniformly from [-1.0, 1.0) and handed to `plot_images`,
    which shows the figure and saves it under the "test_outputs" directory.

    Args:
        generator: Trained generator model passed through to `plot_images`.
        num_images: Number of noise vectors (and images) to generate.
        latent_size: Length of each noise vector; must match the input
            size the generator was built with.
    """
    noise_input = np.random.uniform(-1.0,
                                    1.0,
                                    size=[num_images, latent_size])
    plot_images(generator,
                noise_input=noise_input,
                show=True,
                model_name="test_outputs")

# WGAN 구현 

GAN과 유사하게 생성기와 판별기를 교대로 훈련시키지만, 구현 측면에서 WGAN의 가장 큰 특징은 Wasserstein Loss를 사용한다는 점과 생성기를 1회 훈련시키기 전에 판별기를 n회 훈련시킨다는 점이다.

이는 생성기와 판별기를 동일한 횟수로 훈련시키는 GAN 과는 다른 점이다. 판별기를 훈련시킨다는 것은 판별기의 매개변수를 학습한다는 것을 뜻한다. 

또, EMD의 제약 조건을 만족시키기 위해 립시츠 제약에 대한 변수 또한 필요하다. 립시츠 조건이란 함수가 두 점 사이의 거리를 일정 비율 이상 증가시키지 않는다는 것을 뜻한다.

In [None]:
def build_and_train_models():
  #MNIST 데이터 세트 로딩
  (x_train,_),(_,_) = mnist.load_data()

  # 데이터 형상 변환 및 정규화
  image_size = x_train.shape[1]
  x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
  x_train = x_train.astype('float32')/255
  
  model_name = "wgan_mnist"

  #네트워크 매개 변수 지정 
  latent_size = 100
  
  n_critic = 5
  clip_value = 0.01

