In [1]:
#reference https://github.com/mokemokechicken/keras_BEGAN
#reference https://github.com/siddharthalodha/BEGAN_KERAS/blob/master/BEGAN_v1.ipynb    

In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

#used for residual blocks
from keras.layers import Add

from keras.datasets import mnist

#addin BEGAN
from keras.layers import Lambda #可以自己指定特定层 完成特定的功能 discriminator使用
from keras.layers import Concatenate #将discriminator的两个输出合并使用

Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [3]:
# Image geometry used by the encoder/decoder/discriminator inputs.
WIDTH = 64
HEIGHT = 64
CHANNEL = 3

# NOTE(review): SHAPE is ordered (WIDTH, HEIGHT, CHANNEL) while discriminator() builds its
# Input as (HEIGHT, WIDTH, ...); the two only agree because WIDTH == HEIGHT.
SHAPE = (WIDTH , HEIGHT , CHANNEL)

# NOTE(review): decoder() hard-codes a 64-dimensional input; LATENT_DIM is not referenced
# by any visible cell.
LATENT_DIM = 100 #latent variable z sample from normal distribution

BATCH_SIZE = 16 #crazy!!! slow turtle
EPOCHS = 10 # NOTE(review): not referenced by any visible cell

# Directory whose entries are globbed into IMAGES_PATH.
PATH = '../dataset/CelebA/img_align_celeba/'

# Grid of preview images to draw: rows x columns (passed to subplots in write_image)
ROW = 2
COL = 3

#addin BEGAN
# Base number of convolution filters used by encoder() and decoder().
N_FILTERS = 128

In [4]:
#==============
# Every path directly under PATH; load_image() samples its batches from this list.
IMAGES_PATH = glob(PATH+'*')

In [4]:
def load_image(batch_size = BATCH_SIZE , training = True):
    """Load a random batch of images at two resolutions, rescaled with x/127.5 - 1.

    Args:
        batch_size: number of paths drawn from IMAGES_PATH. np.random.choice
            samples with replacement by default, so a batch may repeat an image.
        training: when True, each image is mirrored left-right with probability
            0.5; the high- and low-resolution copies are always flipped together.

    Returns:
        (images_high_resolution, images_low_resolution): two numpy arrays, one
        entry per sampled image, each rescaled with x/127.5 - 1 (i.e. [-1, 1]
        assuming pixel values in 0..255).

    NOTE(review): HIGH_RESOLUTION_SHAPE and LOW_RESOLUTION_SHAPE are not defined
    in the visible cells of this notebook — confirm where they come from.
    NOTE(review): scipy.misc.imread / imresize are, as far as I know, no longer
    shipped by recent SciPy releases — confirm the pinned SciPy version.
    """
    # Pick the image paths for this batch at random from the image library.
    images = np.random.choice(IMAGES_PATH , size=batch_size)

    images_high_resolution = []
    images_low_resolution = []

    for image in images:
        # The builtin float is the dtype the removed np.float alias pointed to.
        img = scipy.misc.imread(image , mode='RGB').astype(float)

        # Force both target sizes regardless of the source image's own size.
        img_high_resolution = scipy.misc.imresize(img , size=HIGH_RESOLUTION_SHAPE)
        img_low_resolution = scipy.misc.imresize(img , size=LOW_RESOLUTION_SHAPE)

        # Random horizontal flip as augmentation (training only).
        if training and np.random.random()<0.5:
            img_high_resolution = np.fliplr(img_high_resolution)
            img_low_resolution = np.fliplr(img_low_resolution)

        images_high_resolution.append(img_high_resolution)
        images_low_resolution.append(img_low_resolution)

    images_high_resolution = np.array(images_high_resolution)/127.5 - 1
    images_low_resolution = np.array(images_low_resolution)/127.5 - 1

    return images_high_resolution , images_low_resolution


