# Inference example for trained 2D U-Net model on BraTS.
Takes a trained model and performs inference on a few validation examples.

In [None]:
import sys
import platform
# Report which Python interpreter is executing this notebook.
for caption, detail in (
    ("Python version:\n", sys.version),
    ("Path to the python executable:\n", sys.executable),
):
    print(caption, detail)

In [None]:
%matplotlib inline
import os
import psutil
import numpy as np
import keras as K
import h5py
import time
import tensorflow as tf 
# Reset argv to a single empty string, then drop the `sys` name, before
# `argparser` is imported.
# NOTE(review): presumably this keeps the argument parser from seeing the
# Jupyter kernel's own command-line arguments - confirm against argparser.py.
import sys; sys.argv=['']; del sys

from argparser import args
from model import unet

# Choose the Keras implementation: the standalone `keras` package when
# `args.keras_api` is set, otherwise the copy bundled with TensorFlow.
# Either way the module is bound to the name `K`.
if args.keras_api:
    import keras as K
else:
    from tensorflow import keras as K

import matplotlib.pyplot as plt

# Backend switch: when truthy, the nGraph ONNX packages are imported below and
# `plot_results` takes its `if onnx:` branch instead of the Keras branch.
onnx=False
# NOTE(review): `plot_result` is not read by the code visible in this notebook
# (plotting is gated on `args.plot`) - confirm whether it is still needed.
plot_result = True
#TODO - Enable nGraph Bridge - Switch to (decathlon) venv!
# import ngraph_bridge

if onnx:
    #TODO - Include ngraph onnx backend
    # NOTE: `import onnx` rebinds the name `onnx` from the boolean flag above
    # to the imported module; later `if onnx:` checks then test the module object.
    import onnx
    from ngraph_onnx.onnx_importer.importer import import_onnx_model
    import ngraph as ng

# Print the TensorFlow version and whether this build reports Intel MKL as enabled.
print ("We are using Tensorflow version", tf.__version__,\
       "with Intel(R) MKL", "enabled" if tf.pywrap_tensorflow.IsMklEnabled() else "disabled",)

In [None]:
# Create output directory for images. `exist_ok=True` makes this a no-op when
# the directory already exists and avoids the check-then-create race of
# testing `os.path.exists` first.
png_directory = "inference_examples"
os.makedirs(png_directory, exist_ok=True)

# Full paths to the HDF5 data file and to the saved model used for inference,
# both built from the parsed command-line arguments.
data_filename = os.path.join(args.data_path, args.data_filename)
model_filename = os.path.join(args.output_path, args.inference_filename)

#### Define the DICE coefficient and loss function

The Sørensen–Dice coefficient is a statistic used for comparing the similarity of two samples. Given two sets, X and Y, it is defined as

\begin{equation}
dice = \frac{2|X\cap Y|}{|X|+|Y|}
\end{equation}

In [None]:
def calc_dice(target, prediction, smooth=0.01):
    """
    Sorensen Dice coefficient on a binarized prediction.

    The prediction is rounded with `np.round` before being compared with
    the target, so the score reflects hard (0/1) decisions.

    Args:
        target: Ground-truth mask array.
        prediction: Predicted mask array; rounded before scoring.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both masks are empty.

    Returns:
        (2 * sum(target * rounded) + smooth) /
        (sum(target) + sum(rounded) + smooth)
    """
    binarized = np.round(prediction)

    overlap = np.sum(target * binarized)
    mask_total = np.sum(target) + np.sum(binarized)

    return (2.0 * overlap + smooth) / (mask_total + smooth)

def calc_soft_dice(target, prediction, smooth=0.01):
    """
    Sorensen (Soft) Dice coefficient - predictions are NOT rounded.

    Args:
        target: Ground-truth mask array.
        prediction: Predicted mask array, used as-is.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both masks are empty.

    Returns:
        (2 * sum(target * prediction) + smooth) /
        (sum(target) + sum(prediction) + smooth)
    """
    overlap = np.sum(target * prediction)
    mask_total = np.sum(target) + np.sum(prediction)

    return (2.0 * overlap + smooth) / (mask_total + smooth)

## Inference Time!

Inferencing in this example can be done in 3 simple steps:
1. Load the data
1. Load the Keras model 
1. Perform a `model.predict` on an input image (or set of images)

#### Step 1 : Load data

In [None]:
# Open the HDF5 data file read-only and pull out the three testing datasets.
# The file handle `df` is left open on purpose: the datasets are indexed
# later, when individual slices are selected for inference.
df = h5py.File(data_filename, "r")
imgs_testing, msks_testing, files_testing = (
    df[dataset_name]
    for dataset_name in ("imgs_testing", "msks_testing", "testing_input_files")
)

#### Step 2 : Load the model

