In [14]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Add , Multiply

from keras.initializers import random_normal
from keras.initializers import constant
from keras.initializers import truncated_normal

from keras.optimizers import Adam

from keras.datasets import mnist

In [24]:
import os

import matplotlib as plt
import numpy as np

import keras.backend as K

import gc

%matplotlib inline

In [37]:
WIDTH = 96
HEIGHT = 96
CHANNEL = 3

LATENT_DIM = 100 #latent variable z sample from normal distribution

BATCH_SIZE = 128
EPOCHS = 10

PATH = '../gans/faces/'

#生成多少个图像 长*宽
ROW = 5
COL = 5


In [38]:

load_index = 0

images_name = os.listdir(PATH)

IMAGES_COUNT = len(images_name)


In [39]:

def load_image(batch_size = BATCH_SIZE):
    global load_index
    
    images = []
    
    for i in range(batch_size):
        images.append(plt.image.imread(PATH + images_name[(load_index + i) % IMAGES_COUNT]))
    
    load_index += batch_size
    
    return np.array(images)/127.5-1

def write_image(epoch):
    
    noise = np.random.normal(size = (ROW*COL , LATENT_DIM))
    generated_image = decoder_i.predict(noise)
    generated_image = (generated_image+1)*127.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    count=0
    
    for i in range(ROW):
        for j in range(COL):
            axes[i][j].imshow(generated_image[count])
            axes[i][j].axis('off')
            count += 1
            
    fig.savefig('cartoon_aae/No.%d.png' % epoch)
    plt.pyplot.close()
    
    #plt.image.imsave('images/'+str(epoch)+'.jpg')


In [40]:
def discriminator():
    #input a image to discriminate real or fake
    model = Sequential(name='discriminator')
    
    model.add(Dense(512 , input_shape=(LATENT_DIM , )))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1 , activation='sigmoid'))
    
    encoder_feature = Input(shape=(LATENT_DIM , ))
    validity = model(encoder_feature)
    
    return Model(encoder_feature , validity , name='discriminator_Model')

In [41]:
def conv2d(output_size):
    return Conv2D(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=truncated_normal(stddev=0.02),bias_initializer=constant(0.0))

def batch_norm():
    return BatchNormalization(momentum=0.9 , epsilon=1e-5)

def deconv2d(output_size):
    return Conv2DTranspose(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=truncated_normal(stddev=0.02) , bias_initializer=constant(0.0))



In [42]:
def exp(x):
    return K.exp(x)

def encoder():
    image = Input(shape=(HEIGHT , WIDTH , CHANNEL))
    
    h = Conv2D(64 , input_shape=(HEIGHT , WIDTH , CHANNEL) , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=truncated_normal(stddev=0.02),bias_initializer=constant(0.0))(image)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = conv2d(64*2)(h)
    h = batch_norm()(h)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = conv2d(64*4)(h)
    h = batch_norm()(h)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = conv2d(64*8)(h)
    h = batch_norm()(h)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = Flatten()(h)
    
    mean = Dense(LATENT_DIM)(h)
    log_var = Dense(LATENT_DIM)(h)
    
    a=Activation(exp)(log_var)
    b=K.random_normal(shape=K.shape(log_var))
    encoder_feature = Add()([mean , Multiply()([a , b])])
    #    
    return Model(image , a)

def decoder():
    input_feature = Input(shape=(LATENT_DIM , ))
    
    h = Dense(6*6*8*64 , input_shape=(LATENT_DIM , ))(input_feature)
    h = Reshape(target_shape=(6,6,64*8))(h)
    
    h = batch_norm()(h)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = deconv2d(64*4)(h)
    h = batch_norm()(h)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = deconv2d(64*2)(h)
    h = batch_norm()(h)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = deconv2d(64*1)(h)
    h = batch_norm()(h)
    h = LeakyReLU(alpha=0.2)(h)
    
    h = deconv2d(3)(h)
    h = Activation('tanh')(h)
    
    image_hat = h
    
    return Model(input_feature , image_hat)

In [46]:
adam = Adam(lr = 0.0001 , beta_1=0.5)
#learning rate设置为0.0002容易出现nan情况
#增大batchsize or 降低learning rate

