In [1]:
from keras.models import Sequential , Model
from keras.layers import Dense ,  BatchNormalization , Reshape , Input , Flatten
from keras.layers import Conv2D , MaxPool2D , Conv2DTranspose , UpSampling2D , ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Activation
from keras.layers import Dropout

from keras.layers import Embedding

from keras.layers import Multiply

from keras.optimizers import Adam

from keras.initializers import truncated_normal , random_normal , constant

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import os

import matplotlib as plt
import numpy as np

import pandas as pd

import gc

%matplotlib inline

In [3]:
#CIFAR10 dataset
WIDTH = 32
HEIGHT = 32
CHANNEL = 3

LATENT_DIM = 100 #latent variable z sample from normal distribution

BATCH_SIZE = 64
EPOCHS = 10

PATH = '../dataset/train/'

#生成多少个图像 长*宽
ROW = 5
COL = 5

#=========
#=========
#add new FLAG(s)
CLASS_NUM = 10 #mnist=10 CIFAR10=10 CIFAR100=100 CIFAR1000=1000

LABEL2INDEX = {'frog':0 , 'truck':1 , 'deer':2 , 'automobile':3 , 'bird':4 , 'horse':5 , 'ship':6 , 'cat':7 , 'dog':8 , 'airplane':9}
INDEX2LABEL = {value:key for key , value in LABEL2INDEX.items()}

In [4]:
INDEX2LABEL

{0: 'frog',
 1: 'truck',
 2: 'deer',
 3: 'automobile',
 4: 'bird',
 5: 'horse',
 6: 'ship',
 7: 'cat',
 8: 'dog',
 9: 'airplane'}

In [5]:

load_index = 0

images_name = os.listdir(PATH)

IMAGES_COUNT = len(images_name)


In [6]:
train_labels = pd.read_csv('../dataset/trainLabels.csv')
#每张图片对应的类别标号
train_labels = train_labels['label'].map(LABEL2INDEX).get_values()

In [31]:

def load_image(batch_size = BATCH_SIZE):
    global load_index
    
    images = []
    labels = []
    
    for i in range(batch_size):
        images.append(plt.image.imread(PATH + images_name[(load_index + i) % IMAGES_COUNT]))
        labels.append(train_labels[(load_index + i) % IMAGES_COUNT])
    
    load_index += batch_size
    
    return np.array(images)/127.5-1 , np.array(labels)

def write_image(epoch):
    
    noise = np.random.normal(size = (ROW*COL , LATENT_DIM))
    labels = np.random.randint(low=0 , high=CLASS_NUM , size=(ROW*COL , 1))
    
    generated_image = generator_i.predict([noise , labels])
    
    generated_image = (generated_image+1)*127.5
    
    fig , axes = plt.pyplot.subplots(ROW , COL)
    fig.subplots_adjust(hspace=0.9 , wspace=0.9)
    
    count = 0
    
    for i in range(ROW):
        for j in range(COL):
            axes[i][j].imshow(generated_image[count])
            axes[i][j].axis('off')
            axes[i][j].set_title(INDEX2LABEL[labels[i*ROW+j][0]])
            
            count += 1
            
    fig.savefig('CIFAR10_cgan/No.%d.png' % epoch)
    plt.pyplot.close()
    
    #plt.image.imsave('images/'+str(epoch)+'.jpg')


In [15]:
def conv2d(output_size):
    return Conv2D(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=truncated_normal(stddev=0.02) , bias_initializer=constant(0.0))

def dense(output_size):
    return Dense(output_size , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0))

def deconv2d(output_size):
    return Conv2DTranspose(output_size , kernel_size=(5,5) , strides=(2,2) , padding='same' , kernel_initializer=random_normal(stddev=0.02) , bias_initializer=constant(0.0))

def batch_norm():
    return BatchNormalization(momentum=0.9 , epsilon=1e-5)


