In [1]:
#对偶学习的原理

In [20]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU , PReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Concatenate

#addin cycleGAN 使用instance-norm
from keras_contrib.layers.normalization import InstanceNormalization

from keras.initializers import truncated_normal , constant , random_normal

from keras.optimizers import Adam , RMSprop

from keras.datasets import mnist

from keras.utils import to_categorical

In [21]:
import os

import matplotlib as plt
import numpy as np

import gc

from glob import glob

import keras.backend as K

import scipy

%matplotlib inline

In [22]:
WIDTH = 28
HEIGHT = 28
CHANNEL = 1

SHAPE = (WIDTH , HEIGHT , CHANNEL)

BATCH_SIZE = 64
EPOCHS = 10

#生成多少个图像 长*宽
ROW = 2
COL = 2

LATENT_DIM = 100 #噪音维度 100维  noise z
CLASSES_NUM = 10 #mnist一共10中类别


In [23]:

(X_train , y_train),(_ , _) = mnist.load_data()
X_train = X_train/127.5-1
X_train = np.expand_dims(X_train , 3)
y_train = y_train.reshape(-1 , 1)
y_train = to_categorical(y_train , num_classes=CLASSES_NUM + 1) #加1 是因为多出来一个表示不存在的label 代替fake的label


In [24]:

def load_mnist(batch_size = BATCH_SIZE):
    idx = np.random.randint(0 , X_train.shape[0] ,batch_size)
    return X_train[idx] , y_train[idx]
    
def write_image_mnist(epoch):
    
    z = np.random.normal(0 , 1 , size=(ROW * COL , LATENT_DIM))
    
    imgs = generator_i.predict(z)
    
    imgs = imgs*0.5+0.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    
    axes[0][0].imshow(imgs[0,:,:,0] , cmap='gray')
    axes[0][0].axis('off')

    axes[0][1].imshow(imgs[1,:,:,0] , cmap='gray')
    axes[0][1].axis('off')
    
    axes[1][0].imshow(imgs[2,:,:,0] , cmap='gray')
    axes[1][0].axis('off')

    axes[1][1].imshow(imgs[3,:,:,0] , cmap='gray')
    axes[1][1].axis('off')

    fig.savefig('mnist_infogan/No.%d.png' % epoch)
    plt.pyplot.close()



In [25]:
def generator():
    z = Input(shape=(LATENT_DIM , )) #输入一个风格的图像 生成另一个风格的图像
    
    model = Sequential()

    model.add(Dense(128 * 7 * 7, activation="relu", input_shape=(LATENT_DIM , )))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization(momentum=0.8))
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(1, kernel_size=3, padding="same"))
    model.add(Activation("tanh"))
    
    img = model(z)
    
    return Model(z , img)

In [26]:
def discriminator():
    img = Input(shape=SHAPE)
    
    model = Sequential()

    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=SHAPE , padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Flatten())
    
    feature = model(img)
   
    validity = Dense(1 , activation='sigmoid')(feature)
    labels = Dense(CLASSES_NUM+1 , activation='softmax')(feature) #加1 因为多出来一个不存在的label 来代替fake的label
        
    return Model(img , [validity , labels])

In [28]:
adam = Adam(lr = 0.0002 , beta_1=0.5)
#构建计算图
discriminator_i = discriminator()
discriminator_i.compile(optimizer = adam , loss=['binary_crossentropy' , 'categorical_crossentropy'] , metrics=['accuracy'] , loss_weights=[0.5,0.5])

generator_i = generator()

z = Input(shape=(LATENT_DIM , ))

img = generator_i(z)

#freeze D
discriminator_i.trainable = False

validity , labels = discriminator_i(img)

combined = Model(z , validity)
combined.compile(optimizer=adam , loss='binary_crossentropy')

In [29]:
#SGAN add
#防止数据不平衡 引入类别权重
class_weight1 = {0:1 , 1:1}
class_weight2 = {i: CLASSES_NUM/(int(BATCH_SIZE/2)) for i in range(CLASSES_NUM)}
class_weight2[CLASSES_NUM] = 1/int(BATCH_SIZE/2)

