In [None]:
from datetime import datetime
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D, SeparableConv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Permute
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Lambda
from tensorflow.keras.models import Model

In [None]:
#Pre-process images
import os
import cv2
import matplotlib.pyplot as plt

train_path = '../input/wnet-data-large/data_large/Data_Large/train/'
val_path = '../input/wnetdataset/data/Data/val/'

# os.listdir returns entries in arbitrary order; sort so the dataset order is reproducible.
train_images = sorted(os.listdir(train_path))
val_images = sorted(os.listdir(val_path))


def load_grayscale_images(directory, filenames, size=(224, 224)):
    """Read images as grayscale, resize them and scale pixel values to [0, 1].

    Args:
        directory: Folder that contains the images.
        filenames: File names (relative to ``directory``) to load, in order.
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        float32 array of shape ``(n_loaded, size[1], size[0], 1)``.
        Files OpenCV cannot decode are skipped and reported with a printed message.
    """
    images = []
    skipped = []
    for name in filenames:
        img = cv2.imread(os.path.join(directory, name))
        if img is None:  # cv2.imread returns None (no exception) for unreadable/non-image files
            skipped.append(name)
            continue
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        images.append(cv2.resize(gray, size, interpolation=cv2.INTER_AREA))
    if skipped:
        print("Skipped", len(skipped), "unreadable file(s) in", directory, ":", skipped[:5])
    data = np.asarray(images, dtype='float32') / 255.
    # cv2.resize takes (width, height) but returns arrays shaped (height, width).
    return data.reshape(len(images), size[1], size[0], 1)


X_train = load_grayscale_images(train_path, train_images)
X_val = load_grayscale_images(val_path, val_images)

print("Number of train images:",len(X_train), "Shape of X_train:", X_train.shape)
print("Number of val images:",len(X_val), "Shape of X_val:", X_val.shape)


In [None]:
# Shuffle X_train and X_val along the first (image) axis with a fixed seed so the
# order - and therefore the batches used in training - is reproducible across runs.
SEED = 42
rng = np.random.RandomState(SEED)
rng.shuffle(X_train)
rng.shuffle(X_val)

In [None]:
#Parameters
input_img = Input(shape=(224, 224, 1))
droprate = 0.2
droprate_input = 0.8  # not referenced elsewhere in this notebook
num_classes = 3  # background, cell boundary, cell.
num_epochs = 10
ae_lr = 0.001  # autoencoder learning rate (used only by the commented-out autoencoder cells)
enc_lr = 0.0001  # learning rate of the encoder's normalized-cut optimizer
num_batches = 80
# Each epoch runs num_batches batches of batch_size images; the
# len(X_train) % num_batches leftover images are never part of a batch.
batch_size = len(X_train) // num_batches
if batch_size == 0:
    # Otherwise every training batch would be empty.
    raise ValueError(
        "Need at least num_batches (%d) training images, got %d." % (num_batches, len(X_train)))
batch_size

In [None]:
#Concatenation for skip connections in the network
def upconv_concat(bottom_a, bottom_b, n_filter, pool_size, stride, padding='VALID'):
    """Upsample ``bottom_a`` with a transposed convolution and append a skip tensor.

    Args:
        bottom_a: Deeper feature map to upsample.
        bottom_b: Skip-connection tensor placed after the upsampled features.
        n_filter: Number of filters of the transposed convolution.
        pool_size: Side length of the square transposed-convolution kernel.
        stride: Stride of the transposed convolution.
        padding: Padding mode forwarded to ``Conv2DTranspose``.

    Returns:
        ``[upsampled, bottom_b]`` concatenated along the last (channel) axis.
    """
    upsampled = Conv2DTranspose(
        filters=n_filter,
        kernel_size=[pool_size, pool_size],
        strides=stride,
        padding=padding,
    )(bottom_a)
    join = Concatenate(axis=-1)
    return join([upsampled, bottom_b])


In [None]:
#Encoder
# Modules 1-4 each apply two conv units and then 2x2 max pooling; module 5 is the deepest
# level; modules 6-9 upsample with upconv_concat (skip connection from the matching
# module) and apply two conv units.


