In [1]:
#对偶学习的原理

In [13]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Concatenate

#addin cycleGAN 使用instance-norm
from keras_contrib.layers.normalization import InstanceNormalization

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

#残差块使用
from keras.layers import Add

from keras.datasets import mnist


In [14]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [31]:
WIDTH = 28
HEIGHT = 28
CHANNEL = 1

SHAPE = (WIDTH , HEIGHT , CHANNEL)

BATCH_SIZE = 64
EPOCHS = 10

#生成多少个图像 长*宽
ROW = 2
COL = 2

LATENT_DIM = 100 #噪音维度 100维  noise z


In [28]:

(X_train , _),(_ , _) = mnist.load_data()
X_train = X_train/127.5-1
X_train = np.expand_dims(X_train , 3)

#为了使用CoGAN 对样本简单处理一下 旋转90度
X_A = X_train[: int(X_train.shape[0]/2)]
X_B = scipy.ndimage.interpolation.rotate(X_train[int(X_train.shape[0]/2) : ] , angle=90 , axes=(1,2)) #沿着 1 2维上进行旋转 0维是样本数量 3维上是channel


In [32]:

def load_mnist(X , batch_size = BATCH_SIZE):
    return X[np.random.randint(0, X.shape[0], batch_size)]
    
def write_image_mnist(epoch):
    
    z = np.random.normal(0 , 1 , size=(2 , LATENT_DIM))
    
    style1 = generator_style1.predict(z)
    style2 = generator_style2.predict(z)
    
    style1 = style1*0.5+0.5
    style2 = style2*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    axes[0][0].imshow(style1[0,:,:,0] , cmap='gray')
    axes[0][0].set_title('style1')
    axes[0][0].axis('off')

    axes[0][1].imshow(style2[0,:,:,0] , cmap='gray')
    axes[0][1].set_title('style2')
    axes[0][1].axis('off')
    
    axes[1][0].imshow(style1[1,:,:,0] , cmap='gray')
    axes[1][0].set_title('style1')
    axes[1][0].axis('off')

    axes[1][1].imshow(style2[1,:,:,0] , cmap='gray')
    axes[1][1].set_title('style2')
    axes[1][1].axis('off')

    fig.savefig('mnist_cogan/No.%d.png' % epoch)
    plt.pyplot.close()



In [20]:
def generator():
    z = Input(shape=(LATENT_DIM , )) #输入一个风格的图像 生成另一个风格的图像
    
    model = Sequential()
    model.add(Dense(256, input_shape=(LATENT_DIM , )))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    
    feature = model(z)
    
    #G1 G2需要共享前面几层的权重参数
    #G1
    h1 = Dense(1024)(feature)
    h1 = LeakyReLU(alpha=0.2)(h1)
    h1 = BatchNormalization(momentum=0.8)(h1)
    h1 = Dense(units=WIDTH*HEIGHT*CHANNEL , activation='tanh')(h1)
    style1 = Reshape(target_shape=SHAPE)(h1)
    
    #G2
    h2 = Dense(1024)(feature)
    h2 = LeakyReLU(alpha=0.2)(h2)
    h2 = BatchNormalization(momentum=0.8)(h2)
    h2 = Dense(units=WIDTH*HEIGHT*CHANNEL , activation='tanh')(h2)
    style2 = Reshape(target_shape=SHAPE)(h2)
    
    return Model(z , style1) , Model(z , style2)

In [21]:
def discriminator():
    style1 = Input(shape=SHAPE)
    style2 = Input(shape=SHAPE)
    
    model = Sequential()
    model.add(Flatten(input_shape=SHAPE))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    
    hidden_style1 = model(style1)
    hidden_style2 = model(style2)
    
    validity_style1 = Dense(1 , activation='sigmoid')(hidden_style1)
    validity_style2 = Dense(1 , activation='sigmoid')(hidden_style2)
        
    return Model(style1 , validity_style1) , Model(style2 , validity_style2)

In [22]:
adam = Adam(lr = 0.0002 , beta_1=0.5)
#构建计算图
discriminator_style1 , discriminator_style2 = discriminator()
discriminator_style1.compile(optimizer = adam , loss='binary_crossentropy' , metrics=['accuracy'])
discriminator_style2.compile(optimizer = adam , loss='binary_crossentropy' , metrics=['accuracy'])

