In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Concatenate

#addin cycleGAN 使用instance-norm
from keras_contrib.layers.normalization import InstanceNormalization

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

from keras.datasets import mnist

from keras.utils import to_categorical

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [36]:
WIDTH = 28
HEIGHT = 28
CHANNEL = 1

SHAPE = (WIDTH , HEIGHT , CHANNEL)

BATCH_SIZE = 64
EPOCHS = 10

#生成多少个图像 长*宽
ROW = 2
COL = 2

LATENT_DIM = 72 #infogan中变为72维
CLASSES_NUM = 10 #mnist一共10中类别


In [37]:

(X_train , y_train),(_ , _) = mnist.load_data()
X_train = X_train/127.5-1
X_train = np.expand_dims(X_train , 3)
y_train = y_train.reshape(-1 , 1)
y_train = to_categorical(y_train , num_classes=CLASSES_NUM)


In [38]:

def load_mnist(batch_size = BATCH_SIZE):
    idx = np.random.randint(0 , X_train.shape[0] ,batch_size)
    return X_train[idx] , y_train[idx]
    
def write_image_mnist(epoch):
    
    z = np.random.normal(0 , 1 , size=(ROW * COL , LATENT_DIM-CLASSES_NUM))
    c = np.random.randint(0 , high=CLASSES_NUM , size=(ROW*COL , 1))
    c_one_hot = to_categorical(c , num_classes=CLASSES_NUM)
    
    G_input = np.concatenate((z , c_one_hot) , axis = 1)
    
    imgs = generator_i.predict(G_input)
    
    imgs = imgs*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    axes[0][0].imshow(imgs[0,:,:,0] , cmap='gray')
    axes[0][0].set_title('%d' % (c[0]))
    axes[0][0].axis('off')

    axes[0][1].imshow(imgs[1,:,:,0] , cmap='gray')
    axes[0][1].set_title('%d' % (c[1]))
    axes[0][1].axis('off')
    
    axes[1][0].imshow(imgs[2,:,:,0] , cmap='gray')
    axes[1][0].set_title('%d' % (c[2]))
    axes[1][0].axis('off')

    axes[1][1].imshow(imgs[3,:,:,0] , cmap='gray')
    axes[1][1].set_title('%d' % (c[3]))
    axes[1][1].axis('off')

    fig.savefig('mnist_infogan/No.%d.png' % epoch)
    plt.pyplot.close()



In [39]:
def generator():
    z = Input(shape=(LATENT_DIM , )) #输入一个风格的图像 生成另一个风格的图像
    
    model = Sequential()

    model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(LATENT_DIM , )))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization(momentum=0.8))
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(CHANNEL, kernel_size=3, padding="same"))
    model.add(Activation("tanh"))
    
    img = model(z)
    
    return Model(z , img)

In [40]:
def discriminator():
    img = Input(shape=SHAPE)
    
    model = Sequential()

    model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=SHAPE , padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Flatten())
    
    feature = model(img)
   
    validity = Dense(1 , activation='sigmoid')(feature)
    
    q_net = Dense(units=128 , activation='relu')(feature)
    labels = Dense(CLASSES_NUM , activation='softmax')(q_net) #加1 因为多出来一个不存在的label 来代替fake的label
        
    return Model(img , validity) , Model(img , labels)

In [45]:
#给Q-NET使用的互信息损失
def mutual_info_loss(c , c_x):
    epsilon = 1e-8
    
    x_entropy = K.mean(-K.sum(c*K.log(c_x+epsilon) , axis=1 ))
    entropy = K.mean(-K.sum(c*K.log(c+epsilon) , axis=1))
    
    return x_entropy + entropy

In [46]:
adam = Adam(lr = 0.0002 , beta_1=0.5)
#构建计算图
D , Q = discriminator()

D.compile(optimizer = adam , loss='binary_crossentropy' , metrics=['accuracy'])
Q.compile(optimizer = adam , loss=mutual_info_loss , metrics=['accuracy'])

generator_i = generator()

z = Input(shape=(LATENT_DIM , ))
img = generator_i(z)

#freeze D
D.trainable = False

validity = D(img)
labels = Q(img)

combined = Model(z , [validity , labels])
combined.compile(optimizer=adam , loss=['binary_crossentropy' , mutual_info_loss])

