In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Concatenate

#addin cycleGAN 使用instance-norm
from keras_contrib.layers.normalization import InstanceNormalization

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

#残差块使用
from keras.layers import Add


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [3]:
WIDTH = 256
HEIGHT = 256
CHANNEL = 3

SHAPE = (WIDTH , HEIGHT , CHANNEL)

BATCH_SIZE = 4 #crazy!!! slow turtle
EPOCHS = 10

PATH = '../dataset/edges2shoes/'

#生成多少个图像 长*宽
ROW = 2 #几行决定显示几个测试样例 显示2个
COL = 3 #3列是因为要显示 原图像 另一个特征空间的图像 还原后的图像

TRAIN_PATH = glob(PATH + 'train/*')

#卷积使用 基卷积核大小
G_filters = 64
D_filters = 64


In [4]:
patch = int(HEIGHT/(2**4)) #16
disc_patch = (patch , patch , 1) #16*16*1

In [5]:
def load_image(batch_size = BATCH_SIZE , training = True):
    #随机在图片库中挑选
        
    images = np.random.choice(TRAIN_PATH , size=batch_size)
    
    edges = []
    shoes = []
    
    for i in range(batch_size):
        image = scipy.misc.imread(images[i] , mode='RGB').astype(np.float)
        
        edge = image[: , :WIDTH , : ]
        shoe = image[: , WIDTH: , : ]
        
        #随机性地对训练样本进行 左右反转
        if training and np.random.random()<0.5:
            edge = np.fliplr(edge)
            shoe = np.fliplr(shoe)
        
        edges.append(edge)
        shoes.append(shoe)
        
    edges = np.array(edges)/127.5 - 1
    shoes = np.array(shoes)/127.5 - 1
    
    return edges , shoes

def write_image(epoch):
    #生成高分图像时 进行对比显示
    edges , shoes = load_image(batch_size=1 , training=False) #1个batch就是两幅图像 一个edge 一个shoe
    
    fake_shoes = generator_apple2orange.predict(edges)
    fake_edges = generator_orange2apple.predict(shoes)
    
    edges_hat = generator_orange2apple.predict(fake_edges)
    shoes_hat = generator_apple2orange.predict(fake_shoes)
    
    edges = edges*0.5+0.5
    shoes = shoes*0.5+0.5
    
    fake_edges = fake_edges*0.5+0.5
    fake_shoes = fake_shoes*0.5+0.5
    
    edges_hat = edges_hat*0.5+0.5
    shoes_hat = shoes_hat*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    count=0
    
    axes[0][0].imshow(edges[0])
    axes[0][0].set_title('edges')
    axes[0][0].axis('off')

    axes[0][1].imshow(fake_shoes[0])
    axes[0][1].set_title('fake_shoes')
    axes[0][1].axis('off')
    
    axes[0][2].imshow(edges_hat[0])
    axes[0][2].set_title('restruct_edges')
    axes[0][2].axis('off')

    axes[1][0].imshow(shoes[0])
    axes[1][0].set_title('shoes')
    axes[1][0].axis('off')

    axes[1][1].imshow(fake_edges[0])
    axes[1][1].set_title('fake_edges')
    axes[1][1].axis('off')
    
    axes[1][2].imshow(shoes_hat[0])
    axes[1][2].set_title('restruct_shoes')
    axes[1][2].axis('off')
      
    fig.savefig('edges2shoes_discogan/No.%d.png' % epoch)
    plt.pyplot.close()


In [6]:
#==============

In [7]:
def conv2d(input_data , output_size , filter_size=4 , instance_norm=True):
    h = Conv2D(output_size , filter_size , strides=(2,2) , padding='same')(input_data)
    h = LeakyReLU(alpha=0.2)(h)
    
    if instance_norm:
        h = InstanceNormalization()(h)
    
    return h


