In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Concatenate

#addin cycleGAN 使用instance-norm
from keras_contrib.layers.normalization import InstanceNormalization

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

from keras.datasets import mnist

from keras.utils import to_categorical

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [36]:
WIDTH = 28
HEIGHT = 28
CHANNEL = 1

SHAPE = (WIDTH , HEIGHT , CHANNEL)

BATCH_SIZE = 64
EPOCHS = 10

#生成多少个图像 长*宽
ROW = 2
COL = 2

LATENT_DIM = 72 #infogan中变为72维
CLASSES_NUM = 10 #mnist一共10中类别


In [37]:

(X_train , y_train),(_ , _) = mnist.load_data()
X_train = X_train/127.5-1
X_train = np.expand_dims(X_train , 3)
y_train = y_train.reshape(-1 , 1)
y_train = to_categorical(y_train , num_classes=CLASSES_NUM)


In [38]:

def load_mnist(batch_size = BATCH_SIZE):
    idx = np.random.randint(0 , X_train.shape[0] ,batch_size)
    return X_train[idx] , y_train[idx]
    
def write_image_mnist(epoch):
    
    z = np.random.normal(0 , 1 , size=(ROW * COL , LATENT_DIM-CLASSES_NUM))
    c = np.random.randint(0 , high=CLASSES_NUM , size=(ROW*COL , 1))
    c_one_hot = to_categorical(c , num_classes=CLASSES_NUM)
    
    G_input = np.concatenate((z , c_one_hot) , axis = 1)
    
    imgs = generator_i.predict(G_input)
    
    imgs = imgs*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    axes[0][0].imshow(imgs[0,:,:,0] , cmap='gray')
    axes[0][0].set_title('%d' % (c[0]))
    axes[0][0].axis('off')

    axes[0][1].imshow(imgs[1,:,:,0] , cmap='gray')
    axes[0][1].set_title('%d' % (c[1]))
    axes[0][1].axis('off')
    
    axes[1][0].imshow(imgs[2,:,:,0] , cmap='gray')
    axes[1][0].set_title('%d' % (c[2]))
    axes[1][0].axis('off')

    axes[1][1].imshow(imgs[3,:,:,0] , cmap='gray')
    axes[1][1].set_title('%d' % (c[3]))
    axes[1][1].axis('off')

    fig.savefig('mnist_infogan/No.%d.png' % epoch)
    plt.pyplot.close()



In [39]:
def generator():
    z = Input(shape=(LATENT_DIM , )) #输入一个风格的图像 生成另一个风格的图像
    
    model = Sequential()

    model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(LATENT_DIM , )))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization(momentum=0.8))
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(CHANNEL, kernel_size=3, padding="same"))
    model.add(Activation("tanh"))
    
    img = model(z)
    
    return Model(z , img)

In [40]:
def discriminator():
    img = Input(shape=SHAPE)
    
    model = Sequential()

    model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=SHAPE , padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Flatten())
    
    feature = model(img)
   
    validity = Dense(1 , activation='sigmoid')(feature)
    
    q_net = Dense(units=128 , activation='relu')(feature)
    labels = Dense(CLASSES_NUM , activation='softmax')(q_net) #加1 因为多出来一个不存在的label 来代替fake的label
        
    return Model(img , validity) , Model(img , labels)

In [41]:
#给Q-NET使用的互信息损失
def mutual_info_loss(c , c_x):
    epsilon = 1e-8
    
    x_entropy = K.mean(-K.sum(c*K.log(c_x+epsilon)))
    entropy = K.mean(-K.sum(c*K.log(c+epsilon)))
    
    return x_entropy + entropy

In [42]:
adam = Adam(lr = 0.0002 , beta_1=0.5)
#构建计算图
D , Q = discriminator()

D.compile(optimizer = adam , loss='binary_crossentropy' , metrics=['accuracy'])
Q.compile(optimizer = adam , loss=mutual_info_loss , metrics=['accuracy'])

generator_i = generator()

z = Input(shape=(LATENT_DIM , ))
img = generator_i(z)

#freeze D
D.trainable = False

validity = D(img)
labels = Q(img)

combined = Model(z , [validity , labels])
combined.compile(optimizer=adam , loss=['binary_crossentropy' , mutual_info_loss])

