In [1]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math

Using TensorFlow backend.


In [2]:
def generator_model():
    """Build the generator network of the GAN.

    The model takes a 100-dimensional noise vector and produces a
    28x28 single-channel image: two dense layers expand the noise to
    128*7*7 values, which are reshaped to a 7x7x128 feature map and then
    upsampled by a factor of two twice (7 -> 14 -> 28), each time followed
    by a 5x5 'same'-padded convolution.

    Returns:
        An uncompiled keras ``Sequential`` model whose last activation is
        tanh, so output values lie in [-1, 1].
    """
    model = Sequential()
    # Keras 2 names the layer width `units` (first positional argument);
    # the Keras 1 keyword `output_dim` is a legacy spelling.
    model.add(Dense(1024, input_dim=100))
    model.add(Activation('tanh'))
    model.add(Dense(128*7*7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    # Turn the flat 128*7*7 vector into a 7x7 feature map with 128 channels.
    model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))
    # A single output filter yields the one-channel image.
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    return model

In [3]:
def discriminator_model():
    """Build the discriminator network of the GAN.

    The model accepts a (28, 28, 1) image and ends in a single sigmoid
    unit. It is made of two 5x5 convolution + tanh + 2x2 max-pooling
    stages (64 then 128 filters; only the first convolution requests
    'same' padding), followed by a 1024-unit tanh dense layer.

    Returns:
        An uncompiled keras ``Sequential`` model.
    """
    # Layers are listed in the order they are stacked.
    stack = [
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dense(1),
        Activation('sigmoid'),
    ]
    model = Sequential()
    for layer in stack:
        model.add(layer)
    return model

In [4]:
def generator_containing_discriminator(g, d):
    """Stack the generator and the discriminator into one model.

    Args:
        g: The generator model; it becomes the first stage.
        d: The discriminator model; it becomes the second stage. Its
            ``trainable`` attribute is set to False as a side effect.

    Returns:
        An uncompiled ``Sequential`` model that feeds g's output into d.
    """
    # Mark the discriminator as frozen before it is attached to the stack.
    d.trainable = False
    stacked = Sequential()
    for stage in (g, d):
        stacked.add(stage)
    return stacked

In [5]:
def combine_images(generated_images):
    """Tile a batch of images into one 2-D grid image.

    Args:
        generated_images: Array indexed as (image, row, column, channel).
            Only channel 0 of every image is copied into the grid.

    Returns:
        A 2-D array with the same dtype as the input. The grid has
        ``int(sqrt(n))`` columns and enough rows to hold all ``n`` images;
        images are placed row by row and unused cells stay zero.
    """
    count = generated_images.shape[0]
    cols = int(math.sqrt(count))
    rows = int(math.ceil(float(count)/cols))
    tile_shape = generated_images.shape[1:3]
    tile_h = tile_shape[0]
    tile_w = tile_shape[1]
    canvas = np.zeros((rows*tile_h, cols*tile_w),
                      dtype=generated_images.dtype)
    for position, tile in enumerate(generated_images):
        # Row-major placement: grid row and column of this tile.
        grid_row, grid_col = divmod(position, cols)
        top = grid_row*tile_h
        left = grid_col*tile_w
        canvas[top:top+tile_h, left:left+tile_w] = tile[:, :, 0]
    return canvas

In [6]:
def train(BATCH_SIZE, epochs=1):
    """Train the GAN on the MNIST training images.

    Each iteration first updates the discriminator on one batch of real
    images plus one batch of generated images, then updates the generator
    through the stacked generator+discriminator model with "real" labels.

    Side effects: prints the losses for every batch, writes a preview grid
    ``<epoch>_<index>.png`` every 20 batches, and saves the weights to the
    files ``generator`` and ``discriminator`` every 10 batches and once
    more at the end of an epoch when the last batch was not a checkpoint.

    Args:
        BATCH_SIZE: Number of real images (and of generated images) per
            training step.
        epochs: Number of passes over the training set. Defaults to 1,
            which is what the function did before this parameter existed.
    """
    # Only the training images are used; labels and the test split are not.
    (X_train, _), _ = mnist.load_data()
    # (x - 127.5) / 127.5 maps pixel values 0..255 onto -1..1, the range of
    # the generator's final tanh.
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    # Append a channel axis so the images match the discriminator's input.
    X_train = X_train[:, :, :, None]
    d = discriminator_model()
    g = generator_model()
    d_on_g = generator_containing_discriminator(g, d)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)

    num_batches = X_train.shape[0] // BATCH_SIZE
    # Targets are real numpy arrays rather than Python lists, so Keras
    # cannot mistake them for a list of per-output arrays.
    real_labels = np.ones((BATCH_SIZE, 1), dtype=np.float32)
    fake_labels = np.zeros((BATCH_SIZE, 1), dtype=np.float32)
    d_labels = np.concatenate((real_labels, fake_labels))

    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = g.predict(noise, verbose=0)
            if index % 20 == 0:
                # Periodic preview: map [-1, 1] back to 0..255 and save.
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    str(epoch)+"_"+str(index)+".png")
            # Discriminator step: real images first, generated images second,
            # matching the order of d_labels.
            X = np.concatenate((image_batch, generated_images))
            d_loss = d.train_on_batch(X, d_labels)
            print("batch %d d_loss : %f" % (index, d_loss))
            # Generator step: fresh noise, trained to be classified as real.
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            d.trainable = False
            g_loss = d_on_g.train_on_batch(noise, real_labels)
            d.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % 10 == 9:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)
        # The in-loop checkpoint only fires on every 10th batch; when the
        # epoch does not end on one, save again so the final updates are
        # not lost.
        if num_batches % 10 != 0:
            g.save_weights('generator', True)
            d.save_weights('discriminator', True)

