In [None]:
import os
import keras
from tqdm import tqdm
from glob import glob
import tensorflow as tf
from numpy import zeros
from numpy.random import randint

# Data
from tensorflow.image import resize
from keras.preprocessing.image import load_img, img_to_array

# Data Viz
import matplotlib.pyplot as plt

# Model
from keras.layers import add
from keras.layers import Input
from keras.layers import Layer
from keras.layers import Conv2D
from keras.layers import multiply
from keras.layers import concatenate
from keras.layers import Conv2DTranspose
from keras.layers import MaxPool2D
from keras.layers import UpSampling2D
from keras.layers import BatchNormalization
from keras.layers import Dropout
from keras.models import load_model
# Model Functions
from keras.models import Model
from tensorflow.keras.utils import plot_model
from keras.callbacks import Callback, ModelCheckpoint

In [None]:
def load_image(path):
    """Read an image file, scale its pixel values by 1/255 and resize it to 256x256.

    Args:
        path: Location of the image file; passed unchanged to ``load_img``.

    Returns:
        The output of ``tensorflow.image.resize`` applied to the scaled pixel
        array with a target size of ``(256, 256)``.
    """
    pixels = img_to_array(load_img(path))
    scaled = pixels / 255.
    return resize(scaled, (256, 256))

In [None]:
# Directory that holds the input images.
image_path = "drive/MyDrive/better_dataset/image"

# Counts every directory entry (no filtering by file extension).
total_images = len(os.listdir(image_path))
print("Total Number of Images : {}".format(total_images))

Total Number of Images : 300


In [None]:
# Sorted list of the .jpg files found directly inside the image directory.
all_image_paths = sorted(glob(f"{image_path}/*.jpg"))

In [None]:
def load_data(paths):
    """Load every image in ``paths`` together with its matching mask.

    The mask location is derived from the image path by replacing the
    substring ``"image"`` with ``"mask"``. Both files are read through
    ``load_image``.

    Args:
        paths: Sequence of image file paths (must support ``len``).

    Returns:
        Tuple ``(images, masks)`` of float32 arrays, each of shape
        ``(len(paths), 256, 256, 3)``.
    """
    # float32 takes half the memory of NumPy's float64 default.
    # NOTE(review): img_to_array / tf.image.resize are expected to yield
    # float32 already, so no precision should be lost -- confirm.
    images = zeros(shape=(len(paths), 256, 256, 3), dtype="float32")
    masks = zeros(shape=(len(paths), 256, 256, 3), dtype="float32")

    # Wrap the sequence itself rather than the enumerate() iterator so tqdm
    # can read its length and display a total and an ETA.
    for i, path in enumerate(tqdm(paths, desc="Loading")):
        images[i] = load_image(path)

        # NOTE(review): this replaces *every* occurrence of "image" in the
        # path (directory and file name alike) -- confirm that this matches
        # how the mask files are named.
        mask_path = path.replace("image", "mask")
        masks[i] = load_image(mask_path)
    return images, masks

In [None]:
# The first 240 sorted paths make up the training split.
train_paths = all_image_paths[0:240]
X_train, y_train = load_data(paths=train_paths)

Loading: 240it [01:22,  2.89it/s]


In [None]:
# Everything after the first 240 sorted paths is used for validation.
val_paths = all_image_paths[240:len(all_image_paths)]
X_val, y_val = load_data(paths=val_paths)

Loading: 60it [00:17,  3.40it/s]


In [None]:
def show_image(image, title=None, cmap=None):
    """Draw ``image`` on the current matplotlib axes with the axis hidden.

    Args:
        image: Image data accepted by ``plt.imshow``.
        title: Optional text shown above the image.
        cmap: Optional colormap forwarded to ``plt.imshow``. The default
            ``None`` leaves matplotlib's own default in place, so existing
            callers are unaffected; pass e.g. ``'gray'`` for single-channel
            masks.
    """
    plt.imshow(image, cmap=cmap)
    plt.title(title)
    plt.axis('off')

