In [None]:
# Import the necessary libraries and define the global image constants

import itertools
import tensorflow as tf
import numpy as np
from glob import glob
import time
import os

from sklearn.metrics import roc_curve, auc

from matplotlib import pyplot as plt

# Every image is cropped/padded to IMG_H x IMG_W pixels; IMG_C is the channel
# count (set it to 1 for grayscale data).
IMG_H, IMG_W = 128, 128
IMG_C = 3

In [None]:
def load_image(image_path):
    """Read one BMP file and return it as a float32 image scaled to [0, 1].

    The image is centre-cropped or zero-padded to IMG_H x IMG_W; the channel
    count is whatever the BMP file stores (decode_bmp's default).

    Pixels are scaled to [0, 1] rather than [-1, 1] because the generator ends
    in a sigmoid (output range [0, 1]) and the SSIM loss uses max_val=1.0;
    with [-1, 1] targets, the negative half of the input could never be
    reconstructed.
    """
    raw_bytes = tf.io.read_file(image_path)
    img = tf.io.decode_bmp(raw_bytes)
    img = tf.image.resize_with_crop_or_pad(img, IMG_H, IMG_W)
    img = tf.cast(img, tf.float32) / 255.0
    return img

In [None]:
def tf_dataset(images_path, batch_size):
    """Build a shuffled, decoded, batched and prefetched dataset from image file paths.

    Args:
        images_path: list of image file paths, each decoded with `load_image`.
        batch_size: number of images per batch.
    """
    autotune = tf.data.experimental.AUTOTUNE
    return (
        tf.data.Dataset.from_tensor_slices(images_path)
        .shuffle(buffer_size=10240)
        .map(load_image, num_parallel_calls=autotune)
        .batch(batch_size)
        .prefetch(buffer_size=autotune)
    )

In [None]:
# Load the labelled image dataset used for testing
def load_image_test(filename, class_names, size=(128,128)):
    """Load a labelled test set from a directory tree, one image per batch.

    Args:
        filename: root directory with one sub-folder per entry of `class_names`.
        class_names: folder names; each image's integer label is the index of its
            folder in this list.
        size: (height, width) every image is resized to.

    Returns:
        A dataset of (images, labels) batches of size 1. Pixel values are rescaled
        from [0, 255] to [0, 1], matching the range of the generator's sigmoid
        output. The `class_names` attribute of the Keras dataset is preserved.

    NOTE(review): this path resizes images, whereas the training loader
    centre-crops/pads them; the two preprocessing pipelines are not identical.
    """
    pixels = tf.keras.preprocessing.image_dataset_from_directory(
        filename,
        labels='inferred',
        label_mode='int',
        image_size=size,
        batch_size=1,
        class_names=class_names,
    )
    # image_dataset_from_directory yields float32 pixels in [0, 255].
    normalized = pixels.map(lambda images, labels: (images / 255.0, labels))
    # .map() drops the attribute Keras attaches; keep it for callers that read it.
    normalized.class_names = pixels.class_names
    return normalized

In [None]:
def roc(labels, scores, saveto=None):
    """Return the area under the ROC curve for binary labels and continuous scores.

    Args:
        labels: ground-truth binary labels. With 0/1 labels, 1 is the positive
            class (sklearn's default `pos_label`).
        scores: one score per label; higher means "more likely positive".
        saveto: unused; kept only for backward compatibility.

    Returns:
        The ROC AUC as a float.
    """
    fpr, tpr, _ = roc_curve(labels, scores)  # true/false positive rates
    return auc(fpr, tpr)

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print and plot a confusion matrix on the current matplotlib figure.

    Args:
        cm: square matrix; rows are true labels, columns are predicted labels.
        classes: tick labels, one per row/column.
        normalize: if True, divide each row by its sum. Rows with no samples are
            left at 0 instead of dividing by zero.
        title: plot title.
        cmap: matplotlib colormap for the image.
    """
    cm = np.asarray(cm)
    # Normalize BEFORE drawing so the image and the cell text show the same values.
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print as-is; fractions are shown with two decimals.
    fmt = '.2f' if np.issubdtype(cm.dtype, np.floating) else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [None]:
def bn_act(x, act=True):
    """Batch-normalize `x`, then apply a ReLU when `act` equals True."""
    normalized = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.Activation("relu")(normalized) if act == True else normalized

def conv_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Pre-activation convolution unit: BatchNorm -> ReLU -> Conv2D."""
    activated = bn_act(x)
    conv_layer = tf.keras.layers.Conv2D(
        filters, kernel_size, padding=padding, strides=strides
    )
    return conv_layer(activated)

