In [1]:
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input
import numpy as np
import os
import glob
import random
import math
from PIL import Image
import matplotlib.pyplot as plt

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD

%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['figure.figsize'] = (8, 8)

Using TensorFlow backend.


In [2]:
"""
load training data
"""
def list_image_pathes(trainDir, testDir):
    """Return the ``.jpg`` files found directly inside two directories.

    Args:
        trainDir: directory holding the training images.
        testDir: directory holding the test images.

    Returns:
        A ``(train_paths, test_paths)`` tuple of lists. Each list is sorted, so
        the order is deterministic regardless of the filesystem's listing order
        (``glob`` itself makes no ordering guarantee).
    """
    def _sorted_jpgs(directory):
        # os.path.join avoids doubled/missing separators and is portable.
        return sorted(glob.glob(os.path.join(directory, "*.jpg")))

    return _sorted_jpgs(trainDir), _sorted_jpgs(testDir)

def random_images(pathes, target_size, batch_size=32):
    """Load `batch_size` consecutive images from `pathes`, starting at a random offset.

    Every valid window ``pathes[start:start + batch_size]`` can be chosen,
    including the last one.

    Args:
        pathes: sequence of image file paths.
        target_size: size passed to ``image.load_img`` for resizing.
        batch_size: number of images to load.

    Returns:
        float32 numpy array of the stacked images, with pixel values mapped
        from [0, 255] to [-1, 1] via ``(x - 127.5) / 127.5``.

    Raises:
        ValueError: if `pathes` has fewer than `batch_size` entries.
    """
    last_start = len(pathes) - batch_size
    if last_start < 0:
        raise ValueError(
            "need at least {} image paths, got {}".format(batch_size, len(pathes)))
    # random.randint is inclusive on both ends, so last_start is itself a valid pick.
    start = random.randint(0, last_start)
    window = pathes[start:start + batch_size]
    images = [image.img_to_array(image.load_img(path, target_size=target_size))
              for path in window]
    return (np.array(images).astype(np.float32) - 127.5) / 127.5

In [3]:
def combine_generated_images(generated_images):
    """Tile a batch of images into one near-square 3-channel grid.

    Images are placed row-major. The grid is ``int(sqrt(num))`` images wide and
    has as many rows as needed to fit all `num` images; unused cells in the
    last row stay zero.

    Args:
        generated_images: array indexed as (num, height, width, channels); each
            image is written into a 3-channel cell.

    Returns:
        Array of shape (rows * height, cols * width, 3) with the input's dtype.

    Raises:
        ValueError: if the batch contains no images.
    """
    num = generated_images.shape[0]
    if num == 0:
        raise ValueError("combine_generated_images needs at least one image")
    grid_cols = int(math.sqrt(num))
    grid_rows = int(math.ceil(float(num) / grid_cols))
    height, width = generated_images.shape[1:3]
    # Named `grid` (not `image`) so it does not shadow the keras `image` module.
    grid = np.zeros((grid_rows * height, grid_cols * width, 3),
                    dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        row, col = divmod(index, grid_cols)
        grid[row * height:(row + 1) * height,
             col * width:(col + 1) * width, :] = img[:, :, :]
    return grid

In [4]:
"""
build model
"""
def generator_model(input_dim=128):
    """Build the uncompiled generator: noise vector -> 64x64x3 image in [-1, 1].

    Layers, in order:
      * Dense(1024) + tanh on an `input_dim`-length input,
      * Dense(4*4*1024) + BatchNorm + tanh, reshaped to a 4x4x1024 map,
      * four stages of UpSampling2D(4x4) -> 5x5 stride-2 'same' Conv2D ->
        BatchNorm -> tanh, with 512, 256, 128 and finally 3 filters. Each
        stage grows the map by a net factor of 2: 8, 16, 32, 64 pixels.

    NOTE(review): upsampling by 4 and then striding by 2 runs every conv on a
    map 4x larger than its output. UpSampling2D(2) with stride 1 would reach
    the same sizes more cheaply, but is not numerically identical.

    Args:
        input_dim: length of the noise vector consumed by the first Dense layer.

    Returns:
        An uncompiled keras Sequential model.
    """
    net = Sequential()

    # Project the noise vector and reshape it into a 4x4 feature map.
    net.add(Dense(input_dim=input_dim, units=1024))
    net.add(Activation('tanh'))
    net.add(Dense(1024*4*4))
    net.add(BatchNormalization())
    net.add(Activation('tanh'))
    net.add(Reshape((4, 4, 1024), input_shape=(1024*4*4,)))

    # x4 upsample followed by a stride-2 conv => net x2 per stage.
    for n_filters in (512, 256, 128, 3):
        net.add(UpSampling2D(size=(4, 4)))
        net.add(Conv2D(n_filters, strides=(2, 2), kernel_size=5, padding='same'))
        net.add(BatchNormalization())
        net.add(Activation('tanh'))

    return net

generator_model().summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 1024)              132096    
_________________________________________________________________
activation_1 (Activation)    (None, 1024)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 16384)             16793600  
_________________________________________________________________
batch_normalization_1 (Batch (None, 16384)             65536     
_________________________________________________________________
activation_2 (Activation)    (None, 16384)             0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 1024)        0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 16, 16, 1024)      0         
__________