In [16]:
def generator():
    #sample from noise z
    model = Sequential(name='generator')
    
    #CIFAR10 图像使用 32*32*3
    model.add(Dense(2 * 2 * 64*8, activation="relu", input_shape=(LATENT_DIM,)))
    model.add(Reshape((2, 2, 64*8)))
    
    model.add(batch_norm())
    model.add(LeakyReLU(alpha=0.2))
    #4
    model.add(deconv2d(64*4))
    model.add(batch_norm())
    model.add(LeakyReLU(alpha=0.2))
    #8
    model.add(deconv2d(64*2))
    model.add(batch_norm())
    model.add(LeakyReLU(alpha=0.2))
    #16
    model.add(deconv2d(64*1))
    model.add(batch_norm())
    model.add(LeakyReLU(alpha=0.2))
    #32
    model.add(deconv2d(3))
    model.add(Activation('tanh'))
    
    #model.summary()
    
    noise = Input(shape=(LATENT_DIM , ) , name='input1')
    label = Input(shape=(1,) , dtype='int32')
    
    _ = Embedding(input_dim=CLASS_NUM , output_dim=LATENT_DIM)(label)
    embedding_label = Flatten()(_)
    
    noise_embedding_label = Multiply()([noise , embedding_label]) #(None , LATENT_DIM)
    
    image = model(noise_embedding_label)
    
    return Model([noise , label] , image , name='generator_Model')

In [17]:
def discriminator():
    #input a image to discriminate real or fake
    model = Sequential(name='discriminator')
    
    model.add(Conv2D(filters=64 , kernel_size=(5,5) , strides=(2,2) , padding='same' , input_shape=(WIDTH , HEIGHT , CHANNEL) , kernel_initializer=truncated_normal(stddev=0.02) , bias_initializer=constant(0.0) , name='conv1'))
    model.add(LeakyReLU(alpha=0.2))
    
    model.add(conv2d(64*2))
    model.add(batch_norm())
    model.add(LeakyReLU(alpha=0.2))
    
    model.add(conv2d(64*4))
    model.add(batch_norm())
    model.add(LeakyReLU(alpha=0.2))

    model.add(conv2d(64*8))
    model.add(batch_norm())
    model.add(LeakyReLU(alpha=0.2))
    
    model.add(Flatten())
    #=====
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.2))
    #=====
    model.add(Dense(1 , activation='sigmoid'))
    
    #model.summary()
    
    image = Input(shape=(WIDTH , HEIGHT , CHANNEL) , name='input1')
    flatten_feature = model(image)
    
    label = Input(shape=(1,))
    embedding_label = Embedding(input_dim=CLASS_NUM , output_dim=WIDTH*HEIGHT*CHANNEL)(label)
    flatten_embedding_label = Flatten()(embedding_label)
    
    input_ = Multiply()([flatten_feature , flatten_embedding_label])
    input_reshape = Reshape(target_shape=(WIDTH , HEIGHT , CHANNEL))(input_)
    
    #FC层 多加了一层
    #_ = Dense(128)(flatten_feature)
    #_ = LeakyReLU(alpha=0.2)(_)
    
    validity = model(input_reshape)
        
    return Model([image , label] , validity , name='discriminator_Model')

In [21]:
def combined_model(generator_i , discriminator_i):
    #生成器和判别器组合成整体
    z = Input(shape=(LATENT_DIM , ) , name='z')
    label = Input(shape=(1,) , dtype='int32')
    
    image = generator_i([z , label])
    
    #print(image.shape)
    
    discriminator_i.trainable = False
    validity = discriminator_i([image , label])
    
    return Model([z , label] , validity , name='combined_model')

In [22]:
adam = Adam(lr = 0.0002 , beta_1=0.5)

In [23]:
discriminator_i = discriminator()
discriminator_i.compile(optimizer=adam , loss='binary_crossentropy' , metrics=['accuracy'])

generator_i = generator()

combined_model_i = combined_model(generator_i , discriminator_i)


combined_model_i.compile(optimizer=adam , loss='binary_crossentropy')

In [None]:
real_labels = np.ones(shape=(BATCH_SIZE , 1)) #真实样本label为1
fake_labels = np.zeros(shape=(BATCH_SIZE , 1)) #假样本label为0