def write_image(epoch):
    """Save a comparison figure of real vs. generated images for two samples.

    Loads two un-augmented samples with load_image, runs generator_i on the
    low-resolution copies, and draws one row per sample with three panels:
    original high-resolution, generated high-resolution, original low-resolution.
    The figure is written to 'celeba_began/No.<epoch>.png' and then closed.

    Args:
        epoch: integer used in the output file name.

    Note: `plt` in this notebook is the top-level `matplotlib` package, hence
    the `plt.pyplot` spelling.
    """
    originals_high , originals_low = load_image(batch_size=2 , training=False)
    # The generator maps the low-resolution batch to high-resolution images.
    generated_high = generator_i.predict(originals_low)

    # x*0.5+0.5 undoes load_image's x/127.5 - 1 scaling, giving [0, 1] for imshow.
    panels = (
        ('original high' , originals_high*0.5+0.5),
        ('generated high' , generated_high*0.5+0.5),
        ('original low' , originals_low*0.5+0.5),
    )

    fig , axes = plt.pyplot.subplots(ROW , COL)

    # One row per loaded sample (two were requested above), one column per panel.
    for sample in range(2):
        for column , (title , batch) in enumerate(panels):
            ax = axes[sample][column]
            ax.imshow(batch[sample])
            ax.set_title(title)
            ax.axis('off')

    out_path = 'celeba_began/No.%d.png' % epoch
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(os.path.dirname(out_path) , exist_ok=True)
    fig.savefig(out_path)
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.pyplot.close(fig)


In [6]:
#==============

In [5]:
def multi_conv2d(x , output_size , strides , n_layer):
    """Stack n_layer stride-1 3x3 ELU convolutions, then one closing convolution.

    Args:
        x: input tensor.
        output_size: number of filters for every convolution in the stack.
        strides: strides of the closing convolution (e.g. (2,2) to downsample).
        n_layer: how many stride-1 convolutions precede the closing one.

    Returns:
        The output tensor of the closing convolution.
    """
    shared_args = dict(kernel_size=(3,3) , activation='elu' , padding='same')

    for _step in range(n_layer):
        # stride 1 with 'same' padding keeps the spatial size unchanged
        same_size_conv = Conv2D(output_size , strides=(1,1) , **shared_args)
        x = same_size_conv(x)

    # the closing convolution is the only one that uses the caller's strides
    closing_conv = Conv2D(output_size , strides=strides , **shared_args)
    return closing_conv(x)


def multi_deconv2d(x , output_size , n_layer , upsample):
    """Stack n_layer stride-1 3x3 ELU convolutions, optionally followed by upsampling.

    Despite the name, no transposed convolution is used: the stack is plain
    Conv2D layers and the optional size increase comes from UpSampling2D with
    its default arguments.

    Args:
        x: input tensor.
        output_size: number of filters for every convolution.
        n_layer: number of convolutions to apply.
        upsample: when truthy, append an UpSampling2D layer.

    Returns:
        The resulting tensor.
    """
    for _step in range(n_layer):
        same_size_conv = Conv2D(
            output_size ,
            kernel_size=(3,3) ,
            strides=(1,1) ,
            activation='elu' ,
            padding='same' ,
        )
        x = same_size_conv(x)

    return UpSampling2D()(x) if upsample else x


In [6]:
def encoder():
    """Build the encoder model: an image of shape SHAPE -> a 64-unit feature vector.

    Four multi_conv2d stages (two stride-1 convolutions each, plus the closing
    one) use N_FILTERS x 1, 2, 3 and 4 filters; the first three close with
    stride (2,2) and the last with stride (1,1). The result is flattened and
    projected by a linear Dense layer with 64 units.

    Returns:
        A Model mapping the image input to the feature output.
    """
    image = Input(shape=SHAPE)

    # (filter multiplier, strides of the stage's closing convolution)
    stage_specs = (
        (1 , (2,2)),
        (2 , (2,2)),
        (3 , (2,2)),
        (4 , (1,1)),
    )

    h = image
    for multiplier , closing_strides in stage_specs:
        h = multi_conv2d(h , output_size=N_FILTERS*multiplier , strides=closing_strides , n_layer=2)

    flattened = Flatten()(h)
    feature = Dense(units=64 , activation='linear')(flattened)

    return Model(image , feature)
    
