In [1]:
import argparse
import math
import os

import numpy as np
from PIL import Image
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten, Dropout
from keras.optimizers import SGD, Adam
from keras.datasets import mnist
from keras.datasets import cifar10

Using TensorFlow backend.


In [2]:
def generator_model():
    """Build the generator network.

    The model maps a 100-dimensional noise vector to an image tensor:
    two dense stages (1024 units, then 128*8*8 units with batch
    normalisation) feed an (8, 8, 128) feature map, which is doubled in
    size twice. Each upsampling is followed by a 5x5 'same' convolution
    (64 filters, then 3), so the output is 32x32 with 3 channels. Every
    stage uses tanh, so output values lie in [-1, 1].

    Returns:
        An uncompiled keras ``Sequential`` model.
    """
    layer_stack = [
        # Noise vector -> dense features.
        Dense(input_dim=100, units=1024),
        Activation('tanh'),
        Dense(128*8*8),
        BatchNormalization(),
        Activation('tanh'),
        # Flat features -> 8x8 map with 128 channels.
        Reshape((8, 8, 128), input_shape=(128*8*8,)),
        # 8x8 -> 16x16
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5), padding='same'),
        Activation('tanh'),
        # 16x16 -> 32x32, reduced to 3 output channels.
        UpSampling2D(size=(2, 2)),
        Conv2D(3, (5, 5), padding='same'),
        Activation('tanh'),
    ]
    return Sequential(layer_stack)

In [3]:
def discriminator_model():
    """Build the discriminator network.

    The model takes (32, 32, 3) images and produces one sigmoid score per
    image. Two conv/tanh/max-pool stages (64 filters with 'same' padding,
    then 128 filters with the layer's default padding) are followed by a
    1024-unit dense layer with tanh, 50% dropout and a single-unit
    sigmoid output.

    Returns:
        An uncompiled keras ``Sequential`` model.
    """
    layer_stack = [
        # Convolutional feature extractor.
        Conv2D(64, (5, 5), padding='same', input_shape=(32, 32, 3)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Classifier head.
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dropout(0.5),
        Dense(1),
        Activation('sigmoid'),
    ]
    return Sequential(layer_stack)

In [4]:
def combine_images(generated_images):
    """Tile a batch of images into a single 3-channel grid image.

    The grid has ``int(sqrt(N))`` columns and ``ceil(N / columns)`` rows
    for a batch of N images, filled row by row. Cells with no image stay
    zero.

    Args:
        generated_images: array indexed as (image, row, column, channel).
            Only the first three channels of each image are copied.

    Returns:
        Array of shape (rows * tile_height, columns * tile_width, 3) with
        the same dtype as ``generated_images``.
    """
    count = generated_images.shape[0]
    n_cols = int(math.sqrt(count))
    n_rows = int(math.ceil(float(count) / n_cols))

    tile_shape = generated_images.shape[1:3]
    tile_h, tile_w = tile_shape[0], tile_shape[1]
    canvas = np.zeros(
        (n_rows * tile_h, n_cols * tile_w, 3),
        dtype=generated_images.dtype,
    )

    for position, tile in enumerate(generated_images):
        # Row-major placement: position -> (grid row, grid column).
        grid_row, grid_col = divmod(position, n_cols)
        top = grid_row * tile_h
        left = grid_col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = tile[:, :, 0:3]
    return canvas

In [5]:
 def train(BATCH_SIZE, epochs=100):
    """Train the generator and discriminator adversarially on CIFAR-10.

    Each step trains the discriminator on one batch of real images plus
    an equally sized batch of generated images, then trains the generator
    through the stacked generator+discriminator model with "real" labels.

    Side effects:
        * Every 20th batch a preview grid is written to
          ``./img/<epoch>_<index>.png`` (the directory is created if it
          does not exist).
        * Every 10th batch (index % 10 == 9) the weights are saved to the
          files ``generator`` and ``discriminator``, overwriting them.

    Args:
        BATCH_SIZE: number of real images per step. Trailing images that
            do not fill a whole batch are skipped.
        epochs: number of passes over the training set (default 100).
    """
    # Only the training images are used; labels and the test split are ignored.
    (X_train, _), _ = cifar10.load_data()
    # Maps pixel value 0 to -1 and 255 to 1, the range of the generator's tanh.
    X_train = (X_train.astype(np.float32) - 127.5)/127.5

    # Previews are saved from the very first batch, so the target directory
    # must exist before the loop starts.
    os.makedirs("./img", exist_ok=True)

    d = discriminator_model()
    d_opt = Adam(lr=1e-5, beta_1=0.1)
    d.compile(loss='binary_crossentropy', optimizer=d_opt)
    # NOTE(review): relies on Keras capturing `trainable` at compile time, so
    # `d` (compiled above) keeps training while it is frozen inside `dcgan`.
    # This is what triggers the "Discrepancy between trainable weights"
    # warning -- confirm for the Keras version in use.
    d.trainable = False
    g = generator_model()
    dcgan = Sequential([g, d])
    g_opt = Adam(lr=2e-4, beta_1=0.5)
    dcgan.compile(loss='binary_crossentropy', optimizer=g_opt)

    num_batches = X_train.shape[0] // BATCH_SIZE
    # Targets as arrays rather than Python lists, so Keras cannot mistake
    # them for a list of per-output arrays.
    d_labels = np.array([1] * BATCH_SIZE + [0] * BATCH_SIZE)  # real, then fake
    g_labels = np.array([1] * BATCH_SIZE)  # generator wants "real" verdicts

    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = g.predict(noise, verbose=0)
            if index % 20 == 0:
                image = combine_images(generated_images)
                # Undo the [-1, 1] scaling before writing 8-bit pixels.
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save("./img/"+str(epoch)+"_"+str(index)+".png")
            X = np.concatenate((image_batch, generated_images))
            d_loss = d.train_on_batch(X, d_labels)
            print("batch %d d_loss : %f" % (index, d_loss))
            # Fresh noise for the generator update.
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            g_loss = dcgan.train_on_batch(noise, g_labels)
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % 10 == 9:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)

