In [2]:
import PIL
import tensorflow as tf
import numpy as np
from glob import glob
import time

from matplotlib import pyplot as plt

IMG_H = 128
IMG_W = 128
IMG_C = 3  ## Change this to 1 for grayscale.


# Regularization Rate for each loss function
ADV_REG_RATE_LF = 1
REC_REG_RATE_LF = 50
SSIM_REG_RATE_LF = 50
FEAT_REG_RATE_LF = 1

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Loss function for evaluating adversarial loss
adv_loss_fn = tf.losses.MeanSquaredError()

w_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

In [3]:
def load_image(image_path):
  img = tf.io.read_file(image_path)
  img = tf.io.decode_bmp(img)
  img = tf.image.resize_with_crop_or_pad(img, IMG_H, IMG_W)
  img = tf.cast(img, tf.float32)
  img = (img - 127.5) / 127.5
  return img

In [4]:
def tf_dataset(images_path, batch_size):
  dataset = tf.data.Dataset.from_tensor_slices(images_path)
  dataset = dataset.shuffle(buffer_size=10240)
  dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  return dataset

In [5]:
def conv_block(input, num_filters):
    x = tf.keras.layers.Conv2D(num_filters, kernel_size=(1,1), padding="same")(input)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation("relu")(x)

    x = tf.keras.layers.Conv2D(num_filters, kernel_size=(3,3), padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation("relu")(x)

    return x

In [6]:
def decoder_block(input, skip_features, num_filters):
    x = tf.keras.layers.Conv2DTranspose(num_filters, (1, 1), strides=2, padding="same")(input)
    x = tf.keras.layers.Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

In [9]:
class ResUnetGAN(tf.keras.models.Model):
    def __init__(self, input_shape, batch_size):
        super(ResUnetGAN, self).__init__()
        self.discriminator = self.build_discriminator(input_shape)
        self.generator = self.build_generator_resnet50_unet(input_shape)
        self.batch_size = batch_size

        # print("test")


        # self.discriminator.summary()
        # self.generator.summary()

    # create generator model based on resnet50 and unet network
    def build_generator_resnet50_unet(self, input_shape):
        # print(inputs)
        # print("pretained start")
        """ Pre-trained ResNet50 Model """
        resnet50 = tf.keras.applications.ResNet50(include_top=False, weights="imagenet", input_tensor=input_shape)
        # print("testing")
        """ Encoder using resnet50"""
        # for layer in resnet50.layers:
        #   print(layer.name)
        s1 = resnet50.get_layer("input_1").output           ## (128 x 128)
        # print(s1)
        s2 = resnet50.get_layer("conv1_relu").output        ## (64 x 64)
        s3 = resnet50.get_layer("conv2_block3_out").output  ## (32 x 32)
        s4 = resnet50.get_layer("conv3_block4_out").output  ## (16 x 16)

        """ Bridge """
        b1 = resnet50.get_layer("conv4_block6_out").output  ## (32 x 32)

        # print("test")
        # print(b1.get_weights())
        """ Decoder unet"""
        d1 = decoder_block(b1, s4, 128)                     ## (16 x 16)
        d2 = decoder_block(d1, s3, 64)                     ## (32 x 32)
        d3 = decoder_block(d2, s2, 32)                     ## (64 x 64)
        d4 = decoder_block(d3, s1, 16)                      ## (128 x 128)

        """ Output """
        final_model = tf.keras.layers.Conv2D(3, 1, padding="same", activation="sigmoid")(d4)

        model = tf.keras.models.Model(inputs, outputs=[final_model, b1])

        return model
        # return outputs

    # create discriminator model

    def build_discriminator(self ,input_shape):
      # Load the pre-trained model and freeze it.


        x = tf.keras.layers.SeparableConvolution2D(32,kernel_size= (1, 1), strides=(2, 2), padding='same')(input_shape)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Dropout(0.3)(x)

        x = tf.keras.layers.SeparableConvolution2D(64,kernel_size=(1, 1), strides=(2, 2), padding='same')(x)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Dropout(0.3)(x)

        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(1)(x)

        model = tf.keras.models.Model(inputs, x)
        return model
        # return x

    def compile(self, d_optimizer, g_optimizer):
        super(ResUnetGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

  
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
    @tf.function
    def train_step(self, images):
        batch_size = tf.shape(images)[0]
        # print(batch_size, IMG_W, IMG_H, IMG_C)
        # print("test")
        # print(self.generator)
        # print(self.discriminator)
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            reconstructed_images, low_feature = self.generator(images, training=True)
            real_output = self.discriminator(images, training=True)
            # print(generated_images.shape)
            fake_output = self.discriminator(reconstructed_images, training=True)

            # Loss 1: ADVERSARIAL loss
            loss_adv = tf.math.reduce_sum(tf.math.square(real_output - fake_output))
            # Loss 2: RECONSTRUCTION loss
            loss_rec = tf.math.reduce_sum(tf.math.abs(images - reconstructed_images))
            # Loss 3: SSMI loss
            loss_ssim = 1 - tf.math.reduce_mean(tf.image.ssim(images, reconstructed_images, 1.0))
            # Loss 4: FEATURE loss
            loss_feat = tf.math.reduce_sum(tf.math.square(real_output - fake_output))

            gen_loss = (loss_adv * ADV_REG_RATE_LF) + (loss_rec * REC_REG_RATE_LF) + (loss_ssim * SSIM_REG_RATE_LF) + (loss_feat * FEAT_REG_RATE_LF)
            disc_loss = (loss_adv * ADV_REG_RATE_LF) + (loss_feat * FEAT_REG_RATE_LF)

        
        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        
        self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return {"gen_loss": gen_loss, "disc_loss": disc_loss}

    def saved_model(self, filepath, num_of_epoch):
        self.generator.save(filepath + "g_model" + str(num_of_epoch) + ".h5")
        self.discriminator.save(filepath + "d_model" + str(num_of_epoch) + ".h5")

    def loaded_model(self, filepath):
        self.generator.load_weights(filepath)
        self.discriminator.load_weights(filepath)

In [None]:
if __name__ == "__main__":
    # run the function here
    print("start")
    ## Hyperparameters
    batch_size = 24
    input_shape = (IMG_W, IMG_H, IMG_C)
    # print(input_shape)

    """ Input """
    inputs = tf.keras.layers.Input(input_shape, name="input_1")

    num_epochs = 600
    train_images_path = glob("/content/drive/MyDrive/mura_data/train_data/*")


    # d_model = build_discriminator(inputs)
    # g_model = build_generator_resnet50_unet(inputs)

    resunetgan = ResUnetGAN(inputs, batch_size)


    g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
    resunetgan.compile(d_optimizer, g_optimizer)

    # print(train_images_path)
    train_images_dataset = tf_dataset(train_images_path, batch_size)

    # resunetgan.fit(train_images_dataset)


    for epoch in range(num_epochs):
        print("Epoch: ", epoch)
        start = time.time()
        for image_batch in train_images_dataset:
        # print(image_batch.shape)
            resunetgan.fit(image_batch)
            resunetgan.saved_model("/content/drive/MyDrive/mura_data/saved_model/", num_epochs)

    # resunetgan.summary()
    # resunetgan.save_weights("saved_model/resunet_model")


start
Epoch:  0




Epoch:  1
Epoch:  2
Epoch:  3
Epoch:  4
Epoch:  5
Epoch:  6
Epoch:  7
Epoch:  8
Epoch:  9
Epoch:  10
Epoch:  11
Epoch:  12
Epoch:  13
Epoch:  14
Epoch:  15
Epoch:  16
Epoch:  17
Epoch:  18
Epoch:  19
Epoch:  20
Epoch:  21
Epoch:  22
Epoch:  23
Epoch:  24
Epoch:  25
Epoch:  26
Epoch:  27
Epoch:  28
Epoch:  29
Epoch:  30
Epoch:  31
Epoch:  32
Epoch:  33
Epoch:  34
Epoch:  35
Epoch:  36
Epoch:  37
Epoch:  38
Epoch:  39
Epoch:  40
Epoch:  41
Epoch:  42
Epoch:  43
Epoch:  44
Epoch:  45
Epoch:  46
Epoch:  47
Epoch:  48
Epoch:  49
Epoch:  50
Epoch:  51
Epoch:  52
Epoch:  53
Epoch:  54
Epoch:  55
Epoch:  56
Epoch:  57


In [None]:
def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))


