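Data loading follows the TensorFlow Keras basic classification (Fashion-MNIST) tutorial: https://www.tensorflow.org/tutorials/keras/basic_classification

GAN architecture and training loop adapted from Zackory's Keras MNIST GAN: https://github.com/Zackory/Keras-MNIST-GAN/blob/master/mnist_gan.py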


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from keras.layers.advanced_activations import LeakyReLU

In [None]:
# Handle to the Fashion-MNIST dataset loader bundled with tf.keras.
fashion_mnist = keras.datasets.fashion_mnist

In [None]:
# tr_* = training split, tt_* = test split; *_im = images, *_lab = labels.
(tr_im, tr_lab), (tt_im, tt_lab) = fashion_mnist.load_data()

In [None]:
# Human-readable names for the ten Fashion-MNIST classes, in label order.
class_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]


In [None]:
# Inspect the raw training-image array shape (expected (n_images, 28, 28) before flattening).
np.shape(tr_im)

In [None]:
# Display the raw training labels (presumably integer class ids 0-9 that index class_names — confirm).
tr_lab

In [None]:
# Preview the first raw (un-normalised, still 2-D) training image.
plt.imshow(tr_im[0])

In [None]:
# Length of the generator's latent input: each sample is npar i.i.d. standard-normal draws.
npar = 100

In [None]:
# Largest raw pixel value; the normalisation cells divide by this to map pixels into [0, 1].
np.max(tr_im)

In [None]:
# Scale training pixels into [0, 1] and flatten each image into a single row.
# Cast to float32 *before* dividing: Keras works in float32 by default, and a
# float64 copy of the training set would need twice the memory (and would
# upcast every batch concatenated with generator output in the training loop).
# Dividing by the array's own max keeps this cell safe to re-run: a second run
# divides by 1.0.
tr_im = tr_im.astype(np.float32) / tr_im.max()
tr_im = tr_im.reshape(tr_im.shape[0], -1)  # -1 infers the per-image size (784 for 28x28)

In [None]:
# Scale test pixels into [0, 1] and flatten each image into a single row,
# stored as float32 (Keras's default float type) to halve memory vs float64.
# Dividing by the array's own max keeps this cell safe to re-run.
tt_im = tt_im.astype(np.float32) / tt_im.max()
tt_im = tt_im.reshape(tt_im.shape[0], -1)  # -1 infers the per-image size (784 for 28x28)

In [None]:
# Adam optimizer (learning rate 5e-4, beta1 0.5) passed to every compile() call below.
# NOTE(review): tf.train.AdamOptimizer is the TensorFlow 1.x API (TF 2.x only exposes it
# as tf.compat.v1.train.AdamOptimizer), and this single instance is shared by the
# generator, discriminator and gan models. If porting to newer Keras, consider one
# optimizer per model.
opt = tf.train.AdamOptimizer(beta1=0.5, learning_rate=5e-4)

In [None]:
# Generator: maps an npar-long noise vector to a flat 784-value image.
# The final sigmoid keeps outputs in [0, 1], the same range as the scaled training data.
generator = keras.Sequential([
    keras.layers.Dense(
        256,
        input_dim=npar,
        activation=tf.nn.leaky_relu,
        kernel_initializer=keras.initializers.RandomNormal(stddev=0.02),
    ),
    keras.layers.Dense(512, activation=tf.nn.leaky_relu),
    keras.layers.Dense(1024, activation=tf.nn.leaky_relu),
    keras.layers.Dense(784, activation=tf.nn.sigmoid),
])
# Its weights are only updated through the stacked `gan` model; compiling it
# standalone is not needed for generator.predict.
generator.compile(loss='binary_crossentropy', optimizer=opt)
generator.summary()

In [None]:
# Discriminator: maps a flat 784-value image to a single sigmoid score (1 = real, 0 = fake,
# matching the labels used in the training loop). Dropout(0.3) follows each hidden layer.
discriminator = keras.Sequential([
    keras.layers.Dense(
        1024,
        input_dim=784,
        kernel_initializer=keras.initializers.RandomNormal(stddev=0.02),
        activation=tf.nn.leaky_relu,
    ),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(512, activation=tf.nn.leaky_relu),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(256, activation=tf.nn.leaky_relu),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(1, activation='sigmoid'),
])
# Compiled while still trainable, so discriminator.train_on_batch updates these weights.
discriminator.compile(loss='binary_crossentropy', optimizer=opt)
discriminator.summary()

In [None]:
# Stacked model noise -> generator -> discriminator, used to train the generator.
# The discriminator is frozen *before* gan.compile: Keras fixes the trainable-weight
# set at compile time, so gan.train_on_batch updates only the generator's weights.
discriminator.trainable = False
gi = keras.Input(shape=(npar,))  # latent noise input
x = generator(gi)                # generated (flat) image
go = discriminator(x)            # discriminator's "real" score for it
gan = keras.Model(inputs=gi, outputs=go)
gan.compile(loss="binary_crossentropy", optimizer=opt)

In [None]:
# Fixed latent vectors reused by every preview, so successive plots show how the
# generator's output for the *same* inputs changes during training.
plot_noise = np.random.randn(10, npar)


def plot_gen(noise=None):
    """Plot generated samples in one row, each titled with its discriminator score.

    Parameters
    ----------
    noise : array of shape (n, npar), optional
        Latent vectors fed to the generator. Defaults to the module-level
        ``plot_noise`` (looked up at call time, so reassigning it is honoured).
        ``n`` must be at least 1.

    The generator emits flat 784-value rows (sigmoid outputs in [0, 1]); each is
    reshaped to 28x28 for display. Each title is the discriminator's sigmoid
    score for that sample. The caller is responsible for ``plt.show()``.
    """
    if noise is None:
        noise = plot_noise
    n = len(noise)
    # squeeze=False keeps `axes` 2-D even when n == 1, so .ravel() always works.
    _, axes = plt.subplots(nrows=1, ncols=n, figsize=(n, 2), squeeze=False)
    vecs = generator.predict(noise, verbose=0)
    scores = discriminator.predict(vecs, verbose=0).ravel()
    for ax, img, score in zip(axes.ravel(), vecs.reshape(n, 28, 28), scores):
        # Pin the colour range so a uniformly grey sample is not contrast-stretched.
        ax.imshow(img, cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_title("{0:1.3G}".format(score))


plot_gen()

In [None]:
batch_size = 64
# Update steps per "epoch". Real images are sampled with replacement below, so an
# epoch only approximates one pass over the training set.
batch_count = tr_im.shape[0] // batch_size
print(batch_count)
epochs = 200  # range below is inclusive of this value: epochs 1..epochs

# Label arrays are identical every step, so build them once.
# Discriminator batch = batch_size real images (label 1) followed by batch_size fakes (label 0).
d_labels = np.zeros(2 * batch_size)
d_labels[:batch_size] = 1.0
# Generator step labels fakes as "real", pushing the generator to fool the discriminator.
g_labels = np.ones(batch_size)

plot_gen()
plt.show()
for e in range(1, epochs + 1):  # 1-based so `e % 5 == 0` fires on epochs 5, 10, ...
    for _ in range(batch_count):
        # --- Discriminator step: real vs. generated images ---
        noise = np.random.randn(batch_size, npar)
        real_batch = tr_im[np.random.randint(0, tr_im.shape[0], size=batch_size)]
        # verbose=0 avoids a per-call progress bar in newer Keras versions.
        fake_batch = generator.predict(noise, verbose=0)
        d_inputs = np.concatenate([real_batch, fake_batch])
        # Keras documents that `trainable` changes only take effect at the next
        # compile(), so these toggles are likely no-ops. They are kept because they
        # agree with the compile-time settings (discriminator compiled trainable,
        # gan compiled with it frozen). NOTE(review): confirm for the Keras version in use.
        discriminator.trainable = True
        dloss = discriminator.train_on_batch(d_inputs, d_labels)

        # --- Generator step through the stacked model (discriminator frozen) ---
        noise = np.random.randn(batch_size, npar)
        discriminator.trainable = False
        gloss = gan.train_on_batch(noise, g_labels)

    if e % 5 == 0:
        # Losses reported are from the last batch of the epoch.
        print(e, dloss, gloss)
        plot_gen()
        plt.show()
        
    