In [47]:
def train():
    real_labels = np.ones(shape=(BATCH_SIZE , 1))
    fake_labels = np.zeros(shape=(BATCH_SIZE , 1))

    for i in range(1001):
        X_train , y_train = load_mnist()
        
        z = np.random.normal(0 , 1 , (BATCH_SIZE , LATENT_DIM-CLASSES_NUM)) #一定要减去 变为62 因为后面需要和隐藏编码c concatenate 一下
        c = np.random.randint(0 , CLASSES_NUM , size=(BATCH_SIZE , 1)) #隐藏编码c
        c = to_categorical(c , num_classes=CLASSES_NUM)
        
        G_input = np.concatenate((z , c) , axis = 1)

        img = generator_i.predict(G_input)
        #训练判别器
        real_loss = D.train_on_batch(X_train , real_labels  )
        fake_loss = D.train_on_batch(img , fake_labels )
        
        loss = np.add(real_loss , fake_loss)/2

        #训练生成器
        generator_loss = combined.train_on_batch(G_input , [real_labels , c])

        print('epoch:%d loss:%f accu:%f |xentropy_loss:%f mutual_info_loss:%f' % (i , loss[0] , loss[1] , generator_loss[0] , generator_loss[1]))

        if i % 50 == 0:
            write_image_mnist(i)

    write_image_mnist(999)


In [48]:
train()

  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:1.660878 accu:0.226562 |xentropy_loss:3.648218 mutual_info_loss:0.594777
epoch:1 loss:0.911609 accu:0.562500 |xentropy_loss:4.090487 mutual_info_loss:0.999379
epoch:2 loss:0.618219 accu:0.710938 |xentropy_loss:4.333584 mutual_info_loss:1.589423
epoch:3 loss:0.779388 accu:0.609375 |xentropy_loss:4.422917 mutual_info_loss:1.266415
epoch:4 loss:0.752575 accu:0.554688 |xentropy_loss:4.040728 mutual_info_loss:1.315180
epoch:5 loss:0.683502 accu:0.648438 |xentropy_loss:4.074518 mutual_info_loss:1.417691
epoch:6 loss:0.559876 accu:0.679688 |xentropy_loss:4.184185 mutual_info_loss:1.552115
epoch:7 loss:0.635173 accu:0.664062 |xentropy_loss:4.393694 mutual_info_loss:1.529202
epoch:8 loss:0.700243 accu:0.632812 |xentropy_loss:4.335485 mutual_info_loss:1.660262
epoch:9 loss:0.669257 accu:0.648438 |xentropy_loss:4.362889 mutual_info_loss:1.771099
epoch:10 loss:0.652122 accu:0.687500 |xentropy_loss:4.666224 mutual_info_loss:1.774040
epoch:11 loss:0.629556 accu:0.671875 |xentropy_loss:4

epoch:96 loss:0.402413 accu:0.742188 |xentropy_loss:3.068396 mutual_info_loss:2.302006
epoch:97 loss:0.334315 accu:0.804688 |xentropy_loss:3.343876 mutual_info_loss:2.567426
epoch:98 loss:0.389465 accu:0.812500 |xentropy_loss:2.887975 mutual_info_loss:2.350244
epoch:99 loss:0.316252 accu:0.851562 |xentropy_loss:2.755519 mutual_info_loss:2.194165
epoch:100 loss:0.385331 accu:0.773438 |xentropy_loss:2.763360 mutual_info_loss:2.291650
epoch:101 loss:0.345982 accu:0.804688 |xentropy_loss:2.856683 mutual_info_loss:2.341326
epoch:102 loss:0.452049 accu:0.742188 |xentropy_loss:3.040648 mutual_info_loss:2.456188
epoch:103 loss:0.413295 accu:0.781250 |xentropy_loss:3.192911 mutual_info_loss:2.518758
epoch:104 loss:0.389390 accu:0.789062 |xentropy_loss:2.665974 mutual_info_loss:2.282107
epoch:105 loss:0.421258 accu:0.773438 |xentropy_loss:2.981126 mutual_info_loss:2.289420
epoch:106 loss:0.356698 accu:0.812500 |xentropy_loss:2.577420 mutual_info_loss:2.066567
epoch:107 loss:0.391509 accu:0.77343

