In [33]:
import os

# Select the backend before keras is imported (next cell); Keras picks its backend on first import.
os.environ.update(KERAS_BACKEND='tensorflow')

In [34]:
import numpy as np

from keras import backend
from keras.layers import Dense, LeakyReLU, Dropout, Input
from keras.models import Model
from keras.datasets import fashion_mnist
from keras.optimizers import Adam
from keras import initializers

import matplotlib.pyplot as plt
%matplotlib inline

In [35]:
# Length of the noise vector fed to the generator (used for every noise sample below).
random_dim = 100
# Seed NumPy's global RNG so noise sampling and real-batch selection repeat across runs.
# NOTE(review): presumably this does not seed Keras/TensorFlow weight init or dropout — confirm.
np.random.seed(10)

In [36]:
active_backend = backend.backend()
# KERAS_BACKEND only takes effect if it was set before keras was first imported in this
# kernel; fail loudly rather than silently training on a different backend.
assert active_backend == 'tensorflow', 'Expected the TensorFlow backend, got {!r}'.format(active_backend)
active_backend

'tensorflow'

In [37]:
def _flatten_and_scale(images):
    """Cast images to float32, rescale [0, 255] pixels to [-1, 1], and flatten each to 784 values."""
    scaled = (images.astype(np.float32) - 127.5) / 127.5
    return scaled.reshape(scaled.shape[0], 784)


def load_fashion_mnist_data(dataset=None):
    """Load Fashion-MNIST and preprocess the train and test splits identically.

    Pixel values are rescaled to ``[-1, 1]`` (the range of the generator's
    ``tanh`` output) and each image is flattened to a 784-element vector.

    Args:
        dataset: Object exposing ``load_data()`` that returns
            ``((x_train, y_train), (x_test, y_test))``. Defaults to
            ``keras.datasets.fashion_mnist``.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``; labels are returned unchanged.
    """
    if dataset is None:
        dataset = fashion_mnist
    (x_train, y_train), (x_test, y_test) = dataset.load_data()

    # Apply the same preprocessing to both splits so x_test is directly comparable to x_train.
    return (_flatten_and_scale(x_train), y_train, _flatten_and_scale(x_test), y_test)

In [38]:
def get_optimizer(learning_rate=0.0002, beta_1=0.5):
    """Return a new Adam optimizer configured for GAN training.

    Uses the ``learning_rate`` keyword: the legacy ``lr`` alias is deprecated in
    Keras 2.3+/tf.keras and rejected by Keras 3.

    Args:
        learning_rate: Adam step size (default 2e-4).
        beta_1: Exponential decay rate for the first-moment estimates (default 0.5).

    Returns:
        A freshly constructed ``Adam`` instance. Call this once per compiled model
        instead of sharing one instance between models.
    """
    return Adam(learning_rate=learning_rate, beta_1=beta_1)

def get_generator(optimizer=None):
    """Build and compile the generator: ``random_dim`` noise -> 784 values in [-1, 1].

    Architecture: Dense(256) -> Dense(512) -> Dense(1024), each followed by
    ``LeakyReLU(0.2)``, then a 784-unit ``tanh`` output layer (a flattened image).

    Args:
        optimizer: Optimizer passed to ``compile``. If omitted, a fresh
            ``get_optimizer()`` instance is used, so the model can be rebuilt on
            its own (e.g. before ``load_weights``).

    Returns:
        The compiled Keras ``Model``.
    """
    if optimizer is None:
        optimizer = get_optimizer()

    input_tensor = Input(shape=(random_dim,))
    # Only the first layer uses the small-stddev normal initializer; the rest keep the Keras default.
    out = Dense(256, kernel_initializer=initializers.RandomNormal(stddev=0.02))(input_tensor)
    out = LeakyReLU(0.2)(out)
    for units in (512, 1024):
        out = Dense(units)(out)
        out = LeakyReLU(0.2)(out)

    out = Dense(784, activation='tanh')(out)
    generator = Model(input_tensor, out)
    # In this notebook the generator's weights are updated only through the combined GAN model.
    generator.compile(optimizer=optimizer, loss='binary_crossentropy')
    return generator

def get_discriminator(optimizer=None):
    """Build and compile the discriminator: flattened 784-value image -> sigmoid score.

    Architecture: Dense(1024) -> Dense(512) -> Dense(256), each followed by
    ``LeakyReLU(0.2)`` and ``Dropout(0.3)``, then a single ``sigmoid`` unit
    (higher scores mean "judged real").

    Args:
        optimizer: Optimizer passed to ``compile``. If omitted, a fresh
            ``get_optimizer()`` instance is used, so the model can be rebuilt on
            its own (e.g. before ``load_weights``).

    Returns:
        The compiled Keras ``Model`` (trained with binary cross-entropy).
    """
    if optimizer is None:
        optimizer = get_optimizer()

    input_tensor = Input(shape=(784,))
    # Only the first layer uses the small-stddev normal initializer; the rest keep the Keras default.
    out = Dense(1024, kernel_initializer=initializers.RandomNormal(stddev=0.02))(input_tensor)
    out = LeakyReLU(0.2)(out)
    out = Dropout(0.3)(out)
    for units in (512, 256):
        out = Dense(units)(out)
        out = LeakyReLU(0.2)(out)
        out = Dropout(0.3)(out)

    out = Dense(1, activation='sigmoid')(out)
    discriminator = Model(input_tensor, out)
    discriminator.compile(optimizer=optimizer, loss='binary_crossentropy')
    return discriminator

