## Inference on DCM images

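In this notebook the best model (an ensemble of five models) is used for prediction on DICOM (DCM) images.
Code is provided to check several metadata fields before prediction: image resolution, DCM Modality, Body Part and Patient Position; a warning is also issued when the patient's age is outside the supported range.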

In [1]:
!pip install efficientnet >> /dev/null

In [1]:
# restart the kernel after the above command

import pydicom
import numpy as np
import matplotlib.pyplot as plt
import glob
import time

import tensorflow as tf
import efficientnet.tfkeras as efn

In [2]:
# raise the level of TensorFlow's logger to ERROR,
# so that messages below that level are not shown
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

In [3]:
# global constants

# side (in pixels) of the square image fed to the models
SIZE = 512
IMAGE_SIZE = [SIZE, SIZE]
# number of models in the ensemble (one weights file per fold)
FOLDS = 5

# EfficientNet variant: index into the EFNS list (4 -> EfficientNetB4)
EFF_NET = 4

# averaged probability at or above this value is classified as class 1
THRESHOLD = 0.54

# directory with the weights files (fold-1.h5, fold-2.h5, ...)
DIR_MODELS = './tpu-models/'
# directory scanned for the *.dcm files to process
DIR_DCM = './dcm/'

In [4]:
# check image and metadata

#
# check age within boundaries (10, 80)
# if NOT OK a warning is issued
#
def is_age_ok(dcm):
    """Tell whether the patient's age lies within the supported boundaries.

    The age is read from ``dcm.PatientAge``. A value that ``int()`` accepts
    (e.g. ``45`` or ``'45'``) is taken as a number of years. Otherwise the
    value is parsed as a DICOM Age String (``nnnD``, ``nnnW``, ``nnnM`` or
    ``nnnY``: days, weeks, months or years) and converted to years.

    Args:
        dcm: object exposing a ``PatientAge`` attribute.

    Returns:
        True if LOW <= age <= HIGH (both boundaries included), else False.

    Raises:
        TypeError, ValueError: if the age cannot be interpreted.
    """
    # boundaries for age (years, inclusive)
    LOW = 10
    HIGH = 80

    # how many units of each Age String suffix make one year
    units_per_year = {'Y': 1.0, 'M': 12.0, 'W': 365.25 / 7.0, 'D': 365.25}

    raw_age = dcm.PatientAge

    try:
        # plain number of years
        age = int(raw_age)
    except (TypeError, ValueError):
        # fall back to the Age String format, e.g. '045Y' or '006M'
        text = str(raw_age).strip().upper()
        per_year = units_per_year.get(text[-1:])
        if per_year is None:
            # not an Age String either: propagate the original error
            raise
        age = int(text[:-1]) / per_year

    return LOW <= age <= HIGH

#
# check image
# if NOT OK image is rejected
#
def check_image(dcm):
    """Validate a DCM image and its metadata before prediction.

    Checks performed (every failed check prints a short report):
      1. resolution: Rows and Columns must both be >= 512
         (needed for the resize to 512x512);
      2. Modality: must be 'DX';
      3. BodyPartExamined: must be 'CHEST';
      4. PatientPosition: must be 'AP' or 'PA'.

    Args:
        dcm: object exposing the attributes Rows, Columns, Modality,
            BodyPartExamined and PatientPosition.

    Returns:
        True if all the checks passed, False if the image is rejected.
    """
    # minimum number of rows/columns accepted
    MIN_SIZE = 512
    # accepted values for the patient position
    pos_ok = ['AP', 'PA']

    isOK = True

    # resolution must be at least 512x512, otherwise resize is not OK
    if (dcm.Rows < MIN_SIZE) or (dcm.Columns < MIN_SIZE):
        print('Check on resolution NOT passed.')
        print('Image resolution is:', (dcm.Rows, dcm.Columns))
        print('Expected greater than 512x512')
        isOK = False

    # check Modality (Image Type)
    if dcm.Modality != 'DX':
        print('Check on Modality NOT passed.')
        print('Image modality is:', dcm.Modality, 'expected: DX.')
        isOK = False

    # check BODY part: must be CHEST
    if dcm.BodyPartExamined != 'CHEST':
        print('Check on Body Part NOT passed.')
        print('Body part is:', dcm.BodyPartExamined, 'expected: CHEST.')
        isOK = False

    # check Patient Position
    if dcm.PatientPosition not in pos_ok:
        print('Check on Patient position NOT passed.')
        print('Patient position is:', dcm.PatientPosition, 'expected: AP or PA.')
        isOK = False

    if not isOK:
        print('Image rejected.')
        print('')

    return isOK

In [5]:
# here we define the DNN Model

# EfficientNet constructors, position k holds EfficientNetB<k> (k = 0 ... 7)
EFNS = [getattr(efn, 'EfficientNetB%d' % variant) for variant in range(8)]

# as default it used B0

def build_model(dim = SIZE, ef = 0):
    """Build and compile the binary classifier on top of an EfficientNet base.

    Architecture: EfficientNet base (no top, 'imagenet' weights) ->
    GlobalAveragePooling2D -> Dense(512, relu) -> Dense(1, sigmoid).

    Args:
        dim: side (in pixels) of the square RGB input; the model input
            shape is (dim, dim, 3).
        ef: index into EFNS selecting the EfficientNet variant
            (0 -> B0, the default).

    Returns:
        The compiled tf.keras.Model (Adam with learning rate 0.001,
        binary cross-entropy loss, metrics AUC and accuracy).
    """
    # honor the requested input size instead of the global IMAGE_SIZE
    input_shape = (dim, dim, 3)

    inp = tf.keras.layers.Input(shape=input_shape)

    base = EFNS[ef](input_shape=input_shape, weights='imagenet', include_top = False)

    x = base(inp)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs = inp, outputs = x)

    opt = tf.keras.optimizers.Adam(learning_rate = 0.001)

    fn_loss = tf.keras.losses.BinaryCrossentropy()

    model.compile(optimizer = opt, loss = [fn_loss], metrics=['AUC', 'accuracy'])

    return model

In [6]:
# load all the models from file
# build one network per fold and restore its weights from file;
# the resulting models are collected, in fold order, in the list `models`
models = []

print('Loading models...')

# weights files are numbered from 1: fold-1.h5 ... fold-<FOLDS>.h5
for fold in range(1, FOLDS + 1):
    print('Loading model n.', fold)
    weights_file = DIR_MODELS + 'fold-%i.h5' % fold

    model = build_model(dim=SIZE, ef=EFF_NET)
    model.load_weights(weights_file)

    models += [model]

print('Loading completed.')

Loading models...
Loading model n. 1
Loading model n. 2
Loading model n. 3
Loading model n. 4
Loading model n. 5
Loading completed.


In [7]:
# have a look at the architecture and at the number of trainable parameters
# (only the first model of the list is shown)
models[0].summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 512, 512, 3)]     0         
_________________________________________________________________
efficientnet-b4 (Model)      (None, 16, 16, 1792)      17673816  
_________________________________________________________________
global_average_pooling2d (Gl (None, 1792)              0         
_________________________________________________________________
dense (Dense)                (None, 512)               918016    
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 513       
Total params: 18,592,345
Trainable params: 18,467,145
Non-trainable params: 125,200
_________________________________________________________________


In [None]:
# the list `models` now holds the five trained models, one for each fold
# for each image we run the prediction with every model and then average the results

In [8]:
def predict_image(img, threshold):
    """Predict with the ensemble: average the models' outputs, then threshold.

    Args:
        img: input passed unchanged to ``predict`` of each model.
        threshold: minimum averaged probability for class 1.

    Returns:
        (avg_prob, pred_class): the mean of element [0, 0] of the first
        FOLDS models' predictions, rounded to 3 decimals, and the class
        (1 if avg_prob >= threshold, otherwise 0).
    """
    # the explicit 0. start value keeps the accumulation a float addition chain
    total = sum((models[k].predict(img)[0, 0] for k in range(FOLDS)), 0.)

    avg_prob = round(total/FOLDS, 3)
    pred_class = 1 if avg_prob >= threshold else 0

    return avg_prob, pred_class

In [9]:
#
# take as input a DCM, produces a resized tensor ready for prediction
#
def preprocess_image(dcm, size):
    """Turn a DCM into a resized, normalized tensor ready for prediction.

    Args:
        dcm: object exposing ``pixel_array``.
        size: side (in pixels) of the square output image.

    Returns:
        float32 tensor: the image replicated on 3 channels, resized to
        size x size, divided by 255 and with a leading batch dimension.
    """
    # NOTE(review): pixels are converted as uint8 - presumably the input
    # images are 8-bit; confirm for other bit depths
    pixels = tf.convert_to_tensor(dcm.pixel_array, dtype=tf.uint8)

    # append the channel axis, then go from grayscale to RGB
    rgb = tf.image.grayscale_to_rgb(tf.expand_dims(pixels, -1))

    resized = tf.image.resize(rgb, [size, size])

    # normalize, as expected from the NN
    scaled = tf.cast(resized, tf.float32)/255.

    # prepend the batch dimension
    return tf.expand_dims(scaled, axis = 0)

In [11]:
# Batch inference: every *.dcm file found in DIR_DCM is read, validated
# with check_image and, if accepted, classified by the ensemble.
dcm_list = sorted(glob.glob(DIR_DCM + '*.dcm'))

n_images = len(dcm_list)

tStart = time.time()

print('******************************************')
print('Report on processing a batch of ', n_images, 'images.')
print('')

for dcm_name in dcm_list:
    dcm = pydicom.dcmread(dcm_name)

    print('**************************')
    print('File name:', dcm_name)
    # these fields are only reported: a file lacking them must not abort the batch
    print('Patient ID:', getattr(dcm, 'PatientID', 'N/A'))
    print('Study description:', getattr(dcm, 'StudyDescription', 'N/A'))
    print('')

    if not check_image(dcm):
        # image rejected: the reason has already been printed by check_image
        continue

    if not is_age_ok(dcm):
        # issue a warning; show the age as an integer when possible,
        # otherwise as stored in the file
        try:
            age_shown = int(dcm.PatientAge)
        except (TypeError, ValueError):
            age_shown = dcm.PatientAge

        print('Warning: age is not in the defined boundaries for the software: ', age_shown)
        print('Prediction may not be accurate.')
        print('')

    # process image
    img = preprocess_image(dcm, SIZE)

    prob, pred_class = predict_image(img, THRESHOLD)

    print('Model prediction:')

    if pred_class == 1:
        print('The diagnose is Pneumonia. Probability is:', round(prob, 2))
    else:
        # report the probability of the predicted (negative) class
        print('The diagnose is NOT Pneumonia. Probability is:', round(1. - prob, 2))

    print('')

tEla = time.time() - tStart

# avoid a division by zero when the directory holds no DCM file
if n_images > 0:
    tSingle = tEla/n_images

    print('Time(sec.) to process a single image:', round(tSingle, 3))
else:
    print('No DCM images found in', DIR_DCM)

******************************************
Report on processing a batch of  6 images.

**************************
File name: ./dcm/test1.dcm
Patient ID: 2
Study description: No Finding

Prediction may not be accurate.

Model prediction:
The diagnose is NOT Pneumonia. Probability is: 0.55

**************************
File name: ./dcm/test2.dcm
Patient ID: 1
Study description: Cardiomegaly

Model prediction:
The diagnose is NOT Pneumonia. Probability is: 0.61

**************************
File name: ./dcm/test3.dcm
Patient ID: 61
Study description: Effusion

Model prediction:
The diagnose is Pneumonia. Probability is: 0.77

**************************
File name: ./dcm/test4.dcm
Patient ID: 2
Study description: No Finding

Check on Body Part NOT passed.
Body part is: RIBCAGE expected: CHEST.
Image rejected.

**************************
File name: ./dcm/test5.dcm
Patient ID: 2
Study description: No Finding

Check on Modality NOT passed.
Image modality is: CT expected: DX.
Image rejected.

*****