In [1]:
#reference https://github.com/mokemokechicken/keras_BEGAN
#reference https://github.com/siddharthalodha/BEGAN_KERAS/blob/master/BEGAN_v1.ipynb    

In [47]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D , Convolution2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

#残差块使用
from keras.layers import Add

from keras.datasets import mnist

#addin BEGAN
from keras.layers import Lambda #可以自己指定特定层 完成特定的功能 discriminator使用
from keras.layers import Concatenate #将discriminator的两个输出合并使用

In [29]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

from scipy.misc import imread , imsave , imresize

%matplotlib inline

In [63]:
#BEGAN使用64*64*3 图像
WIDTH = 64
HEIGHT = 64
CHANNEL = 3

SHAPE = (WIDTH , HEIGHT , CHANNEL)

LATENT_DIM = 64 #latent variable z sample from normal distribution

BATCH_SIZE = 32
EPOCHS = 10

    PATH = '../dataset/CelebA/img_align_celeba/'

#生成多少个图像 长*宽
ROW = 2
COL = 2

#addin BEGAN
N_FILTERS = 128

In [64]:
#==============
IMAGES_PATH = glob(PATH+'*')

In [78]:
len(IMAGES_PATH)

202599

In [107]:
def load_image(batch_size = BATCH_SIZE , training = True):
    #随机在图片库中挑选
    image_path = np.random.choice(IMAGES_PATH , size=batch_size)
    
    images = []
    
    for an_image_path in image_path:
        img = imread(an_image_path , mode='RGB').astype(np.float)
        
        #尽管原图像不是指定的大小 下面将强制将图像resize
        img = imresize(img , size=SHAPE) #resize到64*64*3尺寸
        
        #随机性地对训练样本进行 左右反转
        #BEGAN不进行左右反转
        #if training and np.random.random()<0.5:
        #    img = np.fliplr(img)
        
        images.append(img)
        
    images = np.array(images)/255.0
    
    return images

def write_image(epoch):
    #生成高分图像时 进行对比显示
    z = np.random.uniform(-1 , 1 , size=(ROW*COL , LATENT_DIM))
    images = Generator.predict(z)
    
    images = images*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    count=0
    
    axes[0][0].imshow(images[0])
    #axes[0][0].set_title('original high')
    axes[0][0].axis('off')

    axes[0][1].imshow(images[1])
    #axes[0][1].set_title('generated high')
    axes[0][1].axis('off')
    

    axes[1][0].imshow(images[2])
    #axes[1][0].set_title('original high')
    axes[1][0].axis('off')

    axes[1][1].imshow(images[3])
    #axes[1][1].set_title('generated high')
    axes[1][1].axis('off')
    
    fig.savefig('celeba_began/No.%d.png' % epoch)
    plt.pyplot.close()


In [108]:
def load_image_new(index_list):
    images = []
    
    for an_image_path in index_list:
        img = imread(IMAGES_PATH[an_image_path] , mode='RGB').astype(np.float)
        
        #尽管原图像不是指定的大小 下面将强制将图像resize
        img = imresize(img , size=SHAPE) #resize到64*64*3尺寸
        
        #随机性地对训练样本进行 左右反转
        #BEGAN不进行左右反转
        #if training and np.random.random()<0.5:
        #    img = np.fliplr(img)
        
        images.append(img)
        
    images = np.array(images)/255.0
    
    return images

In [66]:
#==============

In [67]:
#def multi_conv2d(x , output_size , strides , n_layer):
#    for _ in range(n_layer):
#        x = Conv2D(output_size , kernel_size=(3,3) , strides=(1,1) , activation='elu' , padding='same')(x) #1步卷积 保证图像尺寸不变
#    
#    return Conv2D(output_size , kernel_size=(3,3) , strides=strides , activation='elu' , padding='same')(x) #最后的这个卷积的stride为2 进行降2倍采样

def multi_conv2d(x, filters, strides=(1, 1), name=None, n_layer=2):
    for i in range(1, n_layer):
        x = Convolution2D(filters, (3, 3), activation="elu", padding="same")(x)

    x = Convolution2D(filters, (3, 3), activation="elu", padding="same", strides=strides)(x)
    return x

#def multi_deconv2d(x , output_size , n_layer , upsample):
#    for _ in range(n_layer):
#        x = Conv2D(output_size , kernel_size=(3,3) , strides=(1,1) , activation='elu' , padding='same')(x)
#    
#    if upsample:
#        x = UpSampling2D()(x)
#
#    return x