In [39]:
def get_gan_network(discriminator, random_dim, generator, optimizer):
    """Chain generator -> discriminator into the combined model that trains the generator.

    Args:
        discriminator: Discriminator model; its ``trainable`` flag is set to False.
        random_dim: Length of the noise vector fed to the generator.
        generator: Generator model mapping noise to images.
        optimizer: Optimizer used to compile the combined model.

    Returns:
        The compiled combined ``Model`` (noise in, discriminator score out).
    """
    # Freeze the discriminator *before* compiling so the combined model only
    # updates the generator; the discriminator is trained through its own model.
    discriminator.trainable = False

    noise = Input(shape=(random_dim,))
    validity = discriminator(generator(noise))

    combined = Model(inputs=noise, outputs=validity)
    combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    return combined

In [40]:
def save_epoch_weights_and_images(epoch, generator, discriminator, examples=100, dim=(10, 10), figsize=(10, 10),
                                  weights_dir='gan_weights', images_dir='gan_images/mnist_fashion'):
    """Checkpoint both networks and save a grid of generated samples for ``epoch``.

    Args:
        epoch: Epoch number embedded in every output file name.
        generator: Model whose weights are saved and which produces the samples.
        discriminator: Model whose weights are saved.
        examples: Number of images to sample; must not exceed ``dim[0] * dim[1]``.
        dim: ``(rows, cols)`` of the subplot grid.
        figsize: Matplotlib figure size in inches.
        weights_dir: Directory for the ``.hdf5`` weight files (created if missing).
        images_dir: Directory for the ``.png`` sample grid (created if missing).
    """
    os.makedirs(weights_dir, exist_ok=True)
    os.makedirs(images_dir, exist_ok=True)

    generator.save_weights(os.path.join(weights_dir, 'gan_generator_weights_epoch_{}.hdf5'.format(epoch)))
    discriminator.save_weights(os.path.join(weights_dir, 'gan_discriminator_weights_epoch_{}.hdf5'.format(epoch)))

    noise = np.random.normal(0, 1, size=[examples, random_dim])
    # Generator outputs are flat 784-vectors; view them as 28x28 images for plotting.
    generated_images = generator.predict(noise).reshape(examples, 28, 28)

    fig = plt.figure(figsize=figsize)
    try:
        for i, image in enumerate(generated_images):
            plt.subplot(dim[0], dim[1], i + 1)
            plt.imshow(image, interpolation='nearest', cmap='gray_r')
            plt.axis('off')
        plt.tight_layout()
        fig.savefig(os.path.join(images_dir, 'gan_fashion_epoch_{}.png'.format(epoch)))
    finally:
        # Always close: this runs repeatedly inside the training loop, and unclosed
        # figures pile up in memory (and are all rendered inline when the cell ends).
        plt.close(fig)

In [41]:
def train(epochs=1, batch_size=128):
    """Build a fresh GAN and train it on Fashion-MNIST.

    Each batch first updates the discriminator on ``batch_size`` real plus
    ``batch_size`` generated images, then updates the generator through the
    combined GAN model. Weights and a sample grid are saved after epoch 1 and
    after every 5th epoch.

    Args:
        epochs: Number of epochs to run.
        batch_size: Number of real (and of generated) images per batch.
    """
    x_train, _, _, _ = load_fashion_mnist_data()
    batch_count = x_train.shape[0] // batch_size

    # Give each compiled model its own optimizer. Sharing one instance shares its
    # internal state (e.g. the step counter) across models, and newer Keras
    # versions raise when one optimizer is applied to a different set of variables.
    generator = get_generator(get_optimizer())
    discriminator = get_discriminator(get_optimizer())
    gan = get_gan_network(discriminator, random_dim, generator, get_optimizer())

    # Targets are the same for every batch, so build them once. Real images get a
    # smoothed label of 0.9 instead of 1.0; generated images get 0.
    y_dis = np.zeros(2 * batch_size)
    y_dis[:batch_size] = 0.9
    # The generator is rewarded when the discriminator labels its samples as real.
    y_gen = np.ones(batch_size)

    for e in range(1, epochs + 1):
        print('-----------Epoch {}-----------'.format(e))
        for _ in range(batch_count):
            # Discriminator batch: real images first, then generated ones (matches y_dis).
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            generated_images = generator.predict(noise)
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
            combined_images = np.concatenate([image_batch, generated_images])

            # Keep these explicit toggles: depending on the Keras version, trainability is
            # captured at compile time or when a model's train function is first built.
            discriminator.trainable = True
            discriminator.train_on_batch(combined_images, y_dis)

            # Generator step through the combined model with the discriminator frozen.
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 5 == 0:
            save_epoch_weights_and_images(e, generator, discriminator)