#实现U-Net使用 需要网络的跳连接
def deconv2d(input_data , skip_input , output_size , filter_size=4 , dropout_rate=0.0):
    h = UpSampling2D(size=2)(input_data)
    h = Conv2D(output_size , filter_size , strides=(1,1) , padding='same')(h)
    h = Activation('relu')(h)
    
    if dropout_rate:
        h = Dropout(rate=dropout_rate)(h)
    
    h = InstanceNormalization()(h)
    h =  Concatenate()([h , skip_input]) #跳连接具体实现

    return h
    

In [8]:
#G使用encoder-decoder结构 但是需要引入跳连接 即U-Net
def generator(G_filters , name):
    style = Input(shape=SHAPE) #输入一个风格的图像 生成另一个风格的图像
    
    #encoder
    d1 = conv2d(style , G_filters , instance_norm=False)
    d2 = conv2d(d1 , G_filters*2)
    d3 = conv2d(d2 , G_filters*4)
    d4 = conv2d(d3 , G_filters*8)
    d5 = conv2d(d4 , G_filters*8)
    d6 = conv2d(d5 , G_filters*8)
    d7 = conv2d(d6 , G_filters*8)
 
    #decoder
    
    u1 = deconv2d(d7 , d6 , G_filters*8)
    u2 = deconv2d(u1 , d5 , G_filters*8)
    u3 = deconv2d(u2 , d4 , G_filters*8)
    u4 = deconv2d(u3 , d3 , G_filters*4)
    u5 = deconv2d(u4 , d2 , G_filters*2)
    u6 = deconv2d(u5 , d1 , G_filters)
    
    u7 = UpSampling2D(size=(2,2))(u6)
    other_style = Conv2D(filters=CHANNEL , kernel_size=(4,4) , strides=(1,1) , padding='same' , activation='tanh')(u7) #还原后的图像
    
    return Model(style , other_style , name=name)

In [9]:
def discriminator(D_filters , name):
    style = Input(shape=SHAPE)
    
    h1 = conv2d(style , output_size=D_filters , instance_norm=False)
    h2 = conv2d(h1 , output_size=D_filters*2)
    h3 = conv2d(h2 , output_size=D_filters*4)
    h4 = conv2d(h3 , output_size=D_filters*8)
    
    validity =  Conv2D(1 , kernel_size=(4,4) , strides=(1,1) , padding='same')(h4)
    
    return Model(style , validity , name=name)

In [None]:
adam = Adam(lr = 0.0002 , beta_1=0.5)

discriminator_apple = discriminator(D_filters , name='discriminator_apple')
discriminator_apple.compile(optimizer = adam , loss='mse' , metrics=['accuracy'])
discriminator_orange = discriminator(D_filters , name='discriminator_orange')
discriminator_orange.compile(optimizer = adam , loss='mse' , metrics=['accuracy'])

generator_apple2orange = generator(G_filters , name='generator_apple2orange')
generator_orange2apple = generator(G_filters , name='generator_orange2apple')

apples = Input(shape=SHAPE)
oranges = Input(shape=SHAPE)

fake_oranges = generator_apple2orange(apples)
fake_apples = generator_orange2apple(oranges)

apples_hat = generator_orange2apple(fake_oranges)
oranges_hat = generator_apple2orange(fake_apples)

#freeze D
discriminator_apple.trainable = False
discriminator_orange.trainable = False

validity_apple = discriminator_apple(fake_apples)
validity_orange = discriminator_orange(fake_oranges)

#一共6个输出 最后4个输出希望和原图像一致 这样图像同时具有两个风格
combined = Model([apples , oranges] , [validity_apple , validity_orange , fake_oranges , fake_apples , apples_hat , oranges_hat])
combined.compile(optimizer=adam , loss=['mse' , 'mse' , 'mae' , 'mae' , 'mae' , 'mae'] , loss_weights=[1 ,1,10, 10 , 1, 1]) #cycleGAN的损失权重

In [None]:
#tuple类型相加 相当于cat连接
real_labels = np.ones(shape=(BATCH_SIZE , )+disc_patch) 
fake_labels = np.zeros(shape=(BATCH_SIZE , )+disc_patch)