generator_style1 , generator_style2 = generator()

z = Input(shape=(LATENT_DIM , ))

style1 = generator_style1(z)
style2 = generator_style2(z)

#freeze D
discriminator_style1.trainable = False
discriminator_style2.trainable = False

validity_style1 = discriminator_style1(style1)
validity_style2 = discriminator_style2(style2)

#一共6个输出 最后4个输出希望和原图像一致 这样图像同时具有两个风格
combined = Model(z , [validity_style1 , validity_style2])
combined.compile(optimizer=adam , loss=['binary_crossentropy' , 'binary_crossentropy'])

In [34]:
#tuple类型相加 相当于cat连接
real_labels = np.ones(shape=(BATCH_SIZE , 1)) #真实的
fake_labels = np.zeros(shape=(BATCH_SIZE , 1))

for i in range(1001):
    apples_ = load_mnist(X_A)
    oranges_ = load_mnist(X_B)
    
    z = np.random.normal(0 , 1 , (BATCH_SIZE , LATENT_DIM))

    style1 = generator_style1.predict(z)
    style2 = generator_style2.predict(z)
    #训练判别器
    style1_loss = discriminator_style1.train_on_batch(apples_ , real_labels)
    style1_hat_loss = discriminator_style1.train_on_batch(style1 , fake_labels)
    loss_style1 = np.add(style1_loss , style1_hat_loss)/2

    style2_loss = discriminator_style2.train_on_batch(oranges_ , real_labels)
    style2_hat_loss = discriminator_style2.train_on_batch(style2 , fake_labels)
    loss_style2 = np.add(style2_loss , style2_hat_loss)/2
    
    loss = np.add(loss_style1 , loss_style2)/2

    #训练生成器
    generator_loss = combined.train_on_batch(z , [real_labels , real_labels])
    
    print('epoch:%d loss:%f accu:%f |xentropy_loss:%f xentropy_loss:%f' % (i , loss[0] , loss[1] , generator_loss[0] , generator_loss[1]))

    if i % 50 == 0:
        write_image_mnist(i)

write_image_mnist(999)


  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:2.854656 accu:0.570312 |xentropy_loss:5.427509 xentropy_loss:0.854947
epoch:1 loss:0.394153 accu:0.773438 |xentropy_loss:11.464317 xentropy_loss:1.460063
epoch:2 loss:1.210803 accu:0.429688 |xentropy_loss:13.263485 xentropy_loss:0.657380
epoch:3 loss:0.359027 accu:0.769531 |xentropy_loss:12.772544 xentropy_loss:0.864516
epoch:4 loss:0.591151 accu:0.718750 |xentropy_loss:14.746601 xentropy_loss:0.837653
epoch:5 loss:0.396062 accu:0.738281 |xentropy_loss:11.800735 xentropy_loss:1.119213
epoch:6 loss:0.613550 accu:0.722656 |xentropy_loss:13.050305 xentropy_loss:1.021489
epoch:7 loss:0.389278 accu:0.722656 |xentropy_loss:13.867621 xentropy_loss:0.839195
epoch:8 loss:0.417359 accu:0.714844 |xentropy_loss:13.116418 xentropy_loss:0.865474
epoch:9 loss:0.399165 accu:0.738281 |xentropy_loss:14.128585 xentropy_loss:1.331273
epoch:10 loss:0.677013 accu:0.457031 |xentropy_loss:12.734560 xentropy_loss:0.731470
epoch:11 loss:0.346600 accu:0.769531 |xentropy_loss:11.441513 xentropy_loss:

epoch:99 loss:0.539661 accu:0.695312 |xentropy_loss:8.112101 xentropy_loss:0.942261
epoch:100 loss:0.650499 accu:0.585938 |xentropy_loss:6.468346 xentropy_loss:0.871358
epoch:101 loss:0.504807 accu:0.703125 |xentropy_loss:7.054933 xentropy_loss:0.839633
epoch:102 loss:0.532960 accu:0.667969 |xentropy_loss:7.104817 xentropy_loss:0.876731
epoch:103 loss:0.673426 accu:0.582031 |xentropy_loss:5.738816 xentropy_loss:0.872911
epoch:104 loss:0.455581 accu:0.785156 |xentropy_loss:6.765134 xentropy_loss:0.918432
epoch:105 loss:0.540138 accu:0.683594 |xentropy_loss:5.928705 xentropy_loss:0.866554
epoch:106 loss:0.478396 accu:0.769531 |xentropy_loss:6.222577 xentropy_loss:0.896312
epoch:107 loss:0.586900 accu:0.609375 |xentropy_loss:7.208040 xentropy_loss:0.955591
epoch:108 loss:0.558025 accu:0.656250 |xentropy_loss:6.477937 xentropy_loss:0.902497
epoch:109 loss:0.502606 accu:0.710938 |xentropy_loss:6.231215 xentropy_loss:0.937293
epoch:110 loss:0.550488 accu:0.640625 |xentropy_loss:7.285352 xent

epoch:201 loss:0.602177 accu:0.656250 |xentropy_loss:3.799437 xentropy_loss:0.848340
epoch:202 loss:0.603841 accu:0.667969 |xentropy_loss:3.974969 xentropy_loss:0.886202
epoch:203 loss:0.625974 accu:0.617188 |xentropy_loss:3.463038 xentropy_loss:0.884564
epoch:204 loss:0.650230 accu:0.582031 |xentropy_loss:3.330942 xentropy_loss:0.866033
epoch:205 loss:0.605380 accu:0.679688 |xentropy_loss:3.989462 xentropy_loss:0.859484
epoch:206 loss:0.647096 accu:0.601562 |xentropy_loss:3.391325 xentropy_loss:0.842733
epoch:207 loss:0.652536 accu:0.617188 |xentropy_loss:3.945213 xentropy_loss:0.881620
epoch:208 loss:0.646757 accu:0.601562 |xentropy_loss:3.537750 xentropy_loss:0.832496
epoch:209 loss:0.629017 accu:0.566406 |xentropy_loss:3.669160 xentropy_loss:0.844782
epoch:210 loss:0.635497 accu:0.589844 |xentropy_loss:3.522827 xentropy_loss:0.846630
epoch:211 loss:0.623074 accu:0.648438 |xentropy_loss:3.799749 xentropy_loss:0.848713
epoch:212 loss:0.611311 accu:0.648438 |xentropy_loss:3.406446 xen

epoch:301 loss:0.586874 accu:0.714844 |xentropy_loss:3.279675 xentropy_loss:0.905170
epoch:302 loss:0.587527 accu:0.703125 |xentropy_loss:2.654926 xentropy_loss:0.874278
epoch:303 loss:0.578339 accu:0.718750 |xentropy_loss:3.327827 xentropy_loss:0.915313
epoch:304 loss:0.582202 accu:0.730469 |xentropy_loss:3.350886 xentropy_loss:0.909573
epoch:305 loss:0.555436 accu:0.726562 |xentropy_loss:2.961510 xentropy_loss:0.882228
epoch:306 loss:0.574924 accu:0.699219 |xentropy_loss:2.972694 xentropy_loss:0.904955
epoch:307 loss:0.564161 accu:0.738281 |xentropy_loss:2.814891 xentropy_loss:0.939816
epoch:308 loss:0.598059 accu:0.683594 |xentropy_loss:2.940311 xentropy_loss:0.922055
epoch:309 loss:0.604842 accu:0.671875 |xentropy_loss:2.814119 xentropy_loss:0.874131
epoch:310 loss:0.621228 accu:0.617188 |xentropy_loss:2.678770 xentropy_loss:0.900101
epoch:311 loss:0.562958 accu:0.781250 |xentropy_loss:3.350358 xentropy_loss:0.928411
epoch:312 loss:0.577173 accu:0.722656 |xentropy_loss:2.585773 xen

