In [1]:
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Dropout
from keras.layers import LeakyReLU
from keras.layers import multiply
from keras.layers import BatchNormalization
from keras.layers import Conv2D, UpSampling2D
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.layers.noise import GaussianNoise
import keras.backend as K
from keras.datasets import mnist

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
# np.random.seed(2018)

Using TensorFlow backend.


### generator

In [2]:
def build_generator(latent_size, num_classes=10):
    """Build the class-conditional generator.

    The integer class label is looked up in an embedding of width
    ``latent_size``; the embedding is multiplied element-wise with the latent
    vector and the product is decoded into an image by a dense projection
    followed by a stack of transposed convolutions.

    Args:
        latent_size: Dimensionality of the latent (z) vector.
        num_classes: Number of distinct class labels the embedding can look
            up. Defaults to 10 (the MNIST digits).

    Returns:
        A Keras ``Model`` with inputs ``[latent, image_label]`` and a single
        output image. With the layer settings below the spatial size goes
        7 -> 7 -> 14 -> 28, giving a (28, 28, 1) output in the tanh range.
    """
    model = Sequential()
    # Project the conditioned latent vector onto a 7x7 map with 128 channels.
    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=latent_size))
    model.add(Reshape((7, 7, 128)))

    # (filters, strides) of each transposed-convolution stage; the two
    # stride-2 stages double the resolution.
    for filters, strides in ((256, 1), (128, 2), (64, 2)):
        model.add(Conv2DTranspose(filters, kernel_size=5, strides=strides, padding='same',
                                  activation='relu', kernel_initializer='glorot_normal'))
        model.add(BatchNormalization())

    # Single-channel output squashed to [-1, 1] by tanh.
    model.add(Conv2DTranspose(1, kernel_size=3, strides=1, padding='same',
                              activation='tanh', kernel_initializer='glorot_normal'))

    # this is the z space commonly referred to in GAN papers
    latent = Input(shape=(latent_size, ))

    # integer class label, one per sample
    image_label = Input(shape=(1,), dtype='int32')

    # One embedding row per class, as wide as the latent vector so the two
    # can be combined element-wise.
    embed = Embedding(num_classes, latent_size, embeddings_initializer='glorot_normal')(image_label)
    label_embed = Flatten()(embed)

    model_input = multiply([latent, label_embed])

    fake_image = model(model_input)
    return Model(inputs=[latent, image_label], outputs=fake_image)

### discriminator

In [3]:
def build_discriminator(num_classes=10):
    """Build the two-headed discriminator.

    A convolutional feature extractor over (28, 28, 1) images feeds two dense
    heads:

    * ``generation`` -- one sigmoid unit scoring whether the image is real
      or generated.
    * ``auxiliary`` -- a softmax over ``num_classes + 1`` classes. The extra
      class (index ``num_classes``) is reserved for generated images; the
      training loop in this notebook labels fakes with class 10.

    Args:
        num_classes: Number of real classes. Defaults to 10 (the MNIST digits).

    Returns:
        A Keras ``Model`` mapping an image to ``[fake, aux]``.
    """
    model = Sequential()
    # Small input noise applied before the first convolution.
    model.add(GaussianNoise(0.05, input_shape=(28, 28, 1)))

    # (filters, strides, batch-norm?) for each convolutional stage.
    conv_stages = (
        (32, 2, False),
        (64, 1, True),
        (128, 2, True),
        (256, 1, False),
    )
    for filters, strides, use_batch_norm in conv_stages:
        model.add(Conv2D(filters, kernel_size=3, strides=strides, padding='same',
                         kernel_initializer='glorot_normal'))
        model.add(LeakyReLU(alpha=0.2))
        if use_batch_norm:
            model.add(BatchNormalization(momentum=0.8))
        model.add(Dropout(0.2))

    model.add(Flatten())

    image = Input(shape=(28, 28, 1))

    features = model(image)

    # first output (name=generation) is whether or not the discriminator
    # thinks the image that is being shown is fake, and the second output
    # (name=auxiliary) is the class that the discriminator thinks the image
    # belongs to (with one additional class for generated images).
    fake = Dense(1, activation='sigmoid', name='generation')(features)
    aux = Dense(num_classes + 1, activation='softmax', name='auxiliary')(features)

    return Model(inputs=image, outputs=[fake, aux])


### build model

In [4]:
# Width of the latent (z) vector fed to the generator.
latent_size = 100

# Adam parameters taken from the DCGAN paper: https://arxiv.org/abs/1511.06434
adam_lr = 0.0002
adam_beta_1 = 0.5

# Build and compile the stand-alone discriminator. It has two outputs, so two
# losses are given in output order: real/fake (binary) and class (sparse
# categorical, i.e. integer labels). No metrics are requested, so
# train_on_batch on this model reports losses only.
discriminator = build_discriminator()
discriminator.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),
                      loss=['binary_crossentropy', 'sparse_categorical_crossentropy'])
# discriminator.summary()

# Build the generator. In the cells shown here it is only used through
# predict() and as part of `combined`, never trained on its own.
generator = build_generator(latent_size)
generator.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1), 
                  loss='binary_crossentropy')
