In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Concatenate

#addin cycleGAN 使用instance-norm
from keras_contrib.layers.normalization import InstanceNormalization

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

#残差块使用
from keras.layers import Add

from keras.datasets import mnist

#导入存在的模型
from keras.applications import VGG16 , VGG19

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [8]:
WIDTH = 256
HEIGHT = 256
CHANNEL = 3

SHAPE = (WIDTH , HEIGHT , CHANNEL)

BATCH_SIZE = 4 #crazy!!! slow turtle
EPOCHS = 10

PATH = '../dataset/vangogh2photo/'

#生成多少个图像 长*宽
ROW = 2 #几行决定显示几个测试样例 显示2个
COL = 3 #3列是因为要显示 原图像 另一个特征空间的图像 还原后的图像

TRAIN_APPLE_PATH = glob(PATH + 'trainA/*')
TRAIN_ORANGE_PATH = glob(PATH + 'trainB/*')
TEST_APPLE_PATH = glob(PATH + 'testA/*')
TEST_ORANGE_PATH = glob(PATH + 'testB/*')

#卷积使用 基卷积核大小
G_filters = 32
D_filters = 64


In [4]:
patch = int(HEIGHT/(2**4)) #16
disc_patch = (patch , patch , 1) #16*16*1



In [5]:
def load_image(batch_size = BATCH_SIZE , training = True):
    #随机在图片库中挑选
    if training:
        APPLE_PATH = TRAIN_APPLE_PATH
        ORANGE_PATH = TRAIN_ORANGE_PATH
    else:
        APPLE_PATH = TEST_APPLE_PATH
        ORANGE_PATH = TEST_ORANGE_PATH
        
    images_apple = np.random.choice(APPLE_PATH , size=batch_size)
    images_orange = np.random.choice(ORANGE_PATH , size=batch_size)
    
    apples = []
    oranges = []
    
    for i in range(batch_size):
        apple = scipy.misc.imread(images_apple[i] , mode='RGB').astype(np.float)
        orange = scipy.misc.imread(images_orange[i] , mode='RGB').astype(np.float)
        
        #随机性地对训练样本进行 左右反转
        if training and np.random.random()<0.5:
            apple = np.fliplr(apple)
            orange = np.fliplr(orange)
        
        apples.append(apple)
        oranges.append(orange)
        
    apples = np.array(apples)/127.5 - 1
    oranges = np.array(oranges)/127.5 - 1
    
    return apples , oranges


def write_image(epoch):
    #生成高分图像时 进行对比显示
    apples , oranges = load_image(batch_size=1 , training=False) #1个batch就是两幅图像 一个苹果的 一个橘子的
    
    fake_apples = generator_apple2orange.predict(apples) #橘子风格的苹果
    fake_oranges = generator_orange2apple.predict(oranges) #苹果风格的橘子
    
    apples_hat = generator_orange2apple.predict(fake_apples) #还原后的苹果
    oranges_hat = generator_apple2orange.predict(fake_oranges) #还原后的橘子
    
    
    apples = apples*0.5+0.5
    oranges = oranges*0.5+0.5
    
    fake_apples = fake_apples*0.5+0.5
    fake_oranges = fake_oranges*0.5+0.5
    
    apples_hat = apples_hat*0.5+0.5
    oranges_hat = oranges_hat*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    count=0
    
    axes[0][0].imshow(apples[0])
    axes[0][0].set_title('apple')
    axes[0][0].axis('off')

    axes[0][1].imshow(fake_apples[0])
    axes[0][1].set_title('apple-orange')
    axes[0][1].axis('off')
    
    axes[0][2].imshow(apples_hat[0])
    axes[0][2].set_title('restruct apple')
    axes[0][2].axis('off')

    axes[1][0].imshow(oranges[0])
    axes[1][0].set_title('orange')
    axes[1][0].axis('off')

    axes[1][1].imshow(fake_oranges[0])
    axes[1][1].set_title('orange-apple')
    axes[1][1].axis('off')
    
    axes[1][2].imshow(oranges_hat[0])
    axes[1][2].set_title('restruct orange')
    axes[1][2].axis('off')
      
    fig.savefig('apple2orange_cyclegan/No.%d.png' % epoch)
    plt.pyplot.close()