def _encoder_unit(inputs, filters, separable=True):
    """Apply one conv unit: 3x3 convolution (ReLU, 'same') -> BatchNormalization -> Dropout.

    Args:
        inputs: Input tensor.
        filters: Number of output filters of the convolution.
        separable: Use ``SeparableConv2D`` when True, plain ``Conv2D`` otherwise.

    Returns:
        Tuple ``(conv, bn, dropout)`` of the three intermediate tensors, so the
        module-level ``conv_*``, ``conv_*_bn`` and ``conv_*_do`` names stay available.
    """
    conv_cls = SeparableConv2D if separable else Conv2D
    conv = conv_cls(filters=filters, kernel_size=3, activation='relu', padding='same')(inputs)
    bn = BatchNormalization()(conv)
    do = Dropout(droprate)(bn)
    return conv, bn, do


#Module 1 (plain convolutions)
conv_1_1, conv_1_1_bn, conv_1_1_do = _encoder_unit(input_img, 64, separable=False)
conv_1_2, conv_1_2_bn, conv_1_2_do = _encoder_unit(conv_1_1_do, 64, separable=False)
pool_1 = MaxPooling2D(pool_size=2, strides=2)(conv_1_2_do)  #Module 1 to Module 2

#Module 2
conv_2_1, conv_2_1_bn, conv_2_1_do = _encoder_unit(pool_1, 128)
conv_2_2, conv_2_2_bn, conv_2_2_do = _encoder_unit(conv_2_1_do, 128)
pool_2 = MaxPooling2D(pool_size=2, strides=2)(conv_2_2_do)  #Module 2 to Module 3

#Module 3
conv_3_1, conv_3_1_bn, conv_3_1_do = _encoder_unit(pool_2, 256)
conv_3_2, conv_3_2_bn, conv_3_2_do = _encoder_unit(conv_3_1_do, 256)
pool_3 = MaxPooling2D(pool_size=2, strides=2)(conv_3_2_do)  #Module 3 to Module 4

#Module 4
conv_4_1, conv_4_1_bn, conv_4_1_do = _encoder_unit(pool_3, 512)
conv_4_2, conv_4_2_bn, conv_4_2_do = _encoder_unit(conv_4_1_do, 512)
pool_4 = MaxPooling2D(pool_size=2, strides=2)(conv_4_2_do)  #Module 4 to Module 5

#Module 5 (deepest level)
conv_5_1, conv_5_1_bn, conv_5_1_do = _encoder_unit(pool_4, 1024)
conv_5_2, conv_5_2_bn, conv_5_2_do = _encoder_unit(conv_5_1_do, 1024)
upconv_1 = upconv_concat(conv_5_2_do, conv_4_2_do, n_filter=512, pool_size=2, stride=2)  #Module 5 to 6

#Module 6
conv_6_1, conv_6_1_bn, conv_6_1_do = _encoder_unit(upconv_1, 512)
conv_6_2, conv_6_2_bn, conv_6_2_do = _encoder_unit(conv_6_1_do, 512)
upconv_2 = upconv_concat(conv_6_2_do, conv_3_2_do, n_filter=256, pool_size=2, stride=2)  #Module 6 to 7

#Module 7
conv_7_1, conv_7_1_bn, conv_7_1_do = _encoder_unit(upconv_2, 256)
conv_7_2, conv_7_2_bn, conv_7_2_do = _encoder_unit(conv_7_1_do, 256)
upconv_3 = upconv_concat(conv_7_2_do, conv_2_2_do, n_filter=128, pool_size=2, stride=2)  #Module 7 to 8

#Module 8
conv_8_1, conv_8_1_bn, conv_8_1_do = _encoder_unit(upconv_3, 128)
conv_8_2, conv_8_2_bn, conv_8_2_do = _encoder_unit(conv_8_1_do, 128)
upconv_4 = upconv_concat(conv_8_2_do, conv_1_2_do, n_filter=64, pool_size=2, stride=2)  #Module 8 to 9

#Module 9
conv_9_1, conv_9_1_bn, conv_9_1_do = _encoder_unit(upconv_4, 64)
conv_9_2, conv_9_2_bn, conv_9_2_do = _encoder_unit(conv_9_1_do, 64)

# Per-pixel class logits: 1x1 convolution (num_classes filters) over the last feature map.
final_conv = Conv2D(num_classes, 1, 1)(conv_9_2_do)

