In [12]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.initializers import truncated_normal , random_normal , constant

from keras.optimizers import Adam , RMSprop

from keras.datasets import mnist

In [2]:
import os

import matplotlib as plt
import numpy as np
import keras.backend as K

import gc

%matplotlib inline

In [3]:
WIDTH = 96
HEIGHT = 96
CHANNEL = 3

LATENT_DIM = 100 #latent variable z sample from normal distribution

BATCH_SIZE = 64
EPOCHS = 10

PATH = 'faces/'

#生成多少个图像 长*宽
ROW = 5
COL = 5

#为WGAN增加的
N_CRITIC = 5 #训练G时使用
CLIP_VALUE = 0.01 #更新G的权重参数时进行截断使用


In [4]:

load_index = 0

images_name = os.listdir(PATH)

IMAGES_COUNT = len(images_name)


In [5]:
'''
(X_train , y_train),(X_test , y_test) = mnist.load_data()
X_train = X_train/127.5-1
X_train = np.expand_dims(X_train , 3)
'''

In [6]:
'''
def load_mnist():
    return X_train[np.random.randint(0, X_train.shape[0], BATCH_SIZE)]
    
def write_image_mnist(epoch):
    
    noise = np.random.normal(size = (ROW*COL , LATENT_DIM))
    generated_image = generator_i.predict(noise)
    generated_image = generated_image*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    count=0
    
    for i in range(ROW):
        for j in range(COL):
            axes[i][j].imshow(generated_image[count,:,:,0] , cmap = 'gray')
            axes[i][j].axis('off')
            count += 1
            
    fig.savefig('mnist_wgan/No.%d.png' % epoch)
    plt.pyplot.close()

'''

In [5]:

def load_image(batch_size = BATCH_SIZE):
    global load_index
    
    images = []
    
    for i in range(batch_size):
        images.append(plt.image.imread(PATH + images_name[(load_index + i) % IMAGES_COUNT]))
    
    load_index += batch_size
    
    return np.array(images)/127.5-1

def write_image(epoch):
    
    noise = np.random.normal(size = (ROW*COL , LATENT_DIM))
    generated_image = generator_i.predict(noise)
    generated_image = (generated_image+1)*127.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    count=0
    
    for i in range(ROW):
        for j in range(COL):
            axes[i][j].imshow(generated_image[count])
            axes[i][j].axis('off')
            count += 1
            
    fig.savefig('generated_faces_wgan/No.%d.png' % epoch)
    plt.pyplot.close()
    
    #plt.image.imsave('images/'+str(epoch)+'.jpg')


In [6]:
def conv2d(output_size):
    return Conv2D(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=truncated_normal(stddev=0.02) , bias_initializer=constant(0.0))

def dense(output_size):
    return Dense(output_size , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0))

def deconv2d(output_size):
    return Conv2DTranspose(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0))

def batch_norm():
    return BatchNormalization(momentum=0.9 , epsilon=1e-5)


In [14]:
def generator():
    #sample from noise z
    model = Sequential(name='generator')
    
    #cartoon 图像使用 96*96*3
    model.add(Dense(6*6*8*64 , input_shape=(LATENT_DIM,) , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0)))
    
    model.add(Reshape((6, 6, 64*8)))
    
    model.add(batch_norm())
    model.add(Activation('relu'))

    model.add(deconv2d(64*4))
    model.add(batch_norm())
    model.add(Activation('relu'))
    
    model.add(deconv2d(64*2))
    model.add(batch_norm())
    model.add(Activation('relu'))
    
    model.add(deconv2d(64*1))
    model.add(batch_norm())
    model.add(Activation('relu'))
    
    model.add(deconv2d(3))
    model.add(Activation('tanh'))
    
    model.summary()
    
    noise = Input(shape=(LATENT_DIM , ) , name='input1')
    image = model(noise)
    
    return Model(noise , image , name='generator_Model')

In [15]:
def critic():
    #input a image to discriminate real or fake
    model = Sequential(name='critic')
    
    model.add(Conv2D(filters=64 , kernel_size=(5,5) , strides=(2,2) , padding='same' , input_shape=(WIDTH , HEIGHT , CHANNEL) , kernel_initializer=truncated_normal(stddev=0.02) , bias_initializer=constant(0.0) , name='conv1'))
    model.add(LeakyReLU(0.2))
    
    model.add(conv2d(64*2))
    model.add(batch_norm())
    model.add(LeakyReLU(0.2))
    
    model.add(conv2d(64*4))
    model.add(batch_norm())  
    model.add(LeakyReLU(0.2))

    
    model.add(conv2d(64*8))
    model.add(batch_norm())  
    model.add(LeakyReLU(0.2))

    model.add(Flatten())
    
    #===
    #如果没有下面的两个FC层 训练时发生损失不下降 且生成不出图像
    #model.add(dense(1024))
    #model.add(LeakyReLU(0.2))
    
    model.add(dense(256))
    model.add(LeakyReLU(0.2))
    #===
    model.add(dense(1))
    
    model.summary()
    
    image = Input(shape=(WIDTH , HEIGHT , CHANNEL) , name='input1')
    validity = model(image)
    
    return Model(image , validity , name='critic')