def multi_deconv2d(x, filters, upsample=None, name=None, n_layer=2):
    for i in range(1, n_layer+1):
        x = Convolution2D(filters, (3, 3), activation="elu", padding="same")(x)
    if upsample:
        x = UpSampling2D()(x)
    return x

In [68]:
def encoder():
    image = Input(shape=SHAPE)
    
    h = multi_conv2d(image , N_FILTERS , strides=(2,2) , n_layer=2)
    h = multi_conv2d(h , N_FILTERS*2 , strides=(2,2) , n_layer=2)
    h = multi_conv2d(h , N_FILTERS*3 , strides=(2,2) , n_layer=2)
    h = multi_conv2d(h , N_FILTERS*4 , strides=(1,1) , n_layer=2)
    
    h = Flatten()(h)
    
    feature = Dense(units=LATENT_DIM , activation='linear')(h)
    
    return Model(image , feature)
    
def decoder():
    feature = Input(shape=(LATENT_DIM , ))
    
    h = Dense(units=(8*8*N_FILTERS) , activation='linear')(feature)
    h = Reshape(target_shape=(8,8,N_FILTERS))(h)
    
    h = multi_deconv2d(h , N_FILTERS , upsample=True , n_layer=2 )
    h = multi_deconv2d(h , N_FILTERS , upsample=True , n_layer=2 )
    h = multi_deconv2d(h , N_FILTERS , upsample=True , n_layer=2 )
    h = multi_deconv2d(h , N_FILTERS , upsample=False , n_layer=2)
    
    image = Convolution2D(3 , kernel_size=(3,3) , strides=(1,1) , activation='linear' , padding='same')(h)
    
    return Model(feature , image)


In [69]:
def autoencoder():
    image = Input(shape=(SHAPE))
    
    Encoder = encoder()
    Decoder = decoder()
    
    feature = Encoder(image)
    image_hat = Decoder(feature)
    
    return Model(image , image_hat)

In [70]:
def discriminator():
    input_data = Input(shape=(HEIGHT , WIDTH , CHANNEL*2)) #因为discriminator的输入是两幅图像 所以输入的通道翻一倍
    
    real_image = Lambda(lambda x: x[:,:,:,0:3] , output_shape=SHAPE)(input_data)
    generator_image = Lambda(lambda x: x[:,:,:,3:6] , output_shape=SHAPE)(input_data)
    
    Autoencoder = autoencoder()
    
    real_image_hat = Autoencoder(real_image)
    generator_image_hat = Autoencoder(generator_image)

    output_data = Concatenate(axis=-1)([real_image_hat , generator_image_hat])

    return Model(input_data , output_data)

In [71]:
def generator():
    return decoder() #feature to image

In [97]:
class D_loss(object):
    __name__ = 'discriminator_loss'
    
    def __init__(self , init_k_var=0.0 , init_lambda_k = 0.001 , init_gamma=0.5):
        self.lambda_k = init_lambda_k
        self.gamma = init_gamma
        self.k_var = K.variable(init_k_var , dtype=K.floatx())
        
        self.m_global_var = K.variable(0.0 , dtype=K.floatx())
        self.loss_real_x_var = K.variable(0)
        self.loss_gene_x_var = K.variable(0)
        
        self.updates = []

    def __call__(self , y_true , y_pred):
        real_image_hat_true , generator_image_hat_true = y_true[:,:,:,:3] , y_true[:,:,:,3:]
        real_image_hat_pred ,generator_image_hat_pred = y_pred[:,:,:,:3] , y_true[:,:,:,3:]

        #下面的平均在batch维上还没有平均 所以得到的是1维张量
        loss_real_image_hat = K.mean(K.abs(real_image_hat_true-real_image_hat_pred) , axis=[1,2,3])
        loss_generator_image_hat = K.mean(K.abs(generator_image_hat_true-generator_image_hat_pred) , axis=[1,2,3])

        #paper equation
        discriminator_loss = loss_real_image_hat - self.k_var*loss_generator_image_hat

        mean_loss_real_image_hat = K.mean(loss_real_image_hat)
        mean_loss_generator_image_hat = K.mean(loss_generator_image_hat)
        #paper equation
        new_k = self.k_var + self.lambda_k*(self.gamma*mean_loss_real_image_hat-mean_loss_generator_image_hat)
        new_k = K.clip(new_k , 0 , 1)
        self.updates.append(K.update(self.k_var , new_k))
        
        m_global = mean_loss_real_image_hat + K.abs(self.gamma*mean_loss_real_image_hat-mean_loss_generator_image_hat)
        self.updates.append(K.update(self.m_global_var , m_global))
        
        self.updates.append(K.update(self.loss_real_x_var , mean_loss_real_image_hat))
        self.updates.append(K.update(self.loss_gene_x_var , mean_loss_generator_image_hat))
        
        return discriminator_loss

    @property
    def k(self):
        return K.get_value(self.k_var)
    
    @property
    def m_global(self):
        return K.get_value(self.m_global_var)
    
    @property
    def loss_real_x(self):
        return K.get_value(self.loss_real_x_var)
    
    @property
    def loss_gene_x(self):
        return K.get_value(self.loss_gene_x_var)


