In [1]:
# https://github.com/eriklindernoren/Keras-GAN/blob/master/wgan_gp/wgan_gp.py
# https://github.com/keras-team/keras-contrib/blob/master/examples/improved_wgan.py

import glob
import os
import time
import sys
from functools import partial

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 
os.environ['CUDA_VISIBLE_DEVICES']='1'

# from __future__ import print_function, division

from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Add
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, LeakyReLU
from keras.layers import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
import keras.backend as K

import tensorflow as tf

import matplotlib.pyplot as plt

from tqdm import tqdm

import numpy as np

import h5py
from sklearn.utils import shuffle

# from tensorflow.compat.v1.keras.backend import set_session
# config = tf.compat.v1.ConfigProto()
# config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
# config.log_device_placement = True  # to log device placement (on which device the operation ran)
# sess = tf.compat.v1.Session(config=config)
# graph = tf.compat.v1.get_default_graph()
# set_session(sess)

# import keras.backend.tensorflow_backend as KTF

# KTF.set_session(sess)

Using TensorFlow backend.


In [2]:
# def _compute_gradients(tensor, var_list):
#     grads = tf.gradients(tensor, var_list)
#     return [grad if grad is not None else tf.zeros_like(var)
#           for var, grad in zip(var_list, grads)]

class RandomWeightedAverage(Add):
    """Provides a (random) weighted average between real and generated samples.

    For every sample in the batch a single ``alpha`` is drawn with
    ``K.random_uniform`` (default range [0, 1)) and broadcast over all remaining
    axes, returning ``alpha * input1 + (1 - alpha) * input2``. WGAN-GP uses these
    interpolations to evaluate its gradient penalty.
    """
    def _merge_function(self, inputs):
        input1, input2 = inputs
        # One alpha per sample: shape (batch, 1, ..., 1) with the same rank as the
        # inputs so it broadcasts over every non-batch axis. For 4-D image tensors
        # this is exactly (batch, 1, 1, 1); other ranks now work too.
        # Assumes the static rank of input1 is known (true for Keras Input tensors).
        alpha_shape = (K.shape(input1)[0],) + (1,) * (K.ndim(input1) - 1)
        alpha = K.random_uniform(alpha_shape)
        return (alpha * input1) + ((1 - alpha) * input2)