def decoder():
    """Build the decoder model: a 64-unit feature vector -> a 3-channel image.

    The feature is projected by a linear Dense layer to 8*8*N_FILTERS values,
    reshaped to (8, 8, N_FILTERS), passed through four multi_deconv2d stages
    (the first three upsample, the last does not), and finished with a linear
    3x3 convolution producing 3 channels.

    Returns:
        A Model mapping the feature input to the image output.
    """
    feature = Input(shape=(64 , ))

    projected = Dense(units=8*8*N_FILTERS , activation='linear')(feature)
    h = Reshape(target_shape=(8,8,N_FILTERS))(projected)

    # three upsampling stages, then one stage that keeps the size
    for should_upsample in (True , True , True , False):
        h = multi_deconv2d(h , output_size=N_FILTERS , n_layer=2 , upsample=should_upsample)

    to_image = Conv2D(3 , kernel_size=(3,3) , strides=(1,1) , activation='linear' , padding='same')
    image = to_image(h)

    return Model(feature , image)


In [7]:
def autoencoder():
    """Build an image -> image model by chaining a fresh encoder() and decoder().

    Returns:
        A Model mapping an input of shape SHAPE to the decoder's reconstruction
        of the encoder's feature.
    """
    image = Input(shape=(SHAPE))

    # Build both halves before wiring them together.
    encode = encoder()
    decode = decoder()

    reconstruction = decode(encode(image))

    return Model(image , reconstruction)

In [8]:
def discriminator():
    """Build the discriminator: one shared autoencoder applied to two stacked images.

    The input carries two images concatenated along the channel axis, so it has
    CHANNEL*2 channels. The first 3 channels and the remaining channels are
    split out, each half is passed through the same autoencoder() instance,
    and the two reconstructions are concatenated again.

    Returns:
        A Model mapping the stacked input to the stacked reconstructions.
    """
    stacked_pair = Input(shape=(HEIGHT , WIDTH , CHANNEL*2))

    # first three channels: the real image
    real_image = Lambda(lambda t: t[:,:,:,:3] , output_shape=SHAPE)(stacked_pair)
    # remaining channels: the generated image
    generator_image = Lambda(lambda t: t[:,:,:,3:] , output_shape=SHAPE)(stacked_pair)

    # a single autoencoder instance is reused for both halves
    shared_autoencoder = autoencoder()

    reconstructions = [
        shared_autoencoder(real_image) ,
        shared_autoencoder(generator_image) ,
    ]

    stacked_reconstructions = Concatenate()(reconstructions)

    return Model(stacked_pair , stacked_reconstructions)

In [9]:
def generator():
    """Build the generator, which is simply a fresh decoder() (feature -> image)."""
    feature_to_image = decoder()
    return feature_to_image