for i in range(1001):
    edges , shoes = load_image()
    
    fake_shoes = generator_apple2orange.predict(edges)
    fake_edges = generator_orange2apple.predict(shoes)
    #训练判别器
    edges_loss = discriminator_apple.train_on_batch(edges , real_labels)
    fake_edges_loss = discriminator_apple.train_on_batch(fake_edges , fake_labels)
    loss_edges = np.add(edges_loss , fake_edges_loss)/2

    shoes_loss = discriminator_orange.train_on_batch(shoes , real_labels)
    fake_shoes_loss = discriminator_orange.train_on_batch(fake_shoes , fake_labels)
    loss_shoes = np.add(shoes_loss , fake_shoes_loss)/2

    loss = np.add(loss_edges , loss_shoes)/2

    #训练生成器
    generator_loss = combined.train_on_batch([edges , shoes] , [real_labels , real_labels , shoes , edges , edges , shoes])
    
    print('epoch:%d loss:%f accu:%f | mse1:%f :mse2:%f mae1:%f mae2:%f mae3:%f mae4:%f' % (i , loss[0] , loss[1] , generator_loss[0] , generator_loss[1] , generator_loss[2] , generator_loss[3] , generator_loss[4] , generator_loss[5]))

    if i % 50 == 0:
        write_image(i)

write_image(999)


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  # Remove the CWD from sys.path while we load stuff.
  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:26.839453 accu:0.157959 |mse1:127.972374 :mse2:4.310163 mae1:104.165298 mae2:0.629249 mae3:1.151195 mae4:0.844402
epoch:1 loss:53.837662 accu:0.016846 |mse1:144.365829 :mse2:30.638023 mae1:98.452301 mae2:0.801013 mae3:0.552286 mae4:1.003998
epoch:2 loss:31.830841 accu:0.065430 |mse1:195.412537 :mse2:2.052113 mae1:179.529007 mae2:0.913364 mae3:0.346781 mae4:0.344338
epoch:3 loss:9.288846 accu:0.040527 |mse1:28.717392 :mse2:1.487355 mae1:12.682568 mae2:1.144615 mae3:0.183636 mae4:0.128674
epoch:4 loss:10.061327 accu:0.177979 |mse1:15.776257 :mse2:1.153067 mae1:0.133939 mae2:1.126424 mae3:0.203821 mae4:0.071111
epoch:5 loss:0.683095 accu:0.375977 |mse1:18.832788 :mse2:0.860998 mae1:3.972034 mae2:1.180704 mae3:0.095113 mae4:0.060779
epoch:6 loss:0.931264 accu:0.147461 |mse1:14.891116 :mse2:0.257235 mae1:1.413712 mae2:1.120754 mae3:0.082633 mae4:0.064726
epoch:7 loss:0.336615 accu:0.546387 |mse1:13.652836 :mse2:0.338684 mae1:0.596270 mae2:1.101540 mae3:0.055088 mae4:0.049380
ep

epoch:67 loss:0.316215 accu:0.440186 |mse1:7.552331 :mse2:0.786475 mae1:0.822550 mae2:0.485926 mae3:0.053715 mae4:0.053356
epoch:68 loss:0.371071 accu:0.369873 |mse1:7.224988 :mse2:0.708131 mae1:0.854384 mae2:0.469102 mae3:0.045407 mae4:0.045151
epoch:69 loss:0.243915 accu:0.519287 |mse1:9.097054 :mse2:0.772256 mae1:1.006557 mae2:0.613222 mae3:0.051431 mae4:0.050853
epoch:70 loss:0.406780 accu:0.237305 |mse1:6.797404 :mse2:0.751241 mae1:0.764282 mae2:0.437165 mae3:0.042507 mae4:0.042203
epoch:71 loss:0.314644 accu:0.381836 |mse1:6.657200 :mse2:0.761157 mae1:0.962625 mae2:0.396271 mae3:0.051784 mae4:0.051537
epoch:72 loss:0.445719 accu:0.181152 |mse1:6.817579 :mse2:0.767745 mae1:0.972036 mae2:0.420346 mae3:0.040961 mae4:0.040767
epoch:73 loss:0.406635 accu:0.230225 |mse1:7.711053 :mse2:0.791524 mae1:1.111583 mae2:0.487156 mae3:0.040325 mae4:0.039995
epoch:74 loss:0.330644 accu:0.418213 |mse1:8.513199 :mse2:0.747661 mae1:0.848883 mae2:0.575633 mae3:0.052576 mae4:0.052021
epoch:75 loss:0.

