In [1]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import math

Using TensorFlow backend.


In [0]:
def generator_model():
    """Build the (uncompiled) generator network.

    The network maps a 100-element noise vector to a single-channel image:
    two dense layers expand the noise to 7*7*128 features, which are
    reshaped to a 7x7x128 map and then doubled in spatial size twice
    (7 -> 14 -> 28) by upsampling + 5x5 "same" convolutions.  Every
    activation, including the output, is tanh.

    Returns:
        A Keras ``Sequential`` model; the caller is responsible for compiling.
    """
    layer_stack = [
        # Dense expansion of the noise vector.
        Dense(1024, input_shape=(100, ), activation="tanh"),
        Dense(128 * 7 * 7),
        BatchNormalization(),
        Activation("tanh"),
        # Flat features -> 7x7 feature map with 128 channels (channels last).
        Reshape((7, 7, 128), input_shape=(7 * 7 * 128,)),
        # First 2x upsampling stage.
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5),
               padding="same",
               activation="tanh",
               data_format="channels_last"),
        # Second 2x upsampling stage, collapsing to one output channel.
        UpSampling2D(size=(2, 2)),
        Conv2D(1, (5, 5),
               padding="same",
               activation="tanh",
               data_format="channels_last"),
    ]
    return Sequential(layer_stack)

In [0]:
def discriminator_model():
    """Build the (uncompiled) discriminator network.

    Input is a 28x28 single-channel image (channels last).  Two
    convolution + 2x2 max-pooling stages with tanh activations feed a
    1024-unit tanh dense layer and a final single sigmoid unit.

    Returns:
        A Keras ``Sequential`` model; the caller is responsible for compiling.
    """
    first_conv = Conv2D(64, (5, 5),
                        padding="same",
                        input_shape=(28, 28, 1),
                        activation="tanh",
                        data_format="channels_last")
    # No explicit padding argument here, so Keras' default padding applies.
    second_conv = Conv2D(128, (5, 5),
                         activation="tanh",
                         data_format="channels_last")
    layers_in_order = (
        first_conv,
        MaxPooling2D(pool_size=(2, 2)),
        second_conv,
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(1024, activation="tanh"),
        Dense(1, activation="sigmoid"),
    )

    model = Sequential()
    for layer in layers_in_order:
        model.add(layer)
    return model

In [0]:
def generator_containing_discriminator(generator, discriminator):
    """Chain the generator into the discriminator as one Sequential model.

    Side effect: sets ``discriminator.trainable = False`` on the object
    passed in, so the caller must restore the flag if it wants to train
    the discriminator on its own afterwards.

    Args:
        generator: model placed first in the stack.
        discriminator: model placed second; its ``trainable`` flag is cleared.

    Returns:
        An uncompiled ``Sequential`` containing ``generator`` then
        ``discriminator``.
    """
    # Freeze the discriminator before it becomes part of the stacked model.
    discriminator.trainable = False

    stacked = Sequential()
    for part in (generator, discriminator):
        stacked.add(part)
    return stacked

In [0]:
def combine_images(generated_images):
    """Tile a batch of channels-last images into one 2-D grid image.

    Only channel 0 of each image is drawn.  The grid is
    ``int(sqrt(num))`` tiles wide and as many rows tall as needed to hold
    every image; tiles are placed row by row and unused cells in the last
    row stay zero.

    The previous implementation converted to channels-first with
    ``reshape``, which only reinterprets memory.  That is correct solely
    when the channel axis has length 1; with more channels it mixes pixels
    from different channels.  Slicing the last axis gives the same result
    for single-channel input and the correct channel otherwise.

    Args:
        generated_images: array of shape (num, height, width, channels).

    Returns:
        2-D array of shape (rows * height, cols * width) with the same
        dtype as the input.

    Raises:
        ZeroDivisionError: if the batch is empty (grid width would be 0).
    """
    num, tile_h, tile_w = generated_images.shape[:3]
    # Channel 0 of every image, shape (num, tile_h, tile_w).
    first_channel = generated_images[:, :, :, 0]

    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))

    image = np.zeros((height * tile_h, width * tile_w),
                     dtype=generated_images.dtype)
    for index, tile in enumerate(first_channel):
        row, col = divmod(index, width)
        image[row * tile_h:(row + 1) * tile_h,
              col * tile_w:(col + 1) * tile_w] = tile
    return image

