In [None]:
#gan.py
import math
import os
import random

import numpy as np
from tensorflow.keras.models import Model
from keras import backend as K
from tensorflow.keras.datasets import mnist, cifar10
from tensorflow.keras.regularizers import l1, l2
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation, Flatten, Reshape, SpatialDropout2D
from tensorflow.keras.layers import Activation, Input, Dense, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, GaussianNoise
from tensorflow.keras.losses import mse, binary_crossentropy, kullback_leibler_divergence
from tensorflow.keras.optimizers import Adam, RMSprop

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def load_mnist():
    """Load MNIST and rescale the images for use with a tanh-output generator.

    Each image array is reshaped to (num_images, 28, 28, 1), cast to float32,
    divided by 128 and shifted down by 1. Labels are passed through untouched.

    Returns:
        ((x_train, y_train), (x_test, y_test)) in the same nesting as
        mnist.load_data().
    """
    img_rows, img_cols = 28, 28

    def _prepare(images):
        # Add the trailing channel axis, then rescale in place on the float copy.
        images = images.reshape(images.shape[0], img_rows, img_cols, 1)
        images = images.astype('float32')
        images /= 128.0
        images -= 1.0
        return images

    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = _prepare(x_train)
    x_test = _prepare(x_test)

    print('bounds:', np.min(x_train), np.max(x_train))
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')
    return (x_train, y_train), (x_test, y_test)

In [None]:
# Length of the random noise vector fed to the generator.
seeds = 100
# Channel count of the 7x7 feature map the convolutional generator starts from.
first_channels = 36
# Dropout rate used throughout the generators (0.0 disables dropout).
dropout_rate = 0.0
# `learning_rate` is the supported keyword; the legacy `lr` alias is deprecated
# in tf.keras and rejected by newer Keras releases.
# `opt` is used for the generator and the stacked GAN, `dopt` for the discriminator.
opt = Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
dopt = Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
def build_generative_model():
    """Build and compile the convolutional generator.

    A noise vector of length `seeds` is projected by a dense layer, reshaped to
    a 7x7 map with `first_channels` channels, passed through two stride-2
    transposed convolutions and one stride-1 convolution (each followed by
    batch norm, ReLU and dropout), and finally reduced to one channel by a
    regularized 1x1 convolution with a tanh activation.

    Returns:
        The compiled Keras Model (MSE loss, shared `opt` optimizer). The model
        summary is printed as a side effect.
    """
    latent = Input(shape=[seeds])

    # Dense projection of the noise vector, then reshape into a feature map.
    features = Dense(7 * 7 * first_channels, activation='relu')(latent)
    features = Dropout(dropout_rate)(features)
    features = Reshape([7, 7, first_channels])(features)

    # (layer class, filters, stride) for each 3x3 convolutional stage, in order.
    conv_stages = (
        (Conv2DTranspose, 256, 2),
        (Conv2DTranspose, 128, 2),
        (Conv2D, 64, 1),
    )
    for layer_cls, filters, stride in conv_stages:
        features = layer_cls(filters, (3, 3), strides=stride, padding='same')(features)
        features = BatchNormalization(momentum=0.9)(features)
        features = Activation('relu')(features)
        features = Dropout(dropout_rate)(features)

    # 1x1 convolution down to a single channel, squashed by tanh.
    logits = Conv2D(
        1,
        (1, 1),
        padding='same',
        kernel_regularizer=l2(0.0001),
        activity_regularizer=l1(0.00002),
    )(features)
    image = Activation('tanh')(logits)

    model = Model(latent, image)
    model.compile(loss='mse', optimizer=opt)
    model.summary()
    return model

