- Author: Umair Khan
- Date: 15/07/20

# 1. Imports

In [28]:
from numpy import expand_dims
from numpy import mean
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras import backend
from keras.optimizers import RMSprop, Adam
from keras.models import Sequential
from keras.layers import Activation, Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import Dropout
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
from matplotlib import pyplot
from os import listdir
from os.path import isfile
from os.path import join
from PIL import Image
import numpy as np


# 2. WGAN supporting functions

In [29]:
# clip model weights to a given hypercube
class ClipConstraint(Constraint):
    # set clip value when initialized
    def __init__(self, clip_value):
        self.clip_value = clip_value

    # clip model weights to hypercube
    def __call__(self, weights):
        return backend.clip(weights, -self.clip_value, self.clip_value)

    # get the config
    def get_config(self):
        return {'clip_value': self.clip_value}

# calculate wasserstein loss
def wasserstein_loss(y_true, y_pred):
    return backend.mean(y_true * y_pred)

# 3. Generator, Critic & GAN

In [30]:
# define the standalone critic model
def define_critic(in_shape=(96, 96, 3)):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # weight constraint
    const = ClipConstraint(0.01)
    # define model
    model = Sequential()
    # downsample to 48x48
    model.add(Conv2D(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.5))
    # downsample to 24x24
    model.add(Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.5))
    # downsample to 12x12
    model.add(Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.5))
    # downsample to 6x6
    model.add(Conv2D(512, (3,3), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.5))
    # scoring, linear activation
    model.add(Flatten())
    model.add(Dense(1))
    # compile model
    #opt = RMSprop(lr=0.00005)
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=wasserstein_loss, optimizer=opt)
    return model

# define the standalone generator model
def define_generator(latent_dim):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # define model
    model = Sequential()
    # foundation for 7x7 image
    n_nodes = 1024 * 6 * 6
    model.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))
    model.add(Activation('relu'))
    model.add(Reshape((6, 6, 1024)))
    # upsample to 6x6
    model.add(Conv2DTranspose(512, (5,5), strides=(2,2), padding='same', kernel_initializer=init))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    # upsample to 12x12
    model.add(Conv2DTranspose(512, (5,5), strides=(2,2), padding='same', kernel_initializer=init))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Activation('relu'))
    # upsample to 24x24
    model.add(Conv2DTranspose(512, (5,5), strides=(2,2), padding='same', kernel_initializer=init))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    # upsample to 48x48
    #model.add(Conv2DTranspose(512, (5,5), strides=(2,2), padding='same', kernel_initializer=init))
    model.add(Conv2DTranspose(3, (5,5), strides=(2,2), padding='same', kernel_initializer=init))
    #model.add(BatchNormalization())
    model.add(Activation('tanh'))
    # output 96x96x1
    #model.add(Conv2D(3, (6,6), activation='tanh', padding='same', kernel_initializer=init))
    return model

# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic):
    # make weights in the critic not trainable
    critic.trainable = False
    # connect them
    model = Sequential()
    # add generator
    model.add(generator)
    # add the critic
    model.add(critic)
    # compile model
    #opt = RMSprop(lr=0.00005)
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=wasserstein_loss, optimizer=opt)
    return model

# 4. Data Loading

In [31]:
def load_sample_data(sample_type='both'):
    data_path = "E:\\project\\gan_pcam\\data\\train_sample\\"
    images = []
    labels = []
    sample_paths = [join(data_path, 'positive'), join(data_path, 'negative')] if \
        sample_type == 'both' else [join(data_path, sample_type)]
    
    for sample_path in sample_paths:
        image_paths = [f for f in listdir(sample_path) if isfile(join(sample_path, f))]
        for img_path in image_paths:
            img = Image.open(join(sample_path, img_path))
            images.append(np.array(img))
            labels.append(1 if 'positive' in sample_path else 0)
            
    return np.array(images), np.array(labels)

