In [66]:
import numpy as np
from PIL import Image
from keras.models import Sequential
from keras.layers import Dense, Reshape, Activation, BatchNormalization, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D
from keras.optimizers import SGD
from keras.datasets import mnist
import argparse

In [52]:
 def discriminator():
    model = Sequential()
    model.add(
            Conv2D(64, (5, 5),
            padding='same',
            input_shape=(28,28,1))
            )
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2,2)))
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('relu'))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model

In [81]:
def generator():
    model = Sequential()
    model.add(Dense(input_dim=100, output_dim=1024))
    model.add(Activation('relu'))
    model.add(Dense(128*7*7))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
    model.add(UpSampling2D(size=(2,2)))
    model.add(Conv2D(64, (5,5), padding='same'))
    model.add(Activation('relu'))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('relu'))
    return model


In [82]:
def generator_with_discriminator(g, d):
    model = Sequential()
    model.add(g)
    d.trainable = False
    model.add(d)
    return model

In [83]:
def combine_images(generated_images):
    num = generated_images.shape[0]
    width = int(np.sqrt(num))
    height = int(np.ceil(float(num)/width))
    shape = generated_images.shape[1:3]
    image = np.zeros((height*shape[0], width*shape[1]),
                    dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = int(index/width)
        j = index % width
        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \
            img[:, :, 0]
    return image

In [88]:
def train(batch_size):
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    X_train = X_train[:, :, :, None]
    X_test = X_test[:, :, :, None]
    d = discriminator()
    g = generator()
    d_with_g = generator_with_discriminator(g, d)
    d_opt = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_opt = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g.compile(loss='binary_crossentropy', optimizer=g_opt)
    d_with_g.compile(loss='binary_crossentropy', optimizer=d_opt)
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_opt)
    for epoch in range(100):
        print("Epoch #: ", epoch)
        print('Number of batches', int(X_train.shape[0]/batch_size))
        for index in range(int(X_train.shape[0]/batch_size)):
            noise = np.random.uniform(-1, 1, size=(batch_size, 100))
            image_batch = X_train[index*batch_size: (index+1)*batch_size]
            generated_images = g.predict(noise, verbose=0)
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image*127.5 + 127.5
                Image.fromarray(image.astype(np.uint8)).save(
                str(epoch)+"_"+str(index)+".png")
            X = np.concatenate((image_batch, generated_images))
            y = [1] * batch_size + [0] * batch_size
            d_loss = d.train_on_batch(X, y)
            print("batch: {}   d_loss:  {}".format(index, d_loss))
            noise = np.random.uniform(-1, 1, (batch_size, 100))
            d.trainable = False
            g_loss = d_with_g.train_on_batch(noise, [1] * batch_size)
            d.trainable = True
            print("batch: {}   g_loss:  {}".format(index, g_loss))
            if index % 10 == 9:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)