In [34]:
def train():
    real_labels = np.ones(shape=(BATCH_SIZE , 1))
    fake_labels = np.zeros(shape=(BATCH_SIZE , 1))

    for i in range(1001):
        X_train , y_train = load_mnist()
        
        z = np.random.normal(0 , 1 , (BATCH_SIZE , LATENT_DIM))
        fake_y_train = to_categorical(np.full(shape=(BATCH_SIZE , 1) , fill_value=CLASSES_NUM) , num_classes=CLASSES_NUM+1) #加1 表示多一个不存在的label fake的label

        img = generator_i.predict(z)
        #训练判别器
        real_loss = discriminator_i.train_on_batch(X_train , [real_labels , y_train] , class_weight=[class_weight1 , class_weight2])
        fake_loss = discriminator_i.train_on_batch(img , [fake_labels , fake_y_train] , class_weight=[class_weight1 , class_weight2])
        loss = np.add(real_loss , fake_loss)/2

        #训练生成器
        generator_loss = combined.train_on_batch(z , real_labels , class_weight=[class_weight1 , class_weight2])

        print('epoch:%d loss:%f loss:%f accu:%f |xentropy_loss:%f' % (i , loss[0] , loss[1] , loss[2] , generator_loss))

        if i % 50 == 0:
            write_image_mnist(i)

    write_image_mnist(999)


In [35]:
train()

  'Discrepancy between trainable weights and collected trainable'


epoch:0 loss:0.514145 loss:0.562888 accu:0.465401 |xentropy_loss:0.951901
epoch:1 loss:0.436431 loss:0.420730 accu:0.452131 |xentropy_loss:1.196074
epoch:2 loss:0.397500 loss:0.365278 accu:0.429723 |xentropy_loss:1.366534
epoch:3 loss:0.373885 loss:0.337564 accu:0.410206 |xentropy_loss:1.456960
epoch:4 loss:0.364652 loss:0.340888 accu:0.388417 |xentropy_loss:1.623733
epoch:5 loss:0.320986 loss:0.272020 accu:0.369953 |xentropy_loss:1.828723
epoch:6 loss:0.329750 loss:0.284810 accu:0.374690 |xentropy_loss:1.709741
epoch:7 loss:0.335962 loss:0.300877 accu:0.371047 |xentropy_loss:1.953167
epoch:8 loss:0.305127 loss:0.241958 accu:0.368296 |xentropy_loss:2.176242
epoch:9 loss:0.282134 loss:0.222279 accu:0.341989 |xentropy_loss:2.319942
epoch:10 loss:0.267649 loss:0.190256 accu:0.345043 |xentropy_loss:2.315417
epoch:11 loss:0.255751 loss:0.162791 accu:0.348711 |xentropy_loss:2.401908
epoch:12 loss:0.240089 loss:0.158151 accu:0.322027 |xentropy_loss:2.453415
epoch:13 loss:0.214539 loss:0.11689

epoch:113 loss:0.536006 loss:0.944445 accu:0.127567 |xentropy_loss:1.622544
epoch:114 loss:0.578331 loss:0.968153 accu:0.188508 |xentropy_loss:1.720972
epoch:115 loss:0.484507 loss:0.838538 accu:0.130476 |xentropy_loss:1.754562
epoch:116 loss:0.557829 loss:0.945978 accu:0.169681 |xentropy_loss:1.506038
epoch:117 loss:0.551632 loss:0.954948 accu:0.148316 |xentropy_loss:1.684857
epoch:118 loss:0.609607 loss:1.079880 accu:0.139334 |xentropy_loss:1.542286
epoch:119 loss:0.452072 loss:0.750799 accu:0.153346 |xentropy_loss:1.613542
epoch:120 loss:0.480333 loss:0.789990 accu:0.170677 |xentropy_loss:1.507253
epoch:121 loss:0.516080 loss:0.859737 accu:0.172422 |xentropy_loss:1.320386
epoch:122 loss:0.505460 loss:0.872749 accu:0.138172 |xentropy_loss:1.629429
epoch:123 loss:0.531001 loss:0.902387 accu:0.159616 |xentropy_loss:1.389093
epoch:124 loss:0.492758 loss:0.845755 accu:0.139762 |xentropy_loss:1.259111
epoch:125 loss:0.512911 loss:0.875440 accu:0.150382 |xentropy_loss:1.476598
epoch:126 lo

epoch:222 loss:0.421907 loss:0.704040 accu:0.139773 |xentropy_loss:1.201444
epoch:223 loss:0.391989 loss:0.667501 accu:0.116477 |xentropy_loss:1.226295
epoch:224 loss:0.370535 loss:0.663542 accu:0.077528 |xentropy_loss:1.218315
epoch:225 loss:0.449385 loss:0.775780 accu:0.122989 |xentropy_loss:1.182901
epoch:226 loss:0.390724 loss:0.669849 accu:0.111600 |xentropy_loss:1.152740
epoch:227 loss:0.433638 loss:0.720702 accu:0.146575 |xentropy_loss:1.112723
epoch:228 loss:0.477405 loss:0.795726 accu:0.159084 |xentropy_loss:1.066716
epoch:229 loss:0.398354 loss:0.669098 accu:0.127610 |xentropy_loss:1.306246
epoch:230 loss:0.472533 loss:0.775429 accu:0.169637 |xentropy_loss:1.077637
epoch:231 loss:0.392681 loss:0.659838 accu:0.125525 |xentropy_loss:1.303366
epoch:232 loss:0.406190 loss:0.706470 accu:0.105909 |xentropy_loss:1.097627
epoch:233 loss:0.414056 loss:0.737101 accu:0.091011 |xentropy_loss:1.053106
epoch:234 loss:0.444064 loss:0.749903 accu:0.138225 |xentropy_loss:1.091857
epoch:235 lo