In [None]:
# Build the project's U-Net helper object and use it to load the trained
# model from `model_filename`.
# NOTE(review): assumes `unet.load_model(path)` in model.py takes the path of
# the saved model and returns a Keras model - confirm against model.py.
unet_model = unet()
model = unet_model.load_model(model_filename)

#### Step 3 (final step): Perform prediction with `model.predict`

We first define a convenience function that performs a prediction on one image. This function takes the model, the image and mask datasets, an image index, and the directory in which to save the figure; plotting of the results is toggled by `args.plot`.

After we have performed a prediction, we calculate the Dice score to analyze how good our prediction was compared to the ground truth.

In [None]:
def plot_results(model, imgs_validation, msks_validation, img_no, png_directory):
    """
    Run inference on one image and, if `args.plot` is set, plot and save it.

    The prediction is made with the nGraph ONNX backend when the module-level
    `onnx` switch is truthy, otherwise with `model.predict`. The elapsed
    prediction time is printed in milliseconds either way.

    Args:
        model: Keras model used for prediction on the non-ONNX path.
        imgs_validation: Indexable collection of input images; a selected
            item is indexed as [0, :, :, 0] for display.
        msks_validation: Ground-truth masks, indexed like `imgs_validation`.
        img_no: Index of the image/mask pair to run.
        png_directory: Directory the figure is written to as
            `pred_<img_no>.png`; created if it does not exist.

    Returns:
        None. Output is the printed timing/Dice lines and the saved figure.
    """
    # Index with a one-element list so the leading (batch) axis is kept.
    img = imgs_validation[[img_no], ]
    msk = msks_validation[[img_no], ]

    #TODO load onnx model in ngraph
    if onnx:
        # NOTE(review): the ONNX model path is hard-coded and the model is
        # re-imported on every call - consider hoisting out of this function.
        onnx_protobuf = onnx.load('./output/unet_model_for_decathlon_100_iter.onnx')
        ng_models = import_onnx_model(onnx_protobuf)
        ng_model = ng_models[0]
        runtime = ng.runtime(backend_name='CPU')
        # Named `ng_unet` so it does not shadow the `unet` class imported from model.
        ng_unet = runtime.computation(ng_model['output'], *ng_model['inputs'])

        start_time = time.time()
        pred_mask = ng_unet(img)[0]
        elapsed_ms = (time.time() - start_time) * 1000
        print("Time for prediction ngraph: ", '%.0f' % elapsed_ms, "ms")

    else:
        start_time = time.time()
        # Keras inference on the single-image batch.
        pred_mask = model.predict(img, verbose=0, steps=None)
        elapsed_ms = (time.time() - start_time) * 1000
        print("Time for prediction TF: ", '%.0f' % elapsed_ms, "ms")

    if args.plot:
        # Compute the hard Dice score once; it is used in the title and the log line.
        dice = calc_dice(msk, pred_mask)

        plt.figure(figsize=(10, 10))
        plt.subplot(1, 3, 1)
        plt.imshow(img[0, :, :, 0], cmap="bone", origin="lower")
        plt.title("MRI")
        plt.axis("off")
        plt.subplot(1, 3, 2)
        plt.imshow(msk[0, :, :, 0], origin="lower")
        plt.title("Ground Truth")
        plt.axis("off")
        plt.subplot(1, 3, 3)
        plt.imshow(pred_mask[0, :, :, 0], origin="lower")
        plt.title("Prediction\n(Dice = {:.4f})".format(dice))
        plt.axis("off")

        # Make sure the target directory exists so savefig cannot fail on it.
        os.makedirs(png_directory, exist_ok=True)
        png_filename = os.path.join(png_directory, "pred_{}.png".format(img_no))
        plt.savefig(png_filename, bbox_inches="tight", pad_inches=0)
        print("Dice {:.4f}, Soft Dice {:.4f}, Saved png file to: {}".format(
            dice, calc_soft_dice(msk, pred_mask), png_filename))

#### Step 3 (continued) : Perform prediction on some images. 
Use `args.plot` to toggle plotting the results. If it is set, the prediction results will be saved in the output directory for images, which is passed to `plot_results` as its `png_directory` argument (here `args.output_pngs`).

In [None]:
# Indices into the testing datasets of the examples to run inference on.
indicies_validation = [40, 63, 43, 55, 99, 101, 19, 46]

# Predict each selected example; figures are written to `args.output_pngs`.
for idx in indicies_validation:
    plot_results(
        model=model,
        imgs_validation=imgs_testing,
        msks_validation=msks_testing,
        img_no=idx,
        png_directory=args.output_pngs,
    )

# Can we perform inference even faster? Hmm..

Let's find out. Move on to the next tutorial section.

`Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. SPDX-License-Identifier: EPL-2.0`

`Copyright (c) 2019 Intel Corporation`