def build_dense_generative_model():
    """Build and compile the fully connected generator.

    The noise vector of length `seeds` goes through two dense stages
    (28 units, then 784 units), each followed by batch norm, ReLU and dropout.
    The 784 activations are reshaped to a 28x28x1 map, passed through a 1x1
    convolution and squashed with tanh.

    Returns:
        The compiled Keras Model (MSE loss, shared `opt` optimizer). The model
        summary is printed as a side effect.
    """
    g_input = Input(shape=[seeds])

    H = Dense(28)(g_input)
    H = BatchNormalization(momentum=0.9)(H)
    H = Activation('relu')(H)
    H = Dropout(dropout_rate)(H)

    # Chain from the previous stage. The original wired this layer to g_input,
    # which left the 28-unit stage above disconnected from the model.
    H = Dense(784)(H)
    H = BatchNormalization(momentum=0.9)(H)
    H = Activation('relu')(H)
    H = Dropout(dropout_rate)(H)

    H = Reshape([28, 28, 1])(H)
    pre_logit = Conv2D(1, (1, 1), padding='same')(H)
    pixel = Activation('tanh')(pre_logit)

    generator = Model(g_input, pixel)
    generator.compile(loss='mse', optimizer=opt)
    generator.summary()
    return generator
# Instantiate the fully connected generator; build_generative_model (the
# convolutional variant) is defined in this cell but not used here.
generator = build_dense_generative_model()

In [None]:
def build_discriminative_model(in_shape):
    """Build the convolutional discriminator (returned uncompiled).

    Two stride-2 3x3 convolutions (128 then 64 filters), each followed by
    batch norm and LeakyReLU(0.2), feed a flattened 32-unit dense layer with
    batch norm and a 2-way softmax output. All kernels carry an L2 penalty
    of 0.001.

    Args:
        in_shape: shape of one input image, passed straight to Input().

    Returns:
        The Keras Model. The caller is responsible for compiling it. The
        model summary is printed as a side effect.
    """
    image = Input(in_shape)

    features = image
    for filters in (128, 64):
        features = Conv2D(
            filters,
            (3, 3),
            strides=(2, 2),
            padding='same',
            kernel_regularizer=l2(0.001),
        )(features)
        features = BatchNormalization(momentum=0.9)(features)
        features = LeakyReLU(0.2)(features)

    features = Flatten()(features)
    features = Dense(32, kernel_regularizer=l2(0.001))(features)
    features = BatchNormalization(momentum=0.9)(features)
    class_probs = Dense(2, activation='softmax', kernel_regularizer=l2(0.001))(features)

    model = Model(image, class_probs)
    model.summary()
    return model

def build_dense_discriminative_model(in_shape):
    """Build the fully connected discriminator (returned uncompiled).

    The input is flattened and passed through dense stages of 64 and 32 units,
    each followed by batch norm and LeakyReLU(0.2), ending in an
    L2-regularized 2-way softmax.

    Args:
        in_shape: shape of one input image, passed straight to Input().

    Returns:
        The Keras Model. The caller is responsible for compiling it. The
        model summary is printed as a side effect.
    """
    image = Input(in_shape)

    features = Flatten()(image)
    for units in (64, 32):
        features = Dense(units)(features)
        features = BatchNormalization(momentum=0.9)(features)
        features = LeakyReLU(0.2)(features)

    class_probs = Dense(2, activation='softmax', kernel_regularizer=l2(0.001))(features)

    model = Model(image, class_probs)
    model.summary()
    return model
# Instantiate the fully connected discriminator for 28x28 single-channel images;
# build_discriminative_model (the convolutional variant) is not used here.
discriminator = build_dense_discriminative_model((28, 28, 1))

In [None]:
def build_stacked_gan(generator, discriminator):
    """Chain generator and discriminator into one compiled model.

    The discriminator's `trainable` flag is set to False before the stacked
    model is compiled, and is left that way on return.

    Args:
        generator: model mapping a noise vector of length `seeds` to an image.
        discriminator: model mapping an image to a 2-way class output.

    Returns:
        The compiled stacked Model (categorical cross-entropy, shared `opt`
        optimizer). The model summary is printed as a side effect.
    """
    discriminator.trainable = False

    noise = Input(shape=[seeds])
    verdict = discriminator(generator(noise))

    stacked = Model(noise, verdict)
    stacked.compile(loss='categorical_crossentropy', optimizer=opt)
    stacked.summary()
    return stacked
# Build the stacked model first: build_stacked_gan sets
# discriminator.trainable = False before compiling the combined model.
gan = build_stacked_gan(generator, discriminator)
# Then flip the flag back and compile the discriminator on its own, with its
# dedicated optimizer (dopt), so it can be trained directly.
discriminator.trainable = True
discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)