epoch:330 loss:0.407488 loss:0.713310 accu:0.101665 |xentropy_loss:1.133888
epoch:331 loss:0.384658 loss:0.660709 accu:0.108608 |xentropy_loss:1.159201
epoch:332 loss:0.441507 loss:0.757254 accu:0.125760 |xentropy_loss:1.191174
epoch:333 loss:0.425108 loss:0.699195 accu:0.151021 |xentropy_loss:1.177933
epoch:334 loss:0.410067 loss:0.666556 accu:0.153578 |xentropy_loss:1.038932
epoch:335 loss:0.386036 loss:0.684583 accu:0.087488 |xentropy_loss:1.127638
epoch:336 loss:0.367143 loss:0.633978 accu:0.100307 |xentropy_loss:1.285226
epoch:337 loss:0.398588 loss:0.668666 accu:0.128511 |xentropy_loss:1.176688
epoch:338 loss:0.403927 loss:0.706360 accu:0.101494 |xentropy_loss:1.121494
epoch:339 loss:0.407345 loss:0.660024 accu:0.154666 |xentropy_loss:1.092636
epoch:340 loss:0.404232 loss:0.721442 accu:0.087023 |xentropy_loss:1.243212
epoch:341 loss:0.414047 loss:0.697198 accu:0.130895 |xentropy_loss:1.259846
epoch:342 loss:0.404966 loss:0.678812 accu:0.131120 |xentropy_loss:1.314674
epoch:343 lo

epoch:548 loss:0.363167 loss:0.635729 accu:0.090605 |xentropy_loss:1.318518
epoch:549 loss:0.341799 loss:0.603359 accu:0.080238 |xentropy_loss:1.241890
epoch:550 loss:0.396724 loss:0.687571 accu:0.105878 |xentropy_loss:1.140708
epoch:551 loss:0.350483 loss:0.635429 accu:0.065538 |xentropy_loss:1.189190
epoch:552 loss:0.312779 loss:0.552997 accu:0.072561 |xentropy_loss:1.279585
epoch:553 loss:0.372688 loss:0.620525 accu:0.124852 |xentropy_loss:1.300593
epoch:554 loss:0.330024 loss:0.588843 accu:0.071206 |xentropy_loss:1.326877
epoch:555 loss:0.410619 loss:0.730374 accu:0.090865 |xentropy_loss:1.246990
epoch:556 loss:0.442405 loss:0.742364 accu:0.142446 |xentropy_loss:1.304966
epoch:557 loss:0.341417 loss:0.591409 accu:0.091426 |xentropy_loss:1.205105
epoch:558 loss:0.375248 loss:0.652234 accu:0.098262 |xentropy_loss:1.250738
epoch:559 loss:0.361936 loss:0.634253 accu:0.089620 |xentropy_loss:1.354230
epoch:560 loss:0.343594 loss:0.598300 accu:0.088888 |xentropy_loss:1.182667
epoch:561 lo

epoch:659 loss:0.332720 loss:0.580558 accu:0.084881 |xentropy_loss:1.200113
epoch:660 loss:0.405488 loss:0.701790 accu:0.109186 |xentropy_loss:1.215848
epoch:661 loss:0.338372 loss:0.603979 accu:0.072764 |xentropy_loss:1.203190
epoch:662 loss:0.334386 loss:0.601380 accu:0.067393 |xentropy_loss:1.380892
epoch:663 loss:0.371921 loss:0.631587 accu:0.112254 |xentropy_loss:1.057923
epoch:664 loss:0.392229 loss:0.686354 accu:0.098103 |xentropy_loss:1.143908
epoch:665 loss:0.352034 loss:0.639485 accu:0.064583 |xentropy_loss:1.315231
epoch:666 loss:0.380187 loss:0.681657 accu:0.078717 |xentropy_loss:1.147699
epoch:667 loss:0.348303 loss:0.619458 accu:0.077148 |xentropy_loss:1.397666
epoch:668 loss:0.358101 loss:0.633545 accu:0.082657 |xentropy_loss:1.139467
epoch:669 loss:0.352420 loss:0.626221 accu:0.078620 |xentropy_loss:1.404264
epoch:670 loss:0.354648 loss:0.625974 accu:0.083323 |xentropy_loss:1.142704
epoch:671 loss:0.360157 loss:0.638926 accu:0.081387 |xentropy_loss:1.281743
epoch:672 lo

