In [11]:
from __future__ import division

import os
import operator
import numpy as np
import matplotlib.pyplot as plt
import cPickle as pickle

from keras.models import Sequential
from keras.layers.core import Dense, Dropout
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD

%matplotlib inline

In [23]:
# Preprocessed datasets available to this notebook, keyed by short name.
data_files = {
    'cifar': '../data/cifar.npy',
    'mnist': '../data/mnist.npy',
    'lfw': '../data/lfw.npy',
}
dataset_name = 'mnist'

# Load the selected dataset and cast it to float32; axis 0 indexes samples.
# NOTE(review): no rescaling is applied here -- confirm the stored value range
# matches what the generator is expected to produce.
data = np.load(data_files[dataset_name]).astype('float32')

# Shape of one sample (every axis after the sample axis).
sample_dims = data.shape[1:]
# First two per-sample axes, presumably (height, width) -- any channel axis is dropped.
data_shape = sample_dims[:2]
# Length of one sample once flattened (product of all per-sample axes).
data_dim = reduce(operator.mul, sample_dims)

print('Loaded data {}'.format(data.shape))

Loaded data (70000, 28, 28, 1)


In [32]:
# setup optimizers
def make_optimizer():
    """Return a fresh SGD optimizer with the hyper-parameters shared by every model here.

    Each compiled model gets its own instance: sharing one optimizer object between
    several compiled models entangles their optimizer state (e.g. the iteration
    counter), and some newer Keras optimizers refuse to be reused across models.
    """
    return SGD(lr=0.01, momentum=0.1)

# `opt` stays defined (it is now the discriminator's optimizer) so any code that
# still refers to it keeps working.
opt = make_optimizer()

# setup generator network: 2048-d noise vector -> flattened sample of length data_dim
generator = Sequential()
generator.add(Dense(2048*2, input_dim=2048, activation='relu'))
generator.add(Dense(1024*2, activation='relu'))
# Linear output layer: generated values are unbounded.
# NOTE(review): the real data is only cast to float32 when loaded (no visible
# rescaling) -- confirm real and generated samples live on the same scale.
generator.add(Dense(data_dim, activation='linear'))
# In the cells shown the generator is only trained through gen_dis; this compile
# just gives it its own (unused) optimizer instead of sharing one.
generator.compile(loss='binary_crossentropy', optimizer=make_optimizer())

# setup discriminator network: flattened sample -> probability the sample is real
discriminator = Sequential()
discriminator.add(Dense(2048, input_dim=data_dim, activation='relu'))
discriminator.add(Dense(1024, activation='relu'))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=opt)

# setup combined network (noise -> generator -> discriminator), used to train the generator.
# The discriminator is frozen *before* gen_dis is compiled so that training gen_dis
# leaves the discriminator's weights alone.
# NOTE(review): the discriminator itself was compiled above while still trainable;
# whether its own fit/train_on_batch keeps updating it after this flag flip depends on
# the Keras version (standalone Keras 2 collects trainable weights at compile time) -- confirm.
gen_dis = Sequential()
gen_dis.add(generator)
discriminator.trainable = False
gen_dis.add(discriminator)
gen_dis.compile(loss='binary_crossentropy', optimizer=make_optimizer())

In [None]:
batch_size = 512
n_iters = 1000
gen_every = 10    # generator step on every gen_every-th iteration, discriminator step otherwise
plot_every = 100  # show one generated sample every plot_every iterations

# Noise dimension, read from the generator so it cannot drift from the
# input_dim the generator was built with.
latent_dim = generator.input_shape[-1]

# Discriminator targets for a stacked [generated; real] batch: fakes -> 0, real -> 1.
gen_labels = np.ones(2*batch_size)
gen_labels[:batch_size] = 0
# Generator targets: generated samples labelled "real" so the generator learns to fool the discriminator.
gen_disc_labels = np.ones(batch_size)

# Per-sample shape used to turn a flat generated vector back into an image.
sample_shape = data.shape[1:]

for i in range(n_iters):
    zmb = np.random.uniform(-1, 1, size=(batch_size, latent_dim)).astype('float32')
    # Random contiguous window of batch_size real samples. The bound must use
    # batch_size (previously data_dim) so the slice never runs past the end of the
    # data; randint's upper bound is exclusive, hence the + 1.
    n = np.random.randint(0, data.shape[0] - batch_size + 1)
    xmb = data[n:n+batch_size].reshape(batch_size, -1)
    if i % gen_every == 0:
        # Exactly one gradient step on the generator through the combined model.
        # (fit() would silently split the batch into mini-batches of 32 and take many steps.)
        loss = gen_dis.train_on_batch(zmb, gen_disc_labels)
        print('G: {}'.format(loss))
    else:
        # Exactly one gradient step on the discriminator over generated + real samples,
        # stacked in the same order as gen_labels.
        fake_and_real = np.vstack([generator.predict(zmb), xmb])
        loss = discriminator.train_on_batch(fake_and_real, gen_labels)
        print('D: {}'.format(loss))

    if i % plot_every == 0:
        # Only one sample is displayed, so only one is generated; squeeze drops a
        # trailing single-channel axis so imshow gets a 2-D image for grayscale data.
        fake = generator.predict(zmb[:1])
        plt.imshow(fake[0].reshape(sample_shape).squeeze(), cmap='gray_r')
        plt.show()