In [None]:
# Long run: checkpoints are written after epoch 1 and every 5th epoch, so an
# interrupted run can be resumed from the most recent one.
train(epochs=400)

-----------Epoch 1-----------
-----------Epoch 2-----------
-----------Epoch 3-----------
-----------Epoch 4-----------
-----------Epoch 5-----------
-----------Epoch 6-----------
-----------Epoch 7-----------
-----------Epoch 8-----------
-----------Epoch 9-----------
-----------Epoch 10-----------
-----------Epoch 11-----------
-----------Epoch 12-----------
-----------Epoch 13-----------
-----------Epoch 14-----------
-----------Epoch 15-----------
-----------Epoch 16-----------
-----------Epoch 17-----------
-----------Epoch 18-----------
-----------Epoch 19-----------
-----------Epoch 20-----------
-----------Epoch 21-----------
-----------Epoch 22-----------
-----------Epoch 23-----------
-----------Epoch 24-----------
-----------Epoch 25-----------
-----------Epoch 26-----------
-----------Epoch 27-----------
-----------Epoch 28-----------
-----------Epoch 29-----------
-----------Epoch 30-----------
-----------Epoch 31-----------
-----------Epoch 32-----------
-----------Epoch 

  if __name__ == '__main__':


-----------Epoch 101-----------
-----------Epoch 102-----------
-----------Epoch 103-----------
-----------Epoch 104-----------
-----------Epoch 105-----------
-----------Epoch 106-----------
-----------Epoch 107-----------
-----------Epoch 108-----------
-----------Epoch 109-----------
-----------Epoch 110-----------
-----------Epoch 111-----------
-----------Epoch 112-----------
-----------Epoch 113-----------
-----------Epoch 114-----------
-----------Epoch 115-----------
-----------Epoch 116-----------
-----------Epoch 117-----------
-----------Epoch 118-----------
-----------Epoch 119-----------
-----------Epoch 120-----------
-----------Epoch 121-----------


KeyboardInterrupt: 

#### Load saved weights to resume training

In [None]:
# Checkpoint epoch to resume from. save_epoch_weights_and_images writes one after
# epoch 1 and every 5th epoch; the interrupted run above reached epoch 121, so 120
# is the most recent checkpoint it produced. Adjust as needed, then continue with
# train_on_gen_dis(generator, discriminator, epochs_offset=resume_epoch, epochs=...).
resume_epoch = 120

# get_generator / get_discriminator need the optimizer used to compile each model.
generator = get_generator(get_optimizer())
generator.load_weights('gan_weights/gan_generator_weights_epoch_{}.hdf5'.format(resume_epoch))

discriminator = get_discriminator(get_optimizer())
discriminator.load_weights('gan_weights/gan_discriminator_weights_epoch_{}.hdf5'.format(resume_epoch))

In [32]:
def train_on_gen_dis(generator, discriminator, epochs_offset=0, epochs=1, batch_size=128):
    """Continue GAN training from an existing generator/discriminator pair.

    Runs the same per-batch schedule as a fresh training run: a discriminator
    update on real + generated images, then a generator update through a newly
    built combined GAN model.

    Args:
        generator: Generator model (e.g. rebuilt and restored from a checkpoint).
        discriminator: Compiled discriminator model.
        epochs_offset: Number of epochs already completed. Epoch numbering (printed
            headers and checkpoint file names) continues from ``epochs_offset + 1``.
        epochs: Number of additional epochs to run.
        batch_size: Number of real (and of generated) images per batch.
    """
    x_train, _, _, _ = load_fashion_mnist_data()
    batch_count = x_train.shape[0] // batch_size

    # Fresh optimizer for the combined model; the discriminator keeps the one it was compiled with.
    gan = get_gan_network(discriminator, random_dim, generator, get_optimizer())

    # Targets are the same for every batch. Real images get a smoothed label of 0.9; fakes get 0.
    y_dis = np.zeros(2 * batch_size)
    y_dis[:batch_size] = 0.9
    y_gen = np.ones(batch_size)

    for e in range(epochs_offset + 1, epochs_offset + epochs + 1):
        print('-----------Epoch {}-----------'.format(e))
        for _ in range(batch_count):
            # Discriminator batch: real images first, then generated ones (matches y_dis).
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            generated_images = generator.predict(noise)
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
            combined_images = np.concatenate([image_batch, generated_images])

            # Keep these explicit toggles: depending on the Keras version, trainability is
            # captured at compile time or when a model's train function is first built.
            discriminator.trainable = True
            discriminator.train_on_batch(combined_images, y_dis)

            # Generator step through the combined model with the discriminator frozen.
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 5 == 0:
            # `e` already includes epochs_offset; adding it again would mislabel
            # (and could overwrite) checkpoint files.
            save_epoch_weights_and_images(e, generator, discriminator)