In [85]:
def generate(batch_size, nice=False):
    g = generator()
    g.compile(loss='categorical_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        d = discriminator()
        d.compile(loss='categorical_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        noise = np.random.uniform(-1, 1, (batch_size*20, 100))
        generated_images = g.predict(noise, verbose=1)
        d_pret = d.predict(generated_images, verbose=1)
        index = np.arange(0, batch_size*20)
        index.resize((batch_size*20, 1))
        pre_with_index = list(np.append(d_pret, index, axis=1))
        pre_with_index.sort(key=lambda x: x[0], reverse=True)
        nice_images = np.zeros((batch_size,) + generated_images.shape[1:3], dtype=np.float32)
        nice_images = nice_images[:, :, :, None]
        for i in range(batch_size):
            idx = int(pre_with_index[i][1])
            nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (batch_size, 100))
        generated_images = g.predict(noise, verbose=1)
        image = combine_images(generated_images)
    image = image*127.5 + 127.5
    Image.fromarray(image.astype(np.uint8)).save('generated_image.png')    

In [86]:
# %tb
# def get_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--mode", type=str)
#     parser.add_argument("--batch-size", type=int, default=128)
#     parser.add_argument("--nice", dest="nice", action="store_true")
#     parser.set_defaults(nice=False)
#     args = parser.parse_args()
#     return args

# if __name__ == '__main__':
#     args = get_args()
#     if args.mode == "train":
#         train(batch_size=args.batch_size)
#     elif args.mode == "generate":
#         generate(batch_size=args.batch_size, nice=args.nice)


In [89]:
train(batch_size=128)

  app.launch_new_instance()


Epoch #:  0
Number of batches 468
batch: 0   d_loss:  0.6886478662490845
batch: 0   g_loss:  0.6501742005348206
batch: 1   d_loss:  0.6802733540534973
batch: 1   g_loss:  0.6471580266952515
batch: 2   d_loss:  0.6713612079620361
batch: 2   g_loss:  0.6437466740608215
batch: 3   d_loss:  0.6571463942527771
batch: 3   g_loss:  0.6368536949157715
batch: 4   d_loss:  0.6450299620628357
batch: 4   g_loss:  0.6298059225082397
batch: 5   d_loss:  0.6255583167076111
batch: 5   g_loss:  0.6248408555984497
batch: 6   d_loss:  0.6073341369628906
batch: 6   g_loss:  0.6195707321166992
batch: 7   d_loss:  0.5945340394973755
batch: 7   g_loss:  0.6154245138168335
batch: 8   d_loss:  0.575135350227356
batch: 8   g_loss:  0.6065338850021362
batch: 9   d_loss:  0.559829831123352
batch: 9   g_loss:  0.5981008410453796
batch: 10   d_loss:  0.5420458316802979
batch: 10   g_loss:  0.5896244049072266
batch: 11   d_loss:  0.5307350754737854
batch: 11   g_loss:  0.5790247321128845
batch: 12   d_loss:  0.51440

batch: 102   d_loss:  0.28903061151504517
batch: 102   g_loss:  0.8425058722496033
batch: 103   d_loss:  0.2878354489803314
batch: 103   g_loss:  0.8454259037971497
batch: 104   d_loss:  0.28622734546661377
batch: 104   g_loss:  0.8484641909599304
batch: 105   d_loss:  0.285605251789093
batch: 105   g_loss:  0.8515022397041321
batch: 106   d_loss:  0.2837252616882324
batch: 106   g_loss:  0.8545708060264587
batch: 107   d_loss:  0.2819242775440216
batch: 107   g_loss:  0.8577473163604736
batch: 108   d_loss:  0.28126776218414307
batch: 108   g_loss:  0.8609595894813538
batch: 109   d_loss:  0.28037068247795105
batch: 109   g_loss:  0.864219605922699
batch: 110   d_loss:  0.2799861431121826
batch: 110   g_loss:  0.867525041103363
batch: 111   d_loss:  0.27821582555770874
batch: 111   g_loss:  0.870856523513794
batch: 112   d_loss:  0.2762184739112854
batch: 112   g_loss:  0.87428879737854
batch: 113   d_loss:  0.274557888507843
batch: 113   g_loss:  0.8777478933334351
batch: 114   d_los

batch: 202   d_loss:  0.12120196223258972
batch: 202   g_loss:  1.5647674798965454
batch: 203   d_loss:  0.12038439512252808
batch: 203   g_loss:  1.5773017406463623
batch: 204   d_loss:  0.11833310127258301
batch: 204   g_loss:  1.5897983312606812
batch: 205   d_loss:  0.11654660850763321
batch: 205   g_loss:  1.6024954319000244
batch: 206   d_loss:  0.11498434841632843
batch: 206   g_loss:  1.6151649951934814
batch: 207   d_loss:  0.11319732666015625
batch: 207   g_loss:  1.62790846824646
batch: 208   d_loss:  0.11198340356349945
batch: 208   g_loss:  1.6407408714294434
batch: 209   d_loss:  0.1106451228260994
batch: 209   g_loss:  1.6537134647369385
batch: 210   d_loss:  0.10860374569892883
batch: 210   g_loss:  1.6666924953460693
batch: 211   d_loss:  0.1077309399843216
batch: 211   g_loss:  1.67973792552948
batch: 212   d_loss:  0.1065448671579361
batch: 212   g_loss:  1.692825198173523
batch: 213   d_loss:  0.1048436239361763
batch: 213   g_loss:  1.7059223651885986
batch: 214   

batch: 301   g_loss:  2.8145852088928223
batch: 302   d_loss:  0.03275280445814133
batch: 302   g_loss:  2.825216770172119
batch: 303   d_loss:  0.03256497532129288
batch: 303   g_loss:  2.835803270339966
batch: 304   d_loss:  0.03197375684976578
batch: 304   g_loss:  2.84621524810791
batch: 305   d_loss:  0.032150838524103165
batch: 305   g_loss:  2.8566689491271973
batch: 306   d_loss:  0.03176505118608475
batch: 306   g_loss:  2.866934061050415
batch: 307   d_loss:  0.03132915496826172
batch: 307   g_loss:  2.877211093902588
batch: 308   d_loss:  0.03120749443769455
batch: 308   g_loss:  2.8874897956848145
batch: 309   d_loss:  0.030762968584895134
batch: 309   g_loss:  2.897660732269287
batch: 310   d_loss:  0.029761793091893196
batch: 310   g_loss:  2.9078168869018555
batch: 311   d_loss:  0.02972453460097313
batch: 311   g_loss:  2.9178919792175293
batch: 312   d_loss:  0.02960982918739319
batch: 312   g_loss:  2.928021192550659
batch: 313   d_loss:  0.029414284974336624
batch: 3

batch: 400   g_loss:  3.6386942863464355
batch: 401   d_loss:  0.014725906774401665
batch: 401   g_loss:  3.6452388763427734
batch: 402   d_loss:  0.014811962842941284
batch: 402   g_loss:  3.6516714096069336
batch: 403   d_loss:  0.014451375231146812
batch: 403   g_loss:  3.658155918121338
batch: 404   d_loss:  0.014315484091639519
batch: 404   g_loss:  3.6645870208740234
batch: 405   d_loss:  0.014069047756493092
batch: 405   g_loss:  3.670893430709839
batch: 406   d_loss:  0.013961468823254108
batch: 406   g_loss:  3.677386999130249
batch: 407   d_loss:  0.013908563181757927
batch: 407   g_loss:  3.6838302612304688
batch: 408   d_loss:  0.013805081136524677
batch: 408   g_loss:  3.6901538372039795
batch: 409   d_loss:  0.014260204508900642
batch: 409   g_loss:  3.6964333057403564
batch: 410   d_loss:  0.013815373182296753
batch: 410   g_loss:  3.7027347087860107
batch: 411   d_loss:  0.013575704768300056
batch: 411   g_loss:  3.708984851837158
batch: 412   d_loss:  0.013633981347084

KeyboardInterrupt: 