for i in range(1001):
    #for j in range(int(IMAGES_COUNT/BATCH_SIZE)):
        
    noise = np.random.normal(size=(BATCH_SIZE , LATENT_DIM))
    corresponding_fake_label = np.random.randint(low=0 , high=CLASS_NUM , size=(BATCH_SIZE , 1)) #label的取值范围 可能会发生变化

    real_image , corresponding_real_label = load_image()

    #训练判别器
    fake_image = generator_i.predict([noise , corresponding_fake_label])

    real_loss = discriminator_i.train_on_batch([real_image , corresponding_real_label] , real_labels)
    fake_loss = discriminator_i.train_on_batch([fake_image , corresponding_fake_label] , fake_labels) #应该是real还是fake

    loss = np.add(real_loss , fake_loss)/2

    #训练生成器
    noise2 = np.random.normal(size=(BATCH_SIZE , LATENT_DIM))
    corresponding_fake_label2 = np.random.randint(low=0 , high=CLASS_NUM , size=(BATCH_SIZE , 1))

        #下面的损失是一个list 有两个损失 一个是validity一个是与label的softmax
    generator_loss = combined_model_i.train_on_batch([noise2 , corresponding_fake_label2] , real_labels)

    print('epoch:%d batch:%d loss:%f accu:%f gene_loss:[validity:%f]' % (i , 52 , loss[0] , loss[1] , generator_loss))

    if i % 100 == 0:
        write_image(i)
    
write_image(999)


  'Discrepancy between trainable weights and collected trainable'


epoch:0 batch:52 loss:0.718152 accu:0.000000 gene_loss:[validity:0.718192]
epoch:1 batch:52 loss:0.718136 accu:0.000000 gene_loss:[validity:0.718185]
epoch:2 batch:52 loss:0.718129 accu:0.000000 gene_loss:[validity:0.718162]
epoch:3 batch:52 loss:0.718112 accu:0.000000 gene_loss:[validity:0.718156]
epoch:4 batch:52 loss:0.718102 accu:0.000000 gene_loss:[validity:0.718136]
epoch:5 batch:52 loss:0.718086 accu:0.000000 gene_loss:[validity:0.718126]
epoch:6 batch:52 loss:0.718077 accu:0.000000 gene_loss:[validity:0.718110]
epoch:7 batch:52 loss:0.718060 accu:0.000000 gene_loss:[validity:0.718102]
epoch:8 batch:52 loss:0.718050 accu:0.000000 gene_loss:[validity:0.718082]
epoch:9 batch:52 loss:0.718037 accu:0.000000 gene_loss:[validity:0.718061]
epoch:10 batch:52 loss:0.718022 accu:0.000000 gene_loss:[validity:0.718044]
epoch:11 batch:52 loss:0.718008 accu:0.000000 gene_loss:[validity:0.718033]
epoch:12 batch:52 loss:0.717996 accu:0.000000 gene_loss:[validity:0.718020]
epoch:13 batch:52 loss

epoch:109 batch:52 loss:0.717532 accu:0.000000 gene_loss:[validity:0.717386]
epoch:110 batch:52 loss:0.717515 accu:0.000000 gene_loss:[validity:0.717390]
epoch:111 batch:52 loss:0.717502 accu:0.000000 gene_loss:[validity:0.717392]
epoch:112 batch:52 loss:0.717489 accu:0.000000 gene_loss:[validity:0.717391]
epoch:113 batch:52 loss:0.717476 accu:0.000000 gene_loss:[validity:0.717393]
epoch:114 batch:52 loss:0.717467 accu:0.000000 gene_loss:[validity:0.717394]
epoch:115 batch:52 loss:0.717458 accu:0.000000 gene_loss:[validity:0.717389]
epoch:116 batch:52 loss:0.717479 accu:0.000000 gene_loss:[validity:0.717355]
epoch:117 batch:52 loss:0.717465 accu:0.000000 gene_loss:[validity:0.717363]
epoch:118 batch:52 loss:0.717452 accu:0.000000 gene_loss:[validity:0.717384]
epoch:119 batch:52 loss:0.717442 accu:0.000000 gene_loss:[validity:0.717400]
epoch:120 batch:52 loss:0.717435 accu:0.000000 gene_loss:[validity:0.717409]
epoch:121 batch:52 loss:0.717428 accu:0.000000 gene_loss:[validity:0.717415]

