In [None]:
import pickle
import numpy as np
import pickle
import numpy as np
import tensorflow as tf
import keras
from keras.models import Sequential
from keras.layers import Dense, Conv2D, BatchNormalization, Activation, Conv2DTranspose, UpSampling2D, MaxPooling2D 
from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint

from keras.layers import Input, Lambda, Dropout, concatenate
from keras.models import Model, load_model
from keras.optimizers import Adam

from sklearn.metrics import accuracy_score


def conv2d_block(input_tensor, n_filters, kernel_size=3, batchnorm=True):
    """Apply two successive Conv2D -> (optional) BatchNorm -> ReLU stages.

    Both convolutions use ``n_filters`` square ``kernel_size`` kernels,
    He-normal initialisation and ``padding="same"`` (with the default stride
    of 1), so the spatial size of ``input_tensor`` is preserved.

    Args:
        input_tensor: Keras tensor to transform.
        n_filters: number of output channels of each convolution.
        kernel_size: side length of the square convolution kernel.
        batchnorm: when True, a BatchNormalization layer follows each conv.

    Returns:
        The Keras tensor produced by the second ReLU activation.
    """
    kernel = (kernel_size, kernel_size)
    out = input_tensor
    # Two identical stages; layers are created in the same order every time
    # (conv, [bn], relu, conv, [bn], relu).
    for _stage in range(2):
        out = Conv2D(filters=n_filters, kernel_size=kernel, kernel_initializer="he_normal",
                     padding="same")(out)
        if batchnorm:
            out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out

def make_model(input_img, n_filters=16, dropout=0.5, batchnorm=True):
    """Build a U-Net (4 down-sampling levels + bottleneck) with a 2-channel sigmoid head.

    Args:
        input_img: Keras ``Input`` tensor for the image batch.
        n_filters: filter count of the first level. Encoder/decoder level ``k``
            uses ``n_filters * 2**k`` filters and the bottleneck uses
            ``n_filters * 16``.
        dropout: dropout rate applied after every pooling step and every skip
            concatenation, except after the first pooling, which uses half of it.
        batchnorm: forwarded to every ``conv2d_block``.

    Returns:
        An uncompiled ``keras.Model`` mapping ``input_img`` to a 2-channel map,
        each channel passed through a sigmoid.

    Spatial-size constraint: the encoder applies 2x2 max-pooling four times
    (floor division). The three deepest up-samplings use ``padding='same'``
    (exact doubling), but the last one uses ``padding='valid'``, which maps a
    size ``m`` to ``2*m + 1``. The skip concatenations therefore only line up
    when height and width are of the form ``16*k + 1``, e.g.
    1025 -> 512 -> 256 -> 128 -> 64 and back up to 1025.

    Layers are created in a fixed order: encoder, bottleneck, decoder, output
    head. Keep that order. Reordering changes the auto-generated layer names,
    which can prevent previously saved weights from loading.
    """
    # Encoder: convolve, remember the result as a skip connection, then pool + dropout.
    skips = []
    x = input_img
    for level in range(4):
        skip = conv2d_block(x, n_filters=n_filters * 2 ** level, kernel_size=3, batchnorm=batchnorm)
        skips.append(skip)
        x = MaxPooling2D(pool_size=(2, 2))(skip)
        x = Dropout(dropout * 0.5 if level == 0 else dropout)(x)

    # Bottleneck.
    x = conv2d_block(x, n_filters=n_filters * 16, kernel_size=3, batchnorm=batchnorm)

    # Decoder: up-sample, concatenate the matching skip, dropout, convolve.
    for level in reversed(range(4)):
        outermost = level == 0
        # 'valid' on the outermost level turns size m into 2*m + 1 (512 -> 1025).
        x = Conv2DTranspose(n_filters * 2 ** level, (3, 3), strides=(2, 2),
                            padding='valid' if outermost else 'same')(x)
        x = concatenate([x, skips[level]], axis=3 if outermost else -1)
        x = Dropout(dropout)(x)
        x = conv2d_block(x, n_filters=n_filters * 2 ** level, kernel_size=3, batchnorm=batchnorm)

    outputs = Conv2D(2, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[input_img], outputs=[outputs])