epoch:190 loss:0.615601 accu:0.632812 |xentropy_loss:1.643982 mutual_info_loss:1.486883
epoch:191 loss:0.638192 accu:0.671875 |xentropy_loss:1.571407 mutual_info_loss:1.353924
epoch:192 loss:0.578277 accu:0.687500 |xentropy_loss:1.678254 mutual_info_loss:1.477727
epoch:193 loss:0.615297 accu:0.671875 |xentropy_loss:1.693237 mutual_info_loss:1.532899
epoch:194 loss:0.563646 accu:0.656250 |xentropy_loss:1.599572 mutual_info_loss:1.484494
epoch:195 loss:0.638589 accu:0.593750 |xentropy_loss:1.865933 mutual_info_loss:1.475002
epoch:196 loss:0.629383 accu:0.640625 |xentropy_loss:1.738513 mutual_info_loss:1.554832
epoch:197 loss:0.611387 accu:0.695312 |xentropy_loss:1.525856 mutual_info_loss:1.458127
epoch:198 loss:0.713695 accu:0.562500 |xentropy_loss:1.609466 mutual_info_loss:1.452669
epoch:199 loss:0.574359 accu:0.664062 |xentropy_loss:1.475876 mutual_info_loss:1.339059
epoch:200 loss:0.678301 accu:0.593750 |xentropy_loss:1.466098 mutual_info_loss:1.341454
epoch:201 loss:0.644947 accu:0.6

epoch:284 loss:0.701110 accu:0.578125 |xentropy_loss:1.235927 mutual_info_loss:1.154473
epoch:285 loss:0.723403 accu:0.562500 |xentropy_loss:1.527660 mutual_info_loss:1.357096
epoch:286 loss:0.680079 accu:0.617188 |xentropy_loss:1.271445 mutual_info_loss:1.235235
epoch:287 loss:0.620892 accu:0.656250 |xentropy_loss:1.453178 mutual_info_loss:1.350470
epoch:288 loss:0.731081 accu:0.546875 |xentropy_loss:1.184799 mutual_info_loss:1.134002
epoch:289 loss:0.690815 accu:0.593750 |xentropy_loss:1.271586 mutual_info_loss:1.175390
epoch:290 loss:0.667259 accu:0.593750 |xentropy_loss:1.322805 mutual_info_loss:1.263490
epoch:291 loss:0.679416 accu:0.585938 |xentropy_loss:1.216363 mutual_info_loss:1.169250
epoch:292 loss:0.714773 accu:0.601562 |xentropy_loss:1.266116 mutual_info_loss:1.145254
epoch:293 loss:0.732694 accu:0.539062 |xentropy_loss:1.299702 mutual_info_loss:1.172667
epoch:294 loss:0.774685 accu:0.523438 |xentropy_loss:1.262094 mutual_info_loss:1.134813
epoch:295 loss:0.665072 accu:0.6

epoch:378 loss:0.670625 accu:0.617188 |xentropy_loss:1.352646 mutual_info_loss:1.195592
epoch:379 loss:0.667200 accu:0.593750 |xentropy_loss:1.257302 mutual_info_loss:1.107818
epoch:380 loss:0.692977 accu:0.617188 |xentropy_loss:1.112444 mutual_info_loss:1.011591
epoch:381 loss:0.729180 accu:0.570312 |xentropy_loss:1.167603 mutual_info_loss:1.100200
epoch:382 loss:0.713743 accu:0.593750 |xentropy_loss:1.105091 mutual_info_loss:0.997902
epoch:383 loss:0.699748 accu:0.625000 |xentropy_loss:1.154304 mutual_info_loss:1.091867
epoch:384 loss:0.677144 accu:0.640625 |xentropy_loss:1.181804 mutual_info_loss:1.041301
epoch:385 loss:0.754570 accu:0.523438 |xentropy_loss:1.050599 mutual_info_loss:0.944003
epoch:386 loss:0.729730 accu:0.554688 |xentropy_loss:1.036150 mutual_info_loss:0.992433
epoch:387 loss:0.658422 accu:0.625000 |xentropy_loss:1.269982 mutual_info_loss:1.227030
epoch:388 loss:0.695041 accu:0.593750 |xentropy_loss:1.034812 mutual_info_loss:1.010012
epoch:389 loss:0.648960 accu:0.6