In [None]:
def pretrain_discriminator(x_train, generator, discriminator, iterations):
    """Fit the discriminator for one epoch on real vs. generated images.

    Args:
        x_train: array of real images, indexed as [sample, row, col, channel].
        generator: model used to produce the fake half of the training set.
        discriminator: compiled model to fit; trained in place.
        iterations: number of real images to sample without replacement
            (an equal number of fake images is generated).

    Prints the label count, the combined input shape and the discriminator's
    accuracy on the data it was just fitted on. Returns None.
    """
    picked = random.sample(range(0, x_train.shape[0]), iterations)
    real_images = x_train[picked, :, :, :]
    n_real = real_images.shape[0]

    latent = np.random.uniform(0, 1, size=[n_real, seeds])
    fake_images = generator.predict(latent)
    x = np.concatenate((real_images, fake_images))

    # One-hot targets: column 0 marks real images, column 1 marks generated ones.
    labels = np.zeros([2 * n_real, 2])
    labels[:n_real, 0] = 1
    labels[n_real:, 1] = 1
    print('np sum 1s:', np.sum(labels[:, 0]), 'x shape:', x.shape)

    discriminator.fit(x, labels, epochs=1, batch_size=32, validation_split=0.1, shuffle=True)

    # Measure accuracy of pre-trained discriminator network
    predicted_cls = np.argmax(discriminator.predict(x), axis=1)
    true_cls = np.argmax(labels, axis=1)
    n_tot = labels.shape[0]
    n_rig = (true_cls == predicted_cls).sum()
    acc = n_rig * 100.0 / n_tot
    print("Accuracy: %0.02f pct (%d of %d) right"%(acc, n_rig, n_tot))

def pretrain_generator(x_train, gan, generator, discriminator, iterations):
    """Train the stacked GAN on noise, then report discriminator accuracy.

    Args:
        x_train: array of real images, indexed as [sample, row, col, channel].
        gan: compiled stacked generator+discriminator model.
        generator: generator model, used to produce evaluation images.
        discriminator: discriminator model, only evaluated here (not fitted).
        iterations: number of stacked-model batches to train, and also the
            number of real images sampled for the accuracy check.

    Prints the generator loss every 50 batches and the final discriminator
    accuracy on a real/generated mix. Returns None.
    """
    batch_size = 64
    for step in range(iterations):
        latent_batch = np.random.uniform(0, 1, size=[batch_size, seeds])
        # Tell the model that random is correct
        wanted = np.zeros([batch_size, 2])
        wanted[:, 0] = 1
        g_loss = gan.train_on_batch(latent_batch, wanted)
        if step % 50 == 0:
            print(f"Pretrain gen it: {step}, Generator loss {g_loss}")

    # Build a real/generated evaluation set of equal halves.
    picked = random.sample(range(0, x_train.shape[0]), iterations)
    real_images = x_train[picked, :, :, :]
    n_real = real_images.shape[0]
    latent = np.random.uniform(0, 1, size=[n_real, seeds])
    fake_images = generator.predict(latent)
    x = np.concatenate((real_images, fake_images))

    # One-hot targets: column 0 marks real images, column 1 marks generated ones.
    labels = np.zeros([2 * n_real, 2])
    labels[:n_real, 0] = 1
    labels[n_real:, 1] = 1
    print('np sum 1s:', np.sum(labels[:, 0]), 'x shape:', x.shape)

    # Measure accuracy of pre-trained discriminator network
    predicted_cls = np.argmax(discriminator.predict(x), axis=1)
    true_cls = np.argmax(labels, axis=1)
    n_tot = labels.shape[0]
    n_rig = (true_cls == predicted_cls).sum()
    acc = n_rig * 100.0 / n_tot
    print("Accuracy: %0.02f pct (%d of %d) right"%(acc, n_rig, n_tot))

In [None]:
def make_trainable(net, val):
    """Set the `trainable` flag on a model and on each of its layers.

    Args:
        net: object exposing a `trainable` attribute and a `layers` iterable.
        val: value assigned to every `trainable` flag.
    """
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val