epoch:217 batch:52 loss:0.717130 accu:0.000000 gene_loss:[validity:0.717143]
epoch:218 batch:52 loss:0.717131 accu:0.000000 gene_loss:[validity:0.717144]
epoch:219 batch:52 loss:0.717130 accu:0.000000 gene_loss:[validity:0.717145]
epoch:220 batch:52 loss:0.717131 accu:0.000000 gene_loss:[validity:0.717147]
epoch:221 batch:52 loss:0.717130 accu:0.000000 gene_loss:[validity:0.717145]
epoch:222 batch:52 loss:0.717130 accu:0.000000 gene_loss:[validity:0.717145]
epoch:223 batch:52 loss:0.717130 accu:0.000000 gene_loss:[validity:0.717150]
epoch:224 batch:52 loss:0.717130 accu:0.000000 gene_loss:[validity:0.717184]
epoch:225 batch:52 loss:0.717135 accu:0.000000 gene_loss:[validity:0.717195]
epoch:226 batch:52 loss:0.717140 accu:0.000000 gene_loss:[validity:0.717183]
epoch:227 batch:52 loss:0.717137 accu:0.000000 gene_loss:[validity:0.717165]
epoch:228 batch:52 loss:0.717151 accu:0.000000 gene_loss:[validity:0.717100]
epoch:229 batch:52 loss:0.717132 accu:0.000000 gene_loss:[validity:0.717082]

epoch:325 batch:52 loss:0.716821 accu:0.000000 gene_loss:[validity:0.716802]
epoch:326 batch:52 loss:0.716817 accu:0.000000 gene_loss:[validity:0.716798]
epoch:327 batch:52 loss:0.716812 accu:0.000000 gene_loss:[validity:0.716795]
epoch:328 batch:52 loss:0.716809 accu:0.000000 gene_loss:[validity:0.716792]
epoch:329 batch:52 loss:0.716804 accu:0.000000 gene_loss:[validity:0.716789]
epoch:330 batch:52 loss:0.716800 accu:0.000000 gene_loss:[validity:0.716788]
epoch:331 batch:52 loss:0.716796 accu:0.000000 gene_loss:[validity:0.716791]
epoch:332 batch:52 loss:0.716793 accu:0.000000 gene_loss:[validity:0.716789]
epoch:333 batch:52 loss:0.716792 accu:0.000000 gene_loss:[validity:0.716789]
epoch:334 batch:52 loss:0.716789 accu:0.000000 gene_loss:[validity:0.716790]
epoch:335 batch:52 loss:0.716787 accu:0.000000 gene_loss:[validity:0.716795]
epoch:336 batch:52 loss:0.716790 accu:0.000000 gene_loss:[validity:0.716801]
epoch:337 batch:52 loss:0.716791 accu:0.000000 gene_loss:[validity:0.716811]

epoch:432 batch:52 loss:0.716676 accu:0.000000 gene_loss:[validity:0.716661]
epoch:433 batch:52 loss:0.716667 accu:0.000000 gene_loss:[validity:0.716649]
epoch:434 batch:52 loss:0.716663 accu:0.000000 gene_loss:[validity:0.716637]
epoch:435 batch:52 loss:0.716657 accu:0.000000 gene_loss:[validity:0.716630]
epoch:436 batch:52 loss:0.718338 accu:0.000000 gene_loss:[validity:0.716913]
epoch:437 batch:52 loss:0.718298 accu:0.000000 gene_loss:[validity:0.716349]
epoch:438 batch:52 loss:0.718045 accu:0.000000 gene_loss:[validity:0.715770]
epoch:439 batch:52 loss:0.717639 accu:0.000000 gene_loss:[validity:0.715825]
epoch:440 batch:52 loss:0.717484 accu:0.000000 gene_loss:[validity:0.715994]
epoch:441 batch:52 loss:0.717405 accu:0.000000 gene_loss:[validity:0.716131]
epoch:442 batch:52 loss:0.717348 accu:0.000000 gene_loss:[validity:0.716576]
epoch:443 batch:52 loss:0.717254 accu:0.000000 gene_loss:[validity:0.716851]
epoch:444 batch:52 loss:0.717299 accu:0.000000 gene_loss:[validity:0.716853]