class WGANGP():
    """Wasserstein GAN with gradient penalty for images of shape (width, height, channels).

    The constructor builds a generator and a critic and compiles two training models:

    * ``critic_model`` -- inputs ``[real_img, noise]``; outputs the critic scores of
      the real images, of generated images and of random real/fake interpolations.
      Losses: Wasserstein, Wasserstein and gradient penalty, weighted 1, 1, 10.
      The generator is frozen in this model.
    * ``generator_model`` -- input ``noise``; output the critic score of the generated
      images, trained with the Wasserstein loss while the critic is frozen.
    """
    def __init__(self, width, height, channels):
        self.img_rows = width
        self.img_cols = height
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # NOTE(review): n_critic=5 with RMSprop(lr=5e-5) are the settings of the
        # original WGAN paper; the WGAN-GP paper uses Adam(lr=1e-4, beta_1=0,
        # beta_2=0.9). n_critic is currently not read by train(), which performs
        # one critic update per generator update.
        self.n_critic = 5
        optimizer = RMSprop(lr=0.00005)

        # Build the generator and critic
        self.generator = self.build_generator()
        self.critic = self.build_critic()

        #-------------------------------
        # Construct Computational Graph
        #       for the Critic
        #-------------------------------

        # Freeze the generator while the critic model is compiled so that critic
        # updates leave the generator weights untouched.
        self.generator.trainable = False

        # Image input (real sample)
        real_img = Input(shape=self.img_shape)

        # Noise input and the generated (fake) sample derived from it
        z_disc = Input(shape=(self.latent_dim,))
        fake_img = self.generator(z_disc)

        # Critic scores for fake and real images
        fake = self.critic(fake_img)
        valid = self.critic(real_img)

        # Random interpolation between real and fake images, and its critic score
        interpolated_img = RandomWeightedAverage()([real_img, fake_img])
        validity_interpolated = self.critic(interpolated_img)

        # Keras losses only receive (y_true, y_pred); bind the interpolated images
        # as the extra argument the gradient penalty needs.
        partial_gp_loss = partial(self.gradient_penalty_loss,
                                  averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

        self.critic_model = Model(inputs=[real_img, z_disc],
                                  outputs=[valid, fake, validity_interpolated])
        # loss_weights: 1 for each Wasserstein term, 10 (the penalty's lambda).
        self.critic_model.compile(loss=[self.wasserstein_loss,
                                        self.wasserstein_loss,
                                        partial_gp_loss],
                                  optimizer=optimizer,
                                  loss_weights=[1, 1, 10])
        #-------------------------------
        # Construct Computational Graph
        #         for Generator
        #-------------------------------

        # For the generator model, freeze the critic and unfreeze the generator
        self.critic.trainable = False
        self.generator.trainable = True

        # Noise -> generated image -> critic score
        z_gen = Input(shape=(self.latent_dim,))
        img = self.generator(z_gen)
        valid = self.critic(img)
        self.generator_model = Model(z_gen, valid)
        self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)


    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """Gradient-penalty term: batch mean of (1 - ||d y_pred / d averaged_samples||_2)^2.

        Args:
            y_true: ignored (train() feeds zeros); present because Keras losses take it.
            y_pred: critic scores for ``averaged_samples``.
            averaged_samples: the random real/fake interpolations the critic scored.

        Returns:
            Scalar tensor. The lambda factor (10) is applied through ``loss_weights``
            when the critic model is compiled, not here.
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # Per-sample euclidean norm: square, sum over every non-batch axis, sqrt.
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # (1 - ||grad||)^2 for each sample, then averaged over the batch.
        gradient_penalty = K.square(1 - gradient_l2_norm)
        return K.mean(gradient_penalty)


    def wasserstein_loss(self, y_true, y_pred):
        """Mean of ``y_true * y_pred``.

        train() labels real images -1 and fake images +1, so minimising this
        pushes critic scores up for real samples and down for fake ones.
        """
        return K.mean(y_true * y_pred)


    def build_critic(self):
        """Critic: three stride-2 4x4 convolutions, a 256-unit dense layer, and one
        linear output score per image (no sigmoid, as a Wasserstein critic needs)."""
        kernel_size = 4
        model = Sequential()

        model.add(Conv2D(32, kernel_size=kernel_size, strides=2,
                         input_shape=(self.img_rows, self.img_cols, self.channels), padding="same"))
        model.add(LeakyReLU(alpha=0.1))
        model.add(Dropout(0.15))

        model.add(Conv2D(64, kernel_size=kernel_size, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.1))
        model.add(Dropout(0.15))

        model.add(Conv2D(128, kernel_size=kernel_size, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.1))
        model.add(Dropout(0.15))

        model.add(Flatten())
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.1))
        model.add(Dropout(0.15))

        model.add(Dense(1))

        return model

    def build_generator(self):
        """Generator mapping a ``latent_dim`` noise vector to a tanh image.

        Starts from a (rows/4, cols/4, 32) tensor and alternates nearest-neighbour
        upsampling with stride-1/stride-2 convolutions; the output has shape
        (img_rows, img_cols, channels) when both sizes are divisible by 4.
        BatchNormalization(momentum=0.9, epsilon=0.00002) after each layer is
        currently disabled. Prints the model summary.
        """
        model = Sequential()

        model.add(Dense(32*int(self.img_rows/4)*int(self.img_cols/4), input_dim=self.latent_dim))
        model.add(Activation('relu'))
        model.add(Reshape((int(self.img_rows/4), int(self.img_cols/4), 32)))

        model.add(UpSampling2D(interpolation='nearest'))
        model.add(Conv2D(64, kernel_size=3, strides=1, padding="same"))
        model.add(Activation('relu'))

        model.add(UpSampling2D(interpolation='nearest'))
        model.add(Conv2D(48, kernel_size=4, strides=2, padding="same"))
        model.add(Activation('relu'))

        model.add(UpSampling2D(interpolation='nearest'))
        model.add(Conv2D(32, kernel_size=3, strides=1, padding="same"))
        model.add(Activation('relu'))

        model.add(UpSampling2D(interpolation='nearest'))
        model.add(Conv2D(16, kernel_size=4, strides=2, padding="same"))
        model.add(Activation('relu'))

        model.add(Conv2D(8, kernel_size=3, strides=1, padding="same"))
        model.add(Activation('relu'))

        model.add(UpSampling2D(interpolation='nearest'))
        model.add(Conv2D(16, kernel_size=4, strides=2, padding="same"))
        model.add(Activation('relu'))

        # tanh output in [-1, 1]; sample_images() rescales it to [0, 1].
        model.add(Conv2D(self.channels, kernel_size=3, strides=1, padding="same"))
        model.add(Activation("tanh"))

        print('Actor')
        model.summary()

        return model

    def build_generator1(self):
        """Alternative generator (not used by __init__): 8x8x128 seed, three
        upsampling blocks with BatchNormalization, tanh output. Its fixed 8x8 seed
        yields 64x64 images regardless of ``img_shape``."""
        model = Sequential()

        model.add(Dense(128 * 8 * 8, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((8, 8, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))

        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))

        model.add(UpSampling2D())
        model.add(Conv2D(32, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))

        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation("tanh"))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_critic1(self):
        """Alternative critic (not used by __init__) with BatchNormalization and a
        single linear output score."""
        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0,1),(0,1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, X_train, epochs, batch_size, sample_interval=50):
        """Alternate one critic and one generator update per mini-batch.

        Every epoch visits each row of ``X_train`` exactly once in a fresh random
        order. After each epoch a sample grid and both models are written under
        ``results/wgan_gp``; one more save tagged ``epochs`` follows the loop.

        Args:
            X_train: array of real images indexed along axis 0; presumably scaled
                to [-1, 1] to match the generator's tanh output -- TODO confirm.
            epochs: number of passes over ``X_train``.
            batch_size: maximum samples per mini-batch; the last batch of an
                epoch may be smaller.
            sample_interval: currently unused (samples and models are saved every
                epoch); kept for API compatibility.
        """
        n_samples = X_train.shape[0]
        for epoch in range(epochs):
            # Slicing a permutation gives batches of exactly batch_size (the last
            # one possibly shorter) and at least one batch even if
            # n_samples < batch_size.
            order = np.random.permutation(n_samples)
            progress = tqdm(range(0, n_samples, batch_size), desc="epoch "+str(epoch))
            for start in progress:
                idx = order[start:start + batch_size]
                n = idx.shape[0]
                imgs = X_train[idx]
                noise = np.random.normal(0, 1, (n, self.latent_dim))

                # Targets are sized per batch so they always match the inputs:
                # real -> -1, fake -> +1, and zeros as the dummy target of the
                # gradient-penalty output.
                valid = -np.ones((n, 1))
                fake = np.ones((n, 1))
                dummy = np.zeros((n, 1))

                # ---------------------
                #  Train Critic
                # ---------------------
                d_loss = self.critic_model.train_on_batch([imgs, noise], [valid, fake, dummy])

                # ---------------------
                #  Train Generator
                # ---------------------
                # Re-uses this batch's noise; the generator aims for "real" (-1) scores.
                g_loss = self.generator_model.train_on_batch(noise, valid)

                progress.set_postfix(d_loss=float(d_loss[0]), g_loss=float(g_loss))

            self.sample_images('wgan_gp', epoch)
            self.save_models('wgan_gp', epoch)
        self.sample_images('wgan_gp', epochs)
        self.save_models('wgan_gp', epochs)

    def sample_images(self, folder, epoch):
        """Save a 4x4 grid of generator samples to results/<folder>/samples/epoch_<epoch>.png."""
        r, c = 4, 4
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # Map the generator's tanh output from [-1, 1] to [0, 1] for imshow.
        generatedImage = 0.5 * self.generator.predict(noise) + 0.5

        fig = plt.figure(figsize=(10,10))
        for i in range(r * c):
            ax = fig.add_subplot(r, c, i + 1)
            img = generatedImage[i]
            if img.shape[-1] == 1:
                # imshow rejects (H, W, 1); show single-channel samples as 2-D grayscale.
                img = img[..., 0]
            # cmap only affects 2-D input; RGB(A) images ignore it.
            ax.imshow(img, interpolation='nearest', cmap='gray')
            ax.axis('off')
            ax.set_aspect('equal')
        fig.subplots_adjust(wspace=.008, hspace=.03)

        path = 'results/'+folder+'/samples'
        os.makedirs(path, exist_ok=True)
        fig.savefig(path+'/epoch_%d.png' % epoch)
        # Close this specific figure so repeated calls do not accumulate figures.
        plt.close(fig)

    def save_models(self, name, epoch):
        """Save generator and critic as results/<name>/models/{G,D}_<epoch>.h5."""
        path = 'results/'+name+'/models'
        # makedirs also creates results/<name> when it does not exist yet.
        os.makedirs(path, exist_ok=True)
        self.generator.save(path+'/G_%d.h5' % epoch)
        self.critic.save(path+'/D_%d.h5' % epoch)

In [3]:
# X.hdf5 presumably holds the PatchCamelyon train/test/valid patches, center-cropped
# to 64x64 and scaled to [-1, 1] as in the recipe below -- TODO confirm provenance.
# x_train = (h5py.File('camelyonpatch_level_2_split_train_x.h5', 'r')['x'][:, 16:80,16:80] - 127.5) / 127.5
# x_test = (h5py.File('camelyonpatch_level_2_split_test_x.h5', 'r')['x'][:, 16:80,16:80] - 127.5) / 127.5
# x_valid = (h5py.File('camelyonpatch_level_2_split_valid_x.h5', 'r')['x'][:, 16:80,16:80] - 127.5) / 127.5
# X = np.concatenate([x_train, x_test, x_valid])

# Read the whole dataset into memory, then release the file handle deterministically.
with h5py.File('X.hdf5', 'r') as x_file:
    X = x_file['X'][:]
X.shape

(327680, 64, 64, 3)

In [None]:
# Take the image shape from the data so the generator output and critic input
# always match X, which is (N, rows, cols, channels).
# NOTE: WGANGP.train currently ignores sample_interval and saves every epoch.
n_rows, n_cols, n_channels = X.shape[1:]
wgan = WGANGP(n_rows, n_cols, n_channels)
wgan.train(X, epochs=10000, batch_size=512, sample_interval=10)

In [21]:
def _read_labels(path):
    """Read the 'y' dataset of an HDF5 file as an (N, 1) array and close the file."""
    with h5py.File(path, 'r') as label_file:
        return label_file['y'][:].reshape(-1, 1)

y_train = _read_labels('camelyonpatch_level_2_split_train_y.h5')
y_test = _read_labels('camelyonpatch_level_2_split_test_y.h5')
y_valid = _read_labels('camelyonpatch_level_2_split_valid_y.h5')

# Concatenated in the same train/test/valid order as the image recipe above;
# presumably this lines up with the rows of X -- verify against how X.hdf5 was built.
Y = np.concatenate([y_train, y_test, y_valid])