In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.optimizers import Adam , RMSprop

from keras.initializers import truncated_normal , random_normal , constant

#_Merge
from keras.layers.merge import _Merge

from keras.datasets import mnist

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np
import keras.backend as K

from functools import partial



import gc

%matplotlib inline

In [3]:
WIDTH = 96
HEIGHT = 96
CHANNEL = 3

LATENT_DIM = 100 #latent variable z sample from normal distribution

BATCH_SIZE = 64
EPOCHS = 10

PATH = 'faces/'

#生成多少个图像 长*宽
ROW = 5
COL = 5

#为WGAN增加的
N_CRITIC = 5 #训练G时使用
CLIP_VALUE = 0.01 #更新G的权重参数时进行截断使用


In [4]:

load_index = 0

images_name = os.listdir(PATH)

IMAGES_COUNT = len(images_name)


In [5]:

def load_image(batch_size = BATCH_SIZE):
    global load_index
    
    images = []
    
    for i in range(batch_size):
        images.append(plt.image.imread(PATH + images_name[(load_index + i) % IMAGES_COUNT]))
    
    load_index += batch_size
    
    return np.array(images)/127.5-1

def write_image(epoch):
    
    noise = np.random.normal(size = (ROW*COL , LATENT_DIM))
    generated_image = generator_i.predict(noise)
    generated_image = (generated_image+1)*127.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    count=0
    
    for i in range(ROW):
        for j in range(COL):
            axes[i][j].imshow(generated_image[count])
            axes[i][j].axis('off')
            count += 1
            
    fig.savefig('generated_faces_wgan-gp/No.%d.png' % epoch)
    plt.pyplot.close()
    
    #plt.image.imsave('images/'+str(epoch)+'.jpg')


In [6]:
def conv2d(output_size):
    return Conv2D(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=truncated_normal(stddev=0.02) , bias_initializer=constant(0.0))

def dense(output_size):
    return Dense(output_size , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0))

def deconv2d(output_size):
    return Conv2DTranspose(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0))

def batch_norm():
    return BatchNormalization(momentum=0.9 , epsilon=1e-5)


In [7]:
def generator():
    #sample from noise z
    model = Sequential(name='generator')
    
    #cartoon 图像使用 96*96*3
    model.add(Dense(6*6*8*64 , input_shape=(LATENT_DIM,) , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0)))
    
    model.add(Reshape((6, 6, 64*8)))
    
    model.add(batch_norm())
    model.add(Activation('relu'))

    model.add(deconv2d(64*4))
    model.add(batch_norm())
    model.add(Activation('relu'))
    
    model.add(deconv2d(64*2))
    model.add(batch_norm())
    model.add(Activation('relu'))
    
    model.add(deconv2d(64*1))
    model.add(batch_norm())
    model.add(Activation('relu'))
    
    model.add(deconv2d(3))
    model.add(Activation('tanh'))
    
    #model.summary()
    
    noise = Input(shape=(LATENT_DIM , ) , name='input1')
    image = model(noise)
    
    return Model(noise , image , name='generator_Model')

In [8]:
def critic():
    #input a image to discriminate real or fake
    model = Sequential(name='critic')
    
    model.add(Conv2D(filters=64 , kernel_size=(5,5) , strides=(2,2) , padding='same' , input_shape=(WIDTH , HEIGHT , CHANNEL) , kernel_initializer=truncated_normal(stddev=0.02) , bias_initializer=constant(0.0) , name='conv1'))
    model.add(LeakyReLU(0.2))
    #model.add(Dropout(0.25))
    
    #model.add(BatchNormalization(momentum=0.8))
    model.add(conv2d(64*2))
    model.add(batch_norm())
    model.add(LeakyReLU(0.2))
    
    model.add(conv2d(64*4))
    model.add(batch_norm())  
    model.add(LeakyReLU(0.2))

    
    model.add(conv2d(64*8))
    model.add(batch_norm())  
    model.add(LeakyReLU(0.2))

    model.add(Flatten())
    
    #===
    #如果没有下面的FC层 训练时发生损失不下降 且生成不出图像
    model.add(dense(128))
    model.add(LeakyReLU(0.2))
    #===
    model.add(dense(1)) #不使用sigmoid激活 y=x 激活
    
    #model.summary()
    
    image = Input(shape=(WIDTH , HEIGHT , CHANNEL) , name='input1')
    validity = model(image)
    
    return Model(image , validity , name='critic_Model')