In [None]:
class Encoder(Layer):
    """Convolutional block: BatchNorm -> Conv -> Dropout -> Conv, optionally max-pooled.

    Args:
        filters: Number of filters used by both 3x3 convolutions.
        rate: Dropout rate applied between the two convolutions.
        pooling: When true, ``call`` additionally returns a max-pooled copy
            of the convolution output.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, filters, rate, pooling=True, **kwargs):
        super().__init__(**kwargs)

        # Kept on the instance so get_config() can report them.
        self.filters, self.rate, self.pooling = filters, rate, pooling

        # Both convolutions share the same hyper-parameters.
        conv_kwargs = dict(
            kernel_size=3,
            padding='same',
            activation='relu',
            kernel_initializer="he_normal",
        )

        self.bn = BatchNormalization()
        self.c1 = Conv2D(filters, **conv_kwargs)
        self.drop = Dropout(rate)
        self.c2 = Conv2D(filters, **conv_kwargs)
        self.pool = MaxPool2D()

    def call(self, X):
        """Apply the block to ``X``.

        Returns:
            ``(pooled, features)`` when pooling is enabled, otherwise only
            ``features`` (the output of the second convolution).
        """
        features = self.c2(self.drop(self.c1(self.bn(X))))
        if not self.pooling:
            return features
        return self.pool(features), features

    def get_config(self):
        """Return the base layer config extended with this block's constructor arguments."""
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "rate": self.rate,
            "pooling": self.pooling,
        })
        return config

In [None]:
class Decoder(Layer):
    """Upsampling block: BatchNorm -> stride-2 transposed conv -> concat with skip -> Encoder (no pooling).

    Args:
        filters: Number of filters for the transposed convolution and the
            inner ``Encoder`` block.
        rate: Dropout rate handed to the inner ``Encoder`` block.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, filters, rate, **kwargs):
        super().__init__(**kwargs)

        # Kept on the instance so get_config() can report them.
        self.filters, self.rate = filters, rate

        self.bn = BatchNormalization()
        self.cT = Conv2DTranspose(
            filters,
            kernel_size=3,
            strides=2,
            padding='same',
            activation='relu',
            kernel_initializer="he_normal",
        )
        self.net = Encoder(filters, rate, pooling=False)

    def call(self, X):
        """Upsample the first element of ``X`` and fuse it with the second (the skip tensor)."""
        x, skip_x = X
        upsampled = self.cT(self.bn(x))
        merged = concatenate([upsampled, skip_x])
        return self.net(merged)

    def get_config(self):
        """Return the base layer config extended with this block's constructor arguments."""
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "rate": self.rate,
        })
        return config

In [None]:
# Build the U-Net as a functional Keras model on 256x256 three-channel inputs.
unet_inputs = Input(shape=(256,256,3), name="UNetInput")

# Encoder Network : Downsampling phase.
# Each Encoder returns (pooled output, pre-pool features); the pre-pool
# features c1..c4 are kept as skip connections for the decoders.
p1, c1 = Encoder(64, 0.1, name="Encoder1")(unet_inputs)
p2, c2 = Encoder(128, 0.1, name="Encoder2")(p1)
p3, c3 = Encoder(256, 0.2, name="Encoder3")(p2)
p4, c4 = Encoder(512, 0.2, name="Encoder4")(p3)


# Encoding Layer : Latent Representation (no pooling, so a single tensor is returned).
e = Encoder(512, 0.3, pooling=False)(p4)

# Decoder Network : Upsampling phase.
# NOTE(review): despite the model name "AttentionUNet", no attention gate is
# applied here -- each Decoder simply concatenates the upsampled tensor with
# its skip tensor.
d1 = Decoder(512, 0.2, name="Decoder1")([e, c4])
d2 = Decoder(256, 0.2, name="Decoder2")([d1, c3])
d3 = Decoder(128, 0.1, name="Decoder3")([d2, c2])
d4 = Decoder(64, 0.1, name="Decoder4")([d3, c1])

