In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import BatchNormalization, Dropout
from tensorflow.keras.layers import Conv2D, Flatten, Dense, Input, Reshape
from tensorflow.keras.layers import LeakyReLU, Activation
from tensorflow.keras.layers import Cropping2D, ZeroPadding2D, UpSampling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as k 
from tensorflow.keras.optimizers import RMSprop
from keras.callbacks import Callback
from tqdm import tqdm 
import gc

from Build_model import Generator,Discriminator,Adversarial

In [None]:
# Data source: the MNIST loader from tf.keras. Only the image arrays are kept;
# the label arrays returned by load_data() are bound to `_` and ignored.
dataset = tf.keras.datasets.mnist
# Alternative data source: dataset = tf.keras.datasets.fashion_mnist
(train_images, _), (test_images, _) = dataset.load_data()

# Target image shape. preprocess() resizes images to its first two entries.
# Alternative setting: IMAGE_SHAPE = (128, 128, 3)
IMAGE_SHAPE = (32, 32,1)

def preprocess(images, image_shape=None):
    """Resize a batch of images to the target shape and min-max scale it to [0, 1].

    Args:
        images: Array-like batch of shape (N, H, W) or (N, H, W, C).
        image_shape: Target ``(height, width, channels)``. When omitted, the
            module-level ``IMAGE_SHAPE`` is used.

    Returns:
        A float32 array of shape ``(N, height, width, channels)``. Scaling uses
        the global minimum and maximum of the whole batch; a batch whose values
        are all equal is returned as zeros instead of NaN.

    Raises:
        ValueError: If the (resized) batch cannot be laid out as
            ``(N, height, width, channels)``.
    """
    if image_shape is None:
        image_shape = IMAGE_SHAPE
    height, width, channels = image_shape

    images = np.asarray(images)

    # Compare the spatial dimensions only. Comparing against the full
    # (height, width, channels) tuple is never equal, which forced a resize
    # (and the cv2 import) even for images that already had the right size.
    if tuple(images.shape[1:3]) != (height, width):
        import cv2  # local import: only needed when a resize is required

        # cv2.resize expects dsize as (width, height), not (height, width).
        images = np.array([cv2.resize(image, (width, height)) for image in images])

    # Single-channel batches carry no channel axis (e.g. (N, H, W)); add it.
    if images.ndim == 3:
        images = images[..., np.newaxis]
    if tuple(images.shape[1:]) != (height, width, channels):
        raise ValueError(
            "cannot arrange images of shape %s as (N, %d, %d, %d)"
            % (images.shape, height, width, channels)
        )

    images = images.astype(np.float32)
    lowest = np.min(images)
    span = np.max(images) - lowest
    if span > 0:
        images = (images - lowest) / span
    else:
        # Constant-valued batch: avoid 0/0 and return a well-defined result.
        images = np.zeros_like(images)
    return images

# Keep only the first 10,000 training images. The slice is taken before
# preprocessing so that images which are discarded anyway are never resized.
# NOTE(review): preprocess() min-max scales with the statistics of the array it
# receives, so they are now computed over the subset. Presumably identical for
# MNIST (pixel range 0-255); confirm if the dataset is swapped.
train_images = preprocess(train_images[:10000])
test_images = preprocess(test_images)

