### Predict lung masks for COVID database images using a model trained on the JMS lung segmentation database



In [1]:
# In[1]:

import os, sys, shutil
from os import listdir
from os.path import isfile, join 
import random
import numpy as np
import cv2


In [2]:
# In[2]:

from MODULES.Generators import train_generator_1, val_generator_1, test_generator_1
from MODULES.Generators import train_generator_2, val_generator_2, test_generator_2
from MODULES.Networks import ResNet_Atrous, Dense_ResNet_Atrous
from MODULES.Networks import ResUNet, ResUNet_Big, ResUNet_CR, ResUNet_CR_Big
from MODULES.Losses import dice_coeff
from MODULES.Losses import tani_loss, tani_coeff, weighted_tani_coeff
from MODULES.Losses import weighted_tani_loss, other_metrics
from MODULES.Constants import _Params, _Paths
from MODULES.Utils import get_class_threshold, get_model_memory_usage
import tensorflow as tf 
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import model_from_json, load_model 
from tensorflow.python.client import device_lib
import matplotlib.pyplot as plt
import datetime

# automatic reload of external definitions if changed during testing
%load_ext autoreload
%autoreload 2


In [3]:
# In[3]:

# ### CONSTANTS

# Configuration values, unpacked in the exact order returned by _Params()
(
    HEIGHT, WIDTH, CHANNELS, IMG_COLOR_MODE, MSK_COLOR_MODE, NUM_CLASS,
    KS1, KS2, KS3, DL1, DL2, DL3, NF, NFL, NR1, NR2, DIL_MODE, W_MODE, LS,
    TRAIN_SIZE, VAL_SIZE, TEST_SIZE, DR1, DR2, CLASSES, IMG_CLASS,
) = _Params()

# Dataset paths, unpacked in the exact order returned by _Paths()
(
    TRAIN_IMG_PATH, TRAIN_MSK_PATH, TRAIN_MSK_CLASS,
    VAL_IMG_PATH, VAL_MSK_PATH, VAL_MSK_CLASS,
    TEST_IMG_PATH, TEST_MSK_PATH, TEST_MSK_CLASS,
) = _Paths()


In [5]:
# In[4]: 

# ### LOAD ENTIRE MODEL FROM PREVIOUS RUN END AND COMPILE

# The saved-model file name is built from NF, NR1, NR2 (from _Params())
# plus the id of the training run that produced it.
model_selection = f'model_{NF}F_{NR1}R1_{NR2}R2'
model_number = '2020-10-02_13_50'  # model number from an earlier run
filepath = f'models/{model_selection}_{model_number}_all.h5'

# Load the architecture + weights without the saved training config, then
# compile again with the project's weighted Tanimoto loss and Tanimoto metric
# (presumably so the custom loss need not be resolved at load time — confirm).
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = load_model(filepath, compile=False)
    model.compile(
        optimizer=Adam(),
        loss=weighted_tani_loss,
        metrics=[tani_coeff],
    )
    

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensor

In [6]:
# Show which saved model was loaded (architecture tag and run id)
print(model_selection, model_number, sep=' ')

model_16F_5R1_0R2 2020-10-02_13_50


In [10]:
# In[5]

# ### PREDICT MASKS FOR WORKING SET - PLOTS PATH

print(CLASSES)

# Source directory containing COVID patients lung CXR's
source_resized_img_path = 'dataset/selected_COVID_pos4_neg5_image_resized_equalized/'

# Target directories containing masks predicted for the CXR images in the source directory:
# For JMS database
target_resized_msk_path_binary = 'dataset/selected_COVID_pos4_neg5_masks_binary_952_3/'
target_resized_msk_path_float = 'dataset/selected_COVID_pos4_neg5_masks_float_952_3/'
target_img_mask_path = 'dataset/selected_COVID_pos4_neg5_images_masks_952_3/'
target_dirs = (
    target_resized_msk_path_binary,
    target_resized_msk_path_float,
    target_img_mask_path,
)

# Pixels whose predicted value in channel 0 of the model output exceeds this
# threshold are set to 255 in the binary mask, all others to 0.
MASK_THRESHOLD = 0.5


def _empty_dir(path):
    """Delete every file, symlink and subdirectory inside *path*, keeping *path* itself.

    Does nothing if *path* is not an existing directory.
    """
    if not os.path.isdir(path):
        return
    for entry_name in os.listdir(path):
        entry_path = os.path.join(path, entry_name)
        # Symlinks (even to directories) are unlinked, never followed by rmtree
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.unlink(entry_path)