epoch:399 loss:0.582968 accu:0.667969 |xentropy_loss:2.016433 xentropy_loss:0.882219
epoch:400 loss:0.635659 accu:0.656250 |xentropy_loss:1.874436 xentropy_loss:0.851595
epoch:401 loss:0.600085 accu:0.660156 |xentropy_loss:2.020578 xentropy_loss:0.842535
epoch:402 loss:0.562514 accu:0.742188 |xentropy_loss:2.279024 xentropy_loss:0.920734
epoch:403 loss:0.642209 accu:0.617188 |xentropy_loss:2.227933 xentropy_loss:0.877688
epoch:404 loss:0.629942 accu:0.648438 |xentropy_loss:2.028617 xentropy_loss:0.836651
epoch:405 loss:0.595766 accu:0.699219 |xentropy_loss:2.042780 xentropy_loss:0.863406
epoch:406 loss:0.591885 accu:0.738281 |xentropy_loss:2.068000 xentropy_loss:0.846417
epoch:407 loss:0.641763 accu:0.597656 |xentropy_loss:1.865825 xentropy_loss:0.874552
epoch:408 loss:0.587526 accu:0.671875 |xentropy_loss:2.021907 xentropy_loss:0.881344
epoch:409 loss:0.596133 accu:0.691406 |xentropy_loss:2.172452 xentropy_loss:0.881213
epoch:410 loss:0.588977 accu:0.714844 |xentropy_loss:2.068536 xen

epoch:500 loss:0.650282 accu:0.601562 |xentropy_loss:1.869868 xentropy_loss:0.888724
epoch:501 loss:0.613286 accu:0.671875 |xentropy_loss:1.941945 xentropy_loss:0.897439
epoch:502 loss:0.639119 accu:0.617188 |xentropy_loss:1.936593 xentropy_loss:0.919075
epoch:503 loss:0.662177 accu:0.601562 |xentropy_loss:1.969491 xentropy_loss:0.923278
epoch:504 loss:0.641444 accu:0.632812 |xentropy_loss:1.958510 xentropy_loss:0.916890
epoch:505 loss:0.643115 accu:0.613281 |xentropy_loss:1.941753 xentropy_loss:0.912803
epoch:506 loss:0.633272 accu:0.628906 |xentropy_loss:1.909359 xentropy_loss:0.895321
epoch:507 loss:0.634574 accu:0.597656 |xentropy_loss:2.021181 xentropy_loss:0.943746
epoch:508 loss:0.654888 accu:0.582031 |xentropy_loss:1.945230 xentropy_loss:0.874258
epoch:509 loss:0.662618 accu:0.593750 |xentropy_loss:1.881788 xentropy_loss:0.862791
epoch:510 loss:0.656065 accu:0.593750 |xentropy_loss:1.732987 xentropy_loss:0.819321
epoch:511 loss:0.631516 accu:0.621094 |xentropy_loss:1.894045 xen

epoch:601 loss:0.617821 accu:0.699219 |xentropy_loss:1.939757 xentropy_loss:0.971938
epoch:602 loss:0.612791 accu:0.640625 |xentropy_loss:1.900482 xentropy_loss:0.917904
epoch:603 loss:0.589714 accu:0.726562 |xentropy_loss:1.935615 xentropy_loss:0.928699
epoch:604 loss:0.616818 accu:0.699219 |xentropy_loss:1.973945 xentropy_loss:0.916024
epoch:605 loss:0.605234 accu:0.687500 |xentropy_loss:1.898456 xentropy_loss:0.927042
epoch:606 loss:0.613474 accu:0.675781 |xentropy_loss:1.869891 xentropy_loss:0.913687
epoch:607 loss:0.624987 accu:0.675781 |xentropy_loss:1.829427 xentropy_loss:0.896885
epoch:608 loss:0.606567 accu:0.679688 |xentropy_loss:1.863596 xentropy_loss:0.904919
epoch:609 loss:0.625841 accu:0.628906 |xentropy_loss:1.795309 xentropy_loss:0.891462
epoch:610 loss:0.609187 accu:0.714844 |xentropy_loss:1.915186 xentropy_loss:0.938681
epoch:611 loss:0.601572 accu:0.675781 |xentropy_loss:1.962052 xentropy_loss:0.900620
epoch:612 loss:0.649596 accu:0.640625 |xentropy_loss:1.885301 xen