In [9]:
rmsprop = RMSprop(lr=0.00005)


In [10]:
#需要继承_Merge类
class RandomWeightedAverage(_Merge):
    """Provides a (random) weighted average between real and generated image samples"""
    def _merge_function(self, inputs):
        alpha = K.random_uniform((BATCH_SIZE, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])


In [11]:
def wgan_loss(y_true , y_pred):
    return K.mean(y_true*y_pred)

def GP_penelty_loss(y_true , y_pred , averaged_samples): #WGAN-GP在原有GAN损失中增加的损失
    gradients = K.gradients(y_pred , averaged_samples)[0] #y_true在此位置代替的是averaged_samples 
    gradients_sqr = K.square(gradients)
    gradients_sqr_sum = K.sum(gradients_sqr , axis=np.arange(1 , len(gradients_sqr.shape)))
    
    gradients_l2_norm = K.sqrt(gradients_sqr_sum)
    
    gradients_penalty = K.square(1-gradients_l2_norm)
    
    return K.mean(gradients_penalty)

In [12]:
critic_i = critic()
generator_i = generator()

generator_i.trainable = False

#critic_i.compile(optimizer=rmsprop , loss=wgan_loss , metrics=['accuracy'])

real_image = Input(shape=(HEIGHT , WIDTH , CHANNEL) , name='real_image')


z = Input(shape=(LATENT_DIM , ) , name = 'z')
fake_image = generator_i(z)

validity_fake = critic_i(fake_image)
validity_real = critic_i(real_image)

#根据WGAN-GP中的loss公式中的插值的样本
interpolation_real_fake_image = RandomWeightedAverage()([real_image , fake_image])
validity_interpolation = critic_i(interpolation_real_fake_image)

#==========
#partial为python函数 类似装饰器
#下面就是在原有的GAN上的损失添加的惩罚项
partial_gp_loss = partial(GP_penelty_loss , averaged_samples=interpolation_real_fake_image)
partial_gp_loss.__name__ = 'gradient_penalty'
#==========

critic_model = Model(inputs=[real_image , z] , outputs=[validity_real , validity_fake , validity_interpolation] , name='critic_model')
critic_model.compile(optimizer=rmsprop , loss=[wgan_loss , wgan_loss , partial_gp_loss] , loss_weights=[1,1,10])

#==========
critic_i.trainable = False
generator_i.trainable = True


z_ = Input(shape=(LATENT_DIM , ))
image_ = generator_i(z_)
valid_ = critic_i(image_)

generator_model = Model(z_ , valid_)
generator_model.compile(optimizer=rmsprop , loss=wgan_loss)

In [None]:
real_labels = -np.ones(shape=(BATCH_SIZE , 1)) #真实样本label为1
fake_labels = np.ones(shape=(BATCH_SIZE , 1)) #假样本label为0
dummy = np.zeros(shape=(BATCH_SIZE , 1)) #为WGAN-GP的惩罚项准备的label