# Shape of half an image split along the width axis.
HALF_IMAGE_SHAPE = (IMAGE_SHAPE[0], IMAGE_SHAPE[1] // 2, IMAGE_SHAPE[2])
print(HALF_IMAGE_SHAPE)

BATCH_SIZE = 64
LATENT_DIM = 128             # length of the noise vector fed to the generator
IMAGE_SIZE = IMAGE_SHAPE[0]  # 32; derived from IMAGE_SHAPE instead of repeating the literal
CLIP_VALUE = 0.01            # bound used when clipping the discriminator weights
N_CRITIC = 5                 # discriminator updates per generator update
LR = 5e-5                    # RMSprop learning rate (the generator optimizer uses half)
DECAY = 6e-8                 # RMSprop decay (the generator optimizer uses half)
TRAIN_STEPS = 40000
MODEL_NAME = 'CWGAN_GP'

# Model declaration

In [None]:
""" loss 선언"""
# Wasserstein loss
def wasserstein_loss(y_label, y_pred):
    """Return the negated mean of the element-wise product of labels and predictions."""
    signed_scores = y_label * y_pred
    return -k.mean(signed_scores)

"""
Model declaration
"""
# Symbolic inputs. `shape` is given as an explicit tuple because a bare int is
# rejected by some Keras versions, and the image shapes reuse the configuration
# constants instead of repeating the literals (32, 32, 1) / (32, 16, 1).
noise_input = Input(shape=(LATENT_DIM,))
image_input = Input(shape=IMAGE_SHAPE)
condition_input = Input(shape=HALF_IMAGE_SHAPE)

# Discriminator (critic)
d_model = Discriminator(image_input, condition_input).build_dicriminator()
d_optimizer = RMSprop(learning_rate=LR, decay=DECAY)
d_model.compile(loss=wasserstein_loss,
                optimizer=d_optimizer,
                metrics=['accuracy'])
# Frozen after its own compile() and before the adversarial model is built;
# presumably so that adversarial updates only change the generator. TODO confirm.
d_model.trainable = False

# Generator
g_model = Generator(noise_input, condition_input, IMAGE_SIZE).build_generator()
g_optimizer = RMSprop(learning_rate=LR * 0.5, decay=DECAY * 0.5)

# Adversarial model (generator followed by the discriminator)
adversarial = Adversarial(g_model, d_model, noise_input, condition_input).build_adversarial()
adversarial.compile(loss=wasserstein_loss,
                    optimizer=g_optimizer,
                    metrics=['accuracy'])

models = (g_model, d_model, adversarial)
params = (BATCH_SIZE, LATENT_DIM, N_CRITIC, CLIP_VALUE, TRAIN_STEPS, MODEL_NAME)
# NOTE(review): train() and plot_images() are defined in later cells of this
# notebook. On a fresh kernel those cells must be executed before this call,
# otherwise it raises NameError.
train(models, train_images, params)

In [None]:
def gradient_penalty(batch_size, conditions, real_images, fake_images, discriminator=None):
    """Compute the WGAN-GP gradient penalty on images interpolated between real and fake.

    Args:
        batch_size: Number of samples in the batch.
        conditions: Conditioning inputs passed to the discriminator alongside
            the interpolated images.
        real_images: Batch of real images.
        fake_images: Batch of generated images with the same shape as
            ``real_images``.
        discriminator: Model called as ``discriminator([images, conditions])``.
            When omitted, the module-level ``d_model`` is used.

    Returns:
        Scalar tensor: the batch mean of ``(||grad||_2 - 1) ** 2``, where
        ``grad`` is the gradient of the discriminator output with respect to
        the interpolated images.
    """
    critic = d_model if discriminator is None else discriminator

    # Accept NumPy arrays as well as tensors so the tape can watch the result.
    real_images = tf.cast(real_images, tf.float32)
    fake_images = tf.cast(fake_images, tf.float32)

    # One interpolation factor per sample, drawn uniformly from [0, 1) so that
    # every interpolated image lies between its real and fake counterpart.
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = real_images + alpha * (fake_images - real_images)

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        # Input order [images, conditions] matches how the discriminator is
        # built and how train() feeds it.
        pred = critic([interpolated, conditions], training=True)

    # Only the gradient with respect to the interpolated images is needed.
    grads = gp_tape.gradient(pred, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    return tf.reduce_mean((norm - 1.0) ** 2)


def train(models, x_train, params):
    """Alternate discriminator and generator updates with weight clipping.

    Args:
        models: Tuple ``(generator, discriminator, adversarial)``.
        x_train: Image array indexed as (N, H, W, C). The left half of each
            sampled image (along axis 2) is used as the generator's condition.
        params: Tuple ``(batch_size, latent_size, n_critic, clip_value,
            train_steps, model_name)``. ``model_name`` is unpacked but not used
            here.

    Raises:
        ValueError: If ``n_critic`` is smaller than 1.
    """
    generator, discriminator, adversarial = models
    (batch_size, latent_size, n_critic, clip_value, train_steps, model_name) = params
    if n_critic < 1:
        raise ValueError("n_critic must be at least 1")

    save_interval = 250  # print the log and plot generator samples every 250 steps

    # Fixed noise vectors, reused at every checkpoint to watch the generator evolve.
    noise_input = np.random.uniform(-1, 1, size=[16, latent_size])
    train_size = x_train.shape[0]
    # The condition is cut along the width axis (axis 2).
    half_width = x_train.shape[2] // 2

    real_labels = np.ones((batch_size, 1))

    for i in tqdm(range(train_steps)):
        loss = 0
        acc = 0
        for _ in range(n_critic):  # several discriminator updates per generator update
            # Real batch, and its left halves as conditions. Both come from
            # `x_train` (the argument), never from a notebook-level global.
            rand_indexes = np.random.randint(0, train_size, size=batch_size)
            real_images = x_train[rand_indexes]
            noise = np.random.uniform(-1, 1, size=[batch_size, latent_size])
            condition = real_images[:, :, :half_width, :]
            # verbose=0 keeps predict() from printing a progress bar on every call.
            fake_images = generator.predict([noise, condition], verbose=0)

            # Real and fake batches are trained separately; fake samples get label -1.
            real_loss, real_acc = \
                discriminator.train_on_batch([real_images, condition], real_labels)
            fake_loss, fake_acc = \
                discriminator.train_on_batch([fake_images, condition], -real_labels)

            # Average of the two halves of this discriminator update.
            loss += 0.5 * (real_loss + fake_loss)
            acc += 0.5 * (real_acc + fake_acc)

            # Weight clipping (intended to enforce the Lipschitz constraint):
            # read each layer's weights, clip them, and write them back.
            for layer in discriminator.layers:
                clipped = [np.clip(weight, -clip_value, clip_value)
                           for weight in layer.get_weights()]
                layer.set_weights(clipped)

        # Mean over the n_critic discriminator updates.
        loss /= n_critic
        acc /= n_critic
        log = "%d: [discriminator loss: %f, acc: %f]" % (i, loss, acc)

        # Generator update through the adversarial model. It reuses the
        # conditions of the last discriminator batch with fresh noise.
        noise = np.random.uniform(-1, 1, size=[batch_size, latent_size])
        loss, acc = adversarial.train_on_batch([noise, condition], real_labels)
        log = "%s [adversarial loss :%f, acc :%f]" % (log, loss, acc)

        if (i + 1) % save_interval == 0:
            print(log)
            plot_images(generator, noise_input, x_train[:16])
                    

In [None]:
def plot_images(generator, noise_input, image_input, n_images=7):
    """Show images completed by the generator: real left half + generated right half.

    Args:
        generator: Model whose ``predict([noise, condition])`` returns images
            of the same height and width as ``image_input``.
        noise_input: Noise vectors, one per image in ``image_input``.
        image_input: Image array indexed as (N, H, W, C). The left half along
            axis 2 is used as the condition.
        n_images: Maximum number of images drawn in a single row. If fewer
            images are available, only those are drawn.
    """
    # Split point derived from the image width instead of a hard-coded 16.
    half_width = image_input.shape[2] // 2
    condition = image_input[:, :, :half_width]
    fake_images = generator.predict([noise_input, condition], verbose=0)
    completed = np.concatenate([condition, fake_images[:, :, half_width:]], axis=2)

    n_shown = min(n_images, len(completed))
    if n_shown < 1:
        return

    fig, axes = plt.subplots(1, n_shown, figsize=(10, 10), squeeze=False)
    for ax, image in zip(axes[0], completed):
        # Drop a trailing single-channel axis: some matplotlib versions reject
        # (H, W, 1) arrays in imshow.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        ax.imshow(image, cmap='gray')
    plt.show()
    # Release the figure; this function is called repeatedly during training.
    plt.close(fig)