# Pixel-wise softmax over the class axis. final_conv is channels-last
# (batch, 224, 224, num_classes), so flatten only the spatial dimensions and keep the
# class axis last; Activation("softmax") then normalises each pixel's class scores.
# The former Reshape((num_classes, 224*224)) + Permute((2, 1)) reinterpreted the
# channels-last buffer as channels-first, mixing values of different pixels/classes.
x = Reshape((224 * 224, num_classes))(final_conv)
x = Activation("softmax")(x)
encoder_output = Reshape((224, 224, num_classes))(x)  #Module 9 to 10

In [None]:
#Decoder
# Same module layout as the encoder, applied to encoder_output: modules 10-13 apply two
# conv units and 2x2 max pooling, module 14 is the deepest level, modules 15-18 upsample
# with upconv_concat and apply two conv units.


def _decoder_unit(inputs, filters, separable=True):
    """Apply one conv unit: 3x3 convolution (ReLU, 'same') -> BatchNormalization -> Dropout.

    Args:
        inputs: Input tensor.
        filters: Number of output filters of the convolution.
        separable: Use ``SeparableConv2D`` when True, plain ``Conv2D`` otherwise.

    Returns:
        Tuple ``(conv, bn, dropout)`` of the three intermediate tensors, so the
        module-level ``conv_*``, ``conv_*_bn`` and ``conv_*_do`` names stay available.
    """
    conv_cls = SeparableConv2D if separable else Conv2D
    conv = conv_cls(filters=filters, kernel_size=3, activation='relu', padding='same')(inputs)
    bn = BatchNormalization()(conv)
    do = Dropout(droprate)(bn)
    return conv, bn, do


#Module 10 (plain convolutions)
conv_10_1, conv_10_1_bn, conv_10_1_do = _decoder_unit(encoder_output, 64, separable=False)
conv_10_2, conv_10_2_bn, conv_10_2_do = _decoder_unit(conv_10_1_do, 64, separable=False)
pool_5 = MaxPooling2D(pool_size=2, strides=2)(conv_10_2_do)  #Module 10 to 11

#Module 11
conv_11_1, conv_11_1_bn, conv_11_1_do = _decoder_unit(pool_5, 128)
conv_11_2, conv_11_2_bn, conv_11_2_do = _decoder_unit(conv_11_1_do, 128)
pool_6 = MaxPooling2D(pool_size=2, strides=2)(conv_11_2_do)  #Module 11 to 12

#Module 12
conv_12_1, conv_12_1_bn, conv_12_1_do = _decoder_unit(pool_6, 256)
conv_12_2, conv_12_2_bn, conv_12_2_do = _decoder_unit(conv_12_1_do, 256)
pool_7 = MaxPooling2D(pool_size=2, strides=2)(conv_12_2_do)  #Module 12 to 13

#Module 13
conv_13_1, conv_13_1_bn, conv_13_1_do = _decoder_unit(pool_7, 512)
conv_13_2, conv_13_2_bn, conv_13_2_do = _decoder_unit(conv_13_1_do, 512)
pool_8 = MaxPooling2D(pool_size=2, strides=2)(conv_13_2_do)  #Module 13 to 14

#Module 14 (deepest level)
conv_14_1, conv_14_1_bn, conv_14_1_do = _decoder_unit(pool_8, 1024)
conv_14_2, conv_14_2_bn, conv_14_2_do = _decoder_unit(conv_14_1_do, 1024)
upconv_5 = upconv_concat(conv_14_2_do, conv_13_2_do, n_filter=512, pool_size=2, stride=2)  #Module 14 to 15

#Module 15
conv_15_1, conv_15_1_bn, conv_15_1_do = _decoder_unit(upconv_5, 512)
conv_15_2, conv_15_2_bn, conv_15_2_do = _decoder_unit(conv_15_1_do, 512)
upconv_6 = upconv_concat(conv_15_2_do, conv_12_2_do, n_filter=256, pool_size=2, stride=2)  #Module 15 to 16

#Module 16
conv_16_1, conv_16_1_bn, conv_16_1_do = _decoder_unit(upconv_6, 256)
conv_16_2, conv_16_2_bn, conv_16_2_do = _decoder_unit(conv_16_1_do, 256)
upconv_7 = upconv_concat(conv_16_2_do, conv_11_2_do, n_filter=128, pool_size=2, stride=2)  #Module 16 to 17

