## Import modules

In [1]:
from keras.models import Sequential
from keras.layers.core import Dense , Activation , Flatten , Reshape 
# convolutional layers
from keras.layers.convolutional import Conv2D , UpSampling2D
from keras.layers.pooling import MaxPooling2D
# other module
from keras.layers.normalization import BatchNormalization
from keras.optimizers import SGD
from keras.datasets import mnist
from keras.utils import np_utils

from time import clock as now
import math
import numpy as np
from PIL import Image 

Using TensorFlow backend.


## Define Generator and Discriminator  

In [2]:
def generator_model():
    """Build the (uncompiled) generator network.

    The model takes a 100-dimensional noise vector and applies, in order:
    Dense(1024) -> tanh -> Dense(128*7*7) -> BatchNormalization -> tanh ->
    Reshape to (128, 7, 7) -> 2x upsampling -> 5x5 Conv2D with 64 filters ->
    tanh -> 2x upsampling -> 5x5 Conv2D with 1 filter -> tanh.

    Returns
    -------
    keras.models.Sequential
        The assembled generator.

    Notes
    -----
    The Reshape target (128, 7, 7) puts the channel axis first, so this
    presumably relies on Keras being configured with
    image_data_format='channels_first' -- TODO confirm against keras.json.
    """
    # Layers are instantiated in the same order they are stacked so that
    # Keras' auto-generated layer names come out in the same sequence.
    layer_stack = [
        # Noise projection.
        Dense(1024, input_dim=100, name='gen_layer1'),
        Activation('tanh'),
        # Expand to a flat vector of 128*7*7 values.
        Dense(128 * 7 * 7, name='gen_layer2'),
        BatchNormalization(),
        Activation('tanh'),
        # Fold the flat vector into a (128, 7, 7) block and double its size.
        Reshape((128, 7, 7), input_shape=(128 * 7 * 7,)),
        UpSampling2D(size=(2, 2)),
        # First convolution stage, followed by a second 2x upsampling.
        Conv2D(filters=64, kernel_size=(5, 5), padding='same', name='gen_layer3'),
        Activation('tanh'),
        UpSampling2D(size=(2, 2)),
        # Collapse to a single filter; tanh bounds the output to [-1, 1].
        Conv2D(1, (5, 5), padding='same', name='gen_output'),
        Activation('tanh'),
    ]
    return Sequential(layer_stack)

In [3]:
def discriminator_model():
    """Build the (uncompiled) discriminator network.

    Input shape is (1, 28, 28). Two [5x5 Conv2D -> tanh -> 2x2 max-pool]
    stages (64 then 128 filters) are followed by Flatten, Dense(1024) with
    tanh, and a single sigmoid unit.

    Returns
    -------
    keras.models.Sequential
        The assembled discriminator; its output is one sigmoid value per image.

    Notes
    -----
    The layer names 'dcm_layer1', 'dcm_layer2' and 'dcm_layer3' are also used
    by classifier_model(), which lets weights saved from this model be loaded
    there by name.
    """
    net = Sequential()
    stages = (
        # Convolution stage 1.
        Conv2D(64, (5, 5), padding='same', input_shape=(1, 28, 28), name='dcm_layer1'),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Convolution stage 2.
        Conv2D(128, (5, 5), padding='same', name='dcm_layer2'),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Dense head.
        Flatten(),
        Dense(1024, name='dcm_layer3'),
        Activation('tanh'),
        # Single sigmoid output unit.
        Dense(1, name='dcm_output'),
        Activation('sigmoid'),
    )
    for stage in stages:
        net.add(stage)
    return net

## Define the stacked Generator + Discriminator model (Generator_D) and the Classifier

In [4]:
def generator_containing_discriminator(generator,discriminator):
    """Stack the generator and the discriminator into one Sequential model.

    Parameters
    ----------
    generator : keras model
        Added first in the stack.
    discriminator : keras model
        Added second. Its ``trainable`` attribute is set to False on the
        object passed in, so the caller's discriminator is modified too.

    Returns
    -------
    keras.models.Sequential
        The uncompiled generator -> discriminator stack.
    """
    # Freeze first; the flag only has to be in place before the stacked
    # model is compiled by the caller.
    discriminator.trainable = False
    stacked = Sequential()
    for sub_model in (generator, discriminator):
        stacked.add(sub_model)
    return stacked

In [5]:
def classifier_model():
    """Build the (uncompiled) 10-class classifier.

    The architecture and layer names ('dcm_layer1' .. 'dcm_layer3') are the
    same as in discriminator_model() up to the last Dense(1024)/tanh pair, so
    weights saved from the discriminator can be loaded here with
    ``load_weights(..., by_name=True)``. Only the head differs: a 10-unit
    softmax layer named 'clf_output'.

    Returns
    -------
    keras.models.Sequential
        The assembled classifier taking (1, 28, 28) inputs.
    """
    # Feature extractor: names must stay in sync with discriminator_model().
    shared_trunk = [
        Conv2D(64, (5, 5), padding='same', input_shape=(1, 28, 28), name='dcm_layer1'),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (5, 5), padding='same', name='dcm_layer2'),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(1024, name='dcm_layer3'),
        Activation('tanh'),
    ]
    # Classification head specific to this model.
    head = [
        Dense(10, name='clf_output'),
        Activation('softmax'),
    ]
    return Sequential(shared_trunk + head)