def get_iou(y_true, y_pred):
    """Smoothed intersection-over-union (Jaccard index) of two masks.

    Elements are combined with ``np.logical_and`` / ``np.logical_or``, so any
    non-zero value counts as foreground. The inputs must be broadcast-compatible.

    A small smoothing constant is added to both numerator and denominator. As a
    result, two empty masks score 1.0 instead of dividing by zero.

    Args:
        y_true: reference mask (array-like).
        y_pred: predicted mask (array-like).

    Returns:
        The smoothed IoU as a NumPy float scalar.
    """
    smooth = .001
    intersection = np.sum(np.logical_and(y_pred, y_true))
    union = np.sum(np.logical_or(y_pred, y_true))
    return (intersection + smooth) / (union + smooth)

def get_dice(im1, im2, empty_score=1.0):
    """Dice coefficient 2*|A & B| / (|A| + |B|) of two masks.

    Both inputs are converted to boolean arrays, so any non-zero value counts
    as foreground.

    Args:
        im1: first mask (array-like).
        im2: second mask (array-like), same shape as ``im1``.
        empty_score: value returned when both masks are empty. The formula is
            0/0 in that case. The default of 1.0 matches the score that
            ``get_iou`` gives two empty masks.

    Returns:
        The Dice coefficient in [0, 1], or ``empty_score`` for two empty masks.

    Raises:
        ValueError: if the two masks have different shapes. Without this check,
            broadcasting would silently produce a meaningless score.
    """
    # The builtin ``bool`` replaces ``np.bool``, which NumPy 1.24 removed.
    im1 = np.asarray(im1).astype(bool)
    im2 = np.asarray(im2).astype(bool)
    if im1.shape != im2.shape:
        raise ValueError(f"Shape mismatch: im1 has shape {im1.shape}, im2 has shape {im2.shape}")

    total = im1.sum() + im2.sum()
    if total == 0:
        return empty_score

    intersection = np.logical_and(im1, im2)
    return 2. * intersection.sum() / total

In [2]:
#prepare the miccai data for inference
from multiprocessing import Pool
import matplotlib.pyplot as plt
import os
from PIL import Image
from color_deconvolution_2 import color_deconvolution
import time

# Directory containing the MICCAI image tiles.
im_loc = '/home/bkf15/lumen_seg/MiccaiData_ImageTiles/'
# Plain "...original.jpg" tiles only (the "boundary" variants are excluded).
# The list is sorted because os.listdir returns entries in arbitrary order,
# which would make the saved pickle's ordering non-reproducible.
original_images = sorted(
    o for o in os.listdir(im_loc) if "original.jpg" in o and "boundary" not in o
)
num_samples = len(original_images)
# (width, height) passed to PIL's resize for the deconvolved image; 1025x1025 is
# also the input size the U-Net is built with in the prediction cell.
im_shape = (1025, 1025)

# Start-of-run timestamp.
s = time.time()
def process_image(i):
    """Load tile ``original_images[i]`` and return its deconvolved channel as uint8.

    Processing steps:
      1. Downsample the tile to 256x256.
      2. Run ``color_deconvolution`` and keep element ``[0]`` of its result.
      3. Resize back up to ``im_shape``.
      4. Clip to [0, 255], cast to uint8 and append a trailing channel axis.

    The function is executed by ``Pool`` workers. It therefore reads the
    module-level ``im_loc``, ``original_images`` and ``im_shape`` instead of
    taking them as arguments.
    NOTE(review): assumes ``color_deconvolution(...)[0]`` is a single 2-D
    channel; confirm against color_deconvolution_2.

    Args:
        i: index into ``original_images``.

    Returns:
        dict with "image" (uint8 array with a trailing channel axis of size 1)
        and "name" (the tile's file name).
    """
    name = original_images[i]
    # Context manager closes the file handle instead of leaving it to the GC
    # (12 workers x thousands of tiles).
    with Image.open(os.path.join(im_loc, name)) as im:
        im_small = np.asarray(im.resize((256, 256)))

    im_decon = color_deconvolution(im_small)[0]
    im_decon_re = np.asarray(Image.fromarray(im_decon).resize(im_shape))

    # Values should already be in [0, 255]; clip defensively before the uint8 cast.
    im_int = np.expand_dims(np.clip(im_decon_re, 0, 255).astype(np.uint8), axis=2)
    return {"image": im_int, "name": name}
        
