In [None]:
# Optional: uncomment the two lines below to hide any GPU from CUDA and run on the CPU instead.
# import os 
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
%matplotlib inline

In [None]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import Adam
from keras.datasets import mnist
from PIL import Image
import numpy as np
import argparse
import math
import matplotlib.pyplot as plt
import sys
import os 

# Generative Adversarial Networks
This notebook accompanies the [ML@B](https://ml.berkeley.edu) Fall Workshop on Generative Adversarial Networks. The original code is from [jacobgil/keras-dcgan](https://github.com/jacobgil/keras-dcgan).

In [None]:
# Training hyperparameters
BATCH_SIZE = 128        # number of real (and of generated) images per training step
LEARNING_RATE = 5e-4    # passed as `lr` to both Adam optimizers
BETA = 0.5              # passed as `beta_1` to both Adam optimizers

In [None]:
def generator_model():
    """Build the (uncompiled) generator network.

    The model takes a 100-dimensional input vector, passes it through two
    dense layers, reshapes the result to (7, 7, 128) and then alternates
    2x up-sampling with 5x5 'same' convolutions, ending in a single-filter
    convolution followed by tanh.

    Returns:
        A Keras ``Sequential`` model.
    """
    stack = [
        # Dense "stem" operating on the flat input vector.
        Dense(input_dim=100, units=1024),
        Activation('tanh'),
        Dense(128*7*7),
        BatchNormalization(),
        Activation('tanh'),
        # Switch from a flat vector to a 7x7 feature map with 128 channels.
        Reshape((7, 7, 128), input_shape=(128*7*7,)),
        # Two rounds of "double the resolution, then convolve".
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5), padding='same'),
        Activation('tanh'),
        UpSampling2D(size=(2, 2)),
        Conv2D(1, (5, 5), padding='same'),
        Activation('tanh'),
    ]
    net = Sequential()
    for layer in stack:
        net.add(layer)
    return net

In [None]:
def discriminator_model():
    """Build the (uncompiled) discriminator network.

    The model accepts inputs of shape (28, 28, 1), applies two
    convolution / tanh / 2x2 max-pooling stages, flattens, and finishes with
    a 1024-unit dense layer and a single sigmoid output unit.

    Returns:
        A Keras ``Sequential`` model.
    """
    stack = [
        # Convolutional feature extractor.
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Dense classifier head ending in one sigmoid unit.
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dense(1),
        Activation('sigmoid'),
    ]
    net = Sequential()
    for layer in stack:
        net.add(layer)
    return net

In [None]:
def generator_containing_discriminator(g, d):
    """Chain generator ``g`` and discriminator ``d`` into one Sequential model.

    Side effect: ``d.trainable`` is set to ``False`` on the object passed in,
    before it is added to the stacked model. Callers that want to train ``d``
    on its own afterwards must set ``d.trainable = True`` again themselves.

    Args:
        g: generator model (added first).
        d: discriminator model (added second, marked non-trainable).

    Returns:
        The stacked ``Sequential`` model ``g -> d`` (uncompiled).
    """
    d.trainable = False
    stacked = Sequential()
    for part in (g, d):
        stacked.add(part)
    return stacked

In [None]:
def combine_images(generated_images):
    """Tile a batch of images into a single 2-D grid.

    Only channel 0 of each image is used. The grid has
    ``int(sqrt(n))`` columns and enough rows to hold all ``n`` images; images
    are placed left to right, top to bottom, and any unused cells stay zero.

    Args:
        generated_images: array indexed as (n, height, width, channel).

    Returns:
        2-D array of shape (rows*height, cols*width) with the input's dtype.
    """
    count = generated_images.shape[0]
    cols = int(math.sqrt(count))
    rows = int(math.ceil(float(count)/cols))
    tile_h, tile_w = generated_images.shape[1:3]

    canvas = np.zeros((rows*tile_h, cols*tile_w), dtype=generated_images.dtype)
    for position, tile in enumerate(generated_images):
        row, col = divmod(position, cols)
        top = row*tile_h
        left = col*tile_w
        canvas[top:top+tile_h, left:left+tile_w] = tile[:, :, 0]
    return canvas

In [None]:
import os 

In [None]:
# Make sure the directory that the training loop writes sample grids into exists.
# exist_ok=True keeps the cell safe to re-run and avoids the check-then-create race.
os.makedirs('samples', exist_ok=True)

In [None]:
# ---- Data ------------------------------------------------------------------
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Map pixel values 0..255 onto [-1, 1], the range of the generator's final tanh.
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train[:, :, :, None]  # append a channel axis
# X_test only gets the channel axis (it is NOT rescaled) and is unused below.
X_test = X_test[:, :, :, None]

# ---- Models ----------------------------------------------------------------
d = discriminator_model()
g = generator_model()
d_on_g = generator_containing_discriminator(g, d)
d_optim = Adam(lr=LEARNING_RATE, beta_1=BETA)
g_optim = Adam(lr=LEARNING_RATE, beta_1=BETA)
g.compile(loss='binary_crossentropy', optimizer="SGD")
# Order matters here: the stacked model is compiled while d.trainable is False
# (set inside generator_containing_discriminator); only afterwards is d
# re-enabled and compiled on its own.
d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
d.trainable = True
d.compile(loss='binary_crossentropy', optimizer=d_optim)

# ---- Loop invariants -------------------------------------------------------
# Only full batches are used; a trailing partial batch is dropped.
num_batches = X_train.shape[0] // BATCH_SIZE
# Targets as explicit (N, 1) arrays, matching the discriminator's single
# sigmoid output, instead of Python lists (which some Keras versions interpret
# as a list of per-output targets).
# Discriminator step: first half of the batch is real (1), second half fake (0).
d_labels = np.concatenate((np.ones((BATCH_SIZE, 1), dtype=np.float32),
                           np.zeros((BATCH_SIZE, 1), dtype=np.float32)))
# Generator step: generated images are labelled as real (1).
g_labels = np.ones((BATCH_SIZE, 1), dtype=np.float32)

# ---- Training --------------------------------------------------------------
for epoch in range(100):
    print("Epoch is", epoch)
    print("Number of batches", num_batches)
    for index in range(num_batches):
        noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
        image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
        generated_images = g.predict(noise, verbose=0)

        # Every 20th batch, dump the current generator output as a PNG grid.
        if index % 20 == 0:
            image = combine_images(generated_images)
            image = image*127.5+127.5  # undo the [-1, 1] scaling
            Image.fromarray(image.astype(np.uint8)).save(
                'samples/'+str(epoch)+"_"+str(index)+".png")

        # Discriminator update on real + generated images.
        X = np.concatenate((image_batch, generated_images))
        d_loss = d.train_on_batch(X, d_labels)

        # Generator update through the stacked model, on fresh noise.
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        d.trainable = False
        g_loss = d_on_g.train_on_batch(noise, g_labels)
        d.trainable = True

        # '\r' rewrites the same line; flush so the update is shown right away.
        sys.stdout.write('\r'+"batch%d  d_loss: %f; g_loss : %f" % (index, d_loss, g_loss))
        sys.stdout.flush()

        # Checkpoint both networks every 10 batches (overwriting earlier files).
        if index % 10 == 9:
            g.save_weights('generator', True)
            d.save_weights('discriminator', True)


In [None]:
def display_image(img):
    """Render ``img`` with matplotlib using the 'gray' colormap and no axes.

    Args:
        img: image data accepted by ``plt.imshow``.
    """
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.show()

In [None]:
def generate(BATCH_SIZE, nice=False, gen_model='generator', disc_model='discriminator'):
    """Sample images from a saved generator and return them as one grid.

    A fresh generator is built and its weights are loaded from ``gen_model``.
    With ``nice=True`` a discriminator is also built and loaded from
    ``disc_model``; 20x as many candidates are generated and only the
    ``BATCH_SIZE`` images with the highest discriminator output are kept.

    Args:
        BATCH_SIZE: number of images placed in the returned grid.
        nice: if True, keep only the candidates the discriminator scores
            highest; if False, use the generated images as they come.
        gen_model: path of the generator weights file.
        disc_model: path of the discriminator weights file (used only when
            ``nice`` is True).

    Returns:
        2-D array produced by ``combine_images`` and rescaled with
        ``x*127.5 + 127.5``.
    """
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights(gen_model)

    if nice:
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights(disc_model)

        oversample = 20  # candidates generated per image that is kept
        noise = np.random.uniform(-1, 1, (BATCH_SIZE*oversample, 100))
        candidates = g.predict(noise, verbose=1)
        scores = d.predict(candidates, verbose=1)

        # Rank by descending score. A stable sort on the negated scores keeps
        # tied candidates in generation order.
        best = np.argsort(-scores[:, 0], kind='mergesort')[:BATCH_SIZE]
        # Keep channel 0 only, as a float32 batch with a trailing channel axis.
        chosen = candidates[best, :, :, :1].astype(np.float32)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        chosen = g.predict(noise, verbose=1)

    image = combine_images(chosen)
    return image*127.5+127.5


In [None]:
# Build a generator, load the weights saved under the default name 'generator',
# sample 20 images and show them as a single grid.
display_image(generate(20, nice=False))

In [None]:
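# Alternative: sample from the weight files 'good_gen' / 'good_disc' (presumably pre-trained) instead of the weights saved by the training cell above.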
# display_image(generate(20, nice=False, gen_model='good_gen', disc_model='good_disc' ))