#Module 17
conv_17_1, conv_17_1_bn, conv_17_1_do = _decoder_unit(upconv_7, 128)
conv_17_2, conv_17_2_bn, conv_17_2_do = _decoder_unit(conv_17_1_do, 128)
upconv_8 = upconv_concat(conv_17_2_do, conv_10_2_do, n_filter=64, pool_size=2, stride=2)  #Module 17 to 18

#Module 18 (plain convolutions)
conv_18_1, conv_18_1_bn, conv_18_1_do = _decoder_unit(upconv_8, 64, separable=False)
conv_18_2, conv_18_2_bn, conv_18_2_do = _decoder_unit(conv_18_1_do, 64, separable=False)

# 1x1 convolution down to a single channel (presumably the reconstruction of input_img).
decoder_output = Conv2D(filters=1, kernel_size=1, activation='relu', padding='same')(conv_18_2_do)

## Encoder Model

In [None]:
encoder_model = Model(inputs=input_img, outputs=encoder_output)  # image -> per-pixel class probabilities

In [None]:
# Print the encoder's layer summary.
encoder_model.summary()

In [None]:
# Hyper-parameters of the normalized-cut loss. Both bandwidths are stored squared
# because they divide squared differences / squared distances.
sigma_pixel = tf.square(tf.constant(10.0))  # intensity bandwidth, squared (10^2)
sigma_dist = 4.0 * 4.0  # spatial bandwidth, squared (4^2), in pixels^2
r = 5  # radius (in pixels) of the neighbourhood window
k = tf.constant(num_classes, dtype=tf.float32)  # number of classes as a float tensor

In [None]:
# Placeholder for a batch of preprocessed grayscale images (same shape as input_img).
original_image = tf.placeholder(tf.float32, [None, 224, 224, 1])

In [None]:
# Kernels used by the normalized cut loss.

# Spatial kernel over an s x s window centred on a pixel: exp(-d^2 / sigma_dist) for cells
# whose squared distance d^2 to the centre is below r^2, zero elsewhere.
s = 2 * r + 1
_offsets = np.arange(s) - r
_sq_dist = _offsets[:, np.newaxis] ** 2 + _offsets[np.newaxis, :] ** 2  # [y][x] -> (y-r)^2 + (x-r)^2
spatial_kernel = np.where(_sq_dist < r * r, np.exp(-_sq_dist / sigma_dist), 0.0).astype(np.float32)
# Flattened row-major, so entry y * s + x matches output channel y * s + x of one_dim_kernel.
spatial_kernel = tf.constant(spatial_kernel.reshape(-1), dtype=tf.float32)

# Patch-extraction kernel, shape (s, s, 1, s*s): output channel i is a one-hot filter that
# copies the input value at window position (i // s, i % s).
one_dim_kernel = np.eye(s * s).reshape(s, s, 1, s * s)
one_dim_kernel = tf.constant(one_dim_kernel, dtype=tf.float32)