# Prepare all tiles in parallel, then save them. ``process_image`` is a
# module-level function, which is what makes it picklable for Pool.map.
if __name__ == '__main__':
    # The context manager shuts the worker processes down once map() has
    # returned every result, so they don't linger in the kernel.
    with Pool(processes=12) as pool:
        X_miccai = pool.map(process_image, range(len(original_images)))

    # Saving stays inside the guard: X_miccai only exists when this code
    # runs as __main__.
    os.makedirs('unet_data', exist_ok=True)
    # protocol=4 supports very large pickles; each tile is a 1025x1025 uint8 array.
    with open('unet_data/miccai_prepared_ims.pickle', 'wb') as f:
        pickle.dump(X_miccai, f, protocol=4)


Process ForkPoolWorker-8:
Process ForkPoolWorker-4:
Process ForkPoolWorker-10:
Process ForkPoolWorker-9:
Process ForkPoolWorker-2:
Process ForkPoolWorker-12:
Process ForkPoolWorker-7:
Process ForkPoolWorker-3:
Process ForkPoolWorker-1:
Process ForkPoolWorker-11:
Process ForkPoolWorker-5:
Process ForkPoolWorker-6:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/process.py", line 297

KeyboardInterrupt
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/queues.py", line 351, in get
    with self._rlock:
KeyboardInterrupt
  File "/home/bkf15/anaconda3/envs/unet/lib/python3.7/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
KeyboardInterrupt


In [14]:
#Read in prepared MICCAI images, predict masks, and save results 
import matplotlib.pyplot as plt
import pickle
import time
import os
from skimage.morphology import binary_dilation, binary_erosion, area_opening
from PIL import Image
%matplotlib inline
plt.rcParams['figure.figsize'] = [20, 10]

# Prepared tiles written by the data-preparation cell: a list of
# {"image": ..., "name": ...} dicts. pickle can execute arbitrary code on load,
# so only open files this pipeline produced itself.
with open('unet_data/miccai_prepared_ims.pickle', "rb") as input_file:
    X = pickle.load(input_file)

# Rebuild the U-Net for 1025x1025 single-channel input and restore the trained weights.
model = make_model(Input((1025, 1025, 1), name='img'))
model.load_weights('/home/bkf15/akash_data/models/best_unet')

# One directory holds the masks, the plain tiles and the "boundary" display images.
mask_dir = '/home/bkf15/lumen_seg/MiccaiData_ImageTiles/'
masks = [m for m in os.listdir(mask_dir) if "mask" in m]
original_images = [i for i in os.listdir(mask_dir) if "original.jpg" in i]

out_dir = 'miccai2021_lumen_seg_results_03'
# predict_1 saves overlays into <out_dir>/disp_ims; Image.save does not create
# missing directories, so create them up front.
os.makedirs(os.path.join(out_dir, 'disp_ims'), exist_ok=True)

# Cut-off applied to the lumen channel (channel 0) of the U-Net output.
threshold = 0.3

print("starting prediction")
s = time.time()
def _find_tile_file(file_names, im_name, must_contain=""):
    """Return the entry of ``file_names`` belonging to tile ``im_name``.

    Plain substring matching is ambiguous for numbered tiles: "tile_1" is also
    a substring of "tile_10_mask.png". Because ``os.listdir`` order is
    arbitrary, taking the first hit can silently pick another tile's file.
    Among all names that contain both ``im_name`` and ``must_contain``, the
    shortest is returned, with ties broken alphabetically. Longer hits are the
    ones where ``im_name`` is extended by extra characters.
    NOTE(review): assumes all tiles' files share one naming pattern; confirm.

    Raises:
        FileNotFoundError: if no name matches.
    """
    candidates = [f for f in file_names if im_name in f and must_contain in f]
    if not candidates:
        raise FileNotFoundError(
            f"No file containing {im_name!r} and {must_contain!r} was found")
    return min(candidates, key=lambda f: (len(f), f))