In [5]:
def discriminator_model(input_shape=(64, 64, 3)):
    """Build the uncompiled discriminator: image -> single sigmoid score.

    Four conv stages, each a 2x2 stride-1 Conv2D -> BatchNorm -> tanh ->
    2x2 max-pool, with 64, 128, 256 and 512 filters. Only the first conv
    uses 'same' padding and declares `input_shape`. The later convs use the
    Keras default padding, which trims one pixel per stage (e.g. 32 -> 31).
    The head is Flatten -> Dense(1024) + tanh -> Dense(1) + sigmoid.

    Args:
        input_shape: shape of one input image, passed to the first layer.

    Returns:
        An uncompiled keras Sequential model.
    """
    net = Sequential()

    for stage, n_filters in enumerate((64, 128, 256, 512)):
        first_layer_kwargs = (
            dict(padding='same', input_shape=input_shape) if stage == 0 else {})
        net.add(Conv2D(n_filters, kernel_size=2, strides=1, **first_layer_kwargs))
        net.add(BatchNormalization())
        net.add(Activation('tanh'))
        net.add(MaxPooling2D(pool_size=(2, 2)))

    # Classification head.
    net.add(Flatten())
    net.add(Dense(1024))
    net.add(Activation('tanh'))
    net.add(Dense(1))
    net.add(Activation('sigmoid'))
    return net