In [None]:
discriminator_i = discriminator()
discriminator_i.compile(optimizer=adam , loss='binary_crossentropy' , metrics=['accuracy'])


encoder_i = encoder()
decoder_i = decoder()

image = Input(shape=(HEIGHT , WIDTH , CHANNEL))

encoder_feature = encoder_i(image)
image_hat = decoder_i(encoder_feature)

discriminator_i.trainable = False

validity = discriminator_i(encoder_feature)

#aae为重新构造的组合model
aae = Model(image , [image_hat , validity])

aae.compile(optimizer=adam , loss=['mse' , 'binary_crossentropy'] , loss_weights = [0.999 , 0.001])

In [None]:
real_labels = np.ones(shape=(BATCH_SIZE , 1)) #真实样本label为1
fake_labels = np.zeros(shape=(BATCH_SIZE , 1)) #假样本label为0

for i in range(1000):
    for j in range(int(IMAGES_COUNT/BATCH_SIZE)):
        noise_encoder_feature = np.random.normal(size=(BATCH_SIZE , LATENT_DIM))

        real_image = load_image()
        #训练判别器
        real_encoder_feature = encoder_i.predict(real_image)

        #从高斯分布采样的噪声 与 True label 组成训练集
        #真实的图像经过encoder后得到的特征与 False label 组成训练集
        noise_loss = discriminator_i.train_on_batch(noise_encoder_feature , real_labels)
        real_loss = discriminator_i.train_on_batch(real_encoder_feature , fake_labels)

        loss = np.add(noise_loss , real_loss)/2

        #训练aae
        aae_loss = aae.train_on_batch(real_image , [real_image , real_labels])

        print('epoch:%d batch:%d loss:%f accu:%f aae_loss_1:%f aae_loss_2:%f' % (i , j , loss[0] , loss[1] , aae_loss[0] , aae_loss[1]))

        if i % 20 == 0:
            write_image(i)
    
write_image(999)


  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:0.760958 accu:0.253906 aae_loss_1:0.385436 aae_loss_2:0.384872
epoch:1 loss:1.474787 accu:0.472656 aae_loss_1:0.368721 aae_loss_2:0.366708
epoch:2 loss:0.774396 accu:0.636719 aae_loss_1:0.368552 aae_loss_2:0.364419
epoch:3 loss:1.369915 accu:0.597656 aae_loss_1:0.354331 aae_loss_2:0.351153
epoch:4 loss:2.690154 accu:0.464844 aae_loss_1:0.359615 aae_loss_2:0.356323
epoch:5 loss:2.424586 accu:0.449219 aae_loss_1:0.344994 aae_loss_2:0.339629
epoch:6 loss:3.071923 accu:0.378906 aae_loss_1:0.346112 aae_loss_2:0.339765
epoch:7 loss:2.815377 accu:0.382812 aae_loss_1:0.368055 aae_loss_2:0.358943
epoch:8 loss:4.520477 accu:0.304688 aae_loss_1:0.371252 aae_loss_2:0.358368
epoch:9 loss:0.495848 accu:0.550781 aae_loss_1:0.348392 aae_loss_2:0.333556
epoch:10 loss:0.517954 accu:0.535156 aae_loss_1:0.352856 aae_loss_2:0.338217
epoch:11 loss:0.514346 accu:0.542969 aae_loss_1:0.343539 aae_loss_2:0.329559
epoch:12 loss:0.510888 accu:0.531250 aae_loss_1:0.353562 aae_loss_2:0.340126
epoch:13 