epoch:768 loss:0.372824 loss:0.663972 accu:0.081676 |xentropy_loss:1.123599
epoch:769 loss:0.335593 loss:0.606800 accu:0.064387 |xentropy_loss:1.098200
epoch:770 loss:0.356203 loss:0.637712 accu:0.074695 |xentropy_loss:1.247082
epoch:771 loss:0.362038 loss:0.651797 accu:0.072280 |xentropy_loss:1.194224
epoch:772 loss:0.407371 loss:0.711098 accu:0.103643 |xentropy_loss:1.195441
epoch:773 loss:0.391104 loss:0.681824 accu:0.100384 |xentropy_loss:1.150303
epoch:774 loss:0.453410 loss:0.771912 accu:0.134908 |xentropy_loss:1.215285
epoch:775 loss:0.363595 loss:0.627175 accu:0.100015 |xentropy_loss:1.292907
epoch:776 loss:0.388305 loss:0.651282 accu:0.125329 |xentropy_loss:1.099112
epoch:777 loss:0.403589 loss:0.715586 accu:0.091591 |xentropy_loss:1.116314
epoch:778 loss:0.419540 loss:0.732031 accu:0.107050 |xentropy_loss:1.008160
epoch:779 loss:0.387546 loss:0.707998 accu:0.067094 |xentropy_loss:1.113276
epoch:780 loss:0.344088 loss:0.609880 accu:0.078297 |xentropy_loss:1.307819
epoch:781 lo

epoch:877 loss:0.378769 loss:0.638469 accu:0.119069 |xentropy_loss:1.072580
epoch:878 loss:0.379896 loss:0.687358 accu:0.072434 |xentropy_loss:1.180006
epoch:879 loss:0.368488 loss:0.669642 accu:0.067335 |xentropy_loss:1.243827
epoch:880 loss:0.380847 loss:0.668494 accu:0.093200 |xentropy_loss:1.097429
epoch:881 loss:0.370599 loss:0.643718 accu:0.097480 |xentropy_loss:1.124681
epoch:882 loss:0.359236 loss:0.630076 accu:0.088397 |xentropy_loss:1.227212
epoch:883 loss:0.390497 loss:0.687993 accu:0.093001 |xentropy_loss:1.080902
epoch:884 loss:0.312599 loss:0.562530 accu:0.062667 |xentropy_loss:1.145080
epoch:885 loss:0.423150 loss:0.713975 accu:0.132324 |xentropy_loss:1.146893
epoch:886 loss:0.374140 loss:0.661954 accu:0.086325 |xentropy_loss:1.018915
epoch:887 loss:0.420322 loss:0.730912 accu:0.109732 |xentropy_loss:0.999476
epoch:888 loss:0.338994 loss:0.622165 accu:0.055823 |xentropy_loss:1.095599
epoch:889 loss:0.370565 loss:0.664795 accu:0.076334 |xentropy_loss:1.022239
epoch:890 lo

epoch:988 loss:0.427491 loss:0.769926 accu:0.085056 |xentropy_loss:0.987312
epoch:989 loss:0.395098 loss:0.708172 accu:0.082024 |xentropy_loss:1.168592
epoch:990 loss:0.383110 loss:0.698438 accu:0.067782 |xentropy_loss:1.050951
epoch:991 loss:0.385164 loss:0.708899 accu:0.061430 |xentropy_loss:1.138181
epoch:992 loss:0.375005 loss:0.681108 accu:0.068902 |xentropy_loss:1.076334
epoch:993 loss:0.346258 loss:0.620849 accu:0.071668 |xentropy_loss:1.194771
epoch:994 loss:0.360784 loss:0.626177 accu:0.095391 |xentropy_loss:1.104650
epoch:995 loss:0.341931 loss:0.614425 accu:0.069437 |xentropy_loss:1.302683
epoch:996 loss:0.378438 loss:0.673793 accu:0.083084 |xentropy_loss:1.216136
epoch:997 loss:0.408247 loss:0.732946 accu:0.083548 |xentropy_loss:0.966499
epoch:998 loss:0.347584 loss:0.628296 accu:0.066872 |xentropy_loss:1.105321
epoch:999 loss:0.368848 loss:0.647434 accu:0.090262 |xentropy_loss:1.182199
epoch:1000 loss:0.370808 loss:0.674310 accu:0.067307 |xentropy_loss:1.163575


In [23]:
X_A.shape

(30000, 784)

In [70]:
real_labels.shape

(64, 1)

In [None]:
gc.collect()

In [None]:
gc.collect()

In [4]:
5//2

2

In [5]:
6//5

1

In [6]:
5//6


0