In [7]:
def generate(BATCH_SIZE, nice=False):
    """Generate images with the saved generator and write them to a PNG.

    The generator weights are loaded from the file ``generator``. The
    resulting images are tiled into one grid and saved as
    ``generated_image.png``.

    Args:
        BATCH_SIZE: Number of images in the saved grid.
        nice: When True, ``BATCH_SIZE*20`` candidates are generated, scored
            by the discriminator (weights loaded from the file
            ``discriminator``), and only the ``BATCH_SIZE`` highest-scoring
            ones are kept. When False, ``BATCH_SIZE`` images are generated
            and used directly.
    """
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if not nice:
        latent = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        chosen = g.predict(latent, verbose=1)
    else:
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        candidate_count = BATCH_SIZE*20
        latent = np.random.uniform(-1, 1, (candidate_count, 100))
        candidates = g.predict(latent, verbose=1)
        scores = d.predict(candidates, verbose=1)
        # Candidate indices ordered by descending discriminator score; the
        # sort is stable, so equal scores keep their generation order.
        ranking = sorted(range(candidate_count),
                         key=lambda k: scores[k, 0], reverse=True)
        chosen = np.zeros(
            (BATCH_SIZE,) + candidates.shape[1:3] + (1,), dtype=np.float32)
        for slot, source in enumerate(ranking[:BATCH_SIZE]):
            chosen[slot, :, :, 0] = candidates[source, :, :, 0]
    image = combine_images(chosen)
    # Map the tanh range [-1, 1] back to 0..255 before saving.
    image = image*127.5+127.5
    Image.fromarray(image.astype(np.uint8)).save(
        "generated_image.png")

In [8]:
# Run GAN training with a mini-batch size of 128.
train(128)

  This is separate from the ipykernel package so we can avoid doing imports until


Epoch is 0
Number of batches 468
batch 0 d_loss : 0.712963
batch 0 g_loss : 0.690157
batch 1 d_loss : 0.693470
batch 1 g_loss : 0.687619
batch 2 d_loss : 0.671058
batch 2 g_loss : 0.675324
batch 3 d_loss : 0.666001
batch 3 g_loss : 0.667267
batch 4 d_loss : 0.649058
batch 4 g_loss : 0.659376
batch 5 d_loss : 0.624096
batch 5 g_loss : 0.642514
batch 6 d_loss : 0.605722
batch 6 g_loss : 0.631137
batch 7 d_loss : 0.589367
batch 7 g_loss : 0.617047
batch 8 d_loss : 0.568790
batch 8 g_loss : 0.608502
batch 9 d_loss : 0.555324
batch 9 g_loss : 0.600668
batch 10 d_loss : 0.543000
batch 10 g_loss : 0.586789
batch 11 d_loss : 0.527616
batch 11 g_loss : 0.580579
batch 12 d_loss : 0.509056
batch 12 g_loss : 0.567274
batch 13 d_loss : 0.498991
batch 13 g_loss : 0.555326
batch 14 d_loss : 0.490862
batch 14 g_loss : 0.552215
batch 15 d_loss : 0.488205
batch 15 g_loss : 0.542100
batch 16 d_loss : 0.478661
batch 16 g_loss : 0.538384
batch 17 d_loss : 0.466474
batch 17 g_loss : 0.527592
batch 18 d_loss

batch 150 d_loss : 0.249867
batch 150 g_loss : 0.837692
batch 151 d_loss : 0.245821
batch 151 g_loss : 0.845293
batch 152 d_loss : 0.249444
batch 152 g_loss : 0.906833
batch 153 d_loss : 0.242098
batch 153 g_loss : 0.893135
batch 154 d_loss : 0.258660
batch 154 g_loss : 0.883668
batch 155 d_loss : 0.289212
batch 155 g_loss : 0.922704
batch 156 d_loss : 0.288055
batch 156 g_loss : 0.901016
batch 157 d_loss : 0.259395
batch 157 g_loss : 0.894757
batch 158 d_loss : 0.275787
batch 158 g_loss : 0.939145
batch 159 d_loss : 0.245256
batch 159 g_loss : 0.939180
batch 160 d_loss : 0.245505
batch 160 g_loss : 0.980663
batch 161 d_loss : 0.241653
batch 161 g_loss : 0.999423
batch 162 d_loss : 0.253789
batch 162 g_loss : 1.013940
batch 163 d_loss : 0.249566
batch 163 g_loss : 1.071289
batch 164 d_loss : 0.275921
batch 164 g_loss : 1.043800
batch 165 d_loss : 0.255303
batch 165 g_loss : 1.045381
batch 166 d_loss : 0.271136
batch 166 g_loss : 1.148969
batch 167 d_loss : 0.241880
batch 167 g_loss : 1