In [None]:
class calc_norm_loss():
    """Association term of the soft normalized-cut loss for the encoder output.

    ``cal_sum()`` builds the graph for

        sum_k  mean_over_batch( sum_i u_ik * sum_j w_ij * u_jk  /  sum_i u_ik * sum_j w_ij )

    where ``u_ik`` is the encoder's softmax probability of class ``k`` at pixel ``i`` and
    ``w_ij`` is the affinity between pixel ``i`` and neighbour ``j`` in its
    ``(2r+1) x (2r+1)`` window: a Gaussian of the squared difference of the *input image*
    intensities (bandwidth ``sigma_pixel``) times ``spatial_kernel``.
    The training loss elsewhere in the notebook is ``k - cal_sum()``.

    Uses the notebook globals ``encoder_model``, ``original_image``, ``one_dim_kernel``,
    ``spatial_kernel``, ``sigma_pixel`` and ``num_classes``.
    """

    def __init__(self):
        # Result of the most recent cal_sum() call (kept as an attribute for compatibility).
        self.num_sum = tf.constant(0.0, dtype=tf.float32)

    def cal_sum(self):
        """Return the summed per-class association as a scalar float32 tensor."""
        conv_args = dict(strides=[1, 1, 1, 1], padding='SAME')
        # One forward pass shared by all classes (previously re-run for every class).
        probabilities = encoder_model(original_image)

        # --- Pixel affinities w_ij: depend on the image only, not on the class ---
        # Channel c of this conv holds every pixel's neighbour at window position c.
        neighbour_intensity = tf.nn.conv2d(original_image, one_dim_kernel, **conv_args)
        intensity_sq_dif = tf.squared_difference(neighbour_intensity, original_image)
        # NOTE(review): X_train is scaled to [0, 1] while sigma_pixel = 10**2, so this factor
        # stays close to 1 - confirm the intended intensity bandwidth.
        intensity_values = tf.exp(tf.divide(tf.negative(intensity_sq_dif), sigma_pixel))
        # 'SAME' zero padding creates fake neighbours outside the image; mask them out so
        # they do not inflate the degree of border pixels.
        inside_image = tf.nn.conv2d(tf.ones_like(original_image), one_dim_kernel, **conv_args)
        weights = intensity_values * spatial_kernel * inside_image
        # Degree of every pixel, sum_j w_ij, shape (batch, H, W).
        degree = tf.reduce_sum(weights, axis=3)

        # Local accumulator: the old self.num_sum kept growing on repeated calls.
        total = tf.constant(0.0, dtype=tf.float32)
        for depth in range(num_classes):
            softmax_layer = probabilities[:, :, :, depth:depth + 1]
            # u_jk of every neighbour j of every pixel.
            extracted_pixels = tf.nn.conv2d(softmax_layer, one_dim_kernel, **conv_args)
            # Drop the channel axis without assuming a fixed batch size.
            u_pixels = softmax_layer[:, :, :, 0]
            numerator_inner_sum = tf.reduce_sum(weights * extracted_pixels, axis=3)
            # Sums over all pixels i, kept separate for each image in the batch.
            numerator = tf.reduce_sum(u_pixels * numerator_inner_sum, axis=[1, 2])
            denominator = tf.reduce_sum(u_pixels * degree, axis=[1, 2])
            # Guard against 0/0 when a class has (numerically) zero probability everywhere.
            total += tf.reduce_mean(numerator / tf.maximum(denominator, 1e-8))
        # Return only after all classes are accumulated (previously returned after class 0).
        self.num_sum = total
        return self.num_sum

In [None]:
# Encoder objective: k minus the value returned by calc_norm_loss.cal_sum(), minimised with Adam.
norm_loss = calc_norm_loss()
norm_cut_loss = k - norm_loss.cal_sum()
norm_cut_opt = (
    tf.train.AdamOptimizer(learning_rate=enc_lr, beta1=0.9, beta2=0.999)
    .minimize(norm_cut_loss)
)

In [None]:
# Initialiser ops for all global and local variables defined so far.
init1, init2 = tf.global_variables_initializer(), tf.local_variables_initializer()

In [None]:
# Train encoder model with the normalized-cut loss.
# Variable values live only as long as the session, so the trained weights are saved to
# encoder_weights_path before the `with` block closes it.
encoder_weights_path = 'encoder_weights.h5'

with tf.Session() as sess:
    sess.run(init1)
    sess.run(init2)
    # NOTE(review): only `original_image` is fed, so the Keras learning phase keeps its
    # default value here - confirm whether Dropout/BatchNormalization should run in
    # training mode during these updates.
    for epoch in range(num_epochs):
        print("epoch:", epoch)
        batch_losses = []
        for batch_index in range(num_batches):
            start = batch_index * batch_size
            X_train_batch = X_train[start:start + batch_size]
            _, batch_loss = sess.run([norm_cut_opt, norm_cut_loss],
                                     feed_dict={original_image: X_train_batch})
            batch_losses.append(batch_loss)
        # Mean over the epoch's batches (previously only the last batch was reported, and
        # train_loss was undefined when num_batches == 0).
        train_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')
        print("Train loss after ", str(epoch), "is", str(train_loss))
    encoder_model.save_weights(encoder_weights_path)

## Full model

In [None]:
# wnet_autoencoder = Model(input_img, decoder_output)

In [None]:
# wnet_autoencoder.summary()

In [None]:
# #place holders for wnet_autoencoder
# input_image = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 1])

