In [1]:
# show images inline
%matplotlib inline

# automatically reload modules when they have changed
%load_ext autoreload
%autoreload 2

# import keras
import keras
import os

# import keras_retinanet
from keras_maskrcnn import models
from keras_maskrcnn.utils.visualization import draw_mask
from keras_retinanet.utils.visualization import draw_box, draw_caption, draw_annotations
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras_retinanet.utils.colors import label_color
from keras_retinanet.utils.gpu import setup_gpu

from keras_maskrcnn.bin.train import create_models

import argparse
import sys
import keras_retinanet

from keras_maskrcnn.preprocessing import csv_generator
import pandas as pd

from keras_retinanet.utils.transform import random_transform_generator
from keras_retinanet.utils.image import random_visual_effect_generator

from keras_maskrcnn.bin.train import create_callbacks
from keras_maskrcnn.bin.train import parse_args
from keras_retinanet.callbacks import RedirectModel
from keras_maskrcnn.callbacks.eval import Evaluate
from keras_retinanet.models.retinanet import retinanet_bbox

# import miscellaneous modules
import matplotlib.pyplot as plt
import cv2
import numpy as np
import time
import pyximport
pyximport.install()
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# set tf backend to allow memory to grow, instead of claiming everything
import tensorflow as tf
keras.backend.tensorflow_backend._get_available_gpus()
import keras_maskrcnn


Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
import os,re,random,time,sys
import numpy as np
import SimpleITK as sitk
from matplotlib import pyplot as plt
import pandas as pd
from skimage.measure import regionprops,label
from time import time
from tqdm import tqdm
from math import sqrt
from PIL import Image
from skimage import exposure
import random 


## Define example paths (input images and model weights)

In [17]:
# Example inputs and weights, relative to the notebook's working directory.
# os.path.join keeps the paths valid on both Windows and POSIX; a literal "\"
# is only a directory separator on Windows.
path_input_low_energy = os.path.join("example input", "example_img_low_energy.mha")
path_input_recombined = os.path.join("example input", "example_img_recombined.mha")
# Trained Mask R-CNN weights (ResNet-101 backbone judging by the filename -- confirm).
path_weights = "resnet101_16.h5"

## Preprocessing

In [4]:
def crop_img(img_read_re, img_read_le):
    """Mask both images to the largest Otsu foreground region and crop to it.

    The Otsu threshold of the recombined image defines the foreground. The
    contour of largest area (``cv2.contourArea``) is filled into a binary mask.
    Both images are multiplied by that mask and cropped to the bounding box of
    the positive masked recombined pixels. Each crop is then passed through
    ``pre_processing_for_img``.

    Args:
        img_read_re: SimpleITK image of the recombined view.
        img_read_le: SimpleITK image of the low-energy view. It must have the
            same pixel grid as ``img_read_re``, because the mask is shared.

    Returns:
        Tuple ``(recombined, low_energy)`` of preprocessed numpy arrays.

    Raises:
        ValueError: if no foreground contour is found, or if the masked
            recombined image has no positive pixels.
    """
    otsu_image = sitk.OtsuThresholdImageFilter().Execute(img_read_re)
    otsu_array = sitk.GetArrayFromImage(otsu_image)
    # NOTE(review): assumes the Otsu output is binary {0, 1}; the inversion
    # selects the pixels the filter labelled 0.
    invert_otsu = (np.ones(otsu_array.shape) - otsu_array).astype(np.uint8)

    # findContours returns (image, contours, hierarchy) in OpenCV 3 and
    # (contours, hierarchy) in OpenCV 4; index -2 is the contour list in both.
    contours = cv2.findContours(invert_otsu, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2]
    if len(contours) == 0:
        raise ValueError("crop_img: no foreground contour found in the recombined image")
    # max() keeps the first contour of maximal area, like a strict '>' scan.
    max_ctr = max(contours, key=cv2.contourArea)

    # 1.0 inside the largest contour, 0.0 elsewhere.
    mask = np.zeros(invert_otsu.shape)
    cv2.fillPoly(mask, [max_ctr], [1])

    temp_img_re = mask * sitk.GetArrayFromImage(img_read_re)
    temp_img_le = mask * sitk.GetArrayFromImage(img_read_le)

    props = regionprops(np.array(temp_img_re > 0, np.uint8))
    if not props:
        raise ValueError("crop_img: masked recombined image has no positive pixels")
    r0, c0, r1, c1 = props[0].bbox

    temp_img_re = pre_processing_for_img(temp_img_re[r0:r1, c0:c1])
    temp_img_le = pre_processing_for_img(temp_img_le[r0:r1, c0:c1])
    return temp_img_re, temp_img_le