In [43]:
def train():
    real_labels = np.ones(shape=(BATCH_SIZE , 1))
    fake_labels = np.zeros(shape=(BATCH_SIZE , 1))

    for i in range(1001):
        X_train , y_train = load_mnist()
        
        z = np.random.normal(0 , 1 , (BATCH_SIZE , LATENT_DIM-CLASSES_NUM)) #一定要减去 变为62 因为后面需要和隐藏编码c concatenate 一下
        c = np.random.randint(0 , CLASSES_NUM , size=(BATCH_SIZE , 1)) #隐藏编码c
        c = to_categorical(c , num_classes=CLASSES_NUM)
        
        G_input = np.concatenate((z , c) , axis = 1)

        img = generator_i.predict(G_input)
        #训练判别器
        real_loss = D.train_on_batch(X_train , real_labels  )
        fake_loss = D.train_on_batch(img , fake_labels )
        
        loss = np.add(real_loss , fake_loss)/2

        #训练生成器
        generator_loss = combined.train_on_batch(G_input , [real_labels , c])

        print('epoch:%d loss:%f accu:%f |xentropy_loss:%f mutual_info_loss:%f' % (i , loss[0] , loss[1] , generator_loss[0] , generator_loss[1]))

        if i % 50 == 0:
            write_image_mnist(i)

    write_image_mnist(999)


In [None]:
train()

  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:1.453244 accu:0.312500 |xentropy_loss:194.964401 mutual_info_loss:0.702166
epoch:1 loss:0.697662 accu:0.648438 |xentropy_loss:185.331161 mutual_info_loss:1.461353
epoch:2 loss:0.300386 accu:0.875000 |xentropy_loss:200.293365 mutual_info_loss:2.542433
epoch:3 loss:0.193046 accu:0.953125 |xentropy_loss:178.210068 mutual_info_loss:3.078330
epoch:4 loss:0.166678 accu:0.945312 |xentropy_loss:185.508011 mutual_info_loss:3.348593
epoch:5 loss:0.147102 accu:0.937500 |xentropy_loss:171.259491 mutual_info_loss:3.639608
epoch:6 loss:0.115704 accu:0.984375 |xentropy_loss:179.104614 mutual_info_loss:3.811452
epoch:7 loss:0.072813 accu:0.984375 |xentropy_loss:169.574921 mutual_info_loss:4.354710
epoch:8 loss:0.049793 accu:0.984375 |xentropy_loss:171.637314 mutual_info_loss:4.615820
epoch:9 loss:0.033696 accu:0.992188 |xentropy_loss:182.234589 mutual_info_loss:4.557958
epoch:10 loss:0.035176 accu:1.000000 |xentropy_loss:179.604004 mutual_info_loss:4.936716
epoch:11 loss:0.029306 accu:0.9

epoch:93 loss:0.001996 accu:1.000000 |xentropy_loss:9.389690 mutual_info_loss:6.039533
epoch:94 loss:0.005685 accu:1.000000 |xentropy_loss:12.092136 mutual_info_loss:6.032244
epoch:95 loss:0.004194 accu:1.000000 |xentropy_loss:14.108487 mutual_info_loss:6.042083
epoch:96 loss:0.003649 accu:1.000000 |xentropy_loss:21.581364 mutual_info_loss:6.539426
epoch:97 loss:0.003460 accu:1.000000 |xentropy_loss:13.902529 mutual_info_loss:6.148232
epoch:98 loss:0.003184 accu:1.000000 |xentropy_loss:13.099068 mutual_info_loss:5.838982
epoch:99 loss:0.003482 accu:1.000000 |xentropy_loss:11.685228 mutual_info_loss:6.256062
epoch:100 loss:0.002684 accu:1.000000 |xentropy_loss:10.304833 mutual_info_loss:6.126686
epoch:101 loss:0.005039 accu:1.000000 |xentropy_loss:7.844625 mutual_info_loss:5.861578
epoch:102 loss:0.008356 accu:1.000000 |xentropy_loss:10.240515 mutual_info_loss:5.744391
epoch:103 loss:0.002913 accu:1.000000 |xentropy_loss:9.498930 mutual_info_loss:6.386139
epoch:104 loss:0.004789 accu:1.

epoch:187 loss:0.003013 accu:1.000000 |xentropy_loss:6.514163 mutual_info_loss:6.236753
epoch:188 loss:0.004696 accu:1.000000 |xentropy_loss:6.520231 mutual_info_loss:5.926260
epoch:189 loss:0.003856 accu:1.000000 |xentropy_loss:7.250856 mutual_info_loss:6.320292
epoch:190 loss:0.002694 accu:1.000000 |xentropy_loss:6.920531 mutual_info_loss:6.349882
epoch:191 loss:0.003616 accu:1.000000 |xentropy_loss:6.716771 mutual_info_loss:6.418643
epoch:192 loss:0.003440 accu:1.000000 |xentropy_loss:7.201082 mutual_info_loss:6.300027
epoch:193 loss:0.002581 accu:1.000000 |xentropy_loss:9.876743 mutual_info_loss:6.718013
epoch:194 loss:0.002148 accu:1.000000 |xentropy_loss:7.976922 mutual_info_loss:6.462908
epoch:195 loss:0.002732 accu:1.000000 |xentropy_loss:6.582392 mutual_info_loss:6.319736
epoch:196 loss:0.006378 accu:1.000000 |xentropy_loss:7.772297 mutual_info_loss:6.551520
epoch:197 loss:0.001592 accu:1.000000 |xentropy_loss:8.252666 mutual_info_loss:6.434277
epoch:198 loss:0.005018 accu:1.0

