In [1]:
# Batch wrapper around COVID-Net's 'inference.py': runs a restored checkpoint over
# one image or every PNG in a directory and collects the softmax outputs.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os, argparse
import cv2
from data import process_image_file

class Inference(object):
    """Batch wrapper around a checkpointed COVID-Net TF1 graph.

    Restores the meta graph + checkpoint found in ``weightspath`` and runs the
    model over a single image or every PNG under a directory, returning the
    output-tensor rows (softmax scores) stacked into one array.
    """

    def __init__(self, weightspath, metaname, ckptname, input_size=480, n_classes=3, top_percent=0.08,
                 in_tensorname='input_1:0', out_tensorname='norm_dense_2/Softmax:0', is_severity_model=False):
        """
        weightspath: path to the directory holding the pretrained model
        metaname: name of the checkpoint .meta file inside ``weightspath``
        ckptname: checkpoint prefix inside ``weightspath`` passed to ``Saver.restore``
        input_size: forwarded to ``process_image_file`` (NOTE(review): presumably the
            target side length in pixels - confirm in data.py)
        n_classes: 2 (negative/positive) or 3 (normal/pneumonia/COVID-19); ignored
            when ``is_severity_model`` is True
        top_percent: forwarded to ``process_image_file`` (NOTE(review): presumably the
            fraction of the image top that is cropped off - confirm in data.py)
        in_tensorname: name of the graph tensor fed with the image batch
        out_tensorname: name of the graph tensor read for the class scores
        is_severity_model: use the 2-level COVIDNet CXR-S severity labels instead
        """
        self.weightspath = weightspath
        self.metaname = metaname
        self.ckptname = ckptname
        self.input_size = input_size
        self.n_classes = n_classes
        self.top_percent = top_percent
        self.in_tensorname = in_tensorname
        self.out_tensorname = out_tensorname
        self.is_severity_model = is_severity_model

    def _restore_model(self, sess):
        """Import the meta graph into the current default graph and restore weights into ``sess``."""
        saver = tf.compat.v1.train.import_meta_graph(os.path.join(self.weightspath, self.metaname))
        saver.restore(sess, os.path.join(self.weightspath, self.ckptname))

    def _label_mappings(self):
        """Return ``(mapping, inv_mapping)`` between label names and output column indices.

        Raises ValueError when ``n_classes`` is unsupported (and this is not a severity model).
        """
        if self.is_severity_model:
            # For COVIDNet CXR-S training with COVIDxSev level 1 and level 2 air space severity grading
            mapping = {'level2': 0, 'level1': 1}
        elif self.n_classes == 2:
            # For COVID-19 positive/negative detection
            mapping = {'negative': 0, 'positive': 1}
        elif self.n_classes == 3:
            # For detection of no pneumonia/non-COVID-19 pneumonia/COVID-19 pneumonia
            mapping = {'normal': 0, 'pneumonia': 1, 'COVID-19': 2}
        else:
            raise ValueError('''COVID-Net currently only supports 2 class COVID-19 positive/negative detection
                or 3 class detection of no pneumonia/non-COVID-19 pneumonia/COVID-19 pneumonia''')
        inv_mapping = {index: label for label, index in mapping.items()}
        return mapping, inv_mapping

    def visualise_graph(self):
        """Restore the model and print the name of every tensor in its graph."""
        tf.compat.v1.disable_eager_execution()
        # Private graph so repeated calls (or a later execute) do not clash with an
        # earlier import left in the shared default graph.
        graph = tf.Graph()
        with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
            self._restore_model(sess)
            for op in graph.get_operations():
                for tensor in op.values():
                    print(tensor.name)

    def execute(self, imagepath_or_directory, verbose=False):
        """Run the model over one image, or every PNG found under a directory.

        imagepath_or_directory: path to a single image file, or a directory that is
            searched recursively by ``imagepaths``
        verbose: print the predicted label and per-class confidences for each image

        Returns an ndarray with one row of output-tensor values per image, in the
        order given by ``imagepaths`` (sorted paths) for directories.
        Raises ValueError if no images are found or the class setup is unsupported.
        """
        # To remove TF warnings
        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        tf.compat.v1.disable_eager_execution()

        mapping, inv_mapping = self._label_mappings()

        # Resolve the inputs before loading the model so an empty directory fails fast
        # with a clear message instead of np.vstack([]) after all the loading work.
        if os.path.isdir(imagepath_or_directory):
            imagepaths_list = self.imagepaths(imagepath_or_directory)
        else:
            imagepaths_list = [imagepath_or_directory]
        if not imagepaths_list:
            raise ValueError('No PNG images found under {}'.format(imagepath_or_directory))

        # Import into a private graph: importing the meta graph into the shared default
        # graph again (e.g. calling execute twice without reset_graph) clashes with the
        # copy already there.
        graph = tf.Graph()
        preds_list = []
        with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
            self._restore_model(sess)
            image_tensor = graph.get_tensor_by_name(self.in_tensorname)
            pred_tensor = graph.get_tensor_by_name(self.out_tensorname)

            for imagepath in imagepaths_list:
                x = process_image_file(imagepath, self.top_percent, self.input_size)
                x = x.astype('float32') / 255.0
                # Model expects a batch dimension; feed a batch of one.
                pred = sess.run(pred_tensor, feed_dict={image_tensor: np.expand_dims(x, axis=0)})
                preds_list.append(pred[0])
                if verbose:
                    print('Prediction: {}'.format(inv_mapping[pred.argmax(axis=1)[0]]))
                    print('Confidence')
                    print(' '.join('{}: {:.3f}'.format(cls.capitalize(), pred[0][i]) for cls, i in mapping.items()))
        return np.vstack(preds_list)

    def imagepaths(self, directory, extensions=('.png',)):
        """Return the sorted paths of all files under ``directory`` (recursively) whose
        name ends with one of ``extensions``, compared case-insensitively.

        extensions: a suffix string or a sequence of suffixes (default: PNG only)
        """
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        path_list = []
        for root, _dirs, files in os.walk(directory):
            for name in files:
                # endswith (not substring) so e.g. 'scan.png.bak' is not picked up
                if name.lower().endswith(suffixes):
                    path_list.append(os.path.join(root, name))
        # os.walk order depends on the filesystem; sort so the prediction rows returned
        # by execute follow a reproducible file order.
        return sorted(path_list)

    def reset_graph(self):
        """Clear TensorFlow's global default graph."""
        tf.compat.v1.reset_default_graph()

    def visualise_image(self, imagepath):
        """Preprocess ``imagepath`` exactly as ``execute`` does, print its shape and plot it.

        NOTE(review): if process_image_file is cv2-based it likely returns BGR channel
        order, in which case imshow shows swapped colours - confirm in data.py.
        """
        x = process_image_file(imagepath, self.top_percent, self.input_size)
        print("Image size: " + str(x.shape))
        plt.imshow(x)