In [None]:
# load an image
def load_image_test(filename, size=(128,128)):
	# load image with the preferred size
	pixels = tf.keras.preprocessing.image.load_img(filename, target_size=size)
	# convert to numpy array
	pixels = tf.keras.preprocessing.image.img_to_array(pixels)
	# scale from [0,255] to [-1,1]
	pixels = (pixels - 127.5) / 127.5
	# reshape to 1 sample
	pixels = np.expand_dims(pixels, 0)
	return pixels

In [None]:
# test_images_dataset = tf_dataset(test_images_path, batch_size)
# normal_images = glob('mura_data/mura_data/test_data/normal_*.bmp')
# defect_images = glob('mura_data/mura_data/test_data/defect_*.bmp')
# len_nor_data = len(normal_images)
# len_def_data = len(defect_images)
# print(len_nor_data)
# print(len_def_data)
# threshold = 0.6
# defect_preds = []
# for image in defect_images:
#   # print(image)
#   if "DS_Store" not in image:
#     src_image = load_image_test(image)
#
#     test = d_model.predict(src_image)
#     test = (test + 1) / 2.0
#     defect_preds = np.append(defect_preds,test)
#
#     # preds = (preds - preds.min())/(preds.max()-preds.min())
#     # print(test)
#
#
#
# normal_preds = []
# for image in normal_images:
#   # print(image)
#   if "DS_Store" not in image:
#     src_image = load_image_test(image)
#
#     test = d_model.predict(src_image)
#     test = (test + 1) / 2.0
#     normal_preds = np.append(normal_preds,test)
#
#     # preds = (preds - preds.min())/(preds.max()-preds.min())
#     # print(test)
#
#
# print(defect_preds)
# print(np.mean(defect_preds))
# true_def_pred = len(np.where(defect_preds > threshold)[0])
# print(true_def_pred)
#
#
# print(normal_preds)
# print(np.mean(normal_preds))
# true_nor_pred = len(np.where(normal_preds < threshold)[0])
# print(true_nor_pred)
#
# total_acc = (true_def_pred + true_nor_pred) / (len_nor_data + len_def_data) * 100
# print("total_accuracy: ", total_acc)