In [73]:
def build_generator_loss(Autoencoder):
    def generator_loss(y_true , y_pred):
        y_pred_dash = Autoencoder(y_pred)
        
        return K.mean(K.abs(y_pred - y_pred_dash) , axis=[1,2,3])
    
    return generator_loss

In [109]:
Autoencoder = autoencoder()
Generator = generator()
Discrimminator = discriminator()

In [110]:
loss_discriminator = D_loss()
Discrimminator.compile(optimizer=Adam() , loss=loss_discriminator)

In [111]:
Generator.compile(optimizer=Adam() , loss=build_generator_loss(Autoencoder))

In [112]:
lr_decay_step = 0

#last_m_global = np.Inf
#
#test_use = len(IMAGES_PATH)//1000
#
#for ep in range(1 , 1002):
#    np.random.seed(ep*100)
#    
#    zd = np.random.uniform(-1, 1, (test_use , LATENT_DIM) )
#    zg = np.random.uniform(-1, 1, (test_use , LATENT_DIM) )
#
#    index_order = np.arange(test_use)
#    np.random.shuffle(index_order)
#    
#    lr = max(0.0001*(0.9**lr_decay_step) , 0.00001)
#    
#    K.set_value(Generator.optimizer.lr , lr)
#    K.set_value(Discrimminator.optimizer.lr , lr)
#    
#    m_global_history = []
#    
#    batch_len = test_use//BATCH_SIZE
#    
#    for b_idx in range(batch_len):
#        index_list = index_order[b_idx*BATCH_SIZE : (b_idx+1)*BATCH_SIZE]
#
#        # training discriminator
#        in_x1 = load_image_new(index_list)  # (bs, row, col, ch)
#        in_x2 = Generator.predict_on_batch(zd[index_list])
#        
#        in_x = np.concatenate([in_x1, in_x2], axis=-1)  # (bs, row, col, ch*2)
#        loss_discriminator = Discrimminator.train_on_batch(in_x, in_x)
#
#        # training generator
#        in_x1 = zg[index_list]
#        loss_generator = Generator.train_on_batch(in_x1, np.zeros_like(in_x2))  # y_true is meaningless
#
#        print('epoch %d' % b_idx)
#
#    write_image(ep)
#    lr_decay_step += 1
    
    
for i in range(1001):
    
    lr = max(0.0001*(0.9**lr_decay_step) , 0.00001)
    
    K.set_value(Generator.optimizer.lr , lr)
    K.set_value(Discrimminator.optimizer.lr , lr)
    
    #训练discriminator
    input_x1 = load_image() #真实图像
    input_x2 = Generator.predict(np.random.uniform(-1,1, size=(BATCH_SIZE , LATENT_DIM)))
    input_x = np.concatenate((input_x1 , input_x2) , axis=-1)
    
    loss_d = Discrimminator.train_on_batch(input_x , input_x)

    #训练generator
    input_x3 = np.random.uniform(-1,1 , size=(BATCH_SIZE , LATENT_DIM))
    loss_g = Generator.train_on_batch(input_x3 , np.zeros_like(input_x2)) #G向全0映射
    
    print('epoch:%d loss_D:%f loss_G:%f' % (i , loss_d , loss_g))

    if i % 50 == 0:
        write_image(i)
    
    
    lr_decay_step += 1

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  # This is added back by InteractiveShellApp.init_path()