## Combine images into a grid and generate samples

In [6]:
def combine_images(generate_images):
    """Tile a batch of images into a single 2-D mosaic.

    Parameters
    ----------
    generate_images : numpy.ndarray
        Batch indexed as ``[n, channel, row, col]``; only channel 0 of each
        image is copied into the mosaic.

    Returns
    -------
    numpy.ndarray
        2-D array with the same dtype as the input. The grid has
        ``int(sqrt(n))`` columns and ``ceil(n / columns)`` rows, filled
        row by row; grid cells without an image stay zero.
    """
    count = generate_images.shape[0]
    tile_dims = generate_images.shape[2:]
    tile_h, tile_w = tile_dims[0], tile_dims[1]

    # Grid layout: as close to square as possible, never wider than tall.
    n_cols = int(math.sqrt(count))
    n_rows = math.ceil(count / n_cols)

    # The shape argument must be a tuple.
    mosaic = np.zeros((n_rows * tile_h, n_cols * tile_w),
                      dtype=generate_images.dtype)

    for position, tile in enumerate(generate_images):
        # Row-major placement of each tile.
        row, col = divmod(position, n_cols)
        top, left = row * tile_h, col * tile_w
        mosaic[top:top + tile_h, left:left + tile_w] = tile[0, :, :]
    return mosaic

In [7]:
def generate(BATCH_SIZE,nice=False):
    """Sample images from the trained generator and save them as one tiled PNG.

    Parameters
    ----------
    BATCH_SIZE : int
        Number of images placed in the saved mosaic.
    nice : bool, optional
        If True, ``BATCH_SIZE*20`` candidates are generated, scored by the
        trained discriminator, and only the ``BATCH_SIZE`` highest-scoring
        ones are kept. If False (default), ``BATCH_SIZE`` raw generator
        samples are saved.

    Side effects
    ------------
    Loads 'gan_generator.h5' (and 'gan_discriminator.h5' when ``nice`` is
    True) and writes 'generated_image.png', all relative to the working
    directory.
    """
    generator = generator_model()
    generator.compile(loss='binary_crossentropy', optimizer='SGD')
    generator.load_weights('gan_generator.h5')

    if nice:
        discriminator = discriminator_model()
        discriminator.compile(loss='binary_crossentropy', optimizer='SGD')
        discriminator.load_weights('gan_discriminator.h5')

        # Oversample 20x so there is something to choose from.
        n_candidates = BATCH_SIZE * 20
        noise = np.random.uniform(-1, 1, (n_candidates, 100))

        generated_images = generator.predict(noise, verbose=0)
        d_pret = discriminator.predict(generated_images, verbose=0)

        # Rank candidates by discriminator score, best first. mergesort is
        # stable, so equal scores keep their original relative order.
        best = np.argsort(-d_pret[:, 0], kind='mergesort')[:BATCH_SIZE]

        # Keep channel 0 only. Indexing is zero-based: [idx, 1, :, :] would
        # raise "out of bounds for axis".
        nice_images = generated_images[best, :1].astype(np.float32)
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = generator.predict(noise, verbose=1)
        # Fixed: this branch used `nice_images`, which only exists when
        # nice=True, so nice=False always raised NameError.
        image = combine_images(generated_images)

    # Map the tanh range [-1, 1] back to pixel values [0, 255].
    image = image*127.5+127.5
    Image.fromarray(image.astype(np.uint8)).save('generated_image.png')

## GAN_Training 