epoch:472 loss:0.753387 accu:0.570312 |xentropy_loss:1.104338 mutual_info_loss:1.064758
epoch:473 loss:0.761040 accu:0.593750 |xentropy_loss:0.973422 mutual_info_loss:0.940033
epoch:474 loss:0.743722 accu:0.531250 |xentropy_loss:1.098574 mutual_info_loss:1.008796
epoch:475 loss:0.680211 accu:0.625000 |xentropy_loss:1.123246 mutual_info_loss:1.047757
epoch:476 loss:0.665243 accu:0.609375 |xentropy_loss:1.002207 mutual_info_loss:0.917816
epoch:477 loss:0.708318 accu:0.593750 |xentropy_loss:1.084743 mutual_info_loss:0.952981
epoch:478 loss:0.740846 accu:0.593750 |xentropy_loss:0.896158 mutual_info_loss:0.859506
epoch:479 loss:0.701354 accu:0.593750 |xentropy_loss:1.089541 mutual_info_loss:1.006278
epoch:480 loss:0.729935 accu:0.554688 |xentropy_loss:1.037200 mutual_info_loss:0.973509
epoch:481 loss:0.696560 accu:0.601562 |xentropy_loss:1.122172 mutual_info_loss:1.027231
epoch:482 loss:0.708138 accu:0.562500 |xentropy_loss:0.899121 mutual_info_loss:0.870273
epoch:483 loss:0.636964 accu:0.6

epoch:566 loss:0.729038 accu:0.578125 |xentropy_loss:1.070019 mutual_info_loss:0.999548
epoch:567 loss:0.693342 accu:0.601562 |xentropy_loss:0.998266 mutual_info_loss:0.974237
epoch:568 loss:0.739518 accu:0.554688 |xentropy_loss:0.995203 mutual_info_loss:0.924882
epoch:569 loss:0.743491 accu:0.546875 |xentropy_loss:0.991957 mutual_info_loss:0.906363
epoch:570 loss:0.694587 accu:0.625000 |xentropy_loss:0.921693 mutual_info_loss:0.905785
epoch:571 loss:0.683609 accu:0.609375 |xentropy_loss:0.985155 mutual_info_loss:0.952217
epoch:572 loss:0.746696 accu:0.570312 |xentropy_loss:1.034717 mutual_info_loss:0.983430
epoch:573 loss:0.739974 accu:0.554688 |xentropy_loss:1.002639 mutual_info_loss:0.991410
epoch:574 loss:0.661685 accu:0.617188 |xentropy_loss:1.081796 mutual_info_loss:1.029919
epoch:575 loss:0.656781 accu:0.609375 |xentropy_loss:0.981436 mutual_info_loss:0.909033
epoch:576 loss:0.693345 accu:0.585938 |xentropy_loss:0.883540 mutual_info_loss:0.838072
epoch:577 loss:0.719555 accu:0.5

epoch:660 loss:0.658388 accu:0.601562 |xentropy_loss:0.995916 mutual_info_loss:0.930810
epoch:661 loss:0.697874 accu:0.554688 |xentropy_loss:0.980651 mutual_info_loss:0.884861
epoch:662 loss:0.757969 accu:0.531250 |xentropy_loss:0.885671 mutual_info_loss:0.830092
epoch:663 loss:0.751542 accu:0.554688 |xentropy_loss:0.924314 mutual_info_loss:0.905362
epoch:664 loss:0.736087 accu:0.593750 |xentropy_loss:0.893298 mutual_info_loss:0.854393
epoch:665 loss:0.711016 accu:0.578125 |xentropy_loss:0.913571 mutual_info_loss:0.894847
epoch:666 loss:0.689739 accu:0.625000 |xentropy_loss:0.936537 mutual_info_loss:0.905133
epoch:667 loss:0.780244 accu:0.609375 |xentropy_loss:0.871075 mutual_info_loss:0.844841
epoch:668 loss:0.623875 accu:0.625000 |xentropy_loss:1.061926 mutual_info_loss:0.922297
epoch:669 loss:0.686455 accu:0.601562 |xentropy_loss:0.966220 mutual_info_loss:0.933082
epoch:670 loss:0.722029 accu:0.578125 |xentropy_loss:1.004109 mutual_info_loss:0.965030
epoch:671 loss:0.680398 accu:0.6