# Shown on every run: model output is a research aid, not a diagnosis.
print('**DISCLAIMER**',
      'Do not use this prediction for self-diagnosis. You should check with your local authorities for the latest advice on seeking medical assistance.',
      sep='\n')

**DISCLAIMER**
Do not use this prediction for self-diagnosis. You should check with your local authorities for the latest advice on seeking medical assistance.


In [2]:
# Root of the local raw-data tree. Set the COVIDNET_DATA_ROOT environment variable
# rather than editing the machine-specific default below.
DATA_ROOT = os.environ.get("COVIDNET_DATA_ROOT", "D:/data")

# COVID-Net CXR4-A checkpoint; its softmax output lives at norm_dense_1 (not the class default).
inf = Inference(weightspath="CXR4-A/", metaname="model.meta", ckptname="model-18540",
                in_tensorname='input_1:0', out_tensorname='norm_dense_1/Softmax:0')

# External PolyU
non_suppressed_directory_HK = DATA_ROOT + "/POLYU_COVID19_CXR_CT_Cohort1/cxr/CXR_PNG"
suppressed_Rajaraman_HK = "../Rajaraman_ResNet_BS/bone_suppressed/external_POLYU/"
suppressed_Gusarev_HK = "../Deep-Learning-Models-for-bone-suppression-in-chest-radiographs/bone_suppressed/external_POLYU/"

# JSRT External
# NOTE(review): labelled external here, but the bone-suppressed outputs live under
# 'internal_original' - confirm which split this set belongs to.
non_suppressed_directory_JSRT = DATA_ROOT + "/JSRT/JSRT/"
suppressed_Rajaraman_JSRT = "../Rajaraman_ResNet_BS/bone_suppressed/internal_original/"
suppressed_Gusarev_JSRT = "../Deep-Learning-Models-for-bone-suppression-in-chest-radiographs/bone_suppressed/internal_original/"

# JSRT Non-nodule
non_suppressed_directory_JSRT_NN = DATA_ROOT + "/JSRT/JSRT_NN/"
suppressed_Rajaraman_JSRT_NN = "../Rajaraman_ResNet_BS/bone_suppressed/internal_NN/"

# Dataset evaluated in this run; swap in any of the directories above.
EVAL_DIRECTORY = non_suppressed_directory_JSRT_NN
preds = inf.execute(EVAL_DIRECTORY, verbose=False)

# Clear the TF default graph so any later import starts from an empty graph.
inf.reset_graph()


In [3]:
# Assign each image to its highest-scoring class. Columns follow the 3-class mapping
# in Inference.execute: 0 = normal, 1 = pneumonia, 2 = COVID-19.
# np.argmax resolves exact ties to the lowest index, so every image lands in exactly
# one class and the proportions sum to 1. Pairwise strict '>' comparisons would
# drop tied rows from all three counts.
pred_labels = np.argmax(preds, axis=1)
idx_normal = pred_labels == 0
idx_pneumonia = pred_labels == 1
idx_COVID = pred_labels == 2

print("Proportion classed as normal: " + str(np.sum(idx_normal) / len(idx_normal)))
print("Proportion classed as pneumonia: " + str(np.sum(idx_pneumonia) / len(idx_pneumonia)))
print("Proportion classed as COVID: " + str(np.sum(idx_COVID) / len(idx_COVID)))

Proportion classed as normal: 0.24731182795698925
Proportion classed as pneumonia: 0.053763440860215055
Proportion classed as COVID: 0.6989247311827957