In [6]:
#==============

In [10]:
def conv2d(input_data , output_size , filter_size=4 , instance_norm=True):
    h = Conv2D(output_size , filter_size , strides=(2,2) , padding='same')(input_data)
    h = LeakyReLU(alpha=0.2)(h)
    
    if instance_norm:
        h = InstanceNormalization()(h)
    
    return h


#实现U-Net使用 需要网络的跳连接
def deconv2d(input_data , skip_input , output_size , filter_size=4 , dropout_rate=0.0):
    h = UpSampling2D(size=2)(input_data)
    h = Conv2D(output_size , filter_size , strides=(1,1) , padding='same')(h)
    h = Activation('relu')(h)
    
    if dropout_rate:
        h = Dropout(rate=dropout_rate)(h)
    
    h = InstanceNormalization()(h)
    h =  Concatenate()([h , skip_input]) #跳连接具体实现

    return h
    

In [22]:
#G使用encoder-decoder结构 但是需要引入跳连接 即U-Net
def generator(G_filters , name):
    style = Input(shape=SHAPE) #输入一个风格的图像 生成另一个风格的图像
    
    #encoder
    d1 = conv2d(style , G_filters)
    d2 = conv2d(d1 , G_filters*2)
    d3 = conv2d(d2 , G_filters*4)
    d4 = conv2d(d3 , G_filters*8)

    #decoder
    u1 = deconv2d(d4 , d3 , G_filters*4)
    u2 = deconv2d(u1 , d2 , G_filters*2)
    u3 = deconv2d(u2 , d1 , G_filters)
    
    u4 = UpSampling2D(size=(2,2))(u3)
    other_style = Conv2D(filters=CHANNEL , kernel_size=(4,4) , strides=(1,1) , padding='same' , activation='tanh')(u4) #还原后的图像
    
    return Model(style , other_style , name=name)

In [23]:
def discriminator(D_filters , name):
    style = Input(shape=SHAPE) #风格1 的图像
    #style2 = Input(shape=SHAPE) #风格2 的图像
    
    #style = Concatenate()([style1 , style2])
    
    h1 = conv2d(style , output_size=D_filters , instance_norm=False)
    h2 = conv2d(h1 , output_size=D_filters*2)
    h3 = conv2d(h2 , output_size=D_filters*4)
    h4 = conv2d(h3 , output_size=D_filters*8)
    
    validity =  Conv2D(1 , kernel_size=(4,4) , strides=(1,1) , padding='same')(h4)
    
    return Model(style , validity , name=name)

In [25]:
adam = Adam(lr = 0.0002 , beta_1=0.5)

discriminator_apple = discriminator(D_filters , name='discriminator_apple') #判别苹果风格
discriminator_apple.compile(optimizer = adam , loss='mse' , metrics=['accuracy'])
discriminator_orange = discriminator(D_filters , name='discriminator_orange') #判别橘子风格
discriminator_orange.compile(optimizer = adam , loss='mse' , metrics=['accuracy'])


generator_apple2orange = generator(G_filters , name='generator_apple2orange')
generator_orange2apple = generator(G_filters , name='generator_orange2apple')


apples = Input(shape=SHAPE)
oranges = Input(shape=SHAPE)

fake_apples = generator_apple2orange(apples) #使用G来 将苹果变成橘子风格的苹果
fake_oranges = generator_orange2apple(oranges) #使用F来 将橘子变成苹果风格的橘子