epoch:281 loss:0.003063 accu:1.000000 |xentropy_loss:7.244631 mutual_info_loss:6.727700
epoch:282 loss:0.002263 accu:1.000000 |xentropy_loss:7.145779 mutual_info_loss:6.933306
epoch:283 loss:0.001798 accu:1.000000 |xentropy_loss:7.634421 mutual_info_loss:6.608213
epoch:284 loss:0.003264 accu:1.000000 |xentropy_loss:6.407469 mutual_info_loss:6.352909
epoch:285 loss:0.003887 accu:1.000000 |xentropy_loss:7.330215 mutual_info_loss:7.033831
epoch:286 loss:0.002155 accu:1.000000 |xentropy_loss:7.353854 mutual_info_loss:6.762799
epoch:287 loss:0.002600 accu:1.000000 |xentropy_loss:6.559950 mutual_info_loss:6.050184
epoch:288 loss:0.004561 accu:1.000000 |xentropy_loss:6.758438 mutual_info_loss:6.639702
epoch:289 loss:0.003042 accu:1.000000 |xentropy_loss:8.170339 mutual_info_loss:7.359140
epoch:290 loss:0.006024 accu:1.000000 |xentropy_loss:7.483604 mutual_info_loss:7.126861
epoch:291 loss:0.004298 accu:1.000000 |xentropy_loss:7.953258 mutual_info_loss:7.175683
epoch:292 loss:0.005374 accu:1.0

epoch:375 loss:0.014079 accu:1.000000 |xentropy_loss:6.997089 mutual_info_loss:5.358901
epoch:376 loss:0.007560 accu:1.000000 |xentropy_loss:6.961537 mutual_info_loss:6.012252
epoch:377 loss:0.040190 accu:0.992188 |xentropy_loss:7.550861 mutual_info_loss:5.542336
epoch:378 loss:0.022273 accu:0.992188 |xentropy_loss:7.037471 mutual_info_loss:5.414548
epoch:379 loss:0.012870 accu:1.000000 |xentropy_loss:6.231892 mutual_info_loss:5.731722
epoch:380 loss:0.021339 accu:1.000000 |xentropy_loss:6.068645 mutual_info_loss:5.247360
epoch:381 loss:0.011535 accu:1.000000 |xentropy_loss:7.859032 mutual_info_loss:5.799368
epoch:382 loss:0.018211 accu:1.000000 |xentropy_loss:6.164380 mutual_info_loss:5.067134
epoch:383 loss:0.003776 accu:1.000000 |xentropy_loss:7.155517 mutual_info_loss:5.829760
epoch:384 loss:0.009058 accu:1.000000 |xentropy_loss:6.890091 mutual_info_loss:5.850469
epoch:385 loss:0.011784 accu:1.000000 |xentropy_loss:6.863315 mutual_info_loss:6.091824
epoch:386 loss:0.008341 accu:1.0

epoch:469 loss:0.025783 accu:1.000000 |xentropy_loss:5.494129 mutual_info_loss:4.493676
epoch:470 loss:0.037673 accu:1.000000 |xentropy_loss:7.834383 mutual_info_loss:4.149475
epoch:471 loss:0.019660 accu:1.000000 |xentropy_loss:6.061206 mutual_info_loss:4.288327
epoch:472 loss:0.037362 accu:1.000000 |xentropy_loss:5.602153 mutual_info_loss:4.972959
epoch:473 loss:0.027143 accu:0.992188 |xentropy_loss:6.839282 mutual_info_loss:5.134422
epoch:474 loss:0.018718 accu:1.000000 |xentropy_loss:8.817874 mutual_info_loss:4.627512
epoch:475 loss:0.032421 accu:1.000000 |xentropy_loss:5.545232 mutual_info_loss:4.458145
epoch:476 loss:0.015918 accu:1.000000 |xentropy_loss:7.989838 mutual_info_loss:5.341719
epoch:477 loss:0.014538 accu:1.000000 |xentropy_loss:5.673120 mutual_info_loss:5.198914
epoch:478 loss:0.019503 accu:1.000000 |xentropy_loss:7.531320 mutual_info_loss:4.704385
epoch:479 loss:0.016256 accu:1.000000 |xentropy_loss:4.878246 mutual_info_loss:4.642928
epoch:480 loss:0.037783 accu:1.0

In [23]:
X_A.shape

(30000, 784)

In [70]:
real_labels.shape

(64, 1)

In [None]:
gc.collect()

In [None]:
gc.collect()