In [20]:
class D_loss(object):
    """BEGAN discriminator loss together with its balancing variable k.

    Holds backend variables for k, the convergence measure m_global and the two
    mean reconstruction losses, plus a list of update ops produced each time the
    loss is built.

    Args:
        init_k_var: initial value of k.
        init_lambda_k: step size used when updating k.
        init_gamma: target ratio between generated and real reconstruction loss.

    NOTE(review): the ops collected in `self.updates` are only created here;
    nothing in the visible cells attaches them to a model, so k is not updated
    unless a caller runs them — confirm how they are meant to be applied.
    """
    def __init__(self , init_k_var=0.0 , init_lambda_k = 0.001 , init_gamma=0.5):
        self.lambda_k = init_lambda_k
        self.gamma = init_gamma
        self.k_var = K.variable(init_k_var , dtype=K.floatx())

        self.m_global_var = K.variable(0.0 , dtype=K.floatx())
        # observation-only copies of the two mean reconstruction losses
        self.loss_real_x_var = K.variable(0)
        self.loss_gene_x_var = K.variable(0)

        self.updates = []

    def D_loss(self , y_true , y_pred):
        """Compute the per-sample discriminator loss and queue the variable updates.

        Both tensors are indexed as 4-D with two images stacked on the last
        axis: channels [:3] belong to the real image, channels [3:] to the
        generated image.

        Args:
            y_true: target tensor (stacked real / generated targets).
            y_pred: discriminator output (stacked reconstructions).

        Returns:
            A 1-D tensor (one value per sample):
            L1(real) - k * L1(generated).
        """
        real_image_hat_true , generator_image_hat_true = y_true[:,:,:,:3] , y_true[:,:,:,3:]
        # both prediction halves must come from y_pred; slicing y_true here
        # would make the generated-image loss identically zero
        real_image_hat_pred , generator_image_hat_pred = y_pred[:,:,:,:3] , y_pred[:,:,:,3:]

        # mean absolute error over height, width and channels; the batch axis
        # is kept, so each loss is a 1-D tensor
        loss_real_image_hat = K.mean(K.abs(real_image_hat_true-real_image_hat_pred) , axis=[1,2,3])
        loss_generator_image_hat = K.mean(K.abs(generator_image_hat_true-generator_image_hat_pred) , axis=[1,2,3])

        # paper equation: L_D = L(x) - k * L(G(z))
        discriminator_loss = loss_real_image_hat - self.k_var*loss_generator_image_hat

        mean_loss_real_image_hat = K.mean(loss_real_image_hat)
        mean_loss_generator_image_hat = K.mean(loss_generator_image_hat)
        # paper equation: k <- clip(k + lambda_k * (gamma * L(x) - L(G(z))), 0, 1)
        new_k = self.k_var + self.lambda_k*(self.gamma*mean_loss_real_image_hat-mean_loss_generator_image_hat)
        new_k = K.clip(new_k , 0 , 1)
        self.updates.append(K.update(self.k_var , new_k))

        # convergence measure: M = L(x) + |gamma * L(x) - L(G(z))|
        m_global = mean_loss_real_image_hat + K.abs(self.gamma*mean_loss_real_image_hat-mean_loss_generator_image_hat)
        self.updates.append(K.update(self.m_global_var , m_global))

        self.updates.append(K.update(self.loss_real_x_var , mean_loss_real_image_hat))
        self.updates.append(K.update(self.loss_gene_x_var , mean_loss_generator_image_hat))

        return discriminator_loss

In [11]:
# Single Adam optimizer instance, passed to both compile() calls in the wiring cell below.
adam = Adam(lr = 0.0002 , beta_1=0.5)


In [12]:
# NOTE(review): HIGH_HEIGHT is not defined in any visible cell of this notebook, so this
# cell fails on a fresh kernel unless it is defined elsewhere — confirm.
patch = int(HIGH_HEIGHT/(2**4)) #original note says 16; the actual value depends on HIGH_HEIGHT — TODO confirm
disc_patch = (patch , patch , 1) #original note says 16*16*1; used as the per-sample label shape in the training cell

# Filter counts passed to discriminator(...) / generator(...) in the wiring cell below.
G_filters = 64
D_filters = 64

In [13]:
# NOTE(review): discriminator() and generator() as defined earlier in this notebook take no
# parameters, so the two calls below would raise TypeError against those definitions.
# `vgg` and LOW_RESOLUTION_SHAPE are also not defined in any visible cell. This cell looks
# like it was written for a different (super-resolution) model — confirm before running.
discriminator_i = discriminator(D_filters)
discriminator_i.compile(optimizer=adam , loss='mse' , metrics=['accuracy'])

generator_i = generator(G_filters)

#image_high_resolution = Input(shape=HIGH_RESOLUTION_SHAPE) #not needed to build combined_model, but needed at training time
# At training time:
# both the low- and the high-resolution images come from real samples.
# The vgg features of the high-resolution samples plus real_labels are the targets when training the generator; its training input is the low-resolution image.
# Concretely: the generator turns the low-resolution image into a high-resolution one, vgg extracts its features, and those are compared by mse with the features of the real sample; validity is scored with binary_crossentropy.
image_low_resolution = Input(shape=LOW_RESOLUTION_SHAPE)

fake_image_high_resolution = generator_i(image_low_resolution) #low-resolution image passed through G gives the generated high-resolution image
fake_image_high_resolution_feature = vgg(fake_image_high_resolution) #features of the generated high-resolution image (original note says VGG16 — TODO confirm)

# Set trainable to False on the discriminator before it is used inside the combined model.
discriminator_i.trainable = False
validity = discriminator_i(fake_image_high_resolution) #the discriminator's validity output for the generated high-resolution image

# Combined model: low-resolution input -> [validity, vgg features of the generated image].
combined_model_i = Model(image_low_resolution , [validity , fake_image_high_resolution_feature])

