# Inference example for trained 2D U-Net model on BraTS.
Loads a trained model and performs inference on a few examples from the test set.

In [None]:
import sys
import platform
import os

# Report interpreter and OS details so any results below can be tied to the
# environment that produced them.
print(f"Python version: {sys.version}")
print(f"{platform.platform()}")

In [None]:
def test_intel_tensorflow():
    """
    Check if Intel version of TensorFlow is installed.

    Prints the installed TensorFlow version and whether TensorFlow reports the
    Intel MKL / oneDNN (DNNL) optimizations as enabled.

    On TF 2.x the flag is only reachable through the private
    ``_pywrap_util_port`` module, whose location changed across 2.x releases
    (``tensorflow.python.util`` in later releases, ``tensorflow.python`` in
    early ones). Both locations are tried. If neither can be imported, the
    status is printed as "unknown" rather than raising.
    """
    import tensorflow as tf

    print("We are using Tensorflow version {}".format(tf.__version__))

    major_version = int(tf.__version__.split(".")[0])
    if major_version >= 2:
        try:
            # Location used by later TF 2.x releases.
            from tensorflow.python.util import _pywrap_util_port
        except ImportError:
            try:
                # Location used by early TF 2.x releases.
                from tensorflow.python import _pywrap_util_port
            except ImportError:
                _pywrap_util_port = None

        if _pywrap_util_port is None:
            print("Intel-optimizations (DNNL) enabled: unknown "
                  "(could not import _pywrap_util_port)")
        else:
            print("Intel-optimizations (DNNL) enabled:", _pywrap_util_port.IsMklEnabled())
    else:
        print("Intel-optimizations (DNNL) enabled:", tf.pywrap_tensorflow.IsMklEnabled())


test_intel_tensorflow()

In [None]:
# Directory holding the trained model. It is later passed to
# K.models.load_model (presumably a Keras SavedModel directory written by the
# training notebook — confirm). The same path is reassigned in the data cell.
saved_model_dir = "./output/2d_unet_decathlon"

In [None]:
# Create output directory for images.
# exist_ok=True makes this idempotent without a separate existence check.
# This avoids a check-then-create race. If a *file* named
# "inference_examples" already exists, it now fails here with a clear error
# instead of silently skipping creation and failing later in savefig.
png_directory = "inference_examples"
os.makedirs(png_directory, exist_ok=True)

# os.path.join with a single argument returns it unchanged, so this is simply
# an alias for saved_model_dir. Kept for compatibility; the model itself is
# loaded from saved_model_dir.
model_filename = saved_model_dir

#### Define the Dice coefficient (hard and soft variants)

The Sørensen–Dice coefficient is a statistic used for comparing the similarity of two samples. Given two sets, X and Y, it is defined as

\begin{equation}
dice = \frac{2|X\cap Y|}{|X|+|Y|}
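\end{equation}

It ranges from 0 (no overlap) to 1 (perfect overlap). The functions below add a small `smooth` term to both the numerator and the denominator, which avoids division by zero when both masks are empty. `calc_dice` rounds the predictions to a binary mask first, while `calc_soft_dice` uses the raw (unrounded) predictions.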

In [None]:
import numpy as np

def calc_dice(target, prediction, smooth=0.0001):
    """
    Sorensen Dice coefficient on rounded (hard) predictions.

    ``prediction`` is first passed through ``np.round``. For values in [0, 1]
    this yields a 0/1 mask; note that ``np.round`` rounds halves to even, so
    exactly 0.5 becomes 0.

    ``smooth`` is added to both numerator and denominator, so two all-zero
    inputs give 1.0 instead of a division by zero.

    Returns:
        (2 * sum(target * rounded) + smooth) /
        (sum(target) + sum(rounded) + smooth)
    """
    hard_prediction = np.round(prediction)

    overlap = np.sum(target * hard_prediction)
    total = np.sum(target) + np.sum(hard_prediction)

    return (2.0 * overlap + smooth) / (total + smooth)

def calc_soft_dice(target, prediction, smooth=0.0001):
    """
    Soft Sorensen Dice coefficient.

    Same formula as ``calc_dice`` but ``prediction`` is used as-is (no
    rounding), so fractional values contribute proportionally.

    ``smooth`` is added to both numerator and denominator, so two all-zero
    inputs give 1.0 instead of a division by zero.

    Returns:
        (2 * sum(target * prediction) + smooth) /
        (sum(target) + sum(prediction) + smooth)
    """
    overlap = np.sum(target * prediction)
    size_sum = np.sum(target) + np.sum(prediction)

    return (2.0 * overlap + smooth) / (size_sum + smooth)

## Inference Time!

Inference in this example takes three simple steps:
1. Load the data
1. Load the Keras model 
1. Perform a `model.predict` on an input image (or set of images)

#### Step 1: Load the data

In [None]:
# Locations of the input data and of the trained model
data_path = "../data/decathlon/Task01_BrainTumour/2D_model/"
saved_model_dir = "./output/2d_unet_decathlon"

# Dataset settings
crop_dim = 128    # -1 = Original resolution (240)
batch_size = 128
seed = 816        # Change this to see different examples in the test dataset

In [None]:
from dataloader import DatasetGenerator

# Loader over every .npz file in the "testing" folder under data_path,
# built with the crop/batch/seed settings above and augment=False.
ds_testing = DatasetGenerator(
    os.path.join(data_path, "testing/*.npz"),
    crop_dim=crop_dim,
    batch_size=batch_size,
    augment=False,
    seed=seed,
)

#### Step 2: Load the model

In [None]:
from model import unet

from tensorflow import keras as K
# Load the trained model without compiling it (compile=False); it is only
# used for model.predict in this notebook. The U-Net's custom_objects are
# supplied, presumably so custom layers/functions saved with the model can
# be deserialized.
model = K.models.load_model(
    saved_model_dir,
    custom_objects=unet().custom_objects,
    compile=False,
)

#### Step 3: Perform prediction on some images
Each figure (MRI slice, ground truth, and prediction) is saved as a PNG file in the directory defined by the `png_directory` variable.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

import time

def plot_results(ds, idx):
    """
    Plot one example: MRI slice, ground-truth mask, and model prediction.

    Runs ``model.predict`` on a single image and prints the prediction
    latency. The prediction panel is titled with its hard Dice score
    (``calc_dice``) against the ground truth. The figure is saved as
    ``prediction_tf_<idx>.png`` inside ``png_directory``.

    Args:
        ds: dataset wrapper exposing ``get_dataset()``. Only the first batch
            it yields is used.
        idx: index of the example within that first batch.
    """
    # Only the first batch is taken; use a different dataset seed to see
    # different samples.
    batches = ds.get_dataset().take(1).as_numpy_iterator()

    plt.figure(figsize=(10, 10))

    for img, msk in batches:

        plt.subplot(1, 3, 1)
        plt.imshow(img[idx, :, :, 0], cmap="bone", origin="lower")
        plt.title("MRI {}".format(idx), fontsize=20)

        plt.subplot(1, 3, 2)
        plt.imshow(msk[idx, :, :], cmap="bone", origin="lower")
        plt.title("Ground truth", fontsize=20)

        plt.subplot(1, 3, 3)

        # Predict using the TensorFlow model. Indexing with a list (img[[idx]])
        # keeps a leading axis of length 1, i.e. a batch of one image.
        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring intervals; time.time() follows the wall clock (which can
        # be adjusted) and may have coarse resolution. The first call likely
        # also includes one-time warm-up cost, so expect it to be slower.
        start_time = time.perf_counter()
        prediction = model.predict(img[[idx]])
        elapsed_ms = 1000.0 * (time.perf_counter() - start_time)
        print("Elapsed time = {:.4f} msecs".format(elapsed_ms))

        plt.imshow(prediction[0, :, :, 0], cmap="bone", origin="lower")
        plt.title("Prediction\nDice = {:.4f}".format(calc_dice(msk[idx, :, :], prediction)), fontsize=20)

        plt.savefig(os.path.join(png_directory, "prediction_tf_{}.png".format(idx)))

In [None]:
plot_results(ds=ds_testing, idx=11)

In [None]:
plot_results(ds=ds_testing, idx=17)

In [None]:
plot_results(ds=ds_testing, idx=25)

In [None]:
plot_results(ds=ds_testing, idx=56)

In [None]:
plot_results(ds=ds_testing, idx=89)

In [None]:
plot_results(ds=ds_testing, idx=101)

In [None]:
plot_results(ds=ds_testing, idx=119)

# Can we perform inference even faster? Hmm..

Let's find out. Move on to the next tutorial section.

*Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. SPDX-License-Identifier: EPL-2.0*

*Copyright (c) 2019-2020 Intel Corporation*