# Notebook to predict diseases in chest X-ray

This notebook shows how to run inference with a trained model to predict diseases from chest X-ray images. Refer to the [ChestXray-Training.ipynb](./ChestXray-Training.ipynb) notebook for information on training the model.

## Import required modules

In [7]:
%pylab inline

import warnings
# Ignoring the warnings to improve readability of the notebook
warnings.filterwarnings("ignore", message="numpy.dtype size changed")

from bigdl.nn.layer import Model
from bigdl.nn.criterion import *
from bigdl.optim.optimizer import *
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.functions import col, udf
from pyspark.sql.types import *
from pyspark.sql.types import DoubleType
from pyspark.sql.types import StringType, ArrayType

from zoo.common.nncontext import *
from zoo.feature.image import *
from zoo.models.image.imageclassification import *
from zoo.pipeline.nnframes import *
from zoo.pipeline.api.net import Net
from zoo.pipeline.api.keras.models import Sequential
from zoo.pipeline.api.keras.layers import *
from zoo.pipeline.api.keras.metrics import AUC
from zoo.pipeline.nnframes import NNEstimator
from zoo.pipeline.api.keras.objectives import BinaryCrossEntropy

import time
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

Populating the interactive namespace from numpy and matplotlib


## Get the Spark context and load the trained model

In [8]:
# Function to set up the Spark context and load the trained model
def load_spark_model(model_path, bin_path):
    """Initialise Spark for this notebook and return the trained model.

    Builds a Spark configuration with the application name
    "ChestXray_Inference", initialises the NN context from it, makes sure a
    SparkSession and a SQLContext exist for that configuration/context, and
    finally loads the model with ``Net.load``.

    Args:
        model_path: location of the saved model, passed as the first argument
            to ``Net.load`` (a ``.bigdl`` file in this notebook).
        bin_path: location of the companion ``.bin`` file, passed as the second
            argument to ``Net.load`` (presumably the weights -- confirm against
            the training notebook).

    Returns:
        Whatever ``Net.load(model_path, bin_path)`` returns.
    """
    conf = create_spark_conf().setAppName("ChestXray_Inference")
    context = init_nncontext(conf)

    # The session and SQL context are not used again inside this function;
    # they are created here only so that they exist once the function returns.
    SparkSession.builder.config(conf=conf).getOrCreate()
    SQLContext(context)

    return Net.load(model_path, bin_path)

In [9]:
# Function to display an X-ray image in the notebook
def display_xray(image_path):
    """Read the image file at ``image_path`` and draw it with matplotlib."""
    plt.imshow(mpimg.imread(image_path))
    plt.show()

In [10]:
def text_to_label(text):
    """Convert a "|"-separated finding string into a multi-hot label vector.

    Each of the 14 disease names below owns one slot of the result; the slot
    is set to 1.0 when the name appears in ``text``.  The token "No Finding"
    sets nothing.

    Args:
        text: finding names joined by "|", e.g. "Cardiomegaly|Effusion".
            ``None`` or an empty string is treated as "no findings" so that a
            null/blank CSV cell does not crash the Spark UDF that wraps this
            function.

    Returns:
        A list of 14 floats (0.0 / 1.0) in the order of ``label_texts``.

    Raises:
        KeyError: if a non-empty token is neither "No Finding" nor one of the
            known disease names.
    """
    label_texts = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumonia",
               "Pneumothorax", "Consolidation", "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"]
    label_map = {name: position for position, name in enumerate(label_texts)}

    arr = [0.0] * len(label_texts)
    if not text:
        # None or "" -> nothing to mark
        return arr

    for token in text.split("|"):
        name = token.strip()  # tolerate padding around the separator
        if not name or name == "No Finding":
            continue
        # Unknown names still raise KeyError on purpose: silently dropping them
        # would hide a mismatch between the CSV and the label list.
        arr[label_map[name]] = 1.0
    return arr