# load images
def load_real_samples():
    # load dataset
    #(trainX, trainy), (_, _) = load_data()
    trainX, trainy = load_sample_data(sample_type='positive')
    # convert to float
    X = trainX.astype('float32')
    # scale from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return X

# select real samples
def generate_real_samples(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # select images
    X = dataset[ix]
    # generate class labels, -1 for 'real'
    y = -ones((n_samples, 1))
    return X, y

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    # generate points in latent space
    x_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    X = generator.predict(x_input)
    # create class labels with 1.0 for 'fake'
    y = ones((n_samples, 1))
    return X, y

# 5. Checkpoint and Summary Plotting

In [32]:
# generate samples and save as a plot and save the model
def summarize_performance(epoch, g_model, latent_dim, n_samples=49):
    # prepare fake examples
    X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # scale from [-1,1] to [0,1]
    X = ((X * 127.5) + 127.5).astype('uint8')
    # plot images
    pyplot.figure(figsize=(20, 20))
    for i in range(7 * 7):
        # define subplot
        pyplot.subplot(7, 7, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(X[i, :, :, :])
    # save plot to file
    filename1 = f'.\\WGAN with sample PCam data\\output\\generated_output_{epoch}.png'
    pyplot.savefig(filename1)
    pyplot.close()
    # save the generator model
    filename2 = f'.\\WGAN with sample PCam data\\model\\g_model_{epoch}.h5'
    g_model.save(filename2)
    
# create a line plot of loss for the gan and save to file
def plot_history(d1_hist, d2_hist, g_hist):
    # plot history
    pyplot.figure(figsize=(12, 10))
    pyplot.plot(d1_hist, label='crit_real')
    pyplot.plot(d2_hist, label='crit_fake')
    pyplot.plot(g_hist, label='gen')
    pyplot.legend()
    pyplot.savefig(f'.\\WGAN with sample PCam data\\plot_line_plot_loss.png')
    pyplot.close()


# 6. Execute Training

In [33]:
# train the generator and critic
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=300, n_batch=64, n_critic=5):
    # calculate the number of batches per training epoch
    bat_per_epo = int(dataset.shape[0] / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # calculate the size of half a batch of samples
    half_batch = int(n_batch / 2)
    # lists for keeping track of loss
    c1_hist, c2_hist, g_hist = [], [], []
    c1_epoch, c2_epoch, g_epoch = [], [], []
    # manually enumerate epochs
    for i in range(n_steps):
        # update the critic more than the generator
        c1_tmp, c2_tmp = [], []
        for _ in range(n_critic):
            # get randomly selected 'real' samples
            X_real, y_real = generate_real_samples(dataset, half_batch)
            # update critic model weights
            c_loss1 = c_model.train_on_batch(X_real, y_real)
            c1_tmp.append(c_loss1)
            # generate 'fake' examples
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update critic model weights
            c_loss2 = c_model.train_on_batch(X_fake, y_fake)
            c2_tmp.append(c_loss2)
        # store critic loss
        c1_hist.append(mean(c1_tmp))
        c2_hist.append(mean(c2_tmp))
        # prepare points in latent space as input for the generator
        X_gan = generate_latent_points(latent_dim, n_batch)
        # create inverted labels for the fake samples
        y_gan = -ones((n_batch, 1))
        # update the generator via the critic's error
        g_loss = gan_model.train_on_batch(X_gan, y_gan)
        g_hist.append(g_loss)
        # summarize loss on this batch
        print(f'>Step:{i+1} C1[{c1_hist[-1]:.3f}], C2[{c2_hist[-1]:.3f}] GL[{g_loss:.3f}]\r', end="")
        if (i+1) % bat_per_epo == 0:
            print(f'>Epoch:{(i+1)//bat_per_epo} C1[{mean(c1_hist):.3f}], C2[{mean(c2_hist):.3f}] GL[{mean(g_hist):.3f}]')
            c1_epoch.append(mean(c1_hist))
            c2_epoch.append(mean(c2_hist))
            g_epoch.append(mean(g_hist))
            c1_hist, c2_hist, g_hist = [], [], []
        # evaluate the model performance every 'epoch'
        if (i+1) % (bat_per_epo * 1) == 0:
            summarize_performance((i+1)//bat_per_epo, g_model, latent_dim)
    # line plots of loss
    plot_history(c1_epoch, c2_epoch, g_epoch)

# size of the latent space
latent_dim = 100
# create the critic
critic = define_critic()
# create the generator
generator = define_generator(latent_dim)
# create the gan
gan_model = define_gan(generator, critic)
# load image data
dataset = load_real_samples()
print(dataset.shape)
# train model
train(generator, critic, gan_model, dataset, latent_dim)

(5000, 96, 96, 3)
>Epoch:1 C1[-256.942], C2[-249.015] GL[254.666]
>Epoch:2 C1[-904.879], C2[-854.475] GL[819.779]]15]
>Epoch:3 C1[-1592.109], C2[-1619.909] GL[-969.912]1]
>Epoch:4 C1[-2638.971], C2[-2656.806] GL[-1319.683]]
>Epoch:5 C1[-3791.861], C2[-3792.909] GL[-2936.969]]
>Epoch:6 C1[-5267.056], C2[-5198.860] GL[-5241.060]]
>Epoch:7 C1[-6405.817], C2[-6383.061] GL[-1755.867]]
>Epoch:8 C1[-8354.547], C2[-8055.554] GL[7841.894]]
>Epoch:9 C1[-10133.038], C2[-9334.506] GL[9337.586]83]
>Epoch:10 C1[-11999.361], C2[-10973.760] GL[10985.469]
>Epoch:11 C1[-13972.513], C2[-12860.514] GL[12877.207]
>Epoch:12 C1[-16047.083], C2[-14850.232] GL[14868.760]
>Epoch:13 C1[-18215.445], C2[-16945.980] GL[16964.639]]
>Epoch:14 C1[-20421.018], C2[-19119.559] GL[19146.549]]
>Epoch:15 C1[-22866.027], C2[-21437.871] GL[21456.686]]
>Epoch:16 C1[-25360.209], C2[-23853.127] GL[23885.736]]
>Epoch:17 C1[-27929.955], C2[-26375.875] GL[26402.447]]
>Epoch:18 C1[-30613.369], C2[-29015.025] GL[29037.170]]
>Epoch:19