In [5]:
def resample_intensities(orig_img, bin_nr=256):
    """Quantize intensities into ``bin_nr`` equal-width bins.

    The range [min, max] of ``orig_img`` is split into ``bin_nr`` bins of
    equal width, and each value is replaced by the 0-based index of its bin.
    A value lying exactly on an interior bin edge goes to the upper bin. The
    maximum value goes to the last bin (``bin_nr - 1``).

    Args:
        orig_img: array-like of numeric intensities (any shape).
        bin_nr: number of bins (must be >= 1).

    Returns:
        ``np.uint16`` array of the same shape with values in [0, bin_nr - 1].
        A constant or empty input yields all zeros.
    """
    values = np.asarray(orig_img, dtype=np.float64)
    if values.size == 0:
        return np.zeros(values.shape, dtype=np.uint16)
    min_val = values.min()
    max_val = values.max()
    if max_val == min_val:
        # Zero-width range: every value falls in the first bin.
        return np.zeros(values.shape, dtype=np.uint16)
    # Binning is relative to min_val, so no explicit shift to non-negative is needed.
    step = (max_val - min_val) / bin_nr
    bin_idx = np.floor((values - min_val) / step)
    # Clip puts the max value (and any rounding past the top edge) into the last
    # bin; a float np.arange-based bin loop could otherwise emit index bin_nr.
    return np.clip(bin_idx, 0, bin_nr - 1).astype(np.uint16)

In [6]:
def pre_processing_for_img(img):
    """Clip outlier intensities and rescale a cropped image.

    The 1st and 99th percentiles are computed over the strictly positive pixels
    only, but the clipping is applied to every pixel. Pixels below the 1st
    percentile, including zero background, are therefore raised to it.

    If more than 256 distinct positive values remain, the positive pixels are
    quantized in place with ``resample_intensities``, and an array of the
    input's dtype is returned. Otherwise the whole image is min-max scaled to a
    ``np.uint8`` array in [0, 255]. A constant image divides by zero here.

    The input array itself is not modified.
    """
    result = img.copy()
    foreground = result[result > 0]
    lower = np.quantile(foreground, 0.01)
    upper = np.quantile(foreground, 0.99)
    result[result < lower] = lower
    result[result > upper] = upper

    positive = result > 0
    if len(np.unique(result[positive])) <= 256:
        lo = np.min(result)
        hi = np.max(result)
        return ((result - lo) / (hi - lo) * 255).astype(np.uint8)

    # Too many distinct levels: quantize the positive pixels to bin indices.
    result[positive] = resample_intensities(result[positive])
    return result

In [7]:
def load_patient(path_input_recombined,path_input_low_energy):
    """Read a recombined / low-energy image pair and crop it with ``crop_img``.

    Note the order: the parameters are (recombined, low_energy), but the
    returned tuple is ``(low_energy, recombined)``.
    """
    recombined_img, low_energy_img = crop_img(
        sitk.ReadImage(path_input_recombined),
        sitk.ReadImage(path_input_low_energy),
    )
    return low_energy_img, recombined_img

In [22]:
def _rescale_to_uint8(img):
    """Min-max rescale ``img`` to uint8 [0, 255] (divides by zero if img is constant)."""
    scaled = (img - np.min(img)) / (np.max(img) - np.min(img))
    return (scaled * 255).astype(np.uint8)


