In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

#残差块使用
from keras.layers import Add

from keras.datasets import mnist

#导入存在的模型
from keras.applications import VGG16 , VGG19

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [3]:
WIDTH = 64 #先使用原有的尺寸参数
HEIGHT = 64
CHANNEL = 3

LOW_RESOLUTION_SHAPE = (WIDTH , HEIGHT , CHANNEL)

HIGH_WIDTH = WIDTH*4
HIGH_HEIGHT = HEIGHT*4
HIGH_RESOLUTION_SHAPE = (HIGH_WIDTH , HIGH_HEIGHT , CHANNEL)


LATENT_DIM = 100 #latent variable z sample from normal distribution

BATCH_SIZE = 4 #crazy!!! slow turtle
EPOCHS = 10

PATH = '../dataset/CelebA/img_align_celeba/'

#生成多少个图像 长*宽
ROW = 2
COL = 3

#addin SRGAN
RESIDUAL_BLOCK_NUM = 16 #使用残差块 数量

In [4]:
#==============
IMAGES_PATH = glob(PATH+'*')

In [5]:
def load_image(batch_size = BATCH_SIZE , training = True):
    #随机在图片库中挑选
    images = np.random.choice(IMAGES_PATH , size=batch_size)
    
    images_high_resolution = []
    images_low_resolution = []
    
    for image in images:
        img = scipy.misc.imread(image , mode='RGB').astype(np.float)
        
        #尽管原图像不是指定的大小 下面将强制将图像resize
        img_high_resolution = scipy.misc.imresize(img , size=HIGH_RESOLUTION_SHAPE)
        img_low_resolution = scipy.misc.imresize(img , size=LOW_RESOLUTION_SHAPE)
        
        #随机性地对训练样本进行 左右反转
        if training and np.random.random()<0.5:
            img_high_resolution = np.fliplr(img_high_resolution)
            img_low_resolution = np.fliplr(img_low_resolution)
        
        images_high_resolution.append(img_high_resolution)
        images_low_resolution.append(img_low_resolution)
        
    images_high_resolution = np.array(images_high_resolution)/127.5 - 1
    images_low_resolution = np.array(images_low_resolution)/127.5 - 1
    
    return images_high_resolution , images_low_resolution


def write_image(epoch):
    #生成高分图像时 进行对比显示
    high_resolution_image , low_resolution_image = load_image(batch_size=2 , training=False)
    fake_high_resolution_image = generator_i.predict(low_resolution_image) #使用G来生成高分图像 使用低分图像生成原始的高分图像 但是难免有偏差 细节表现
    
    low_resolution_image = low_resolution_image*0.5+0.5
    high_resolution_image = high_resolution_image*0.5+0.5
    fake_high_resolution_image = fake_high_resolution_image*0.5+0.5
    
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    count=0
    
    axes[0][0].imshow(high_resolution_image[0])
    axes[0][0].set_title('original high')
    axes[0][0].axis('off')

    axes[0][1].imshow(fake_high_resolution_image[0])
    axes[0][1].set_title('generated high')
    axes[0][1].axis('off')
    
    axes[0][2].imshow(low_resolution_image[0])
    axes[0][2].set_title('original low')
    axes[0][2].axis('off')

    axes[1][0].imshow(high_resolution_image[1])
    axes[1][0].set_title('original high')
    axes[1][0].axis('off')

    axes[1][1].imshow(fake_high_resolution_image[1])
    axes[1][1].set_title('generated high')
    axes[1][1].axis('off')
    
    axes[1][2].imshow(low_resolution_image[1])
    axes[1][2].set_title('original low')
    axes[1][2].axis('off')

            
    fig.savefig('celeba_began/No.%d.png' % epoch)
    plt.pyplot.close()


In [6]:
#==============

In [7]:
def conv2d(output_size):
    return Conv2D(output_size , kernel_size=(3,3) , strides=(1,1) , padding='same')