In [16]:
rmsprop = RMSprop(lr=0.00005)

def wgan_loss(y_true , y_pred):
    return K.mean(y_true*y_pred)

In [17]:
critic_i = critic()
critic_i.compile(optimizer=rmsprop , loss=wgan_loss , metrics=['accuracy'])

generator_i = generator()

z = Input(shape=(LATENT_DIM , ) , name='z')
image = generator_i(z)
validity = critic_i(image)

combined_model_i = Model(z , validity)
combined_model_i.compile(optimizer=rmsprop , loss=wgan_loss)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1 (Conv2D)               (None, 48, 48, 64)        4864      
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 48, 48, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 24, 24, 128)       204928    
_________________________________________________________________
batch_normalization_8 (Batch (None, 24, 24, 128)       512       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 24, 24, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 12, 12, 256)       819456    
_________________________________________________________________
batch_normalization_9 (Batch (None, 12, 12, 256)       1024      
__________

In [19]:
real_labels = -np.ones(shape=(BATCH_SIZE , 1)) #真实样本label为1
fake_labels = np.ones(shape=(BATCH_SIZE , 1)) #假样本label为0

for i in range(1000):
    #============================
    #训练一次G就要训练N_CRITIC次D（Discriminator）
    #===
    critic_i.trainable = True
    for layer in critic_i.layers:
        layer.trainable = True
    #===
    for _ in range(N_CRITIC):
        
        noise = np.random.normal(size=(BATCH_SIZE , LATENT_DIM))

        real_image = load_image()
        #训练判别器
        fake_image = generator_i.predict(noise)

        real_loss = critic_i.train_on_batch(real_image , real_labels)
        fake_loss = critic_i.train_on_batch(fake_image , fake_labels)

        loss = np.add(real_loss , fake_loss)/2
        
        #addin
        #WGAN的变化 对权重参数进行裁剪
        for layer in critic_i.layers:
            weights = layer.get_weights()
            weights = [np.clip(w , - CLIP_VALUE , CLIP_VALUE) for w in weights]
            layer.set_weights(weights)
    #============================
    #===
    critic_i.trainable = False
    for layer in critic_i.layers:
        layer.trainable = False
    #===
    
    #训练生成器
    noise2 = np.random.normal(size=(BATCH_SIZE , LATENT_DIM))
    generator_loss = combined_model_i.train_on_batch(noise2 , real_labels)

    print('epoch:%d loss:%f gene_loss:%f' % (i , 1-loss[0] , 1-generator_loss))

    if i % 10 == 0:
        write_image(i)
    
write_image(999)


  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:1.800870 gene_loss:0.420035
epoch:1 loss:1.860250 gene_loss:0.373381
epoch:2 loss:1.920574 gene_loss:0.317472
epoch:3 loss:1.981230 gene_loss:0.266104
epoch:4 loss:2.048466 gene_loss:0.214421
epoch:5 loss:2.102844 gene_loss:0.164625
epoch:6 loss:2.177643 gene_loss:0.116024
epoch:7 loss:2.234054 gene_loss:0.074761
epoch:8 loss:2.292432 gene_loss:0.035752
epoch:9 loss:2.346036 gene_loss:-0.000503
epoch:10 loss:2.383422 gene_loss:-0.029925
epoch:11 loss:2.435460 gene_loss:-0.066685
epoch:12 loss:2.483436 gene_loss:-0.100914
epoch:13 loss:2.527373 gene_loss:-0.136183
epoch:14 loss:2.573527 gene_loss:-0.175185
epoch:15 loss:2.600257 gene_loss:-0.206935
epoch:16 loss:2.648395 gene_loss:-0.244676
epoch:17 loss:2.698341 gene_loss:-0.280201
epoch:18 loss:2.731285 gene_loss:-0.308801
epoch:19 loss:2.777992 gene_loss:-0.343140
epoch:20 loss:2.809980 gene_loss:-0.375676
epoch:21 loss:2.846431 gene_loss:-0.406862
epoch:22 loss:2.896318 gene_loss:-0.442894
epoch:23 loss:2.931679 gene_lo

epoch:190 loss:3.803133 gene_loss:-1.471437
epoch:191 loss:3.804697 gene_loss:-1.466181
epoch:192 loss:3.808078 gene_loss:-1.472929
epoch:193 loss:3.803762 gene_loss:-1.466535
epoch:194 loss:3.808145 gene_loss:-1.472509
epoch:195 loss:3.806570 gene_loss:-1.473524
epoch:196 loss:3.806127 gene_loss:-1.468312
epoch:197 loss:3.808488 gene_loss:-1.473478
epoch:198 loss:3.806086 gene_loss:-1.470021
epoch:199 loss:3.807036 gene_loss:-1.473267
epoch:200 loss:3.808108 gene_loss:-1.471969
epoch:201 loss:3.810417 gene_loss:-1.474874
epoch:202 loss:3.804861 gene_loss:-1.469447
epoch:203 loss:3.810587 gene_loss:-1.474670
epoch:204 loss:3.811357 gene_loss:-1.474313
epoch:205 loss:3.808753 gene_loss:-1.471510
epoch:206 loss:3.809780 gene_loss:-1.475173
epoch:207 loss:3.805023 gene_loss:-1.466982
epoch:208 loss:3.809957 gene_loss:-1.476110
epoch:209 loss:3.809104 gene_loss:-1.473895
epoch:210 loss:3.809937 gene_loss:-1.468246
epoch:211 loss:3.810262 gene_loss:-1.476298
epoch:212 loss:3.809054 gene_los

epoch:377 loss:3.801017 gene_loss:-1.457913
epoch:378 loss:3.802518 gene_loss:-1.463113
epoch:379 loss:3.800969 gene_loss:-1.455535
epoch:380 loss:3.806221 gene_loss:-1.462500
epoch:381 loss:3.806056 gene_loss:-1.465640
epoch:382 loss:3.805408 gene_loss:-1.464550
epoch:383 loss:3.805786 gene_loss:-1.461814
epoch:384 loss:3.802516 gene_loss:-1.465362
epoch:385 loss:3.807281 gene_loss:-1.467212
epoch:386 loss:3.804342 gene_loss:-1.463690
epoch:387 loss:3.809145 gene_loss:-1.467850
epoch:388 loss:3.802522 gene_loss:-1.462652
epoch:389 loss:3.809324 gene_loss:-1.469354
epoch:390 loss:3.810121 gene_loss:-1.465284
epoch:391 loss:3.807993 gene_loss:-1.468410
epoch:392 loss:3.806672 gene_loss:-1.465441
epoch:393 loss:3.809218 gene_loss:-1.469035
epoch:394 loss:3.808664 gene_loss:-1.466835
epoch:395 loss:3.811724 gene_loss:-1.470194
epoch:396 loss:3.809163 gene_loss:-1.469512
epoch:397 loss:3.809860 gene_loss:-1.466820
epoch:398 loss:3.810922 gene_loss:-1.470439
epoch:399 loss:3.808951 gene_los

epoch:564 loss:3.806898 gene_loss:-1.453367
epoch:565 loss:3.805144 gene_loss:-1.456576
epoch:566 loss:3.803969 gene_loss:-1.462035
epoch:567 loss:3.800801 gene_loss:-1.444471
epoch:568 loss:3.807053 gene_loss:-1.463873
epoch:569 loss:3.807825 gene_loss:-1.461736
epoch:570 loss:3.799843 gene_loss:-1.461870
epoch:571 loss:3.802947 gene_loss:-1.461673
epoch:572 loss:3.807482 gene_loss:-1.463588
epoch:573 loss:3.793334 gene_loss:-1.448540
epoch:574 loss:3.808275 gene_loss:-1.467096
epoch:575 loss:3.809508 gene_loss:-1.466359
epoch:576 loss:3.805493 gene_loss:-1.459304
epoch:577 loss:3.797812 gene_loss:-1.449726
epoch:578 loss:3.806082 gene_loss:-1.467113
epoch:579 loss:3.806259 gene_loss:-1.466015
epoch:580 loss:3.788645 gene_loss:-1.426539
epoch:581 loss:3.798097 gene_loss:-1.461115
epoch:582 loss:3.802943 gene_loss:-1.463788
epoch:583 loss:3.801775 gene_loss:-1.463731
epoch:584 loss:3.801646 gene_loss:-1.462528
epoch:585 loss:3.797191 gene_loss:-1.457582
epoch:586 loss:3.803156 gene_los

epoch:751 loss:3.814669 gene_loss:-1.475324
epoch:752 loss:3.816029 gene_loss:-1.472886
epoch:753 loss:3.816287 gene_loss:-1.475584
epoch:754 loss:3.814265 gene_loss:-1.470305
epoch:755 loss:3.813953 gene_loss:-1.466035
epoch:756 loss:3.811086 gene_loss:-1.459377
epoch:757 loss:3.812134 gene_loss:-1.444570
epoch:758 loss:3.810899 gene_loss:-1.429932
epoch:759 loss:3.809475 gene_loss:-1.427191
epoch:760 loss:3.807673 gene_loss:-1.431852
epoch:761 loss:3.807743 gene_loss:-1.416010
epoch:762 loss:3.803235 gene_loss:-1.386970
epoch:763 loss:3.794336 gene_loss:-1.268582
epoch:764 loss:3.793496 gene_loss:1.222467
epoch:765 loss:3.699900 gene_loss:2.037062
epoch:766 loss:3.768307 gene_loss:2.956992
epoch:767 loss:3.773214 gene_loss:3.725720
epoch:768 loss:3.787179 gene_loss:3.778543
epoch:769 loss:3.780385 gene_loss:3.257481
epoch:770 loss:3.772239 gene_loss:3.242775
epoch:771 loss:3.774400 gene_loss:3.496365
epoch:772 loss:3.767679 gene_loss:3.846582
epoch:773 loss:3.774600 gene_loss:3.54016

epoch:939 loss:3.725879 gene_loss:3.000499
epoch:940 loss:3.740181 gene_loss:3.597590
epoch:941 loss:3.732969 gene_loss:3.522485
epoch:942 loss:3.743046 gene_loss:2.758092
epoch:943 loss:3.755360 gene_loss:3.606364
epoch:944 loss:3.765016 gene_loss:3.522363
epoch:945 loss:3.763556 gene_loss:1.974002
epoch:946 loss:3.760409 gene_loss:3.297387
epoch:947 loss:3.765330 gene_loss:3.395870
epoch:948 loss:3.780151 gene_loss:3.255010
epoch:949 loss:3.745528 gene_loss:-0.247847
epoch:950 loss:3.769017 gene_loss:-0.377298
epoch:951 loss:3.755682 gene_loss:1.465835
epoch:952 loss:3.774311 gene_loss:1.985044
epoch:953 loss:3.723406 gene_loss:-0.461999
epoch:954 loss:3.759938 gene_loss:3.057737
epoch:955 loss:3.689037 gene_loss:-1.068154
epoch:956 loss:3.767183 gene_loss:0.205434
epoch:957 loss:3.752760 gene_loss:-1.068387
epoch:958 loss:3.758076 gene_loss:0.037460
epoch:959 loss:3.752205 gene_loss:-0.967462
epoch:960 loss:3.770689 gene_loss:1.585284
epoch:961 loss:3.747769 gene_loss:-0.318208
epoc

In [70]:
real_labels.shape

(64, 1)

In [53]:
discriminator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 28, 28, 1)         0         
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 533505    
Total params: 533,505
Trainable params: 0
Non-trainable params: 533,505
_________________________________________________________________