def stem(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """First residual unit of the encoder: Conv2D -> conv_block, plus a 1x1 projected shortcut.

    `strides` is applied exactly once on each branch (the first Conv2D and the
    shortcut). Both branches therefore end with the same spatial size before
    the element-wise Add, as in `residual_block`. Applying it to both
    convolutions would downsample the main path twice and break the Add
    for strides > 1.
    """
    conv = tf.keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
    conv = conv_block(conv, filters, kernel_size=kernel_size, padding=padding, strides=1)

    shortcut = tf.keras.layers.Conv2D(filters, kernel_size=(1, 1), padding=padding, strides=strides)(x)
    shortcut = bn_act(shortcut, act=False)

    output = tf.keras.layers.Add()([conv, shortcut])
    return output

def residual_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    """Pre-activation residual unit: two conv_blocks plus a 1x1 projected shortcut.

    Only the first conv_block and the shortcut use `strides`, so both branches
    have matching shapes at the element-wise Add.
    """
    conv_kwargs = dict(kernel_size=kernel_size, padding=padding)
    branch = conv_block(x, filters, strides=strides, **conv_kwargs)
    branch = conv_block(branch, filters, strides=1, **conv_kwargs)

    projection = tf.keras.layers.Conv2D(
        filters, kernel_size=(1, 1), padding=padding, strides=strides
    )(x)
    projection = bn_act(projection, act=False)

    return tf.keras.layers.Add()([projection, branch])

def upsample_concat_block(x, xskip):
    """Upsample `x` by 2 in both spatial dims, then concatenate `xskip` along the last axis."""
    upsampled = tf.keras.layers.UpSampling2D((2, 2))(x)
    return tf.keras.layers.Concatenate()([upsampled, xskip])

In [None]:
class ResUnetGAN(tf.keras.models.Model):
    """Adversarially trained ResUNet autoencoder used for image anomaly scoring.

    The generator reconstructs its input image; the discriminator outputs a
    real/fake logit for real and reconstructed images. At test time, the
    reconstruction error and the real/fake logit gap are combined into an
    anomaly score (see `test_and_eval`).

    Args:
        input_shape: despite its name, a Keras `Input` tensor (the __main__ cell
            passes `tf.keras.layers.Input(...)`). It is the input of both sub-models.
        batch_size: stored on the instance; not otherwise used by this class.
    """

    def __init__(self, input_shape, batch_size):
        super(ResUnetGAN, self).__init__()
        self.discriminator = self.build_discriminator(input_shape)
        self.generator = self.build_generator_resnet50_unet(input_shape)
        self.batch_size = batch_size

        # Weight of each loss term in the generator / discriminator objectives.
        self.ADV_REG_RATE_LF = 1
        self.REC_REG_RATE_LF = 50
        self.SSIM_REG_RATE_LF = 50
        self.FEAT_REG_RATE_LF = 1

        # Defaults; `compile` may replace them.
        self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)

        # Checkpoint settings.
        checkpoint_dir = './training_checkpoints'
        self.checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
        self._make_checkpoint()

    def _make_checkpoint(self):
        """(Re)create the checkpoint so it tracks the optimizers currently in use."""
        self.checkpoint = tf.train.Checkpoint(generator_optimizer=self.g_optimizer,
                                              discriminator_optimizer=self.d_optimizer,
                                              generator=self.generator,
                                              discriminator=self.discriminator)

    def build_generator_resnet50_unet(self, inputs):
        """Build the ResUNet generator as a Keras Model on top of `inputs`.

        NOTE: despite the name, this is not a ResNet-50 backbone. The structure is:
        - a stem, then 4 stride-2 residual levels on the way down;
        - a 2-conv bridge;
        - 4 upsample+concat residual levels on the way up;
        - a 1x1 sigmoid conv whose channel count matches the input's.
        """
        f = [16, 32, 64, 128, 256]

        ## Encoder
        e0 = inputs
        e1 = stem(e0, f[0])
        e2 = residual_block(e1, f[1], strides=2)
        e3 = residual_block(e2, f[2], strides=2)
        e4 = residual_block(e3, f[3], strides=2)
        e5 = residual_block(e4, f[4], strides=2)

        ## Bridge
        b0 = conv_block(e5, f[4], strides=1)
        b1 = conv_block(b0, f[4], strides=1)

        ## Decoder
        u1 = upsample_concat_block(b1, e4)
        d1 = residual_block(u1, f[4])

        u2 = upsample_concat_block(d1, e3)
        d2 = residual_block(u2, f[3])

        u3 = upsample_concat_block(d2, e2)
        d3 = residual_block(u3, f[2])

        u4 = upsample_concat_block(d3, e1)
        d4 = residual_block(u4, f[1])

        # Reconstruct as many channels as the input has (falls back to 3 if unknown).
        out_channels = inputs.shape[-1] or 3
        outputs = tf.keras.layers.Conv2D(out_channels, (1, 1), padding="same", activation="sigmoid")(d4)

        return tf.keras.models.Model(inputs, outputs)

    def build_discriminator(self, input_shape):
        """Build the discriminator: two separable-conv/LeakyReLU/Dropout stages -> one logit.

        `input_shape` is the Keras `Input` tensor the model is built on (the name
        is kept for backward compatibility).
        """
        inputs = input_shape

        x = tf.keras.layers.SeparableConvolution2D(32, kernel_size=(1, 1), strides=(2, 2), padding='same')(inputs)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Dropout(0.3)(x)

        x = tf.keras.layers.SeparableConvolution2D(64, kernel_size=(1, 1), strides=(2, 2), padding='same')(x)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Dropout(0.3)(x)

        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(1)(x)  # raw logit; the losses apply the sigmoid

        return tf.keras.models.Model(inputs, x)

    def compile(self, d_optimizer, g_optimizer):
        """Install the optimizers and re-point the checkpoint at them."""
        super(ResUnetGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self._make_checkpoint()

    def _gan_losses(self, images, training):
        """Run both sub-models on `images` and return every loss term as a dict.

        When used for training, this must be called inside the gradient tapes.
        NOTE(review): the SSIM term uses max_val=1.0, i.e. it assumes images are
        scaled to [0, 1] like the generator's sigmoid output — confirm the loader.
        """
        reconstructed_images = self.generator(images, training=training)
        real_output = self.discriminator(images, training=training)
        fake_output = self.discriminator(reconstructed_images, training=training)

        # Loss 1a: discriminator ADVERSARIAL loss (real -> 1, reconstructed -> 0).
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=real_output, labels=tf.ones_like(real_output)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=fake_output, labels=tf.zeros_like(fake_output)))
        loss_adv = d_loss_real + d_loss_fake

        # Loss 1b: generator ADVERSARIAL loss. The generator wants its
        # reconstructions classified as real, so its target is 1 — minimizing
        # the discriminator's own loss would push reconstructions toward "fake".
        g_loss_adv = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=fake_output, labels=tf.ones_like(fake_output)))

        # Loss 2: RECONSTRUCTION loss (mean absolute error).
        loss_rec = tf.math.reduce_mean(tf.keras.losses.mae(images, reconstructed_images))

        # Loss 3: SSIM loss, averaged over every image in the batch.
        loss_ssim = tf.math.reduce_mean(1.0 - tf.image.ssim(images, reconstructed_images, max_val=1.0))

        # Loss 4: FEATURE loss (squared gap between real and reconstructed logits).
        loss_feat = tf.math.reduce_mean(tf.math.square(real_output - fake_output))

        gen_loss = ((g_loss_adv * self.ADV_REG_RATE_LF)
                    + (loss_rec * self.REC_REG_RATE_LF)
                    + (loss_ssim * self.SSIM_REG_RATE_LF)
                    + (loss_feat * self.FEAT_REG_RATE_LF))
        # NOTE(review): minimizing loss_feat in the discriminator pulls real and
        # fake logits together, working against its BCE term — confirm intent.
        disc_loss = (loss_adv * self.ADV_REG_RATE_LF) + (loss_feat * self.FEAT_REG_RATE_LF)

        return {
            "gen_loss": gen_loss,
            "disc_loss": disc_loss,
            "loss_adv": loss_adv,
            "loss_rec": loss_rec,
            "loss_ssim": loss_ssim,
            "loss_feat": loss_feat
        }

    # `tf.function` compiles this step into a graph.
    @tf.function
    def train_step(self, images):
        """One optimization step for both generator and discriminator; returns the loss dict."""
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            losses = self._gan_losses(images, training=True)

        gradients_of_generator = gen_tape.gradient(losses["gen_loss"], self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(losses["disc_loss"], self.discriminator.trainable_variables)

        self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return losses

    @tf.function
    def _gan_eval_step(self, images):
        """Forward-only loss computation (inference mode, no weight updates)."""
        return self._gan_losses(images, training=False)

    def saved_model(self, filepath, num_of_epoch):
        """Save both sub-models as HDF5 files named `<filepath>{g,d}_model<epoch>.h5`."""
        self.generator.save(filepath + "g_model" + str(num_of_epoch) + ".h5")
        self.discriminator.save(filepath + "d_model" + str(num_of_epoch) + ".h5")

    def loaded_model(self, g_filepath, d_filepath):
        """Load generator and discriminator weights from the given files."""
        self.generator.load_weights(g_filepath)
        self.discriminator.load_weights(d_filepath)

    def test_and_eval(self, filepath, g_filepath, d_filepath, threshold=0.6, anomaly_weight=0.9):
        """Load saved weights, score every test image, print the ROC AUC and plot a confusion matrix.

        Args:
            filepath: test directory for `load_image_test`, with one sub-folder
                per class in ["defect", "normal"].
            g_filepath: generator weight file.
            d_filepath: discriminator weight file.
            threshold: min-max-scaled anomaly score at or above which an image is
                predicted "defect".
            anomaly_weight: weight of the reconstruction loss in the anomaly score.
                The logit-gap (feature) loss gets `1 - anomaly_weight`; expected
                in [0, 1].

        Returns:
            The ROC AUC, with "defect" as the positive class.

        Raises:
            ValueError: if the test directory yields no images.
        """
        class_names = ["defect", "normal"]  # label 0 = defect, label 1 = normal
        test_dataset = load_image_test(filepath, class_names)
        self.loaded_model(g_filepath, d_filepath)

        scores_ano = []
        real_label = []
        for images, labels in test_dataset:
            # Forward pass only: evaluation must not update weights, so
            # train_step (which applies gradients) is not used here.
            losses = self._gan_eval_step(images)
            score = anomaly_weight * losses['loss_rec'] + (1 - anomaly_weight) * losses['loss_feat']
            scores_ano.append(float(score))
            # load_image_test yields batches of one image.
            real_label.append(int(labels.numpy()[0]))

        if not scores_ano:
            raise ValueError("No test images found in " + str(filepath))
        scores_ano = np.asarray(scores_ano, dtype=np.float64)
        real_label = np.asarray(real_label, dtype=np.int64)

        # Min-max scale scores into [0, 1]; constant scores would divide by zero.
        score_range = scores_ano.max() - scores_ano.min()
        if score_range > 0:
            scores_ano = (scores_ano - scores_ano.min()) / score_range
        else:
            scores_ano = np.zeros_like(scores_ano)

        # A higher score means "more anomalous", so "defect" (label 0) is the
        # positive class for the ROC curve.
        is_defect = (real_label == 0).astype(np.int64)
        auc_out = roc(is_defect, scores_ano)
        print("auc: ", auc_out)

        # Predict "defect" (0) at or above the threshold, otherwise "normal" (1).
        predictions = np.where(scores_ano >= threshold, 0, 1)
        cm = tf.math.confusion_matrix(labels=real_label,
                                      predictions=predictions,
                                      num_classes=len(class_names)).numpy()

        plot_confusion_matrix(cm, class_names)
        return auc_out

In [None]:
if __name__ == "__main__":
    # Train the GAN on the training images, save it, then evaluate on the test set.
    print("start")

    ## Hyperparameters
    batch_size = 24
    num_epochs = 600
    # Keras image inputs are (height, width, channels).
    input_shape = (IMG_H, IMG_W, IMG_C)

    train_dir = "mura_data/mura_data/train_data/"
    saved_model_path = "mura_data/saved_model/"
    test_data_path = "mura_data/mura_data/test_data/"

    train_images_path = glob(train_dir + "*.bmp")
    if not train_images_path:
        raise FileNotFoundError("No .bmp training images found in " + train_dir)

    # Input
    inputs = tf.keras.layers.Input(input_shape, name="input_1")
    resunetgan = ResUnetGAN(inputs, batch_size)

    g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
    resunetgan.compile(d_optimizer, g_optimizer)

    train_images_dataset = tf_dataset(train_images_path, batch_size)

    # A single fit() over the whole dataset replaces one fit() call per batch
    # inside a manual epoch loop; Keras reports per-epoch progress itself.
    start = time.time()
    resunetgan.fit(train_images_dataset, epochs=num_epochs)
    print("Training time (s): ", time.time() - start)

    # Save once after training (h5 saving fails if the directory is missing).
    os.makedirs(saved_model_path, exist_ok=True)
    resunetgan.saved_model(saved_model_path, num_epochs)

    resunetgan.test_and_eval(test_data_path,
                             saved_model_path + "g_model" + str(num_epochs) + ".h5",
                             saved_model_path + "d_model" + str(num_epochs) + ".h5")