def predict_1(i):
    """Predict, post-process and save the lumen segmentation for tile ``X[i]``.

    Steps:
      1. Run ``model`` on the tile and take output channel 0 as the lumen map.
      2. Zero the map wherever the tile's mask (resized to the map) is < 1.
      3. Threshold at ``threshold``, dilate 3 times, erode 5 times, then
         ``area_opening`` with an area threshold of 500 pixels.
      4. Draw the result in red on the tile's "boundary" image (mask resized to
         that image) and save it as ``<out_dir>/disp_ims/<name>_lumen_drawing.jpg``.
      5. Save the mask itself as ``<out_dir>/<name>_lumen_seg.png``.

    Reads the module-level ``X``, ``model``, ``masks``, ``mask_dir``,
    ``original_images``, ``threshold``, ``out_dir`` and ``s`` (start time used
    in progress messages).

    Args:
        i: index into ``X``.
    """
    x_i = X[i]["image"]
    y_i = model.predict(np.expand_dims(x_i, axis=0))[0]
    lumen_pred = y_i[:, :, 0].copy()
    pred_size = (lumen_pred.shape[1], lumen_pred.shape[0])  # PIL wants (width, height)

    # Strip the trailing "_original.jpg" (13 characters) to get the tile's base name.
    im_name = X[i]["name"][:-len("_original.jpg")]

    # Mask out non-duct pixels.
    with Image.open(os.path.join(mask_dir, _find_tile_file(masks, im_name))) as mask_img:
        mask = np.asarray(mask_img.resize(pred_size))
    if mask.ndim == 3:
        # Multi-channel (e.g. RGB) mask file: use the first channel.
        mask = mask[:, :, 0]
    lumen_pred[mask < 1] = 0

    # Threshold + morphology; the binary ops return boolean arrays.
    lumen_mask = lumen_pred > threshold
    for _ in range(3):
        lumen_mask = binary_dilation(lumen_mask)
    for _ in range(5):
        lumen_mask = binary_erosion(lumen_mask)
    lumen_mask = area_opening(lumen_mask, 500)

    # Resize the lumen mask to the display image and paint it red.
    boundary_name = _find_tile_file(original_images, im_name, "boundary")
    with Image.open(os.path.join(mask_dir, boundary_name)) as disp_img:
        disp_im = np.array(disp_img)  # writable copy
    resized_lumen = np.asarray(
        Image.fromarray(lumen_mask).resize((disp_im.shape[1], disp_im.shape[0])))
    # "> 0" guarantees boolean-mask indexing (an integer array would be fancy indexing).
    disp_im[resized_lumen > 0] = (255, 0, 0)
    Image.fromarray(disp_im).save(os.path.join(out_dir, 'disp_ims', im_name + '_lumen_drawing.jpg'))

    Image.fromarray(lumen_mask).save(os.path.join(out_dir, im_name + '_lumen_seg.png'))

    if i > 0 and i % 10 == 0:
        print(f'Done {i} of {len(X)} in {round(time.time() - s, 2)} seconds.')
        
# Predict and save every prepared tile, one at a time.
for idx in range(len(X)):
    predict_1(idx)


starting prediction
Done 10 of 1880 in 17.1 seconds.
Done 20 of 1880 in 31.69 seconds.
Done 30 of 1880 in 46.65 seconds.
Done 40 of 1880 in 61.4 seconds.
Done 50 of 1880 in 76.05 seconds.
Done 60 of 1880 in 90.75 seconds.
Done 70 of 1880 in 105.48 seconds.
Done 80 of 1880 in 120.13 seconds.
Done 90 of 1880 in 135.08 seconds.
Done 100 of 1880 in 149.88 seconds.
Done 110 of 1880 in 164.65 seconds.
Done 120 of 1880 in 179.5 seconds.
Done 130 of 1880 in 193.95 seconds.
Done 140 of 1880 in 208.29 seconds.
Done 150 of 1880 in 223.15 seconds.
Done 160 of 1880 in 237.72 seconds.
Done 170 of 1880 in 252.45 seconds.
Done 180 of 1880 in 267.33 seconds.
Done 190 of 1880 in 282.01 seconds.
Done 200 of 1880 in 296.45 seconds.
Done 210 of 1880 in 311.28 seconds.
Done 220 of 1880 in 326.35 seconds.
Done 230 of 1880 in 340.96 seconds.
Done 240 of 1880 in 355.83 seconds.
Done 250 of 1880 in 370.77 seconds.
Done 260 of 1880 in 385.67 seconds.
Done 270 of 1880 in 400.54 seconds.
Done 280 of 1880 in 415.71

In [None]:
# Number of prepared tiles, shown as the cell's output.
len(X)