# # Reconstruction loss and loss for autoencoder model
# rec_loss = tf.reduce_mean(tf.squared_difference(wnet_autoencoder(input_image), input_image))
# rec_opt = tf.train.GradientDescentOptimizer(learning_rate=ae_lr).minimize(rec_loss)

# init1 = tf.global_variables_initializer()
# init2 = tf.local_variables_initializer()

In [None]:
# #train autoencoder
# with tf.Session() as sess:
#     sess.run(init1)
#     sess.run(init2)
#     for epoch in range(num_epochs):
#         print("epoch:", epoch)
#         count = 0
#         batch_start_index = 0
#         while (count != num_batches):
#             X_train_batch = X_train[batch_start_index : batch_start_index+batch_size]
#             _, train_loss = sess.run([rec_opt,rec_loss], feed_dict={input_image: X_train_batch})
#             batch_start_index+=batch_size
#             count+=1
#         print("Train loss after ", str(epoch), "is", str(train_loss))
        

In [None]:
# Train full Auto-Encoder with keras

# with tf.Session() as sess:
#     sess.run(init1)
#     sess.run(init2)
#     wnet_autoencoder.compile(optimizer='adagrad', loss='mean_squared_error')
#     history = wnet_autoencoder.fit(X_train, X_train,
#                     epochs=5,
#                     batch_size=8,
#                     shuffle=True,
#                     validation_data=(X_val, X_val)).history

In [None]:
# Train Encoder with keras. 

# with tf.Session() as sess:
#     sess.run(init1)
#     sess.run(init2)
#     encoder_model.compile(optimizer='adagrad', loss='sparse_categorical_crossentropy')
#     history = encoder_model.fit(X_train, X_train,
#                     epochs=5,
#                     batch_size=8,
#                     shuffle=True,
#                     validation_data=(X_val, X_val)).history

In [None]:
#Plot graph for train_loss vs val_loss for keras training

# plt.plot(history['loss'], linewidth=2, label='Train')
# plt.plot(history['val_loss'], linewidth=2, label='Test')
# plt.legend(loc='upper right')
# plt.title('Model loss')
# plt.ylabel('Loss')
# plt.xlabel('Epoch')
# # plt.show()

## Testing the Results

In [None]:
#Load a test image and preprocess it the same way as the training images.

test_image_path = '../input/test-dataset/a184.jpg'
img = cv2.imread(test_image_path)
if img is None:  # cv2.imread returns None instead of raising for a missing/unreadable file
    raise FileNotFoundError("Could not read image: " + test_image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
# Same dtype and scaling as X_train: float32 in [0, 1].
resized = resized.astype('float32') / 255.
# Add the channel and batch axes: (224, 224) -> (1, 224, 224, 1).
resized1 = resized[np.newaxis, :, :, np.newaxis]
print(resized1.shape)

In [None]:
# Fresh initialiser ops for the test session below.
init1, init2 = tf.global_variables_initializer(), tf.local_variables_initializer()

In [None]:
#Test Encoder model
# Running the initialisers in a new session resets every weight, so reload trained weights
# from encoder_weights_path when that file exists; otherwise the prediction comes from a
# freshly initialised (untrained) encoder.
encoder_weights_path = 'encoder_weights.h5'

with tf.Session() as sess:
    sess.run(init1)
    sess.run(init2)
    if os.path.exists(encoder_weights_path):
        encoder_model.load_weights(encoder_weights_path)
    else:
        print("Warning: no saved encoder weights at", encoder_weights_path,
              "- predicting with freshly initialised weights.")
    img3 = encoder_model.predict(resized1)

In [None]:
# Raw encoder prediction for the single test image (reshaped to (224, 224, 3) further below).
img3

In [None]:
print(np.shape(img3))  # batch, height, width, classes

In [None]:
# Drop the batch axis and show the three class channels as an RGB image.
img3 = img3.reshape(224, 224, 3)
plt.imshow(img3)

In [None]:
# Sanity check on the top-left pixel: expected to be ~1 if the output is a per-pixel softmax.
img3[0, 0, 0] + img3[0, 0, 1] + img3[0, 0, 2]

In [None]:
# Show only channel 1 (presumably "cell boundary", given the class order in the parameters
# cell): zero channels 0 and 2 of a copy so img3 itself is left untouched.
output_image_copy = img3.copy()
output_image_copy[:, :, [0, 2]] = 0
plt.imshow(output_image_copy)