In [54]:
generator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 100)               0         
_________________________________________________________________
generator (Sequential)       (None, 28, 28, 1)         1097744   
Total params: 1,097,744
Trainable params: 1,095,184
Non-trainable params: 2,560
_________________________________________________________________


In [39]:
combined_model_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z (InputLayer)               (None, 100)               0         
_________________________________________________________________
generator_Model (Model)      (None, 96, 96, 3)         29029120  
_________________________________________________________________
discriminator_Model (Model)  (None, 1)                 14320641  
Total params: 43,349,761
Trainable params: 29,025,536
Non-trainable params: 14,324,225
_________________________________________________________________


In [None]:
gc.collect()

In [14]:
modeli = Sequential()

modeli.add(Dense(128 * 7 * 7, activation="relu", input_shape=(LATENT_DIM,)))
modeli.add(Reshape((7, 7, 128)))
modeli.add(UpSampling2D())
modeli.add(Conv2D(128, kernel_size=3, padding="same"))
modeli.add(BatchNormalization(momentum=0.8))
modeli.add(Activation("relu"))
modeli.add(UpSampling2D())
modeli.add(Conv2D(64, kernel_size=3, padding="same"))
modeli.add(BatchNormalization(momentum=0.8))
modeli.add(Activation("relu"))
modeli.add(Conv2D(CHANNEL, kernel_size=3, padding="same"))
modeli.add(Activation("tanh"))

In [15]:
modeli.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_3 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 128)       147584    
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 28, 28, 128)       0         
__________

In [41]:
32*400

12800

In [None]:
gc.collect()