# Remove the content of the target directories if already present.
# Safety guard: the target paths are relative, so they are only wiped when the
# notebook runs from this exact project directory. Elsewhere, old files are kept
# (and overwritten when names collide), which we now report instead of skipping silently.
pwd = os.getcwd()
root_dir = '/wsu/home/aa/aa14/aa1426/Documents/JENA/MYOTUBES_SEGMENTATION/CXR-Net/Module_1'
if root_dir == pwd:
    for target_dir in target_dirs:
        _empty_dir(target_dir)
else:
    print(f'Working directory is not {root_dir}: existing files in the target directories were NOT removed.')

# Create directories that will store the masks on which to train the classification network
for target_dir in target_dirs:
    os.makedirs(target_dir, exist_ok=True)

# Get CXR image names from the source directory. Sorted because listdir order is
# arbitrary; this makes the last processed image (reused by the plotting cells
# below) reproducible.
source_img_names = sorted(
    f for f in listdir(source_resized_img_path)
    if isfile(join(source_resized_img_path, f))
)

for name in source_img_names:
    input_img = cv2.imread(join(source_resized_img_path, name), cv2.IMREAD_GRAYSCALE)
    if input_img is None:
        # cv2.imread does not raise on unreadable / non-image files; it returns None
        print(f'Skipping {name}: could not be read as an image')
        continue

    # Scale to [0, 1] and add batch and channel axes: (H, W) -> (1, H, W, 1)
    scaled_img = np.expand_dims(input_img / 255, axis=[0, -1])
    mask = model(scaled_img).numpy()
    # Channel 0 of the prediction; presumably CLASSES[0] ('lungs') — confirm
    lung_prob = mask[0, :, :, 0]

    # mask_float: prediction scaled to 0..255 (float); mask_binary: 0/255 (uint8).
    # Both are kept for the plotting cells below; lung_prob * 255 builds a new
    # array, so the model output `mask` is not modified.
    mask_float = lung_prob * 255
    mask_binary = np.where(lung_prob > MASK_THRESHOLD, 255, 0).astype(np.uint8)

    # Write explicit uint8 images rather than relying on OpenCV's implicit
    # float/int64 -> 8-bit conversion; rint rounds like OpenCV's saturate_cast.
    outputs = (
        (join(target_resized_msk_path_float, name),
         np.clip(np.rint(mask_float), 0, 255).astype(np.uint8)),
        (join(target_resized_msk_path_binary, name), mask_binary),
    )
    for out_path, out_img in outputs:
        if not cv2.imwrite(out_path, out_img):
            print(f'Failed to write {out_path}')

    # Side-by-side figure: input CXR (left) and predicted binary mask (right).
    # The file name keeps the original image extension before the suffix.
    fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(20, 10))
    fig.subplots_adjust(hspace=0.4, wspace=0.2)
    ax_img.imshow(input_img, cmap="gray")
    ax_mask.imshow(mask_binary, cmap="gray")

    fig.savefig(join(target_img_mask_path, name + '_img_and_pred_mask.png'))
    plt.close(fig)
    

['lungs', 'non_lungs']


In [11]:
# Shapes of the arrays left over from the last image processed by the prediction loop
print(*(arr.shape for arr in (input_img, mask, mask_float, mask_binary)))

(300, 340) (1, 300, 340, 2) (300, 340) (300, 340)


In [None]:
# In[6]

# ### PLOT ONE EXAMPLE OF CONTINUOUS MASK

# Last image processed by the prediction loop (left) next to its 0..255 float mask (right)
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(20, 10))
fig.subplots_adjust(hspace=0.4, wspace=0.2)

ax_img.imshow(np.squeeze(input_img), cmap="gray")
ax_mask.imshow(np.squeeze(mask_float), cmap="gray")


In [None]:
# In[7]

# ### PLOT ONE EXAMPLE OF BINARY MASK

# Last image processed by the prediction loop (left) next to its thresholded 0/255 mask (right)
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(20, 10))
fig.subplots_adjust(hspace=0.4, wspace=0.2)

ax_img.imshow(np.squeeze(input_img), cmap="gray")
ax_mask.imshow(np.squeeze(mask_binary), cmap="gray")