combined_model_i.compile(optimizer=adam , loss=['binary_crossentropy' , 'mse'] , loss_weights=[1e-3 , 1])

In [14]:
#tuple类型相加 相当于cat连接
real_labels = np.ones(shape=(BATCH_SIZE , )+disc_patch) #真实样本label为1
fake_labels = np.zeros(shape=(BATCH_SIZE , )+disc_patch) #假样本label为0

for i in range(1001):
    
    high_resolution_image , low_resolution_image = load_image() #真实的高分图像和低分图像都是来自真实样本
    
    fake_high_resolution_image = generator_i.predict(low_resolution_image) #使用G生成真低分样本的高分样本
    #训练判别器
    real_loss = discriminator_i.train_on_batch(high_resolution_image , real_labels) #使用真实的高分图像 训练 label全1
    fake_loss = discriminator_i.train_on_batch(fake_high_resolution_image , fake_labels) #使用G生成的假的高分图像 训练 label全0 

    loss = np.add(real_loss , fake_loss)/2

    #训练生成器
    high_resolution_image , low_resolution_image = load_image() #真实的高分图像和低分图像都是来自真实样本
    
    feature_high_resolution_image = vgg.predict(high_resolution_image)
    
    generator_loss = combined_model_i.train_on_batch(low_resolution_image , [real_labels , feature_high_resolution_image])

    print('epoch:%d loss:%f accu:%f gene_loss[x_entropy]:%f gene_loss[mse]:%f' % (i , loss[0] , loss[1] , generator_loss[0] , generator_loss[1]))

    if i % 50 == 0:
        write_image(i+1000)
    #write_image_mnist(i)
    
write_image(999)
#write_image_mnist(999)


  if issubdtype(ts, int):
  elif issubdtype(type(size), float):
  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:0.289170 accu:0.364258 gene_loss[x_entropy]:35.002327 gene_loss[mse]:0.726299
epoch:1 loss:0.240119 accu:0.560547 gene_loss[x_entropy]:37.199436 gene_loss[mse]:0.993910
epoch:2 loss:0.193222 accu:0.729980 gene_loss[x_entropy]:28.870689 gene_loss[mse]:1.238815
epoch:3 loss:0.188500 accu:0.707031 gene_loss[x_entropy]:21.937769 gene_loss[mse]:1.285119
epoch:4 loss:0.238619 accu:0.590820 gene_loss[x_entropy]:19.904377 gene_loss[mse]:1.846873
epoch:5 loss:0.178662 accu:0.746582 gene_loss[x_entropy]:16.685862 gene_loss[mse]:2.045166
epoch:6 loss:0.182008 accu:0.762695 gene_loss[x_entropy]:19.846888 gene_loss[mse]:1.548349
epoch:7 loss:0.082215 accu:0.955566 gene_loss[x_entropy]:15.399577 gene_loss[mse]:1.685131
epoch:8 loss:0.072267 accu:0.963379 gene_loss[x_entropy]:15.132298 gene_loss[mse]:2.216459
epoch:9 loss:0.044728 accu:0.997559 gene_loss[x_entropy]:16.553905 gene_loss[mse]:2.506667
epoch:10 loss:0.034918 accu:0.996094 gene_loss[x_entropy]:12.300821 gene_loss[mse]:2.55914

KeyboardInterrupt: 

In [70]:
# Inspect the label array's shape.
# NOTE(review): the saved output below is 2-D, (64, 1), while the training cell builds
# real_labels as (BATCH_SIZE,)+disc_patch with a 3-element disc_patch — the output looks stale.
real_labels.shape

(64, 1)

In [None]:
# Run a garbage-collection pass.
gc.collect()

In [None]:
# Run a garbage-collection pass (same statement as the previous cell).
gc.collect()

In [2]:
# NOTE(review): VGG19 is not imported in any visible cell, and the result is not assigned
# to a name (the training cells refer to `vgg`) — confirm how this cell is meant to be used.
VGG19(weights='imagenet')

A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of cbe5617147190e668d6c5d5026f83318 so we will re-download the data.
Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5


<keras.engine.training.Model at 0x1ddbc5e7a58>