In [0]:
def train(BATCH_SIZE, epochs=100):
    """Train the generator/discriminator pair on the MNIST training images.

    Each step trains the discriminator on one batch of real images plus one
    batch of generated images, then trains the generator through the stacked
    generator->discriminator model with "real" labels.

    Side effects:
        * downloads/loads MNIST via ``mnist.load_data()``;
        * every 20th batch writes a preview grid to ``"<epoch>_<index>.png"``;
        * whenever ``index % 10 == 9`` writes the weights to the files
          ``"generator"`` and ``"discriminator"``;
        * prints both losses for every batch.

    Args:
        BATCH_SIZE: number of real images (and of generated images) per step.
            Trailing images that do not fill a whole batch are not used.
        epochs: number of passes over the training set.  Defaults to 100,
            the value that was previously hard-coded.
    """
    (X_train, _), _ = mnist.load_data()
    # Rescale pixel values from [0, 255] to [-1, 1].
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # Append the channel axis directly (channels last, as the models expect)
    # instead of inserting it at axis 1 and reshaping every batch afterwards.
    X_train = X_train.reshape(X_train.shape + (1,))

    discriminator = discriminator_model()
    generator = generator_model()
    discriminator_on_generator = \
        generator_containing_discriminator(generator, discriminator)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    generator.compile(loss="binary_crossentropy", optimizer="SGD")
    # The stacked model is compiled while the discriminator is still frozen
    # (generator_containing_discriminator cleared its trainable flag) ...
    discriminator_on_generator.compile(
        loss="binary_crossentropy", optimizer=g_optim)
    # ... and the flag is restored before compiling the stand-alone
    # discriminator.  Keep this order.
    discriminator.trainable = True
    discriminator.compile(loss="binary_crossentropy", optimizer=d_optim)

    # Loop invariants: batch count and the label lists for both updates.
    num_batches = X_train.shape[0] // BATCH_SIZE
    real_then_fake_labels = [1] * BATCH_SIZE + [0] * BATCH_SIZE
    all_real_labels = [1] * BATCH_SIZE

    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = generator.predict(noise, verbose=0)
            if index % 20 == 0:
                # Periodic preview; map [-1, 1] back to [0, 255] for saving.
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    str(epoch)+"_"+str(index)+".png")

            # Discriminator update: real images first, then generated ones,
            # matching the order of the labels.
            X = np.concatenate((image_batch, generated_images))
            d_loss = discriminator.train_on_batch(X, real_then_fake_labels)
            print("epoch %d batch %d d_loss : %f" % (epoch, index, d_loss), end=' ')

            # Generator update: fresh noise, labelled as real, pushed through
            # the stacked model with the discriminator flagged non-trainable.
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(
                noise, all_real_labels)
            discriminator.trainable = True
            print("g_loss : %f" % g_loss)

            if index % 10 == 9:
                generator.save_weights("generator", True)
                discriminator.save_weights("discriminator", True)

In [0]:
def generate(BATCH_SIZE, nice=False):
    """Generate digits with the saved generator and write them as one PNG.

    Loads generator weights from the file ``"generator"``, produces
    ``BATCH_SIZE`` images, tiles them with ``combine_images`` and saves the
    grid to ``"generated_image.png"``.

    Args:
        BATCH_SIZE: number of images in the saved grid.
        nice: when true, generate ``BATCH_SIZE * 20`` candidates, score them
            with the discriminator (weights loaded from the file
            ``"discriminator"``) and keep only the ``BATCH_SIZE`` candidates
            with the highest score.
    """
    generator = generator_model()
    generator.compile(loss="binary_crossentropy", optimizer="SGD")
    generator.load_weights("generator")
    if nice:
        discriminator = discriminator_model()
        discriminator.compile(loss="binary_crossentropy", optimizer="SGD")
        discriminator.load_weights("discriminator")

        candidate_count = BATCH_SIZE * 20
        noise = np.random.uniform(-1, 1, (candidate_count, 100))
        generated_images = generator.predict(noise, verbose=1)
        d_pret = discriminator.predict(generated_images, verbose=1)

        # Rank candidates by discriminator score, highest first.  A stable
        # sort keeps the original order among equal scores.
        ranking = np.argsort(-d_pret[:, 0], kind="mergesort")
        # Keep the complete top-scoring images.  The generator output is
        # channels last; the previous channels-first indexing
        # (generated_images[idx, 0, :, :]) copied only the first pixel row
        # of each image into the result.
        nice_images = generated_images[ranking[:int(BATCH_SIZE)]]
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = generator.predict(noise, verbose=1)
        image = combine_images(generated_images)
    # Map the tanh range [-1, 1] back to [0, 255] before saving.
    image = image*127.5+127.5
    Image.fromarray(image.astype(np.uint8)).save(
        "generated_image.png")

In [8]:
# Number of real (and generated) images per training step; the run logged
# below reports 29 batches per epoch at this size.
batch_size = 2048

# Start training.  train() writes preview PNGs and the "generator" /
# "discriminator" weight files into the current working directory.
train(BATCH_SIZE=batch_size)

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
Epoch is 0
Number of batches 29
epoch 0 batch 0 d_loss : 0.654408 g_loss : 0.642399
epoch 0 batch 1 d_loss : 0.646542 g_loss : 0.639732
epoch 0 batch 2 d_loss : 0.639531 g_loss : 0.634365
epoch 0 batch 3 d_loss : 0.629243 g_loss : 0.631074
epoch 0 batch 4 d_loss : 0.617791 g_loss : 0.625404
epoch 0 batch 5 d_loss : 0.607520 g_loss : 0.619827
epoch 0 batch 6 d_loss : 0.598166 g_loss : 0.614048
epoch 0 batch 7 d_loss : 0.580690 g_loss : 0.607806
epoch 0 batch 8 d_loss : 0.568847 g_loss : 0.602620
epoch 0 batch 9 d_loss : 0.558618 g_loss : 0.596358
epoch 0 batch 10 d_loss : 0.545028 g_loss : 0.590469
epoch 0 batch 11 d_loss : 0.536879 g_loss : 0.585429
epoch 0 batch 12 d_loss : 0.522759 g_loss : 0.581719
epoch 0 batch 13 d_loss : 0.515672 g_loss : 0.574403
epoch 0 batch 14 d_loss : 0.506115 g_loss : 0.571806
epoch 0 batch 15 d_loss : 0.497043 g_loss : 0.567917
epoch 0 batch 16 d_loss : 0.491329 g_loss : 0.563444
epoch 0

KeyboardInterrupt: ignored

In [9]:
# Save a grid of 2048 generated digits to "generated_image.png", using the
# weights stored in the "generator" file that train() writes.
generate(BATCH_SIZE=2048)