epoch:539 batch:52 loss:0.716631 accu:0.000000 gene_loss:[validity:0.716400]
epoch:540 batch:52 loss:0.716632 accu:0.000000 gene_loss:[validity:0.716467]
epoch:541 batch:52 loss:0.716625 accu:0.000000 gene_loss:[validity:0.716519]
epoch:542 batch:52 loss:0.716619 accu:0.000000 gene_loss:[validity:0.716556]
epoch:543 batch:52 loss:0.716611 accu:0.000000 gene_loss:[validity:0.716580]
epoch:544 batch:52 loss:0.716601 accu:0.000000 gene_loss:[validity:0.716590]
epoch:545 batch:52 loss:0.716592 accu:0.000000 gene_loss:[validity:0.716792]
epoch:546 batch:52 loss:0.716739 accu:0.000000 gene_loss:[validity:0.716419]
epoch:547 batch:52 loss:0.716746 accu:0.000000 gene_loss:[validity:0.716269]
epoch:548 batch:52 loss:0.716671 accu:0.000000 gene_loss:[validity:0.716321]
epoch:549 batch:52 loss:0.716642 accu:0.000000 gene_loss:[validity:0.716416]
epoch:550 batch:52 loss:0.716629 accu:0.000000 gene_loss:[validity:0.716493]
epoch:551 batch:52 loss:0.716618 accu:0.000000 gene_loss:[validity:0.716554]

epoch:647 batch:52 loss:0.716344 accu:0.000000 gene_loss:[validity:0.716330]
epoch:648 batch:52 loss:0.716327 accu:0.000000 gene_loss:[validity:0.716300]
epoch:649 batch:52 loss:0.720768 accu:0.000000 gene_loss:[validity:0.711636]
epoch:650 batch:52 loss:0.717925 accu:0.000000 gene_loss:[validity:0.712065]
epoch:651 batch:52 loss:0.717164 accu:0.000000 gene_loss:[validity:0.713250]
epoch:652 batch:52 loss:0.716939 accu:0.000000 gene_loss:[validity:0.714229]
epoch:653 batch:52 loss:0.716851 accu:0.000000 gene_loss:[validity:0.714960]
epoch:654 batch:52 loss:0.716788 accu:0.000000 gene_loss:[validity:0.715509]
epoch:655 batch:52 loss:0.716757 accu:0.000000 gene_loss:[validity:0.715879]
epoch:656 batch:52 loss:0.716730 accu:0.000000 gene_loss:[validity:0.716121]
epoch:657 batch:52 loss:0.716707 accu:0.000000 gene_loss:[validity:0.716273]
epoch:658 batch:52 loss:0.716686 accu:0.000000 gene_loss:[validity:0.716369]
epoch:659 batch:52 loss:0.716659 accu:0.000000 gene_loss:[validity:0.716434]

In [70]:
real_labels.shape

(64, 1)

In [53]:
discriminator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 28, 28, 1)         0         
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 533505    
Total params: 533,505
Trainable params: 0
Non-trainable params: 533,505
_________________________________________________________________


In [54]:
generator_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input1 (InputLayer)          (None, 100)               0         
_________________________________________________________________
generator (Sequential)       (None, 28, 28, 1)         1097744   
Total params: 1,097,744
Trainable params: 1,095,184
Non-trainable params: 2,560
_________________________________________________________________


In [39]:
combined_model_i.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z (InputLayer)               (None, 100)               0         
_________________________________________________________________
generator_Model (Model)      (None, 96, 96, 3)         29029120  
_________________________________________________________________
discriminator_Model (Model)  (None, 1)                 14320641  
Total params: 43,349,761
Trainable params: 29,025,536
Non-trainable params: 14,324,225
_________________________________________________________________


In [None]:
gc.collect()

In [41]:
32*400

12800

In [None]:
gc.collect()