epoch:701 loss:0.592435 accu:0.691406 |xentropy_loss:1.950910 xentropy_loss:0.936163
epoch:702 loss:0.592158 accu:0.691406 |xentropy_loss:1.960640 xentropy_loss:0.947918
epoch:703 loss:0.593972 accu:0.707031 |xentropy_loss:1.974727 xentropy_loss:0.956969
epoch:704 loss:0.591915 accu:0.691406 |xentropy_loss:1.918667 xentropy_loss:0.941602
epoch:705 loss:0.621281 accu:0.625000 |xentropy_loss:1.894068 xentropy_loss:0.919626
epoch:706 loss:0.614043 accu:0.660156 |xentropy_loss:1.883939 xentropy_loss:0.919487
epoch:707 loss:0.590053 accu:0.718750 |xentropy_loss:1.892239 xentropy_loss:0.912783
epoch:708 loss:0.610558 accu:0.675781 |xentropy_loss:1.882521 xentropy_loss:0.886904
epoch:709 loss:0.597962 accu:0.718750 |xentropy_loss:1.880113 xentropy_loss:0.900807
epoch:710 loss:0.600597 accu:0.707031 |xentropy_loss:1.938342 xentropy_loss:0.950098
epoch:711 loss:0.631870 accu:0.648438 |xentropy_loss:1.897823 xentropy_loss:0.916560
epoch:712 loss:0.614860 accu:0.675781 |xentropy_loss:1.933518 xen

epoch:801 loss:0.594788 accu:0.675781 |xentropy_loss:1.967402 xentropy_loss:0.916503
epoch:802 loss:0.614471 accu:0.644531 |xentropy_loss:1.949373 xentropy_loss:0.937877
epoch:803 loss:0.601941 accu:0.707031 |xentropy_loss:1.951339 xentropy_loss:0.933161
epoch:804 loss:0.575584 accu:0.742188 |xentropy_loss:2.014297 xentropy_loss:0.939019
epoch:805 loss:0.614293 accu:0.664062 |xentropy_loss:1.990867 xentropy_loss:0.973812
epoch:806 loss:0.593287 accu:0.699219 |xentropy_loss:2.055638 xentropy_loss:0.988244
epoch:807 loss:0.578840 accu:0.753906 |xentropy_loss:2.021537 xentropy_loss:0.961661
epoch:808 loss:0.601241 accu:0.660156 |xentropy_loss:1.964193 xentropy_loss:0.936601
epoch:809 loss:0.615479 accu:0.660156 |xentropy_loss:1.970800 xentropy_loss:0.939056
epoch:810 loss:0.603349 accu:0.679688 |xentropy_loss:1.934354 xentropy_loss:0.932947
epoch:811 loss:0.603811 accu:0.707031 |xentropy_loss:1.985818 xentropy_loss:0.932334
epoch:812 loss:0.588854 accu:0.710938 |xentropy_loss:1.983568 xen

epoch:901 loss:0.581686 accu:0.734375 |xentropy_loss:1.976059 xentropy_loss:0.952673
epoch:902 loss:0.608582 accu:0.687500 |xentropy_loss:1.999272 xentropy_loss:0.966502
epoch:903 loss:0.610407 accu:0.683594 |xentropy_loss:2.024820 xentropy_loss:0.983756
epoch:904 loss:0.583788 accu:0.730469 |xentropy_loss:2.015879 xentropy_loss:0.995236
epoch:905 loss:0.601274 accu:0.707031 |xentropy_loss:1.977930 xentropy_loss:0.971003
epoch:906 loss:0.602616 accu:0.714844 |xentropy_loss:1.966646 xentropy_loss:0.948185
epoch:907 loss:0.579406 accu:0.707031 |xentropy_loss:2.052119 xentropy_loss:0.941313
epoch:908 loss:0.597949 accu:0.730469 |xentropy_loss:1.967826 xentropy_loss:0.917869
epoch:909 loss:0.603791 accu:0.699219 |xentropy_loss:1.895253 xentropy_loss:0.875898
epoch:910 loss:0.596907 accu:0.687500 |xentropy_loss:1.942404 xentropy_loss:0.925287
epoch:911 loss:0.596554 accu:0.671875 |xentropy_loss:2.028066 xentropy_loss:0.982608
epoch:912 loss:0.641980 accu:0.628906 |xentropy_loss:1.982925 xen

In [23]:
X_A.shape

(30000, 784)

In [70]:
real_labels.shape

(64, 1)

In [None]:
gc.collect()

In [None]:
gc.collect()