def plot_gen_color(generator, n_ex=16, dim=(4,4), figsize=(24,24), random_seeds=None, save_path=None):
    """Plot a grid of images produced by the generator.

    Args:
        generator: model whose predict() returns images indexed as
            [sample, row, col, channel]; only channel 0 is drawn.
        n_ex: number of noise vectors to draw when `random_seeds` is None.
        dim: (rows, cols) of the subplot grid; it must have room for every
            generated image.
        figsize: matplotlib figure size.
        random_seeds: optional noise array to feed the generator; when None,
            `n_ex` uniform [0, 1) vectors of length `seeds` are drawn.
        save_path: if truthy, the figure is written there (creating the parent
            directory if needed) and closed; otherwise it is shown.
    """
    if random_seeds is None:
        random_seeds = np.random.uniform(0, 1, size=[n_ex, seeds])
    generated_images = generator.predict(random_seeds)

    fig = plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        # Shift from [-1, 1] (the range of the tanh output used by the
        # generators in this file) to [0, 1]; computed out of place so the
        # prediction array is not modified.
        img = (generated_images[i, :, :, 0] + 1.0) / 2.0
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    if save_path:
        # dirname is '' for a bare filename; os.makedirs('') would raise.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()
        
        
def train_for_n(epochs, data, generator, discriminator, gan):
    """Alternate discriminator and generator updates for a number of batches.

    Args:
        epochs: number of iterations; each one trains the discriminator on one
            mixed batch and the stacked GAN on one noise batch.
        data: ((x_train, y_train), (x_test, y_test)); only the image arrays
            are used.
        generator: generator model, used to produce fake images and samples.
        discriminator: compiled discriminator, trained with train_on_batch.
        gan: compiled stacked model, trained with train_on_batch.

    Every 100 iterations the current losses are printed and a fixed set of 16
    noise vectors is rendered with plot_gen_color. Losses are accumulated in a
    local dict that is not returned.
    """
    # set up loss storage vector
    losses = {"d": [], "g": []}
    (x_train, _), (x_test, _) = data
    print('bounds:', np.min(x_train), np.max(x_train))
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')

    # Fixed noise so the periodic sample plots are comparable over time.
    samples_seeds = np.random.uniform(0, 1, size=[16, seeds])
    batch_size = 256

    for epoch in range(epochs):
        # Make generative images
        latent = np.random.uniform(0, 1, size=[batch_size, seeds])
        fake_batch = generator.predict(latent)
        real_rows = np.random.randint(0, x_train.shape[0], size=batch_size)
        real_batch = x_train[real_rows, :, :, :]

        # Train discriminator on real + generated images; targets are 0.9 in
        # column 0 for real images and 0.9 in column 1 for generated ones.
        mixed_batch = np.concatenate((real_batch, fake_batch))
        targets = np.zeros([2 * batch_size, 2])
        targets[:batch_size, 0] = 0.9
        targets[batch_size:, 1] = 0.9
        discriminator.trainable = True
        d_loss = discriminator.train_on_batch(mixed_batch, targets)
        losses["d"].append(d_loss)
        discriminator.trainable = False

        # train Generator-Discriminator stack on input noise to non-generated output class
        stack_noise = np.random.uniform(0, 1, size=[batch_size, seeds])
        # Tell the model that random is correct
        wanted = np.zeros([batch_size, 2])
        wanted[:, 0] = 1
        g_loss = gan.train_on_batch(stack_noise, wanted)
        losses["g"].append(g_loss)

        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Generator loss {g_loss}, discriminator loss {d_loss}")
            plot_gen_color(generator, random_seeds=samples_seeds)

In [None]:
# Load and rescale MNIST; `data` is ((x_train, y_train), (x_test, y_test)).
data = load_mnist()
# Optional warm-up steps, currently disabled:
#pretrain_discriminator(data[0][0], generator, discriminator, 1200)
#pretrain_generator(data[0][0], gan, generator, discriminator, 100)
# Run 150000 alternating discriminator/generator batch updates; losses and
# sample images are emitted every 100 iterations by train_for_n.
train_for_n(150000, data, generator, discriminator, gan)