apples_hat = generator_orange2apple(fake_apples) #使用F将橘子风格的苹果还原为原苹果
oranges_hat = generator_apple2orange(fake_oranges) #使用G将苹果风格的橘子还原为原橘子

apples_id = generator_orange2apple(apples)
oranges_id = generator_apple2orange(oranges)

#freeze D
discriminator_apple.trainable = False
discriminator_orange.trainable = False

validity_apple = discriminator_apple(fake_oranges) #真苹果 和 苹果风格的橘子 之间的潜在模式相似度
validity_orange = discriminator_orange(fake_apples) #真橘子 和 橘子风格的苹果 之间的潜在模式相似度

#一共6个输出 最后4个输出希望和原图像一致 这样图像同时具有两个风格
combined = Model([apples , oranges] , [validity_apple , validity_orange , apples_hat , oranges_hat , apples_id , oranges_id])
combined.compile(optimizer=adam , loss=['mse' , 'mse' , 'mae' , 'mae' , 'mae' , 'mae'] , loss_weights=[1 ,1,10, 10 , 1, 1])

In [26]:
#tuple类型相加 相当于cat连接
real_labels = np.ones(shape=(BATCH_SIZE , )+disc_patch) 
fake_labels = np.zeros(shape=(BATCH_SIZE , )+disc_patch)


for i in range(1000001):
    apples_ , oranges_ = load_image()
    
    fake_apples_ = generator_apple2orange.predict(apples_) #使用G将苹果变成 橘子风格的苹果
    fake_oranges_ = generator_orange2apple.predict(oranges_) #使用F将橘子变成 苹果风格的橘子
    #训练判别器
    apple_loss = discriminator_apple.train_on_batch(apples_ , real_labels)
    fake_apple_loss = discriminator_apple.train_on_batch(fake_apples_ , fake_labels)
    loss_apple = np.add(apple_loss , fake_apple_loss)/2

    orange_loss = discriminator_orange.train_on_batch(oranges_ , real_labels)
    fake_orange_loss = discriminator_orange.train_on_batch(fake_oranges_ , fake_labels)
    loss_orange = np.add(orange_loss , fake_orange_loss)/2

    loss = np.add(loss_apple , loss_orange)/2

    #训练生成器
    generator_loss = combined.train_on_batch([apples_ , oranges_] , [real_labels , real_labels , apples_ , oranges_ , apples_ , oranges_])
    
    print('epoch:%d loss:%f accu:%f |mse1:%f :mse2:%f mae1:%f mae2:%f mae3:%f mae4:%f' % (i , loss[0] , loss[1] , generator_loss[0] , generator_loss[1] , generator_loss[2] , generator_loss[3] , generator_loss[4] , generator_loss[5]))

    if i % 50 == 0:
        write_image(i)

write_image(999)


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:11.553784 accu:0.214111 |mse1:77.071190 :mse2:33.420551 mae1:25.639145 mae2:0.877545 mae3:0.782217 mae4:0.701103
epoch:1 loss:49.221931 accu:0.045166 |mse1:113.575821 :mse2:43.440929 mae1:55.090828 mae2:0.741155 mae3:0.641522 mae4:0.676303
epoch:2 loss:45.915680 accu:0.069580 |mse1:21.033131 :mse2:4.208228 mae1:3.649053 mae2:0.669284 mae3:0.501797 mae4:0.760986
epoch:3 loss:11.965063 accu:0.122070 |mse1:16.649717 :mse2:1.430929 mae1:1.866860 mae2:0.673152 mae3:0.504948 mae4:0.846047
epoch:4 loss:1.906404 accu:0.242676 |mse1:14.612094 :mse2:1.134128 mae1:1.180288 mae2:0.638189 mae3:0.434565 mae4:0.825814
epoch:5 loss:0.774842 accu:0.398682 |mse1:15.909773 :mse2:1.323307 mae1:1.301226 mae2:0.605599 mae3:0.564084 mae4:0.782066
epoch:6 loss:0.342780 accu:0.623291 |mse1:15.112564 :mse2:0.802843 mae1:1.201771 mae2:0.559132 mae3:0.595427 mae4:0.787570
epoch:7 loss:0.362108 accu:0.578369 |mse1:13.687601 :mse2:0.350701 mae1:0.931773 mae2:0.507696 mae3:0.571418 mae4:0.734756
epoch:8