epoch:754 loss:0.714056 accu:0.609375 |xentropy_loss:0.998646 mutual_info_loss:0.986858
epoch:755 loss:0.722992 accu:0.625000 |xentropy_loss:1.008946 mutual_info_loss:0.922149
epoch:756 loss:0.702223 accu:0.593750 |xentropy_loss:0.924299 mutual_info_loss:0.889720
epoch:757 loss:0.655980 accu:0.632812 |xentropy_loss:0.929846 mutual_info_loss:0.876173
epoch:758 loss:0.698164 accu:0.578125 |xentropy_loss:0.974219 mutual_info_loss:0.961550
epoch:759 loss:0.700464 accu:0.578125 |xentropy_loss:1.019006 mutual_info_loss:0.854815
epoch:760 loss:0.745188 accu:0.523438 |xentropy_loss:0.929848 mutual_info_loss:0.906682
epoch:761 loss:0.695440 accu:0.617188 |xentropy_loss:0.946333 mutual_info_loss:0.889145
epoch:762 loss:0.684203 accu:0.625000 |xentropy_loss:0.975804 mutual_info_loss:0.927179
epoch:763 loss:0.687724 accu:0.570312 |xentropy_loss:0.942310 mutual_info_loss:0.885810
epoch:764 loss:0.707054 accu:0.609375 |xentropy_loss:1.027888 mutual_info_loss:0.980326
epoch:765 loss:0.646728 accu:0.6

epoch:848 loss:0.678827 accu:0.570312 |xentropy_loss:0.950238 mutual_info_loss:0.919962
epoch:849 loss:0.695940 accu:0.578125 |xentropy_loss:0.909034 mutual_info_loss:0.896389
epoch:850 loss:0.723314 accu:0.554688 |xentropy_loss:0.927911 mutual_info_loss:0.868380
epoch:851 loss:0.746197 accu:0.554688 |xentropy_loss:0.893912 mutual_info_loss:0.837788
epoch:852 loss:0.658165 accu:0.632812 |xentropy_loss:0.916020 mutual_info_loss:0.906103
epoch:853 loss:0.693792 accu:0.578125 |xentropy_loss:0.967958 mutual_info_loss:0.937479
epoch:854 loss:0.711465 accu:0.546875 |xentropy_loss:1.014864 mutual_info_loss:0.964873
epoch:855 loss:0.649125 accu:0.609375 |xentropy_loss:0.962643 mutual_info_loss:0.911488
epoch:856 loss:0.729599 accu:0.546875 |xentropy_loss:0.890463 mutual_info_loss:0.878843
epoch:857 loss:0.667247 accu:0.609375 |xentropy_loss:0.969900 mutual_info_loss:0.940785
epoch:858 loss:0.780485 accu:0.515625 |xentropy_loss:0.935756 mutual_info_loss:0.865656
epoch:859 loss:0.691624 accu:0.5

epoch:943 loss:0.721342 accu:0.562500 |xentropy_loss:0.965723 mutual_info_loss:0.909457
epoch:944 loss:0.703303 accu:0.593750 |xentropy_loss:0.965702 mutual_info_loss:0.933349
epoch:945 loss:0.701156 accu:0.609375 |xentropy_loss:1.089261 mutual_info_loss:0.986134
epoch:946 loss:0.700064 accu:0.585938 |xentropy_loss:0.970364 mutual_info_loss:0.877660
epoch:947 loss:0.686625 accu:0.609375 |xentropy_loss:0.979049 mutual_info_loss:0.940773
epoch:948 loss:0.714077 accu:0.554688 |xentropy_loss:0.923266 mutual_info_loss:0.905026
epoch:949 loss:0.694335 accu:0.570312 |xentropy_loss:0.999085 mutual_info_loss:0.932918
epoch:950 loss:0.731926 accu:0.570312 |xentropy_loss:1.035846 mutual_info_loss:0.965930
epoch:951 loss:0.688354 accu:0.593750 |xentropy_loss:0.957732 mutual_info_loss:0.905726
epoch:952 loss:0.694808 accu:0.601562 |xentropy_loss:0.972502 mutual_info_loss:0.942115
epoch:953 loss:0.706954 accu:0.578125 |xentropy_loss:0.858912 mutual_info_loss:0.849679
epoch:954 loss:0.660592 accu:0.6

In [23]:
X_A.shape

(30000, 784)

In [70]:
real_labels.shape

(64, 1)

In [None]:
gc.collect()

In [None]:
gc.collect()