for i in range(1001):
    #============================
    #训练一次G就要训练N_CRITIC次D（Discriminator）
    for _ in range(N_CRITIC):
        
        noise = np.random.normal(size=(BATCH_SIZE , LATENT_DIM))

        real_image = load_image()
        
        #训练判别器
        loss = critic_model.train_on_batch([real_image , noise] , [real_labels , fake_labels , dummy])
        
        #取消对权重参数的裁剪 clip
        #到底需不需要进行参数截断
        #在mnist数据集上需要 在cartoon上不好说 耗长时间去验证
        for layer in critic_i.layers:
            weights = layer.get_weights()
            weights = [np.clip(w , -CLIP_VALUE , CLIP_VALUE) for w in weights]
            layer.set_weights(weights)
    #============================
    
    #训练生成器
    noise2 = np.random.normal(size=(BATCH_SIZE , LATENT_DIM))
    generator_loss = generator_model.train_on_batch(noise2 , real_labels)

    print('epoch:%d loss1:%f loss2:%f loss3:%f gene_loss:%f' % (i , loss[0] , loss[1] , loss[2] , generator_loss))

    if i % 10 == 0:
        write_image(i)
    
write_image(999)


epoch:0 loss1:9.728308 loss2:-0.073971 loss3:-0.087728 gene_loss:0.087464
epoch:1 loss1:9.707046 loss2:-0.073995 loss3:-0.088454 gene_loss:0.088914
epoch:2 loss1:9.715948 loss2:-0.075428 loss3:-0.092194 gene_loss:0.090894
epoch:3 loss1:9.698011 loss2:-0.072518 loss3:-0.094405 gene_loss:0.094871
epoch:4 loss1:9.673544 loss2:-0.073216 loss3:-0.096763 gene_loss:0.098487
epoch:5 loss1:9.678801 loss2:-0.073581 loss3:-0.100370 gene_loss:0.101943
epoch:6 loss1:9.642222 loss2:-0.075849 loss3:-0.104160 gene_loss:0.104589
epoch:7 loss1:9.616793 loss2:-0.078041 loss3:-0.105591 gene_loss:0.105823
epoch:8 loss1:9.566839 loss2:-0.078663 loss3:-0.109154 gene_loss:0.108544
epoch:9 loss1:9.479568 loss2:-0.076215 loss3:-0.113224 gene_loss:0.111901
epoch:10 loss1:9.337676 loss2:-0.115847 loss3:-0.113366 gene_loss:0.116112
epoch:11 loss1:9.210569 loss2:-0.045073 loss3:-0.121073 gene_loss:0.117517
epoch:12 loss1:9.141478 loss2:-0.132518 loss3:-0.118188 gene_loss:0.130391
epoch:13 loss1:8.290174 loss2:-0.09

In [53]:
discriminator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 28, 28, 1)         0         
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 533505    
Total params: 533,505
Trainable params: 0
Non-trainable params: 533,505
_________________________________________________________________


In [54]:
generator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 100)               0         
_________________________________________________________________
generator (Sequential)       (None, 28, 28, 1)         1097744   
Total params: 1,097,744
Trainable params: 1,095,184
Non-trainable params: 2,560
_________________________________________________________________


In [39]:
combined_model_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z (InputLayer)               (None, 100)               0         
_________________________________________________________________
generator_Model (Model)      (None, 96, 96, 3)         29029120  
_________________________________________________________________
discriminator_Model (Model)  (None, 1)                 14320641  
Total params: 43,349,761
Trainable params: 29,025,536
Non-trainable params: 14,324,225
_________________________________________________________________


In [None]:
gc.collect()

In [14]:
modeli = Sequential()

modeli.add(Dense(128 * 7 * 7, activation="relu", input_shape=(LATENT_DIM,)))
modeli.add(Reshape((7, 7, 128)))
modeli.add(UpSampling2D())
modeli.add(Conv2D(128, kernel_size=3, padding="same"))
modeli.add(BatchNormalization(momentum=0.8))
modeli.add(Activation("relu"))
modeli.add(UpSampling2D())
modeli.add(Conv2D(64, kernel_size=3, padding="same"))
modeli.add(BatchNormalization(momentum=0.8))
modeli.add(Activation("relu"))
modeli.add(Conv2D(CHANNEL, kernel_size=3, padding="same"))
modeli.add(Activation("tanh"))

In [15]:
modeli.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_3 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 128)       147584    
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 28, 28, 128)       0         
__________

In [41]:
32*400

12800

In [None]:
gc.collect()