In [8]:
def train(BATCH_SIZE,EPOCH):
    """Train the GAN on the MNIST training images.

    Parameters
    ----------
    BATCH_SIZE : int
        Number of real images (and of generated images) per discriminator
        step. A trailing partial batch is dropped.
    EPOCH : int
        Number of passes over the training set.

    Side effects
    ------------
    * Every 100 batches, prints the discriminator and generator losses and
      saves a mosaic of the current generated batch to
      'Process_generator/<epoch>_<index>.png'. That directory must already
      exist; it is not created here.
    * After every epoch, writes 'gan_generator.h5' and
      'gan_discriminator.h5' to the working directory.
    """
    (x_train, _), _ = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
    x_train = (x_train.astype(np.float32)-127.5)/127.5
    x_train = x_train.reshape((x_train.shape[0],1,28,28))

    # define models
    generator = generator_model()
    discriminator = discriminator_model()
    discriminator_on_generator = generator_containing_discriminator(generator,discriminator)

    # compile models
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    generator.compile(loss='binary_crossentropy', optimizer='SGD')
    discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)
    # The stacked model was built with the discriminator frozen; re-enable
    # training before compiling the stand-alone discriminator.
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    # Loop-invariant values: batches per epoch and the target labels.
    BATCH_NUM = int(x_train.shape[0]/BATCH_SIZE)
    # Real images are labelled 1, generated images 0.
    d_labels = np.array([1]*BATCH_SIZE + [0]*BATCH_SIZE)
    # The generator is trained to make the discriminator output 1.
    g_labels = np.array([1]*BATCH_SIZE)

    # train over epochs and batches
    for epoch in range(EPOCH):
        print('---------------------EPOCH IS %d TOTAL OF %d-------------------- '%(epoch,EPOCH))
        for index in range(BATCH_NUM):
            # create the training batch: fresh fakes + the index-th real slice
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            generated_images = generator.predict(noise,verbose=0)
            # Fixed: the slice used the stale inner-loop variable `i`
            # (always BATCH_SIZE-1), so every step saw the same real images.
            image_batch = x_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]

            X = np.concatenate((image_batch,generated_images))
            # train the discriminator
            d_loss = discriminator.train_on_batch(X,d_labels)
            # watch the process
            if index%100 == 0:
                print('In batch %d of %d ,dcm loss is:%f'%(index,BATCH_NUM,d_loss))
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save('Process_generator/'+str(epoch)+'_'+str(index)+'.png')

            # train the generator through the stacked model on fresh noise
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(noise,g_labels)
            discriminator.trainable = True
            if index%100 == 0:
                print('In batch %d of %d ,gen loss is:%f'%(index,BATCH_NUM,g_loss))

        # checkpoint both networks at the end of every epoch
        generator.save_weights('gan_generator.h5')
        discriminator.save_weights('gan_discriminator.h5')

## Classifier_training

In [None]:
def classifier_train():
    """Train and evaluate the 10-class MNIST classifier, starting from GAN weights.

    The classifier is initialised from 'gan_discriminator.h5' with
    ``by_name=True``, so only layers whose names match the saved
    discriminator ('dcm_layer1' .. 'dcm_layer3') receive weights; the
    'clf_output' head starts from its default initialisation.

    Side effects
    ------------
    Reads 'gan_discriminator.h5', trains for 10 epochs with batch size 128,
    prints the test loss and accuracy, and writes 'gan_classifier.h5'.
    """
    (x_train,y_train),(x_test,y_test) = mnist.load_data()
    # Same [-1, 1] scaling as used for GAN training.
    x_train = (x_train-127.5)/127.5
    x_test  = (x_test -127.5)/127.5
    x_train = x_train.reshape([x_train.shape[0],1,28,28])
    x_test  = x_test.reshape([x_test.shape[0],1,28,28])
    # One-hot encode the digit labels for the softmax head.
    y_train = np_utils.to_categorical(y_train,num_classes = 10)
    y_test  = np_utils.to_categorical(y_test, num_classes = 10)

    classifier = classifier_model()
    # Fixed: a 10-way softmax with one-hot targets needs categorical
    # cross-entropy. With 'binary_crossentropy' the 'accuracy' metric is
    # computed element-wise over the 10 outputs and reads misleadingly high.
    classifier.compile(loss='categorical_crossentropy',optimizer='SGD',metrics=['accuracy'])
    classifier.load_weights('gan_discriminator.h5',by_name=True)
    classifier.fit(x_train,y_train,epochs=10,batch_size=128)

    loss,accuracy=classifier.evaluate(x_test,y_test,verbose = 1)
    # The evaluation results were previously computed and discarded.
    print('Test loss: %f, test accuracy: %f'%(loss,accuracy))
    classifier.save_weights('gan_classifier.h5')

## Main 

In [None]:
# Time the GAN training call only. `now` is the alias created by
# `from time import clock as now` in the import cell.
# NOTE(review): time.clock was removed in Python 3.8, so this cell needs an
# older interpreter or a different timer -- confirm the target Python version.
time1=now()
train(BATCH_SIZE=128,EPOCH=100)
time2=now()
# Save a mosaic of the 128 best-scored samples out of 128*20 candidates.
generate(BATCH_SIZE=128,nice=True)
# Train the classifier from the discriminator weights saved by train().
classifier_train()
# Elapsed time of train() only; generate() and classifier_train() run after
# time2 is taken and are not included.
print(time2-time1)

---------------------EPOCH IS 0 TOTAL OF 100-------------------- 


kwargs passed to function are ignored with Tensorflow backend


In batch 0 of 468 ,dcm loss is:0.663925
In batch 0 of 468 ,gen loss is:0.724368
In batch 100 of 468 ,dcm loss is:0.210255
In batch 100 of 468 ,gen loss is:1.159469