epoch:67 loss:0.237849 accu:0.691895 |mse1:6.539742 :mse2:0.189083 mae1:0.176901 mae2:0.278064 mae3:0.202497 mae4:0.669647
epoch:68 loss:0.335674 accu:0.618896 |mse1:7.359292 :mse2:0.148239 mae1:0.370633 mae2:0.278194 mae3:0.266630 mae4:0.621288
epoch:69 loss:0.251676 accu:0.675049 |mse1:7.531012 :mse2:0.352548 mae1:0.421875 mae2:0.304857 mae3:0.244166 mae4:0.615229
epoch:70 loss:0.266829 accu:0.721436 |mse1:6.844807 :mse2:0.274891 mae1:0.167250 mae2:0.274582 mae3:0.241722 mae4:0.654408
epoch:71 loss:0.243897 accu:0.731934 |mse1:7.361660 :mse2:0.276413 mae1:0.170888 mae2:0.285685 mae3:0.266940 mae4:0.604747
epoch:72 loss:0.276924 accu:0.662354 |mse1:7.509018 :mse2:0.233627 mae1:0.664754 mae2:0.258492 mae3:0.269582 mae4:0.603068
epoch:73 loss:0.292874 accu:0.660889 |mse1:6.931196 :mse2:0.283773 mae1:0.211343 mae2:0.312321 mae3:0.201735 mae4:0.638508
epoch:74 loss:0.412264 accu:0.495117 |mse1:7.421015 :mse2:0.345390 mae1:0.124306 mae2:0.323230 mae3:0.232698 mae4:0.678226
epoch:75 loss:0.

epoch:134 loss:0.196368 accu:0.734131 |mse1:6.208761 :mse2:0.213504 mae1:0.252441 mae2:0.267361 mae3:0.220499 mae4:0.425637
epoch:135 loss:0.424306 accu:0.564941 |mse1:7.118859 :mse2:0.933553 mae1:0.715220 mae2:0.263731 mae3:0.207876 mae4:0.389450
epoch:136 loss:0.456725 accu:0.499023 |mse1:6.137365 :mse2:0.172899 mae1:0.234004 mae2:0.262190 mae3:0.224902 mae4:0.449729
epoch:137 loss:0.277249 accu:0.713135 |mse1:5.995755 :mse2:0.550426 mae1:0.309121 mae2:0.240020 mae3:0.193883 mae4:0.373585
epoch:138 loss:0.243767 accu:0.664551 |mse1:6.071097 :mse2:0.213473 mae1:0.552719 mae2:0.224906 mae3:0.227073 mae4:0.375839
epoch:139 loss:0.326402 accu:0.530762 |mse1:5.314434 :mse2:0.156407 mae1:0.121485 mae2:0.259213 mae3:0.175160 mae4:0.402177
epoch:140 loss:0.258037 accu:0.671875 |mse1:6.319360 :mse2:0.373255 mae1:0.798568 mae2:0.238576 mae3:0.195275 mae4:0.382807
epoch:141 loss:0.300581 accu:0.646240 |mse1:5.986039 :mse2:0.265205 mae1:0.139851 mae2:0.252764 mae3:0.226756 mae4:0.406645
epoch:14