def predict(model, label_path, image_path):
    """Run ``model`` on the image(s) at ``image_path`` and attach their labels.

    Steps:
      1. Read the image(s) with ``NNImageReader`` and add an ``Image_Index``
         column holding the file's base name.
      2. Read the label CSV, derive a ``label`` column from ``Finding_Labels``
         via ``text_to_label``.
      3. Inner-join images and labels on ``Image_Index``.
      4. Apply the preprocessing chain and the model through ``NNModel``.

    Args:
        model: the loaded model handed to ``NNModel``.
        label_path: location of the label CSV (with header).  Both the
            "Image Index"/"Finding Labels" and the
            "Image_Index"/"Finding_Labels" header spellings are accepted.
        image_path: location of the image(s) to run inference on.

    Returns:
        The DataFrame produced by ``NNModel.transform`` on the joined data.
        It is empty when no CSV row matches an image name (inner join).
    """
    # Local import: the notebook's import cell never imports ``os`` explicitly.
    import os

    # Fetch the active Spark handles instead of relying on notebook-global
    # ``sc`` / ``sqlContext`` names, which are only bound as locals inside
    # load_spark_model().  getOrCreate() returns the already running session.
    spark = SparkSession.builder.getOrCreate()
    sc = spark.sparkContext

    getLabel = udf(text_to_label, ArrayType(DoubleType()))
    # row[0] is the first field of the image struct; its base name is used as
    # the join key (presumably the origin URI of the file -- confirm).
    getName = udf(lambda row: os.path.basename(row[0]), StringType())

    # load the image
    imageDF = NNImageReader.readImages(image_path, sc, resizeH=256, resizeW=256, image_codec=1)\
                    .withColumn("Image_Index", getName(col('image')))

    # The renames must run BEFORE the select: withColumnRenamed is a no-op when
    # the column is absent, so this accepts either header spelling.
    labelDF = spark.read.option('timestampFormat', 'yyyy/MM/dd HH:mm:ss ZZ')\
                .load(label_path, format="csv", sep=",", inferSchema="true", header="true")\
                .withColumnRenamed('Image Index', 'Image_Index')\
                .withColumnRenamed('Finding Labels', 'Finding_Labels')\
                .select("Image_Index", "Finding_Labels")\
                .withColumn("label", getLabel(col('Finding_Labels')))
    inferDF = imageDF.join(labelDF, on="Image_Index", how="inner")

    # Preprocessing applied to the "image" column before it reaches the model:
    # crop the 256x256 read size down to 224x224, then subtract per-channel
    # means (presumably the ImageNet means used at training time -- confirm).
    transformer = ChainedPreprocessing([
        RowToImageFeature(),
        ImageCenterCrop(224, 224),
        ImageChannelNormalize(123.68, 116.779, 103.939),
        ImageMatToTensor(),
        ImageFeatureToTensor()])
    classifier_model = NNModel(model, transformer).setFeaturesCol("image")\
                        .setBatchSize(1)
    return classifier_model.transform(inferDF)

def print_prediction_output(predDF):
    """Print a table of predicted vs. actual values for the first row of ``predDF``.

    One line is printed per entry of the row's ``prediction`` field, paired
    with the disease name and the entry of the row's ``label`` field at the
    same position.

    Args:
        predDF: object exposing ``collect()`` that returns rows with
            ``prediction`` and ``label`` attributes (the DataFrame returned by
            ``predict`` in this notebook).

    Returns:
        None.  When ``predDF`` has no rows a short notice is printed instead of
        the table.
    """
    print("\n\n")
    label_texts = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumonia",
               "Pneumothorax", "Consolidation", "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"]

    # Collect exactly once: each collect() is a separate Spark action, so
    # calling it twice would evaluate the whole inference pipeline twice.
    rows = predDF.collect()
    if not rows:
        # e.g. the inner join in predict() matched no label for the image
        print("No prediction rows to display.")
        print("\n\n")
        return

    first_row = rows[0]
    predictions_list = first_row.prediction
    labelList = first_row.label

    row_format = "{:<25} - {:<25} - {:<15}"
    print(row_format.format('Finding_Labels', 'Prediction', 'Label'))
    print(row_format.format('-'*len('Finding_Labels'), '-'*len('Prediction'), '-'*len('Label')))
    for indx in range(0, len(predictions_list)):
        print(row_format.format(label_texts[indx], predictions_list[indx], labelList[indx]))
    print("\n\n")

In [11]:
%%time
# Entry-point cell: defines the input locations and loads the trained model.
# model_path, bin_path, label_path and trained_model end up as notebook globals;
# the following cell reads label_path and trained_model.
if __name__== "__main__":
    
    # Locations of the trained model (.bigdl and .bin files) and of the NIH label CSV.
    # NOTE: these are absolute file:// URIs under one user's home directory;
    # edit them before running the notebook on another machine.
    model_path = "file:///home/bala/xray-inference/BigDL-ImageProcessing-Examples/trained-model/xray_model_2019_04_04_05_31_10.bigdl"
    bin_path = "file:///home/bala/xray-inference/BigDL-ImageProcessing-Examples/trained-model/xray_model_2019_04_04_05_31_10.bin"
    label_path = "file:///home/bala/xray-inference/BigDL-ImageProcessing-Examples/trained-model/Data_Entry_2017.csv"
   
    # Sets up Spark and loads the model from the two files above
    trained_model = load_spark_model(model_path,bin_path)

CPU times: user 16.2 ms, sys: 2.22 ms, total: 18.4 ms
Wall time: 4.68 s


In [12]:
%%time
    # Path to the image for inference
    image_path = "file:///home/bala/xray-inference/BigDL-ImageProcessing-Examples/images/00000001_002.png"
    predictionDF = predict(trained_model, label_path, image_path)
    print_prediction_output(predictionDF)

creating: createRowToImageFeature
creating: createImageCenterCrop
creating: createImageChannelNormalize
creating: createImageMatToTensor
creating: createImageFeatureToTensor
creating: createChainedPreprocessing
creating: createTensorToSample
creating: createChainedPreprocessing
creating: createNNModel



Finding_Labels            - Prediction                - Label          
--------------            - ----------                - -----          
Atelectasis               - 0.0513895861804           - 0.0            
Cardiomegaly              - 0.999990344048            - 1.0            
Effusion                  - 0.772566437721            - 1.0            
Infiltration              - 0.104134827852            - 0.0            
Mass                      - 0.00671234633774          - 0.0            
Nodule                    - 0.00529360491782          - 0.0            
Pneumonia                 - 0.00492261582986          - 0.0            
Pneumothorax              - 0.00217330595478  