>Epoch:270 C1[-3263274.250], C2[-3285083.750] GL[3283076.000]]
>Epoch:271 C1[-3292423.750], C2[-3311980.250] GL[3312646.250]]
>Epoch:272 C1[-3317852.750], C2[-3338231.500] GL[3338023.250]]
>Epoch:273 C1[-3340739.000], C2[-3361717.500] GL[3362646.500]]
>Epoch:274 C1[-3365827.500], C2[-3386159.500] GL[3385862.250]]
>Epoch:275 C1[-3389163.500], C2[-3409550.750] GL[3411141.000]]
>Epoch:276 C1[-3413424.000], C2[-3434363.250] GL[3434424.250]]
>Epoch:277 C1[-3438932.500], C2[-3456892.250] GL[3458516.500]]
>Epoch:278 C1[-3461696.000], C2[-3481393.250] GL[3483438.250]]
>Epoch:279 C1[-3486485.750], C2[-3505093.000] GL[3506781.250]]
>Epoch:280 C1[-3510147.750], C2[-3529271.000] GL[3531442.750]]
>Epoch:281 C1[-3533050.250], C2[-3553948.000] GL[3554892.250]]
>Epoch:282 C1[-3558153.500], C2[-3577483.000] GL[3578402.000]]
>Epoch:283 C1[-3581313.250], C2[-3601453.250] GL[3603485.500]]
>Epoch:284 C1[-3606013.250], C2[-3625874.000] GL[3627284.500]]
>Epoch:285 C1[-3629757.500], C2[-3650853.000] GL[365217

1232345