epoch:0 loss_D:0.418878 loss_G:0.045258
epoch:1 loss_D:0.443488 loss_G:0.037838
epoch:2 loss_D:0.392980 loss_G:0.033248
epoch:3 loss_D:0.300696 loss_G:0.029937
epoch:4 loss_D:0.361292 loss_G:0.027289
epoch:5 loss_D:0.293150 loss_G:0.025801
epoch:6 loss_D:0.233563 loss_G:0.024543
epoch:7 loss_D:0.248045 loss_G:0.023060
epoch:8 loss_D:0.293889 loss_G:0.022275
epoch:9 loss_D:0.273142 loss_G:0.021573
epoch:10 loss_D:0.225354 loss_G:0.020987
epoch:11 loss_D:0.250578 loss_G:0.020224
epoch:12 loss_D:0.241314 loss_G:0.019485
epoch:13 loss_D:0.239344 loss_G:0.019324
epoch:14 loss_D:0.240600 loss_G:0.018877
epoch:15 loss_D:0.232112 loss_G:0.018133
epoch:16 loss_D:0.237394 loss_G:0.018214
epoch:17 loss_D:0.231723 loss_G:0.017812
epoch:18 loss_D:0.228199 loss_G:0.017803
epoch:19 loss_D:0.230269 loss_G:0.017756
epoch:20 loss_D:0.221120 loss_G:0.017484
epoch:21 loss_D:0.227154 loss_G:0.017124
epoch:22 loss_D:0.217813 loss_G:0.017081
epoch:23 loss_D:0.236953 loss_G:0.016824
epoch:24 loss_D:0.233071 l

epoch:198 loss_D:0.131197 loss_G:0.006844
epoch:199 loss_D:0.141513 loss_G:0.006905
epoch:200 loss_D:0.138327 loss_G:0.006848
epoch:201 loss_D:0.131111 loss_G:0.006886
epoch:202 loss_D:0.127584 loss_G:0.006837
epoch:203 loss_D:0.130179 loss_G:0.006900
epoch:204 loss_D:0.128798 loss_G:0.006763
epoch:205 loss_D:0.128042 loss_G:0.006661
epoch:206 loss_D:0.132492 loss_G:0.006888
epoch:207 loss_D:0.137565 loss_G:0.006807
epoch:208 loss_D:0.124384 loss_G:0.006570
epoch:209 loss_D:0.125838 loss_G:0.006660
epoch:210 loss_D:0.130285 loss_G:0.006533
epoch:211 loss_D:0.131025 loss_G:0.006708
epoch:212 loss_D:0.129272 loss_G:0.006647
epoch:213 loss_D:0.138353 loss_G:0.006510
epoch:214 loss_D:0.136552 loss_G:0.006621
epoch:215 loss_D:0.127353 loss_G:0.006479
epoch:216 loss_D:0.123801 loss_G:0.006459
epoch:217 loss_D:0.126151 loss_G:0.006557
epoch:218 loss_D:0.122483 loss_G:0.006371
epoch:219 loss_D:0.130405 loss_G:0.006480
epoch:220 loss_D:0.135506 loss_G:0.006332
epoch:221 loss_D:0.133041 loss_G:0

epoch:394 loss_D:0.112139 loss_G:0.003904
epoch:395 loss_D:0.113133 loss_G:0.003884
epoch:396 loss_D:0.116346 loss_G:0.003895
epoch:397 loss_D:0.126030 loss_G:0.003841
epoch:398 loss_D:0.110545 loss_G:0.003861
epoch:399 loss_D:0.113972 loss_G:0.003791
epoch:400 loss_D:0.111488 loss_G:0.003879
epoch:401 loss_D:0.112030 loss_G:0.003909
epoch:402 loss_D:0.114749 loss_G:0.003832
epoch:403 loss_D:0.114312 loss_G:0.003800
epoch:404 loss_D:0.105869 loss_G:0.003751
epoch:405 loss_D:0.104407 loss_G:0.003699
epoch:406 loss_D:0.109372 loss_G:0.003848
epoch:407 loss_D:0.121311 loss_G:0.003781
epoch:408 loss_D:0.108516 loss_G:0.003812
epoch:409 loss_D:0.110904 loss_G:0.003791
epoch:410 loss_D:0.107545 loss_G:0.003747
epoch:411 loss_D:0.111380 loss_G:0.003776
epoch:412 loss_D:0.105992 loss_G:0.003757
epoch:413 loss_D:0.118349 loss_G:0.003713
epoch:414 loss_D:0.112594 loss_G:0.003743
epoch:415 loss_D:0.108859 loss_G:0.003745
epoch:416 loss_D:0.110121 loss_G:0.003677
epoch:417 loss_D:0.112306 loss_G:0