def conv2d_with_stride_2(output_size):
    return Conv2D(output_size , kernel_size=(3,3) , strides=(2,2) , padding='same')

def dense(output_size):
    return Dense(output_size , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0))

def deconv2d_(input_data):
    return Conv2DTranspose(256 , kernel_size=(3,3) , strides=(2,2) , padding='same' , activation='relu')(input_data)

def deconv2d(input_data):
    h = UpSampling2D(size=2)(input_data)
    h = conv2d(output_size = 256)(h)
    h = Activation('relu')(h)

    return h
    
def batch_norm():
    return BatchNormalization(momentum=0.8)

def res_block(output_size , input_data):
    h = conv2d(output_size)(input_data)
    h = Activation('relu')(h)
    h = batch_norm()(h)
    h = conv2d(output_size)(h)
    h = batch_norm()(h)
    h = Add()([h , input_data])
    
    return h
    

In [8]:
def generator(G_filters):
    image_low_resolution = Input(shape=LOW_RESOLUTION_SHAPE)
    #64位置参数应为G_filters 参数传入 64
    c1 = Conv2D(64 , kernel_size=(3,3) , strides=(1,1) , padding='same')(image_low_resolution)
    c1 = Activation('relu')(c1)
    
    r = res_block(output_size=G_filters , input_data=c1)
    r = res_block(output_size=G_filters , input_data=r)
    r = res_block(output_size=G_filters , input_data=r)
    r = res_block(output_size=G_filters , input_data=r)
    r = res_block(output_size=G_filters , input_data=r)
    
    c2 = conv2d(output_size=G_filters)(r)
    c2 = batch_norm()(c2)
    c2 = Add()([c1,c2])

    #使用传统的反卷积生成的高分图像当中 有一些彩色的格子 隐隐约约的出现
    #使用paper中的方法 上采样后进行卷积
    u1 = deconv2d(c2)
    u2 = deconv2d(u1)
    
    gen_image_high_resolution = Conv2D(filters=CHANNEL , kernel_size=(9,9) , strides=(1,1) , padding='same')(u2)
    gen_image_high_resolution = Activation('tanh')(gen_image_high_resolution)

    #model input shape 64*64*3 output shape 256*256*3
    #由低分图像生成高分图像
    return Model(image_low_resolution , gen_image_high_resolution , name='generator_Model')

In [9]:
def discriminator(G_filters):
    image_high_resolution = Input(shape=HIGH_RESOLUTION_SHAPE) #64*64*64低分图像时 此输入为256*256*3
    
    h1 = conv2d(output_size=D_filters)(image_high_resolution)
    h1 = LeakyReLU(alpha=0.2)(h1)
    
    h2 = conv2d_with_stride_2(output_size=D_filters)(h1)
    h2 = LeakyReLU(alpha=0.2)(h2)
    h2 = batch_norm()(h2)
    
    h3 = conv2d(output_size=D_filters*2)(h2)
    h3 = LeakyReLU(alpha=0.2)(h3)
    h3 = batch_norm()(h3)
    
    h4 = conv2d_with_stride_2(output_size=D_filters*2)(h3)
    h4 = LeakyReLU(alpha=0.2)(h4)
    h4 = batch_norm()(h4)
    
    h5 = conv2d(output_size=D_filters*4)(h4)
    h5 = LeakyReLU(alpha=0.2)(h5)
    h5 = batch_norm()(h5)
    
    h6 = conv2d_with_stride_2(output_size=D_filters*4)(h5)
    h6 = LeakyReLU(alpha=0.2)(h6)
    h6 = batch_norm()(h6)
    
    h7 = conv2d(output_size=D_filters*8)(h6)
    h7 = LeakyReLU(alpha=0.2)(h7)
    h7 = batch_norm()(h7)
    
    h8 = conv2d_with_stride_2(output_size=D_filters*8)(h7)
    h8 = LeakyReLU(alpha=0.2)(h8)
    h8 = batch_norm()(h8)
    
    #此处不需要Flatten层 
    
    h9 = dense(output_size=D_filters*16)(h8)
    h9 = LeakyReLU(alpha=0.2)(h9)
    
    validity = Dense(units=1 , activation='sigmoid')(h9)
    
    return Model(image_high_resolution , validity , name='discriminator_Model')

