# Generative Adversarial Networks in Keras

In [3]:
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *

from tqdm import tqdm

Using TensorFlow backend.


In [4]:
from keras.datasets import mnist
# Load MNIST as separate train/test pairs of image and label arrays.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Shown as the cell output; the recorded run reports (60000, 28, 28).
X_train.shape

(60000, 28, 28)

In [5]:
# Number of training images; reused below for reshaping and for sampling real images.
n = X_train.shape[0]

In [6]:
# Flatten every image into one row vector and convert to float32 so the
# arithmetic in the following cells is done in floating point.
X_train = np.reshape(X_train, (n, -1)).astype(np.float32)
X_test = np.reshape(X_test, (X_test.shape[0], -1)).astype(np.float32)

In [7]:
# Shift pixel values by 127.5 (the midpoint of a 0-255 range) so they are centred on zero.
X_train = np.subtract(X_train, 127.5)
X_test = np.subtract(X_test, 127.5)

In [8]:
# Divide by the same 127.5 used for centring: assuming raw 0-255 pixels this maps
# the data to [-1, 1], the range of the generator's tanh output. (Dividing by 255.
# instead, as an earlier variant of this cell did, would give [-0.5, 0.5].)
X_train /= 127.5
X_test /= 127.5

In [14]:
def plot_multi(im, dim=(4,4), figsize=(6,6), **kwargs ):
    """Draw a sequence of images on a tightly packed grid of subplots.

    Args:
        im: iterable of images; each one is passed to ``plt.imshow``.
        dim: ``(rows, cols)`` tuple describing the subplot grid.
        figsize: figure size forwarded to ``plt.figure``.
        **kwargs: extra keyword arguments forwarded to ``plt.imshow``
            (for example ``cmap='gray'``).

    Returns:
        The ``plt`` object itself, so the caller can call ``savefig`` / ``close``
        on the figure that was just drawn.
    """
    plt.figure(figsize=figsize)
    plt.subplots_adjust(wspace=.03, hspace=.03)
    # Subplot positions are 1-based, hence the enumerate offset.
    for slot, picture in enumerate(im, start=1):
        plt.subplot(*(dim + (slot,)))
        plt.imshow(picture, **kwargs)
        plt.axis('off')
    return plt
    
def plot_gen(G, n_ex=16):
    """Plot ``n_ex`` images sampled from generator ``G`` on a square grid.

    Args:
        G: generator model exposing ``predict``; it is fed ``noise(n_ex)`` and
            its output is reshaped to ``(n_ex, 28, 28)``.
        n_ex: number of samples to generate and draw.

    Returns:
        Whatever ``plot_multi`` returns (the ``plt`` object).
    """
    # The grid side must be an integer: math.sqrt returns a float, which
    # matplotlib does not accept as a subplot row/column count. Rounding up
    # keeps every sample on the grid when n_ex is not a perfect square.
    side = int(math.ceil(math.sqrt(n_ex)))
    samples = G.predict(noise(n_ex)).reshape(n_ex, 28, 28)
    return plot_multi(samples, dim=(side, side), figsize=(side, side), cmap='gray')

In [11]:
def noise(bs, latent_dim=100):
    """Sample a batch of latent vectors from a standard normal distribution.

    Args:
        bs: number of vectors (rows) to draw.
        latent_dim: length of each vector. The default of 100 matches the
            ``input_dim`` of the generator defined in this notebook.

    Returns:
        ``np.ndarray`` of shape ``(bs, latent_dim)``.
    """
    return np.random.randn(bs, latent_dim)

In [12]:
def data_D(sz, G):
    """Assemble one discriminator batch: ``sz`` real images followed by ``sz`` fakes.

    Args:
        sz: number of real images to sample from ``X_train`` (indices are drawn
            independently, so repeats are possible); the same number of images
            is produced by ``G``.
        G: generator exposing ``predict``; it is fed ``noise(sz)``.

    Returns:
        ``(X, y)`` where ``X`` stacks the real images ahead of the generated
        ones and ``y`` is a plain list of ``sz`` ones (real) followed by
        ``sz`` zeros (generated).
    """
    # Draw the real indices before the latent noise so the global RNG is
    # consumed in the same order on every call.
    picks = np.random.randint(0, n, size=sz)
    fakes = G.predict(noise(sz))
    batch = np.concatenate((X_train[picks], fakes))
    labels = [1] * sz + [0] * sz
    return batch, labels

In [13]:
def make_trainable(net, val):
    """Set the ``trainable`` flag on a model and on each of its layers.

    Args:
        net: object with a ``trainable`` attribute and a ``layers`` iterable.
        val: value assigned to ``net.trainable`` and to every layer's
            ``trainable`` attribute.
    """
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val