epoch:590 loss_D:0.108587 loss_G:0.002590
epoch:591 loss_D:0.094692 loss_G:0.002560
epoch:592 loss_D:0.100704 loss_G:0.002582
epoch:593 loss_D:0.100682 loss_G:0.002591
epoch:594 loss_D:0.093396 loss_G:0.002546
epoch:595 loss_D:0.102624 loss_G:0.002555
epoch:596 loss_D:0.104266 loss_G:0.002547
epoch:597 loss_D:0.098732 loss_G:0.002588
epoch:598 loss_D:0.098669 loss_G:0.002536
epoch:599 loss_D:0.102227 loss_G:0.002552
epoch:600 loss_D:0.100574 loss_G:0.002500
epoch:601 loss_D:0.098564 loss_G:0.002526
epoch:602 loss_D:0.106293 loss_G:0.002517
epoch:603 loss_D:0.101007 loss_G:0.002461
epoch:604 loss_D:0.102762 loss_G:0.002500
epoch:605 loss_D:0.095993 loss_G:0.002547
epoch:606 loss_D:0.103725 loss_G:0.002459
epoch:607 loss_D:0.101004 loss_G:0.002518
epoch:608 loss_D:0.098070 loss_G:0.002499
epoch:609 loss_D:0.106334 loss_G:0.002480
epoch:610 loss_D:0.104363 loss_G:0.002477
epoch:611 loss_D:0.095376 loss_G:0.002478
epoch:612 loss_D:0.103375 loss_G:0.002473
epoch:613 loss_D:0.101281 loss_G:0

epoch:786 loss_D:0.102410 loss_G:0.001838
epoch:787 loss_D:0.093282 loss_G:0.001824
epoch:788 loss_D:0.091094 loss_G:0.001843
epoch:789 loss_D:0.090603 loss_G:0.001853
epoch:790 loss_D:0.095700 loss_G:0.001840
epoch:791 loss_D:0.090215 loss_G:0.001881
epoch:792 loss_D:0.094849 loss_G:0.001835
epoch:793 loss_D:0.093798 loss_G:0.001825
epoch:794 loss_D:0.091493 loss_G:0.001847
epoch:795 loss_D:0.092975 loss_G:0.001812
epoch:796 loss_D:0.090450 loss_G:0.001844
epoch:797 loss_D:0.095579 loss_G:0.001851
epoch:798 loss_D:0.097631 loss_G:0.001849
epoch:799 loss_D:0.097484 loss_G:0.001845
epoch:800 loss_D:0.089971 loss_G:0.001786
epoch:801 loss_D:0.095923 loss_G:0.001783
epoch:802 loss_D:0.091206 loss_G:0.001821
epoch:803 loss_D:0.098190 loss_G:0.001814
epoch:804 loss_D:0.086182 loss_G:0.001795
epoch:805 loss_D:0.092913 loss_G:0.001793
epoch:806 loss_D:0.101296 loss_G:0.001821
epoch:807 loss_D:0.094094 loss_G:0.001788
epoch:808 loss_D:0.095503 loss_G:0.001802
epoch:809 loss_D:0.091857 loss_G:0

epoch:982 loss_D:0.088893 loss_G:0.001417
epoch:983 loss_D:0.087907 loss_G:0.001406
epoch:984 loss_D:0.088518 loss_G:0.001402
epoch:985 loss_D:0.089701 loss_G:0.001415
epoch:986 loss_D:0.087453 loss_G:0.001369
epoch:987 loss_D:0.092119 loss_G:0.001409
epoch:988 loss_D:0.096502 loss_G:0.001420
epoch:989 loss_D:0.095783 loss_G:0.001396
epoch:990 loss_D:0.092756 loss_G:0.001401
epoch:991 loss_D:0.093843 loss_G:0.001399
epoch:992 loss_D:0.092321 loss_G:0.001365
epoch:993 loss_D:0.096396 loss_G:0.001414
epoch:994 loss_D:0.091100 loss_G:0.001401
epoch:995 loss_D:0.088427 loss_G:0.001394
epoch:996 loss_D:0.091272 loss_G:0.001361
epoch:997 loss_D:0.090497 loss_G:0.001399
epoch:998 loss_D:0.091405 loss_G:0.001383
epoch:999 loss_D:0.090404 loss_G:0.001396
epoch:1000 loss_D:0.087295 loss_G:0.001386


In [89]:
IMAGES_PATH[index_list[0]]

'../dataset/CelebA/img_align_celeba\\122405.jpg'

In [None]:
gc.collect()

In [None]:
gc.collect()

In [104]:
len(IMAGES_PATH)//1000

202

In [None]:
import google.protobuf