def preprocessing(path_input_low_energy,path_input_recombined):
    """Build the 3-channel BGR network input from a low-energy / recombined pair.

    Both images are loaded and cropped via ``load_patient``, min-max rescaled
    to uint8, and contrast-enhanced with CLAHE (16x16 tiles).

    Returns:
        ``np.uint8`` array of shape (H, W, 3) in BGR channel order:
        B = recombined with CLAHE clipLimit 1.0,
        G = recombined with CLAHE clipLimit 2.5,
        R = low-energy with CLAHE clipLimit 2.5.
    """
    clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(16, 16))
    clahe_recombined_2 = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(16, 16))

    img_le, img_re = load_patient(path_input_recombined, path_input_low_energy)
    im_le = _rescale_to_uint8(img_le)
    im_re = _rescale_to_uint8(img_re)

    x_le = clahe.apply(im_le)
    x_re = clahe.apply(im_re)
    x_re_2 = clahe_recombined_2.apply(im_re)

    # The RGB composite is (low energy, recombined, recombined_2); stacking it
    # reversed gives the BGR layout directly, without a PIL round-trip.
    return np.dstack((x_re_2, x_re, x_le))

## Prediction

In [27]:
def detect_and_classify_lesions(path_input_recombined, path_input_low_energy, path_weights,
                                score_threshold=0.1, max_detections=5, iou_threshold=0.01):
    """Detect lesions in an image pair and classify them as benign or malignant.

    The input is built with ``preprocessing``. A ResNet-101 Mask R-CNN is
    created from ``path_weights`` and run once. Non-max suppression is then
    applied to its boxes, and detections above ``score_threshold`` whose four
    box coordinates are all positive are kept.

    Args:
        path_input_recombined: path to the recombined image.
        path_input_low_energy: path to the low-energy image.
        path_weights: path to the model weights file.
        score_threshold: minimum score for a detection to be kept (default 0.1).
        max_detections: maximum number of boxes kept by NMS (default 5).
        iou_threshold: IoU threshold used by NMS (default 0.01).

    Returns:
        Tuple ``(masks, predictions)``:
        masks: ``np.ndarray`` stacking one mask per kept detection, taken from
            the channel of its predicted label.
        predictions: dict mapping ``str(box.astype(int))`` (box in original
            image coordinates) to ``'benign'`` or ``'malignant'``.
    """
    img_bgr = preprocessing(path_input_low_energy, path_input_recombined)

    backbone = models.backbone("resnet101")
    model, _, _ = create_models(backbone_retinanet=backbone.maskrcnn, num_classes=2,
                                weights=path_weights, freeze_backbone=False)

    labels_to_names = {0: 'benign', 1: 'malignant'}
    image = preprocess_image(img_bgr.copy())
    image, scale = resize_image(image)

    # NOTE(review): assumes the last four model outputs are
    # (boxes, scores, labels, masks), as the indexing below implies.
    outputs = model.predict_on_batch(np.expand_dims(image, axis=0))
    boxes = outputs[-4][0]
    scores = outputs[-3][0]
    labels = outputs[-2][0]
    masks = outputs[-1][0]
    boxes /= scale  # in place keeps the dtype that NMS receives unchanged

    flat_boxes = boxes.reshape(-1, 4)
    flat_scores = scores.flatten()
    selected_indices = tf.image.non_max_suppression(
        tf.constant(flat_boxes), tf.constant(flat_scores),
        max_output_size=tf.constant(max_detections), iou_threshold=iou_threshold)
    with tf.Session() as session:
        keep = session.run(selected_indices)

    # Gather every per-detection array with the SAME indices, so each kept box
    # stays paired with its own score, label and mask.
    boxes = flat_boxes[keep]
    scores = flat_scores[keep]
    labels = labels[keep]
    masks = masks[keep]

    mask_prediction = []
    dict_prediction = {}
    for box, score, label, mask in zip(boxes, scores, labels, masks):
        if score > score_threshold and np.all(box > 0):
            b = box.astype(int)
            dict_prediction[str(b)] = labels_to_names[label]
            # Per the original author's note: the mask is relative to the box
            # and has a fixed 28x28 size.
            mask_prediction.append(mask[:, :, label])

    return np.array(mask_prediction), dict_prediction