In [6]:
def generate(BATCH_SIZE, nice=False):
    """Generate images with the saved generator and write them as a grid.

    Loads generator weights from the file ``generator``, samples uniform
    noise in [-1, 1], and saves the tiled result to
    ``generated_image.png``.

    Args:
        BATCH_SIZE: number of images in the saved grid.
        nice: when true, generate ``BATCH_SIZE * 20`` candidates, score
            them with the discriminator loaded from the file
            ``discriminator``, and keep the ``BATCH_SIZE`` images with
            the highest scores (highest first).
    """
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))
        generated_images = g.predict(noise, verbose=1)
        d_pret = d.predict(generated_images, verbose=1)
        # One score per candidate; rank descending. The stable sort keeps
        # the original order among candidates with equal scores.
        scores = np.asarray(d_pret).reshape(-1)
        best = np.argsort(-scores, kind="mergesort")[:BATCH_SIZE]
        # Keep every channel of the selected images. Copying only channel 0
        # into a one-channel buffer would save a grey rendering of the red
        # channel instead of the colour image.
        nice_images = generated_images[best]
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=1)
        image = combine_images(generated_images)
    # Undo the [-1, 1] scaling before writing 8-bit pixels.
    image = image*127.5+127.5
    Image.fromarray(image.astype(np.uint8)).save("generated_image.png")

In [7]:
# Start training with 100 real images per step; train() pairs each real batch
# with an equally sized batch of generated images.
train(BATCH_SIZE=100)

Epoch is 0
Number of batches 500


  'Discrepancy between trainable weights and collected trainable'


batch 0 d_loss : 0.707008
batch 0 g_loss : 0.670932
batch 1 d_loss : 0.700514
batch 1 g_loss : 0.597870
batch 2 d_loss : 0.696429
batch 2 g_loss : 0.538662
batch 3 d_loss : 0.704431
batch 3 g_loss : 0.462353
batch 4 d_loss : 0.696870
batch 4 g_loss : 0.418958
batch 5 d_loss : 0.694693
batch 5 g_loss : 0.406823
batch 6 d_loss : 0.703857
batch 6 g_loss : 0.375859
batch 7 d_loss : 0.718289
batch 7 g_loss : 0.358038
batch 8 d_loss : 0.716161
batch 8 g_loss : 0.366927
batch 9 d_loss : 0.723528
batch 9 g_loss : 0.360275
batch 10 d_loss : 0.747817
batch 10 g_loss : 0.363908
batch 11 d_loss : 0.723017
batch 11 g_loss : 0.380034
batch 12 d_loss : 0.746603
batch 12 g_loss : 0.376350
batch 13 d_loss : 0.773362
batch 13 g_loss : 0.399260
batch 14 d_loss : 0.745472
batch 14 g_loss : 0.408827
batch 15 d_loss : 0.763203
batch 15 g_loss : 0.413658
batch 16 d_loss : 0.782959
batch 16 g_loss : 0.438747
batch 17 d_loss : 0.784774
batch 17 g_loss : 0.452947
batch 18 d_loss : 0.785310
batch 18 g_loss : 0.4

KeyboardInterrupt: 