batch 298 d_loss : 0.574658
batch 298 g_loss : 0.908186
batch 299 d_loss : 0.582629
batch 299 g_loss : 0.917890
batch 300 d_loss : 0.606080
batch 300 g_loss : 0.901772
batch 301 d_loss : 0.572360
batch 301 g_loss : 0.950799
batch 302 d_loss : 0.583085
batch 302 g_loss : 0.976962
batch 303 d_loss : 0.517421
batch 303 g_loss : 1.043594
batch 304 d_loss : 0.603699
batch 304 g_loss : 1.050511
batch 305 d_loss : 0.514088
batch 305 g_loss : 1.081458
batch 306 d_loss : 0.487782
batch 306 g_loss : 1.168267
batch 307 d_loss : 0.550568
batch 307 g_loss : 1.094818
batch 308 d_loss : 0.507594
batch 308 g_loss : 1.161890
batch 309 d_loss : 0.456510
batch 309 g_loss : 1.180079
batch 310 d_loss : 0.647902
batch 310 g_loss : 1.147270
batch 311 d_loss : 0.599441
batch 311 g_loss : 1.012503
batch 312 d_loss : 0.506220
batch 312 g_loss : 1.005777
batch 313 d_loss : 0.514145
batch 313 g_loss : 0.988546
batch 314 d_loss : 0.543809
batch 314 g_loss : 1.009192
batch 315 d_loss : 0.529881
batch 315 g_loss : 0

batch 446 d_loss : 0.500627
batch 446 g_loss : 1.052918
batch 447 d_loss : 0.516877
batch 447 g_loss : 1.033765
batch 448 d_loss : 0.540100
batch 448 g_loss : 1.025306
batch 449 d_loss : 0.515256
batch 449 g_loss : 1.032378
batch 450 d_loss : 0.518889
batch 450 g_loss : 0.969863
batch 451 d_loss : 0.494298
batch 451 g_loss : 0.979168
batch 452 d_loss : 0.523121
batch 452 g_loss : 1.032262
batch 453 d_loss : 0.515570
batch 453 g_loss : 1.041769
batch 454 d_loss : 0.513583
batch 454 g_loss : 1.017233
batch 455 d_loss : 0.553459
batch 455 g_loss : 0.996767
batch 456 d_loss : 0.527875
batch 456 g_loss : 1.029938
batch 457 d_loss : 0.503970
batch 457 g_loss : 1.034992
batch 458 d_loss : 0.458707
batch 458 g_loss : 1.031411
batch 459 d_loss : 0.501081
batch 459 g_loss : 1.080558
batch 460 d_loss : 0.492093
batch 460 g_loss : 1.030357
batch 461 d_loss : 0.483584
batch 461 g_loss : 1.068798
batch 462 d_loss : 0.504948
batch 462 g_loss : 1.074921
batch 463 d_loss : 0.529442
batch 463 g_loss : 1

In [15]:
# nice=True: generate 128*20 candidates and keep the 128 images the
# discriminator scores highest; the grid is written to generated_image.png.
generate(128, True)

  This is separate from the ipykernel package so we can avoid doing imports until




In [8]:
def get_args(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            previous behavior. Passing an explicit list lets the function
            be used where ``sys.argv`` does not belong to this script,
            e.g. inside a notebook kernel or a test.

    Returns:
        An ``argparse.Namespace`` with the attributes ``mode`` (str, or
        None when not given), ``batch_size`` (int, default 128) and
        ``nice`` (bool, default False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str,
                        help='what to run: "train" or "generate"')
    parser.add_argument("--batch_size", type=int, default=128,
                        help="number of images per batch (default: 128)")
    parser.add_argument("--nice", dest="nice", action="store_true",
                        help="when generating, keep only the images the "
                             "discriminator scores highest")
    parser.set_defaults(nice=False)
    return parser.parse_args(argv)

In [None]:
if __name__ == "__main__":
    # Entry point: run the action named by --mode. Any other mode value
    # (including a missing --mode) does nothing.
    args = get_args()
    if args.mode == "generate":
        generate(BATCH_SIZE=args.batch_size, nice=args.nice)
    elif args.mode == "train":
        train(BATCH_SIZE=args.batch_size)