epoch:201 loss:0.270349 accu:0.601562 |mse1:5.947822 :mse2:0.375101 mae1:0.169704 mae2:0.218040 mae3:0.238469 mae4:0.387277
epoch:202 loss:0.340926 accu:0.546143 |mse1:5.927408 :mse2:0.747594 mae1:0.367255 mae2:0.223575 mae3:0.195694 mae4:0.332640
epoch:203 loss:0.167337 accu:0.778564 |mse1:5.674900 :mse2:0.256491 mae1:0.204516 mae2:0.235665 mae3:0.202073 mae4:0.433436
epoch:204 loss:0.479622 accu:0.528320 |mse1:5.680304 :mse2:0.587809 mae1:0.547195 mae2:0.207305 mae3:0.181382 mae4:0.317761
epoch:205 loss:0.241996 accu:0.626465 |mse1:5.871343 :mse2:0.505170 mae1:0.125919 mae2:0.244828 mae3:0.190576 mae4:0.436231
epoch:206 loss:0.206743 accu:0.726807 |mse1:5.468586 :mse2:0.187272 mae1:0.339731 mae2:0.214347 mae3:0.213598 mae4:0.365368
epoch:207 loss:0.285093 accu:0.604492 |mse1:5.789184 :mse2:0.334768 mae1:0.124828 mae2:0.280831 mae3:0.189316 mae4:0.384867
epoch:208 loss:0.301025 accu:0.663818 |mse1:6.196558 :mse2:0.723089 mae1:0.226337 mae2:0.251693 mae3:0.193620 mae4:0.388296
epoch:20

epoch:268 loss:0.164533 accu:0.799805 |mse1:4.249792 :mse2:0.197857 mae1:0.083080 mae2:0.190145 mae3:0.144406 mae4:0.332699
epoch:269 loss:0.126231 accu:0.831787 |mse1:5.014046 :mse2:0.142114 mae1:0.125072 mae2:0.236800 mae3:0.160255 mae4:0.367309
epoch:270 loss:0.090514 accu:0.908203 |mse1:5.247957 :mse2:0.306515 mae1:0.091391 mae2:0.244072 mae3:0.174754 mae4:0.353092
epoch:271 loss:0.081030 accu:0.913330 |mse1:4.152543 :mse2:0.147868 mae1:0.045064 mae2:0.186204 mae3:0.142732 mae4:0.366485
epoch:272 loss:0.175056 accu:0.759277 |mse1:5.005109 :mse2:0.146009 mae1:0.103868 mae2:0.197908 mae3:0.207919 mae4:0.326856
epoch:273 loss:0.231875 accu:0.648193 |mse1:4.427078 :mse2:0.231457 mae1:0.107713 mae2:0.185806 mae3:0.155548 mae4:0.337473
epoch:274 loss:0.338398 accu:0.543213 |mse1:4.930279 :mse2:0.310727 mae1:0.055031 mae2:0.234344 mae3:0.160367 mae4:0.342835
epoch:275 loss:0.223128 accu:0.758301 |mse1:4.005049 :mse2:0.104603 mae1:0.057066 mae2:0.155821 mae3:0.170862 mae4:0.312947
epoch:27

epoch:335 loss:0.078039 accu:0.919434 |mse1:5.206775 :mse2:0.340966 mae1:0.288020 mae2:0.185693 mae3:0.211122 mae4:0.305002
epoch:336 loss:0.180029 accu:0.785889 |mse1:4.398787 :mse2:0.075556 mae1:0.297053 mae2:0.206176 mae3:0.136877 mae4:0.392843
epoch:337 loss:0.106831 accu:0.873779 |mse1:4.808628 :mse2:0.145799 mae1:0.461561 mae2:0.219302 mae3:0.132967 mae4:0.363135
epoch:338 loss:0.258622 accu:0.619873 |mse1:3.926843 :mse2:0.171676 mae1:0.116157 mae2:0.174991 mae3:0.122789 mae4:0.369955
epoch:339 loss:0.139710 accu:0.807617 |mse1:5.213583 :mse2:0.226121 mae1:0.218685 mae2:0.232199 mae3:0.179294 mae4:0.365461
epoch:340 loss:0.159385 accu:0.790283 |mse1:4.980995 :mse2:0.105759 mae1:0.258961 mae2:0.207224 mae3:0.179759 mae4:0.364147
epoch:341 loss:0.177075 accu:0.805908 |mse1:4.681004 :mse2:0.094167 mae1:0.237210 mae2:0.183047 mae3:0.179821 mae4:0.349920
epoch:342 loss:0.127980 accu:0.856934 |mse1:4.265828 :mse2:0.072668 mae1:0.069489 mae2:0.207978 mae3:0.134697 mae4:0.323950
epoch:34