In [10]:
def restruct_vgg(): #使用预训练的VGG16 的中间层输出（FC）作为一个新的model 来获取图像的低维feature
    vgg = VGG19(weights='imagenet') #第一次运行 下载VGG16 在ImageNet数据集上的预训练参数 耗时
    
    vgg.outputs = [vgg.layers[9].output] #修改vgg结构的输出 此输出为VGG16最后一层卷积层的输出
    
    image = Input(shape=HIGH_RESOLUTION_SHAPE)
    image_vgg_feature = vgg(image)
    
    return Model(image , image_vgg_feature) #新model重构完成

In [11]:
adam = Adam(lr = 0.0002 , beta_1=0.5)

vgg = restruct_vgg()
vgg.trainable = False
vgg.compile(loss='mse' , optimizer=adam , metrics=['accuracy'])

In [12]:
patch = int(HIGH_HEIGHT/(2**4)) #16
disc_patch = (patch , patch , 1) #16*16*1

G_filters = 64
D_filters = 64

In [13]:
discriminator_i = discriminator(D_filters)
discriminator_i.compile(optimizer=adam , loss='mse' , metrics=['accuracy'])

generator_i = generator(G_filters)

#image_high_resolution = Input(shape=HIGH_RESOLUTION_SHAPE) #不需要参与combined_model的整体构建 但是在训练的时候 是需要的
#在训练的时候
#来自真实样本的低分图像和高分图像
#高分样本经过VGG后的 低维特征和real_labels 作为训练generator时的labels 训练数据是低分图像
#具体过程为 将低分图像使用generator变为高分图像 然后经过VGG得到低维特征与训练样本中的低维样本(上面一句话中的低维特征)进行mse validity进行binary_crossentropy
image_low_resolution = Input(shape=LOW_RESOLUTION_SHAPE)

fake_image_high_resolution = generator_i(image_low_resolution) #低分 图像经过G后 生成高分图像
fake_image_high_resolution_feature = vgg(fake_image_high_resolution) #生成的高分图像经过VGG16得到的特征值

discriminator_i.trainable = False
validity = discriminator_i(fake_image_high_resolution) #判别器对生辰高分图像的validity值

combined_model_i = Model(image_low_resolution , [validity , fake_image_high_resolution_feature])

combined_model_i.compile(optimizer=adam , loss=['binary_crossentropy' , 'mse'] , loss_weights=[1e-3 , 1])

In [14]:
#tuple类型相加 相当于cat连接
real_labels = np.ones(shape=(BATCH_SIZE , )+disc_patch) #真实样本label为1
fake_labels = np.zeros(shape=(BATCH_SIZE , )+disc_patch) #假样本label为0

for i in range(1001):
    
    high_resolution_image , low_resolution_image = load_image() #真实的高分图像和低分图像都是来自真实样本
    
    fake_high_resolution_image = generator_i.predict(low_resolution_image) #使用G生成真低分样本的高分样本
    #训练判别器
    real_loss = discriminator_i.train_on_batch(high_resolution_image , real_labels) #使用真实的高分图像 训练 label全1
    fake_loss = discriminator_i.train_on_batch(fake_high_resolution_image , fake_labels) #使用G生成的假的高分图像 训练 label全0 

    loss = np.add(real_loss , fake_loss)/2

    #训练生成器
    high_resolution_image , low_resolution_image = load_image() #真实的高分图像和低分图像都是来自真实样本
    
    feature_high_resolution_image = vgg.predict(high_resolution_image)
    
    generator_loss = combined_model_i.train_on_batch(low_resolution_image , [real_labels , feature_high_resolution_image])

    print('epoch:%d loss:%f accu:%f gene_loss[x_entropy]:%f gene_loss[mse]:%f' % (i , loss[0] , loss[1] , generator_loss[0] , generator_loss[1]))

    if i % 50 == 0:
        write_image(i+1000)
    #write_image_mnist(i)
    