51
50
[0.64748341 0.64582115 0.64495695 0.63706183 0.64389598 0.64372158
 0.64029634 0.64423901 0.63870716 0.64056474 0.63585079 0.63704026
 0.6411283  0.64476216 0.63866848 0.63880849 0.63715732 0.63976002
 0.64145702 0.64556265 0.64365792 0.6378026  0.64362186 0.63863498
 0.64006096 0.64166403 0.64609551 0.63979173 0.64554846 0.63616109
 0.64451241 0.64017868 0.6469053  0.63884485 0.63761288 0.64289093
 0.64132708 0.64631164 0.64095742 0.63797343 0.64088923 0.63647813
 0.65117645 0.63967681 0.64133358 0.63783252 0.64289385 0.64198399
 0.63981831 0.63435316]
0.6412786686420441
50
[0.63939631 0.64166212 0.63711405 0.63674265 0.63656467 0.63669407
 0.63940752 0.63682628 0.63999265 0.63952291 0.63734055 0.64020377
 0.63860971 0.640733   0.64046967 0.63901877 0.63815737 0.63886672
 0.6363433  0.63823372 0.63949871 0.63427675 0.63712913 0.64116955
 0.63772613 0.63943613 0.63968712 0.6397562  0.63884401 0.64015615
 0.63704967 0.64029896 0.63756126 0.63938802 0.63810068 0.63947988
 0.6398171