In [16]:
def train(D, G, m, nb_epoch=5000, bs=128, out_dir='../data/results/dcgan'):
    """Alternate discriminator and generator updates, saving sample grids along the way.

    Each iteration trains ``D`` on one batch of ``bs//2`` real plus ``bs//2``
    generated images, then trains the stacked model ``m`` on ``bs`` noise
    vectors with target 1 (the label ``data_D`` gives to real images).

    Args:
        D: compiled discriminator model.
        G: generator model (used for sampling; it is updated through ``m``).
        m: compiled stacked model that feeds ``G``'s output into ``D``.
        nb_epoch: number of iterations; each one is a single batch per model,
            not a full pass over the data.
        bs: batch size for the generator step; the discriminator step uses
            ``bs//2`` real and ``bs//2`` generated images.
        out_dir: directory that receives the ``dcgan_mnist_<iteration>.png``
            snapshots; created if it does not exist.

    Returns:
        ``(dl, gl)``: lists holding the value returned by ``train_on_batch`` at
        every iteration for the discriminator and the generator respectively.
    """
    import os  # local import: only needed to prepare the snapshot directory

    # savefig does not create missing directories, so make sure it exists up front.
    os.makedirs(out_dir, exist_ok=True)

    dl,gl=[],[]
    for e in tqdm(range(nb_epoch)):
        X,y = data_D(bs//2, G)
        dl.append(D.train_on_batch(X,y))
        make_trainable(D, False)
        try:
            gl.append(m.train_on_batch(noise(bs), np.ones([bs])))
        finally:
            # Always restore the flag, even if the step fails or is interrupted,
            # so a later call does not start with a frozen discriminator.
            make_trainable(D, True)
        # Snapshot every 20 iterations up to 200, then every 100 iterations.
        if (e <= 200 and e % 20 == 0) or (e > 200 and e % 100 == 0):
            p = plot_multi(G.predict(noise(100)).reshape(100, 28,28), figsize=(10,10), dim=(10,10), cmap='gray')
            try:
                p.savefig('%s/dcgan_mnist_%s.png' % (out_dir, e), bbox_inches='tight')
            finally:
                # Close the figure even when saving fails so figures do not pile up.
                p.close()
    return dl,gl

In [17]:
# Restore the image layout with an explicit single channel axis, matching the
# input_shape=(28, 28, 1) expected by the convolutional discriminator.
X_train = np.reshape(X_train, (n, 28, 28, 1))
X_test = np.reshape(X_test, (X_test.shape[0], 28, 28, 1))

In [19]:
# Generator: maps a 100-dimensional latent vector to a single-channel image
# whose values lie in the tanh range.
CNN_G = Sequential()
# Project the latent vector and reshape it into a 7x7 map with 128 channels.
CNN_G.add(Dense(128*7*7, input_dim=100,
                kernel_initializer=initializers.random_normal(stddev=0.01)))
CNN_G.add(LeakyReLU(.2))
CNN_G.add(BatchNormalization())
CNN_G.add(Reshape((7, 7, 128)))
# First upsampling stage (UpSampling2D with default settings) followed by a 5x5 convolution.
CNN_G.add(UpSampling2D())
CNN_G.add(Convolution2D(64, (5, 5), padding='same'))
CNN_G.add(LeakyReLU(.2))
CNN_G.add(BatchNormalization())
# Second upsampling stage; the final convolution collapses to one channel with tanh.
CNN_G.add(UpSampling2D())
CNN_G.add(Convolution2D(1, (5, 5), padding='same', activation='tanh'))

In [21]:
# Discriminator: takes a 28x28x1 image and ends in a single sigmoid unit.
CNN_D = Sequential()
# Two strided 5x5 convolutions, each followed by LeakyReLU and dropout.
CNN_D.add(Convolution2D(64, (5, 5), strides=(2,2), padding='same', input_shape=(28, 28, 1),
                        kernel_initializer=initializers.random_normal(stddev=0.01)))
CNN_D.add(LeakyReLU(.2))
CNN_D.add(Dropout(.3))
CNN_D.add(Convolution2D(128, (5, 5), strides=(2,2), padding='same'))
CNN_D.add(LeakyReLU(.2))
CNN_D.add(Dropout(.3))
# Disabled alternative to Flatten: GlobalAveragePooling2D()
CNN_D.add(Flatten())
CNN_D.add(Dense(1, activation='sigmoid'))

# Compile the discriminator on its own so it can be trained directly on real/generated batches.
CNN_D.compile(optimizer=keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=6e-8),
              loss="binary_crossentropy")

In [24]:
# Freeze the discriminator *before* compiling the stacked model: Keras fixes a
# model's set of trainable weights when it is compiled, so compiling with CNN_D
# still trainable would let the generator updates modify the discriminator too.
make_trainable(CNN_D, False)
CNN_m = Sequential([CNN_G, CNN_D])
CNN_m.compile(keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=3e-8), "binary_crossentropy")
# Restore the flag so CNN_D, which was compiled separately, still trains when
# it is updated on its own batches.
make_trainable(CNN_D, True)

In [25]:
# Run 5000 training iterations with the default batch size, keeping the
# per-iteration discriminator (dl) and generator (gl) training results.
dl, gl = train(CNN_D, CNN_G, CNN_m, nb_epoch=5000)

100%|██████████| 5000/5000 [13:44<00:00,  6.06it/s]


In [27]:
# Persist the weights of both networks (weights only, not the model
# definitions) so they can be loaded back into identically built models.
CNN_G.save_weights('../data/results/dcgan_mnist_generator.h5')
CNN_D.save_weights('../data/results/dcgan_mnist_discriminator.h5')