epoch:402 loss:0.240223 accu:0.712891 |mse1:4.516990 :mse2:0.086623 mae1:0.076694 mae2:0.196409 mae3:0.168327 mae4:0.356813
epoch:403 loss:0.088381 accu:0.905273 |mse1:4.340602 :mse2:0.030328 mae1:0.056343 mae2:0.191264 mae3:0.162268 mae4:0.334810
epoch:404 loss:0.134387 accu:0.827637 |mse1:4.240595 :mse2:0.075878 mae1:0.223017 mae2:0.169021 mae3:0.150024 mae4:0.381186
epoch:405 loss:0.204746 accu:0.703613 |mse1:4.087601 :mse2:0.038726 mae1:0.052267 mae2:0.170851 mae3:0.169647 mae4:0.307627
epoch:406 loss:0.094201 accu:0.903320 |mse1:3.860641 :mse2:0.333349 mae1:0.044540 mae2:0.155983 mae3:0.128935 mae4:0.358571
epoch:407 loss:0.137231 accu:0.821777 |mse1:4.928768 :mse2:0.117877 mae1:0.075174 mae2:0.214002 mae3:0.191581 mae4:0.401265
epoch:408 loss:0.077379 accu:0.923096 |mse1:4.219263 :mse2:0.164768 mae1:0.038613 mae2:0.179752 mae3:0.146210 mae4:0.402045
epoch:409 loss:0.112248 accu:0.895752 |mse1:4.614598 :mse2:0.274076 mae1:0.060593 mae2:0.208025 mae3:0.154575 mae4:0.387652
epoch:41

epoch:469 loss:0.101679 accu:0.886230 |mse1:3.444670 :mse2:0.035601 mae1:0.023807 mae2:0.158107 mae3:0.108852 mae4:0.402276
epoch:470 loss:0.069918 accu:0.928955 |mse1:4.155539 :mse2:0.109302 mae1:0.115334 mae2:0.187869 mae3:0.132200 mae4:0.379715
epoch:471 loss:0.063461 accu:0.951660 |mse1:3.445394 :mse2:0.036643 mae1:0.081210 mae2:0.163773 mae3:0.113272 mae4:0.338312
epoch:472 loss:0.067201 accu:0.973633 |mse1:3.533886 :mse2:0.087246 mae1:0.090772 mae2:0.144541 mae3:0.124651 mae4:0.355955
epoch:473 loss:0.083144 accu:0.927490 |mse1:4.140047 :mse2:0.180140 mae1:0.303930 mae2:0.158329 mae3:0.131335 mae4:0.374767
epoch:474 loss:0.182284 accu:0.715820 |mse1:4.128108 :mse2:0.018370 mae1:0.123468 mae2:0.180616 mae3:0.155024 mae4:0.318852
epoch:475 loss:0.173770 accu:0.764893 |mse1:4.904886 :mse2:0.589378 mae1:0.481618 mae2:0.187578 mae3:0.129387 mae4:0.316367
epoch:476 loss:0.181815 accu:0.751709 |mse1:5.236001 :mse2:0.437786 mae1:0.494226 mae2:0.169268 mae3:0.198438 mae4:0.350332


KeyboardInterrupt: 

In [70]:
real_labels.shape

(64, 1)

In [None]:
gc.collect()

In [None]:
gc.collect()