epoch:107 loss:0.208827 accu:0.988281 aae_loss_1:0.220064 aae_loss_2:0.206522
epoch:108 loss:0.212189 accu:0.968750 aae_loss_1:0.228173 aae_loss_2:0.213940
epoch:109 loss:0.207834 accu:0.964844 aae_loss_1:0.238820 aae_loss_2:0.225270
epoch:110 loss:0.198953 accu:0.980469 aae_loss_1:0.231217 aae_loss_2:0.217596
epoch:111 loss:0.174124 accu:0.996094 aae_loss_1:0.224593 aae_loss_2:0.210338
epoch:112 loss:0.175385 accu:0.992188 aae_loss_1:0.220802 aae_loss_2:0.206034
epoch:113 loss:0.163089 accu:1.000000 aae_loss_1:0.215603 aae_loss_2:0.201254
epoch:114 loss:0.148157 accu:0.996094 aae_loss_1:0.301502 aae_loss_2:0.285811
epoch:115 loss:0.219028 accu:0.941406 aae_loss_1:0.202640 aae_loss_2:0.189280
epoch:116 loss:0.906095 accu:0.937500 aae_loss_1:0.235861 aae_loss_2:0.221703
epoch:117 loss:0.172958 accu:0.980469 aae_loss_1:0.225758 aae_loss_2:0.211546
epoch:118 loss:0.163587 accu:0.984375 aae_loss_1:0.237572 aae_loss_2:0.222625
epoch:119 loss:0.181927 accu:0.953125 aae_loss_1:0.228114 aae_lo

epoch:213 loss:0.061381 accu:0.992188 aae_loss_1:0.210854 aae_loss_2:0.195277
epoch:214 loss:0.060061 accu:0.996094 aae_loss_1:0.221319 aae_loss_2:0.205589
epoch:215 loss:0.054179 accu:1.000000 aae_loss_1:0.220636 aae_loss_2:0.204883
epoch:216 loss:0.051971 accu:1.000000 aae_loss_1:0.223423 aae_loss_2:0.207513
epoch:217 loss:0.053617 accu:1.000000 aae_loss_1:0.209276 aae_loss_2:0.193352
epoch:218 loss:0.051067 accu:1.000000 aae_loss_1:0.225706 aae_loss_2:0.209964
epoch:219 loss:0.050281 accu:1.000000 aae_loss_1:0.214499 aae_loss_2:0.198688
epoch:220 loss:0.046443 accu:1.000000 aae_loss_1:0.205519 aae_loss_2:0.189591
epoch:221 loss:0.050374 accu:1.000000 aae_loss_1:0.224609 aae_loss_2:0.208713
epoch:222 loss:0.047786 accu:1.000000 aae_loss_1:0.213368 aae_loss_2:0.197558
epoch:223 loss:0.046861 accu:1.000000 aae_loss_1:0.221177 aae_loss_2:0.205264
epoch:224 loss:0.046570 accu:1.000000 aae_loss_1:0.208443 aae_loss_2:0.192611
epoch:225 loss:0.054550 accu:1.000000 aae_loss_1:0.212118 aae_lo

In [70]:
real_labels.shape

(64, 1)

In [53]:
discriminator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 28, 28, 1)         0         
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 533505    
Total params: 533,505
Trainable params: 0
Non-trainable params: 533,505
_________________________________________________________________


In [54]:
generator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 100)               0         
_________________________________________________________________
generator (Sequential)       (None, 28, 28, 1)         1097744   
Total params: 1,097,744
Trainable params: 1,095,184
Non-trainable params: 2,560
_________________________________________________________________


In [39]:
combined_model_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z (InputLayer)               (None, 100)               0         
_________________________________________________________________
generator_Model (Model)      (None, 96, 96, 3)         29029120  
_________________________________________________________________
discriminator_Model (Model)  (None, 1)                 14320641  
Total params: 43,349,761
Trainable params: 29,025,536
Non-trainable params: 14,324,225
_________________________________________________________________


In [None]:
gc.collect()

In [14]:
modeli = Sequential()

modeli.add(Dense(128 * 7 * 7, activation="relu", input_shape=(LATENT_DIM,)))
modeli.add(Reshape((7, 7, 128)))
modeli.add(UpSampling2D())
modeli.add(Conv2D(128, kernel_size=3, padding="same"))
modeli.add(BatchNormalization(momentum=0.8))
modeli.add(Activation("relu"))
modeli.add(UpSampling2D())
modeli.add(Conv2D(64, kernel_size=3, padding="same"))
modeli.add(BatchNormalization(momentum=0.8))
modeli.add(Activation("relu"))
modeli.add(Conv2D(CHANNEL, kernel_size=3, padding="same"))
modeli.add(Activation("tanh"))

In [15]:
modeli.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_3 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 128)       147584    
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 28, 28, 128)       0         
__________

In [41]:
32*400

12800

In [None]:
gc.collect()

In [30]:
aa=encoder()

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'