discriminator_model().summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_5 (Conv2D)            (None, 64, 64, 64)        832       
_________________________________________________________________
batch_normalization_6 (Batch (None, 64, 64, 64)        256       
_________________________________________________________________
activation_7 (Activation)    (None, 64, 64, 64)        0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 31, 31, 128)       32896     
_________________________________________________________________
batch_normalization_7 (Batch (None, 31, 31, 128)       512       
_________________________________________________________________
activation_8 (Activation)    (None, 31, 31, 128)       0         
__________

In [6]:
def concatenate_g_d(g, d):
    """Chain generator `g` into discriminator `d` as one Sequential model.

    Side effect: sets ``d.trainable = False`` on the caller's `d` object, so
    that `d` is frozen when the returned model is compiled. Callers that also
    train `d` on its own must set ``d.trainable = True`` again before
    compiling `d`.

    Args:
        g: generator model; its output feeds `d`.
        d: discriminator model.

    Returns:
        An uncompiled Sequential model computing d(g(z)).
    """
    # Setting the flag is independent of adding g, so do it up front.
    d.trainable = False
    stacked = Sequential()
    for sub_model in (g, d):
        stacked.add(sub_model)
    return stacked

In [7]:
"""
optimizer
"""
# Same Nesterov-momentum SGD settings for generator and discriminator, but two
# separate instances: one for each model that gets compiled with them.
g_optim, d_optim = (SGD(lr=0.0002, momentum=0.5, nesterov=True) for _copy in range(2))

In [None]:
# Build and compile the models. Statement order matters here:
# concatenate_g_d() sets d_model.trainable = False, so the flag is False while
# concatenate_model is compiled (freezing d inside the combined model). It is
# then set back to True before d_model itself is compiled. Keep this ordering.
"""
create model
"""
input_dim = 100  # length of the noise vector fed to the generator
g_model = generator_model(input_dim=input_dim)
d_model = discriminator_model()
concatenate_model = concatenate_g_d(g_model, d_model)

"""
configuration
"""
# NOTE(review): in the training loop g_model is only used via predict() and
# save_weights(); generator updates happen through concatenate_model, so this
# compile (with the string 'SGD', not g_optim) does not drive training.
g_model.compile(loss='binary_crossentropy', optimizer='SGD')
# Generator update step: g is trained through the frozen discriminator.
concatenate_model.compile(loss='binary_crossentropy', optimizer=g_optim)
# Re-enable d before compiling it so its own training updates its weights.
d_model.trainable = True
d_model.compile(loss='binary_crossentropy', optimizer=d_optim)

In [None]:
"""
train
"""
trainDir = os.path.join('data', 'cars_train')
testDir = os.path.join('data', 'cars_test')
target_size = (64, 64)
batch_size = 128
num_epochs = 10
snapshot_every = 10    # save a grid of generated samples every N batches
checkpoint_every = 50  # save model weights every N batches
train_p, test_p = list_image_pathes(trainDir, testDir)
pathes = train_p + test_p

# Floor division keeps this an int on both Python 2 and Python 3
# (plain `/` returns a float on Python 3, which range() rejects).
batches_per_epoch = len(pathes) // batch_size

# Targets are identical every iteration, so build them once as arrays.
# Discriminator: the real images come first (label 1), then the generated
# ones (label 0). Generator: it is trained to make the discriminator output 1.
d_targets = np.array([1] * batch_size + [0] * batch_size)
g_targets = np.array([1] * batch_size)

for epoch in range(num_epochs):
    # NOTE: real batches come from random_images(), so an "epoch" is
    # batches_per_epoch random draws, not one pass over every path.
    for times in range(batches_per_epoch):
        # Sample the generator on fresh uniform noise.
        noises = np.random.uniform(-1, 1, size=(batch_size, input_dim))
        generated_images = g_model.predict(noises, verbose=0)

        # Periodically store the current samples, rescaled from [-1, 1] to [0, 255].
        if times % snapshot_every == 0:
            combinedImg = combine_generated_images(generated_images)
            combinedImg = combinedImg*127.5+127.5
            Image.fromarray(combinedImg.astype(np.uint8)).save("{}_epoch_{}_times.png".format(epoch, times))

        # Train the discriminator on real images followed by generated ones.
        realImgs = random_images(pathes, target_size=target_size, batch_size=batch_size)
        X = np.concatenate((realImgs, generated_images))
        d_loss = d_model.train_on_batch(X, d_targets)

        # Train the generator through concatenate_model on new noise.
        # NOTE(review): d's weights were frozen inside concatenate_model at
        # compile time. The toggle is kept because Keras may consult
        # `trainable` when it lazily builds the train function (e.g. which
        # layer updates it includes).
        noises = np.random.uniform(-1, 1, size=(batch_size, input_dim))
        d_model.trainable = False
        g_loss = concatenate_model.train_on_batch(noises, g_targets)
        d_model.trainable = True
        print("epoch={}, times={}, d_loss={}, g_loss={}".format(epoch, times, d_loss, g_loss))

        if times % checkpoint_every == 0:
            g_model.save_weights('generator_weights.h5', True)
            d_model.save_weights('discriminator_weights.h5', True)

epoch=0, times=0, d_loss=0.891181826591, g_loss=0.682783842087
epoch=0, times=1, d_loss=0.854298532009, g_loss=0.690512657166
epoch=0, times=2, d_loss=0.826978206635, g_loss=0.693809092045
epoch=0, times=3, d_loss=0.805445075035, g_loss=0.712772905827
epoch=0, times=4, d_loss=0.78109061718, g_loss=0.674741506577
epoch=0, times=5, d_loss=0.761171400547, g_loss=0.699962496758
epoch=0, times=6, d_loss=0.7159512043, g_loss=0.681566953659
epoch=0, times=7, d_loss=0.682907938957, g_loss=0.708471655846
epoch=0, times=8, d_loss=0.661209821701, g_loss=0.70990973711
epoch=0, times=9, d_loss=0.646705389023, g_loss=0.711235404015
epoch=0, times=10, d_loss=0.661039292812, g_loss=0.74479842186
epoch=0, times=11, d_loss=0.618103802204, g_loss=0.756978034973
epoch=0, times=12, d_loss=0.603818535805, g_loss=0.724099516869
epoch=0, times=13, d_loss=0.587458610535, g_loss=0.710181117058
epoch=0, times=14, d_loss=0.548788845539, g_loss=0.760829031467
epoch=0, times=15, d_loss=0.548465192318, g_loss=0.7225

epoch=1, times=3, d_loss=0.107972547412, g_loss=0.687520503998
epoch=1, times=4, d_loss=0.154996290803, g_loss=0.673300623894
epoch=1, times=5, d_loss=0.143704116344, g_loss=0.680193185806
epoch=1, times=6, d_loss=0.138908788562, g_loss=0.676130115986
epoch=1, times=7, d_loss=0.139152869582, g_loss=0.697505950928
epoch=1, times=8, d_loss=0.134952992201, g_loss=0.662557244301
epoch=1, times=9, d_loss=0.12763504684, g_loss=0.676229655743
epoch=1, times=10, d_loss=0.117300920188, g_loss=0.638946473598
epoch=1, times=11, d_loss=0.135232374072, g_loss=0.677282094955
epoch=1, times=12, d_loss=0.124821312726, g_loss=0.655606746674
epoch=1, times=13, d_loss=0.154268473387, g_loss=0.663149178028
epoch=1, times=14, d_loss=0.118320629001, g_loss=0.660589456558
epoch=1, times=15, d_loss=0.103697814047, g_loss=0.627007365227
epoch=1, times=16, d_loss=0.138829946518, g_loss=0.640044271946
epoch=1, times=17, d_loss=0.125721439719, g_loss=0.638021230698
epoch=1, times=18, d_loss=0.116897955537, g_loss

epoch=2, times=5, d_loss=0.0881198272109, g_loss=0.542742013931
epoch=2, times=6, d_loss=0.0631159320474, g_loss=0.569753408432
epoch=2, times=7, d_loss=0.0763810425997, g_loss=0.540539085865
epoch=2, times=8, d_loss=0.0680690780282, g_loss=0.524990797043
epoch=2, times=9, d_loss=0.0795699879527, g_loss=0.572964966297
epoch=2, times=10, d_loss=0.0703251808882, g_loss=0.571820259094
epoch=2, times=11, d_loss=0.0666812583804, g_loss=0.555266737938
epoch=2, times=12, d_loss=0.0814810246229, g_loss=0.557921826839
epoch=2, times=13, d_loss=0.0874053239822, g_loss=0.559351563454
epoch=2, times=14, d_loss=0.0694440081716, g_loss=0.560148477554
epoch=2, times=15, d_loss=0.0714719370008, g_loss=0.547395586967
epoch=2, times=16, d_loss=0.081167884171, g_loss=0.570879340172
epoch=2, times=17, d_loss=0.0720067694783, g_loss=0.549884974957
epoch=2, times=18, d_loss=0.0681639611721, g_loss=0.544079005718
epoch=2, times=19, d_loss=0.0725938677788, g_loss=0.545015096664
epoch=2, times=20, d_loss=0.077

epoch=3, times=6, d_loss=0.067554756999, g_loss=0.483840435743
epoch=3, times=7, d_loss=0.0455225370824, g_loss=0.537879943848
epoch=3, times=8, d_loss=0.0464908666909, g_loss=0.492729902267
epoch=3, times=9, d_loss=0.0517211258411, g_loss=0.49878436327
epoch=3, times=10, d_loss=0.0386270098388, g_loss=0.495340287685
epoch=3, times=11, d_loss=0.0381702706218, g_loss=0.476252198219
epoch=3, times=12, d_loss=0.0668121427298, g_loss=0.502767980099
epoch=3, times=13, d_loss=0.0499800369143, g_loss=0.517434418201
epoch=3, times=14, d_loss=0.0393762625754, g_loss=0.502191007137
epoch=3, times=15, d_loss=0.0416799411178, g_loss=0.482301771641
epoch=3, times=16, d_loss=0.0529023557901, g_loss=0.482148379087
epoch=3, times=17, d_loss=0.0486023984849, g_loss=0.516412496567
epoch=3, times=18, d_loss=0.0711122751236, g_loss=0.496411859989
epoch=3, times=19, d_loss=0.0549451857805, g_loss=0.479088097811
epoch=3, times=20, d_loss=0.071121327579, g_loss=0.462647289038
epoch=3, times=21, d_loss=0.0675

epoch=4, times=7, d_loss=0.0616159774363, g_loss=0.449789345264
epoch=4, times=8, d_loss=0.0458917915821, g_loss=0.448183655739
epoch=4, times=9, d_loss=0.0394621342421, g_loss=0.42987370491
epoch=4, times=10, d_loss=0.0451115183532, g_loss=0.47376292944
epoch=4, times=11, d_loss=0.0529289580882, g_loss=0.463842511177
epoch=4, times=12, d_loss=0.0318158380687, g_loss=0.470794767141
epoch=4, times=13, d_loss=0.0384828671813, g_loss=0.436557888985
epoch=4, times=14, d_loss=0.0260342601687, g_loss=0.478190600872
epoch=4, times=15, d_loss=0.0269006732851, g_loss=0.450914263725
epoch=4, times=16, d_loss=0.0289090964943, g_loss=0.467066526413
epoch=4, times=17, d_loss=0.0384876541793, g_loss=0.45425760746
epoch=4, times=18, d_loss=0.0421004556119, g_loss=0.436431378126
epoch=4, times=19, d_loss=0.0362788662314, g_loss=0.445310741663
epoch=4, times=20, d_loss=0.0407154746354, g_loss=0.460417091846
epoch=4, times=21, d_loss=0.0294247381389, g_loss=0.459323346615
epoch=4, times=22, d_loss=0.028

epoch=5, times=8, d_loss=0.0461366847157, g_loss=0.407288193703
epoch=5, times=9, d_loss=0.0313345566392, g_loss=0.421734899282
epoch=5, times=10, d_loss=0.0293298028409, g_loss=0.398859739304
epoch=5, times=11, d_loss=0.0316550917923, g_loss=0.425453305244
epoch=5, times=12, d_loss=0.027719527483, g_loss=0.416379392147
epoch=5, times=13, d_loss=0.0257288720459, g_loss=0.4278678298
epoch=5, times=14, d_loss=0.0368267484009, g_loss=0.407321989536
epoch=5, times=15, d_loss=0.0248149521649, g_loss=0.409546107054
epoch=5, times=16, d_loss=0.0310111343861, g_loss=0.407366901636
epoch=5, times=17, d_loss=0.0338807925582, g_loss=0.397598743439
epoch=5, times=18, d_loss=0.0347255170345, g_loss=0.403487533331
epoch=5, times=19, d_loss=0.0426669828594, g_loss=0.41078466177
epoch=5, times=20, d_loss=0.0255410578102, g_loss=0.381487637758
epoch=5, times=21, d_loss=0.0233717598021, g_loss=0.443369626999
epoch=5, times=22, d_loss=0.0256636887789, g_loss=0.398667603731
epoch=5, times=23, d_loss=0.021

epoch=6, times=9, d_loss=0.0334482677281, g_loss=0.375753939152
epoch=6, times=10, d_loss=0.0453641042113, g_loss=0.379286408424
epoch=6, times=11, d_loss=0.0216820538044, g_loss=0.355836361647
epoch=6, times=12, d_loss=0.0233584325761, g_loss=0.353764474392
epoch=6, times=13, d_loss=0.0277601983398, g_loss=0.400144636631
epoch=6, times=14, d_loss=0.031565438956, g_loss=0.353773057461
epoch=6, times=15, d_loss=0.0223543234169, g_loss=0.373497843742
epoch=6, times=16, d_loss=0.0198988150805, g_loss=0.366450935602
epoch=6, times=17, d_loss=0.0357082784176, g_loss=0.366789579391
epoch=6, times=18, d_loss=0.0249175541103, g_loss=0.354903519154
epoch=6, times=19, d_loss=0.0314569734037, g_loss=0.382441103458
epoch=6, times=20, d_loss=0.0331264138222, g_loss=0.36617743969
epoch=6, times=21, d_loss=0.0192571394145, g_loss=0.382663041353
epoch=6, times=22, d_loss=0.0236108154058, g_loss=0.382521003485
epoch=6, times=23, d_loss=0.0190779268742, g_loss=0.387805610895
epoch=6, times=24, d_loss=0.

epoch=7, times=10, d_loss=0.0162050798535, g_loss=0.369708776474
epoch=7, times=11, d_loss=0.0303962379694, g_loss=0.362471044064
epoch=7, times=12, d_loss=0.0311548616737, g_loss=0.330468058586
epoch=7, times=13, d_loss=0.0282625891268, g_loss=0.346171945333
epoch=7, times=14, d_loss=0.0154851907864, g_loss=0.36923968792
epoch=7, times=15, d_loss=0.0209649689496, g_loss=0.373245239258
epoch=7, times=16, d_loss=0.0161788593978, g_loss=0.364374011755
epoch=7, times=17, d_loss=0.0411042608321, g_loss=0.336753368378
epoch=7, times=18, d_loss=0.0276979412884, g_loss=0.347206711769
epoch=7, times=19, d_loss=0.0306510422379, g_loss=0.354404687881
epoch=7, times=20, d_loss=0.0222952160984, g_loss=0.361521244049
epoch=7, times=21, d_loss=0.0251934435219, g_loss=0.351037561893
epoch=7, times=22, d_loss=0.0225868411362, g_loss=0.360456317663
epoch=7, times=23, d_loss=0.0166827067733, g_loss=0.352429866791
epoch=7, times=24, d_loss=0.0171906147152, g_loss=0.321602106094
epoch=7, times=25, d_loss=

epoch=8, times=11, d_loss=0.0225734803826, g_loss=0.317920029163
epoch=8, times=12, d_loss=0.0200542230159, g_loss=0.331325232983
epoch=8, times=13, d_loss=0.0147419050336, g_loss=0.347433686256
epoch=8, times=14, d_loss=0.0217952746898, g_loss=0.329251080751
epoch=8, times=15, d_loss=0.0163193568587, g_loss=0.300486832857
epoch=8, times=16, d_loss=0.0256621148437, g_loss=0.319092720747
epoch=8, times=17, d_loss=0.051259867847, g_loss=0.325147628784
epoch=8, times=18, d_loss=0.0420331060886, g_loss=0.306105673313
epoch=8, times=19, d_loss=0.0206706412137, g_loss=0.310238361359
epoch=8, times=20, d_loss=0.0162025019526, g_loss=0.318551898003
epoch=8, times=21, d_loss=0.0219259057194, g_loss=0.332616180182
epoch=8, times=22, d_loss=0.0272389966995, g_loss=0.331139445305
epoch=8, times=23, d_loss=0.0505234226584, g_loss=0.331843316555
epoch=8, times=24, d_loss=0.0154986744747, g_loss=0.338759094477
epoch=8, times=25, d_loss=0.0403342396021, g_loss=0.349749565125
epoch=8, times=26, d_loss=

epoch=9, times=12, d_loss=0.0276824850589, g_loss=0.316394776106
epoch=9, times=13, d_loss=0.0169074498117, g_loss=0.294770896435
epoch=9, times=14, d_loss=0.0173599887639, g_loss=0.321575343609
epoch=9, times=15, d_loss=0.022206492722, g_loss=0.323951840401
epoch=9, times=16, d_loss=0.0261010788381, g_loss=0.301266402006
epoch=9, times=17, d_loss=0.0177898220718, g_loss=0.30170199275
epoch=9, times=18, d_loss=0.0164088960737, g_loss=0.324144184589
epoch=9, times=19, d_loss=0.015076380223, g_loss=0.326410293579
epoch=9, times=20, d_loss=0.0189616065472, g_loss=0.327118396759
epoch=9, times=21, d_loss=0.017714323476, g_loss=0.294868171215
epoch=9, times=22, d_loss=0.0358661822975, g_loss=0.314464837313
epoch=9, times=23, d_loss=0.017045063898, g_loss=0.319771111012
epoch=9, times=24, d_loss=0.0180835388601, g_loss=0.335602998734
epoch=9, times=25, d_loss=0.0192157328129, g_loss=0.294457286596