epoch:134 loss:0.160127 accu:0.801758 |mse1:7.209192 :mse2:0.959266 mae1:0.864592 mae2:0.427032 mae3:0.061927 mae4:0.061772
epoch:135 loss:0.456590 accu:0.203857 |mse1:6.473062 :mse2:0.805270 mae1:0.622126 mae2:0.422802 mae3:0.035889 mae4:0.035746
epoch:136 loss:0.427980 accu:0.289307 |mse1:5.774258 :mse2:0.820599 mae1:0.584970 mae2:0.366442 mae3:0.030838 mae4:0.030773
epoch:137 loss:0.357896 accu:0.262451 |mse1:5.586280 :mse2:0.832122 mae1:0.708997 mae2:0.337417 mae3:0.030411 mae4:0.030339
epoch:138 loss:0.423350 accu:0.368896 |mse1:5.847128 :mse2:0.906239 mae1:0.532478 mae2:0.364403 mae3:0.036381 mae4:0.036312
epoch:139 loss:0.341536 accu:0.461426 |mse1:6.175719 :mse2:0.953649 mae1:0.646293 mae2:0.358235 mae3:0.057851 mae4:0.057794
epoch:140 loss:0.544469 accu:0.159912 |mse1:5.563390 :mse2:0.623981 mae1:0.652575 mae2:0.358641 mae3:0.031079 mae4:0.030984
epoch:141 loss:0.197705 accu:0.760010 |mse1:6.549042 :mse2:1.011944 mae1:0.776501 mae2:0.390446 mae3:0.041965 mae4:0.041842
epoch:14

epoch:201 loss:0.252432 accu:0.604492 |mse1:5.909969 :mse2:0.914918 mae1:0.738307 mae2:0.355957 mae3:0.031103 mae4:0.031062
epoch:202 loss:0.152767 accu:0.806885 |mse1:6.410722 :mse2:1.025745 mae1:0.652340 mae2:0.395254 mae3:0.034925 mae4:0.034818
epoch:203 loss:0.250755 accu:0.725098 |mse1:6.889791 :mse2:0.955908 mae1:0.795326 mae2:0.414414 mae3:0.051963 mae4:0.051876
epoch:204 loss:0.611001 accu:0.114502 |mse1:6.868327 :mse2:0.740222 mae1:0.527899 mae2:0.475644 mae3:0.033122 mae4:0.033086
epoch:205 loss:0.192312 accu:0.632812 |mse1:6.464798 :mse2:1.034609 mae1:0.688270 mae2:0.380120 mae3:0.049684 mae4:0.049617
epoch:206 loss:0.505937 accu:0.184326 |mse1:5.886548 :mse2:0.831780 mae1:0.487260 mae2:0.382981 mae3:0.032114 mae4:0.032087
epoch:207 loss:0.165440 accu:0.699707 |mse1:6.294809 :mse2:1.007177 mae1:0.689899 mae2:0.373683 mae3:0.043610 mae4:0.043557
epoch:208 loss:0.259340 accu:0.576660 |mse1:6.482959 :mse2:0.977935 mae1:0.579205 mae2:0.396991 mae3:0.050698 mae4:0.050623
epoch:20

In [70]:
real_labels.shape

(64, 1)

In [None]:
gc.collect()

In [None]:
gc.collect()

In [13]:
scipy.misc.imread('../dataset/edges2shoes/train/10000_AB.jpg' , mode='RGB').shape

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  """Entry point for launching an IPython kernel.


(256, 512, 3)