write_image(999)
#write_image_mnist(999)


  if issubdtype(ts, int):
  elif issubdtype(type(size), float):
  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:0.289170 accu:0.364258 gene_loss[x_entropy]:35.002327 gene_loss[mse]:0.726299
epoch:1 loss:0.240119 accu:0.560547 gene_loss[x_entropy]:37.199436 gene_loss[mse]:0.993910
epoch:2 loss:0.193222 accu:0.729980 gene_loss[x_entropy]:28.870689 gene_loss[mse]:1.238815
epoch:3 loss:0.188500 accu:0.707031 gene_loss[x_entropy]:21.937769 gene_loss[mse]:1.285119
epoch:4 loss:0.238619 accu:0.590820 gene_loss[x_entropy]:19.904377 gene_loss[mse]:1.846873
epoch:5 loss:0.178662 accu:0.746582 gene_loss[x_entropy]:16.685862 gene_loss[mse]:2.045166
epoch:6 loss:0.182008 accu:0.762695 gene_loss[x_entropy]:19.846888 gene_loss[mse]:1.548349
epoch:7 loss:0.082215 accu:0.955566 gene_loss[x_entropy]:15.399577 gene_loss[mse]:1.685131
epoch:8 loss:0.072267 accu:0.963379 gene_loss[x_entropy]:15.132298 gene_loss[mse]:2.216459
epoch:9 loss:0.044728 accu:0.997559 gene_loss[x_entropy]:16.553905 gene_loss[mse]:2.506667
epoch:10 loss:0.034918 accu:0.996094 gene_loss[x_entropy]:12.300821 gene_loss[mse]:2.55914

epoch:90 loss:0.000130 accu:1.000000 gene_loss[x_entropy]:9.448153 gene_loss[mse]:4.312959
epoch:91 loss:0.000223 accu:1.000000 gene_loss[x_entropy]:9.804703 gene_loss[mse]:4.449502
epoch:92 loss:0.000217 accu:1.000000 gene_loss[x_entropy]:6.716733 gene_loss[mse]:4.583884
epoch:93 loss:0.000327 accu:1.000000 gene_loss[x_entropy]:8.234782 gene_loss[mse]:4.382513
epoch:94 loss:0.000195 accu:1.000000 gene_loss[x_entropy]:10.530398 gene_loss[mse]:4.486005
epoch:95 loss:0.000257 accu:1.000000 gene_loss[x_entropy]:11.005095 gene_loss[mse]:4.198820
epoch:96 loss:0.000243 accu:1.000000 gene_loss[x_entropy]:10.559433 gene_loss[mse]:4.516956
epoch:97 loss:0.000147 accu:1.000000 gene_loss[x_entropy]:11.293322 gene_loss[mse]:4.495035
epoch:98 loss:0.000199 accu:1.000000 gene_loss[x_entropy]:10.266997 gene_loss[mse]:4.448599
epoch:99 loss:0.000180 accu:1.000000 gene_loss[x_entropy]:11.849442 gene_loss[mse]:4.331164
epoch:100 loss:0.000304 accu:1.000000 gene_loss[x_entropy]:8.402928 gene_loss[mse]:4

KeyboardInterrupt: 

In [70]:
real_labels.shape

(64, 1)

In [None]:
gc.collect()

In [None]:
gc.collect()

In [2]:
VGG19(weights='imagenet')

A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of cbe5617147190e668d6c5d5026f83318 so we will re-download the data.
Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5


<keras.engine.training.Model at 0x1ddbc5e7a58>