# Output: three sigmoid channels, matching the three-channel masks built by load_data.
unet_out = Conv2D(3, kernel_size=3, padding='same', activation='sigmoid')(d4)

# Model
UNet = Model(
    inputs=unet_inputs,
    outputs=unet_out,
    name="AttentionUNet"
)

# Compiling: per-pixel binary cross-entropy with the Adam optimizer (no extra metrics).
UNet.compile(
    loss='binary_crossentropy',
    optimizer='adam'
)

In [None]:
# Mini-batch size used for training.
BATCH_SIZE = 16
# Steps per epoch: number of full batches that fit in the training set.
SPE = X_train.shape[0] // BATCH_SIZE

In [None]:
class ShowProgress(Callback):
    """Callback that plots one randomly chosen validation sample after every epoch.

    The figure shows, side by side, the input image, its ground-truth mask
    and the mask currently predicted by the model. Samples are drawn from
    the module-level ``X_val`` / ``y_val`` arrays.
    """

    def on_epoch_end(self, epoch, logs=None):
        sample_idx = randint(len(X_val))
        image, mask = X_val[sample_idx], y_val[sample_idx]

        # The model expects a batch axis; take the single prediction back out.
        batch = tf.expand_dims(image, axis=0)
        pred_mask = self.model(batch)[0]

        panels = (
            (image, "Original Image"),
            (mask, "Original Mask"),
            (pred_mask, "Predicted Mask"),
        )

        plt.figure(figsize=(10,8))
        for position, (panel, panel_title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            show_image(panel, title=panel_title)

        plt.tight_layout()
        plt.show()

In [None]:
# Callbacks passed to training.
cbs = [
    # Writes to this path only when the callback's monitored quantity improves.
    ModelCheckpoint(
        filepath='drive/MyDrive/training_dataset/segmodel.h5',
        save_best_only=True,
    ),
    # Plots a random validation prediction after each epoch.
    ShowProgress(),
]

In [None]:
# Train for 30 epochs, validating on the held-out split after each one.
UNet.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=30,
    steps_per_epoch=SPE,
    validation_data=(X_val, y_val),
    callbacks=cbs,
)

In [None]:
# Restore the weights stored at the path the ModelCheckpoint callback writes to.
UNet.load_weights(filepath='drive/MyDrive/training_dataset/segmodel.h5')

In [None]:
# Show image / ground-truth mask / raw prediction / thresholded prediction
# for every training sample.
# The bound comes from the array being indexed: X_train holds fewer than 300
# samples, so the previous hard-coded range(300) raised IndexError.
for idx in range(len(X_train)):
    image = X_train[idx]
    mask = y_train[idx]
    # verbose=0 keeps predict() from printing a progress bar per sample.
    pred_mask = UNet.predict(tf.expand_dims(image, axis=0), verbose=0)[0]
    # Binarise the first output channel at a 0.7 threshold.
    post_process = (pred_mask[:,:,0] > 0.7).astype('uint')

    fig = plt.figure(figsize=(10,8))
    plt.subplot(1,4,1)
    show_image(image, title="Original Image")

    plt.subplot(1,4,2)
    show_image(mask, title="Original Mask")

    plt.subplot(1,4,3)
    show_image(pred_mask, title="Predicted Mask")

    plt.subplot(1,4,4)
    # Drawn once, directly with the gray colormap (previously the panel was
    # rendered with the default colormap and then drawn over in gray).
    plt.imshow(post_process, cmap='gray')
    plt.title("Post=Processed Mask")
    plt.axis('off')

    plt.tight_layout()
    plt.show()
    # Release the figure so memory does not grow across iterations.
    plt.close(fig)