# generator.summary()

# Symbolic inputs of the stacked generator -> discriminator model.
latent = Input(shape=(latent_size, ))
image_class = Input(shape=(1,), dtype='int32')

# get a fake image
fake = generator([latent, image_class])

# Freeze the discriminator inside the stacked model so that training
# `combined` updates the generator only. This is done AFTER the discriminator
# itself was compiled above.
# NOTE(review): this relies on Keras fixing the set of trainable weights at
# compile time -- confirm for the Keras version in use.
discriminator.trainable = False
# `fake` is rebound here: from the generated image to the real/fake score.
fake, aux = discriminator(fake)

# Stacked model: (latent, class) -> (real/fake score, class prediction).
combined = Model(inputs=[latent, image_class], outputs=[fake, aux])
combined.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),
                 loss=['binary_crossentropy', 'sparse_categorical_crossentropy'])

combined.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
input_5 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, 28, 28, 1)    2480489     input_4[0][0]                    
                                                                 input_5[0][0]                    
__________________________________________________________________________________________________
model_1 (Model)                 [(None, 1), (None, 1 539148      model_2[1][0]                    
Total para

In [5]:
# Only the training split is needed; the test split is discarded.
(X_train, y_train), (_, _) = mnist.load_data()

# Centre and scale by 127.5 (maps 8-bit pixel values 0..255 onto [-1, 1], the
# range of the generator's tanh output), then append a trailing channel axis.
X_train = ((X_train.astype(np.float32) - 127.5) / 127.5)[..., np.newaxis]
# One integer label per row, as a column vector.
y_train = np.reshape(y_train, (-1, 1))

In [6]:
def train(X_train, y_train, epochs=3, batch_size=64, sample_interval=50):
    """Alternate discriminator and generator updates on random mini-batches.

    Uses the notebook-level ``generator``, ``discriminator``, ``combined``
    models and ``latent_size``. Every pass of the loop draws ONE random batch,
    so ``epochs`` counts single-batch iterations, not full passes over
    ``X_train``.

    Args:
        X_train: Training images, indexed along axis 0.
        y_train: Integer class labels aligned with ``X_train``.
        epochs: Number of single-batch training iterations.
        batch_size: Number of real (and of generated) images per iteration.
        sample_interval: Print the losses and save a sample grid every this
            many iterations.
    """
    # Real/fake targets for the discriminator's `generation` head.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # Random batch of real images (sampled with replacement).
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        img_labels = y_train[idx]

        # Latent vectors and target classes for the generated batch.
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
        sampled_labels = np.random.randint(0, 10, (batch_size, 1))
        gen_imgs = generator.predict([noise, sampled_labels])

        # Generated images are assigned the extra class 10 on the auxiliary head.
        fake_labels = 10 * np.ones(img_labels.shape)

        # Discriminator step: one update on real images, one on generated ones.
        d_loss_real = discriminator.train_on_batch(imgs, [valid, img_labels])
        d_loss_fake = discriminator.train_on_batch(gen_imgs, [fake, fake_labels])
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Generator step through the stacked model: the generator is rewarded
        # when its images are scored as real and classified as the requested digit.
        g_loss = combined.train_on_batch([noise, sampled_labels], [valid, sampled_labels])

        if epoch % sample_interval == 0:
            # The models are compiled without metrics, so train_on_batch returns
            # [total loss, generation-head loss, auxiliary-head loss]; index 1
            # is a loss, not an accuracy.
            print("epoch: %d [D loss: %f, real/fake: %f, class: %f] [G loss: %f]"
                  % (epoch, d_loss[0], d_loss[1], d_loss[2], g_loss[0]))
            save_images(generator, epoch)

In [7]:
import os
def save_images(generator, epoch, latent_size=100):
    """Render a 10x10 grid of generated digits and save it as a PNG.

    Each row is a fresh set of latent samples; column ``j`` is always
    conditioned on class ``j``. The figure is written to
    ``images/<epoch>.png`` (the directory is created if missing) and then
    closed.

    Args:
        generator: Model whose ``predict([noise, labels])`` returns images
            with a trailing channel axis and values in [-1, 1].
        epoch: Iteration number, used as the file name.
        latent_size: Width of the latent vector expected by ``generator``.
    """
    r, c = 10, 10
    # Sample from the same distribution the training loop feeds the generator
    # (uniform on [-1, 1]) so the previews reflect what it was trained on.
    noise = np.random.uniform(-1.0, 1.0, size=(r * c, latent_size))
    # Labels 0..c-1 repeated for every row, as a column vector matching the
    # generator's (1,)-shaped label input.
    sampled_labels = np.tile(np.arange(c), r).reshape(-1, 1)
    gen_imgs = generator.predict([noise, sampled_labels])
    # Rescale images from [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the order of gen_imgs.
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')

    if not os.path.exists("images"):
        os.makedirs("images")
    fig.savefig("images/%d.png" % epoch)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [8]:
# Run 2500 single-batch iterations (64 real + 64 generated images each);
# every 50th iteration prints the losses and saves a sample grid under images/.
train(X_train, y_train, epochs=2500, batch_size=64, sample_interval=50)