In [None]:
#default_exp batch

In [None]:
import os
# Raise TensorFlow's minimum C++ log level (presumably to quiet its console output);
# set here, before tensorflow is imported in the next cell.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

In [None]:
from fastai.vision import *
import tensorflow.compat.v1 as tf
import pydicom
from tempfile import mkstemp
import warnings
warnings.filterwarnings('ignore')

In [None]:
#exporti
import os
import pandas as pd
from misas.core import *
from misas.mri import *
import altair as alt
from tqdm.notebook import tqdm
from tempfile import mkdtemp
from functools import partial

## Aim: Get a first impression of the model's performance on the given data before going into a detailed evaluation

In this case study we demonstrate how `misas` can give you an overview of a model's performance by creating plots that show the average Dice score over a whole batch of images for different transformation parameters. We use:
 - **Model**: `ukbb_cardiac` [network](https://github.com/baiwenjia/ukbb_cardiac) by [Bai et al. 2018 [1]](https://doi.org/10.1186/s12968-018-0471-x), trained on [UK Biobank](https://www.ukbiobank.ac.uk/) cardiac MRI images
 - **Data**: Kaggle [Data Science Bowl Cardiac Challenge Data](https://www.kaggle.com/c/second-annual-data-science-bowl) MRI images

# Prepare Model for Misas

The model used here was trained on UK Biobank cardiac imaging data to segment short-axis images of the heart into left ventricle (LV), right ventricle (RV) and myocardium (MY). For details about the model please read [the paper (Bai et al. 2018)](https://doi.org/10.1186/s12968-018-0471-x) and cite it if you use it. For implementation, training and usage see the [GitHub repository](https://github.com/baiwenjia/ukbb_cardiac). We downloaded the pre-trained model for short-axis images from https://www.doc.ic.ac.uk/~wbai/data/ukbb_cardiac/trained_model/ (local copy in `example/kaggle/FCN_sa`). In order to use it with `misas` we need to wrap it in a class that implements the desired interface (`prepareSize` and `predict`, both taking an `Image` as input; see the main documentation for more details).

`ukbb_cardiac` is written in `tensorflow` v1. With `tensorflow` v2 make sure to import the compat module.

The model requires both image dimensions to be multiples of 16, so we pad images accordingly in `prepareSize`. Additionally, `image_to_input` takes care of transforming a three-channel image into a batch containing a single single-channel image. In `predict`, the predicted class map is converted to an `ImageSegment`.

In [None]:
class ukbb_model:
    '''Wrapper that makes the pre-trained `ukbb_cardiac` TensorFlow v1 network usable by misas.

    Implements the interface misas expects: `prepareSize` and `predict`, both
    taking a fastai `Image`.
    '''
    def __init__(self, model_path):
        # The checkpoint is a TF1 graph: switch off eager execution before creating a session
        tf.disable_eager_execution()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        # Rebuild the graph from '<model_path>.meta' and load the trained weights into the session
        restorer = tf.train.import_meta_graph(f'{model_path}.meta')
        restorer.restore(self.sess, model_path)

    def prepareSize(self, image):
        '''Crop/pad `image` so that height and width become multiples of 16 (rounded up).

        `crop_pad` is expected to modify `image` in place (image_to_input relies on
        this); the same object is returned for convenience.
        NOTE(review): predictions are made on the padded image, so the predicted
        mask presumably has the padded size rather than the original one.
        '''
        _, height, width = image.shape
        target_size = (math.ceil(height / 16) * 16, math.ceil(width / 16) * 16)
        image.crop_pad(target_size, padding_mode="zeros")
        return image

    def image_to_input(self, image):
        '''Convert a fastai Image into the (1, H, W, 1) numpy batch fed to the network.

        Works on a clone, so `image` itself is not resized; only the first channel is used.
        '''
        resized = image.clone()
        self.prepareSize(resized)
        first_channel = np.asanyarray(resized.data[0])
        # add a leading batch axis and a trailing channel axis
        return first_channel[np.newaxis, ..., np.newaxis]

    def predict(self, image):
        '''Run the network on `image`.

        Returns: tuple (mask, probabilities) -- the predicted class map wrapped in an
        `ImageSegment`, and the network's 'prob:0' output with the leading batch axis removed.
        '''
        feed = {'image:0': self.image_to_input(image), 'training:0': False}
        probabilities, class_map = self.sess.run(['prob:0', 'pred:0'], feed_dict=feed)
        probabilities = np.squeeze(probabilities, 0)
        return ImageSegment(ByteTensor(class_map)), probabilities

In [None]:
# Load the pre-trained short-axis network; ukbb_model reads 'example/kaggle/FCN_sa.meta'
# and restores the checkpoint stored under the same prefix.
# (Warnings are already silenced in the import cell, so no second filterwarnings call is needed.)
model = ukbb_model('example/kaggle/FCN_sa')

INFO:tensorflow:Restoring parameters from example/kaggle/FCN_sa


# Prepare Images

In [None]:
#exporti
def windowed(tensor, width, center):
    '''Scale pixel intensities to [0, 1] using a window width and window center.

    The window spans `center - width//2` to `center + width//2`. Values below the
    window are clipped to 0, values above it to 1, and values inside are scaled
    linearly. The input tensor is not modified (a clone is used).

    -tensor: torch tensor of pixel intensities
    -width (number): window width
    -center (number): window center

    Raises: ValueError if `width//2` is not positive, which would collapse the
    window to a single value and make the scaling divide by zero.

    Returns: tensor of the same shape as `tensor` with values scaled into [0, 1]
    '''
    px_min = center - width // 2
    px_max = center + width // 2
    if px_max <= px_min:
        raise ValueError(f"window width must be at least 2, got {width}")
    px = tensor.clone()
    px[px < px_min] = px_min
    px[px > px_max] = px_max
    return (px - px_min) / (px_max - px_min)

In [None]:
#exporti
def dicom_to_Image(file):
    '''Read a DICOM file and return it as a 3-channel fastai `Image` for downstream evaluation.

    The pixel data are windowed with the file's own WindowWidth / WindowCenter
    (the first window is used when several are stored), copied into three
    identical channels, mirrored left-right and rotated by 90 degrees into the
    orientation used for the evaluations.

    -file: path of the DICOM file

    Returns: fastai `Image` with pixel values scaled to [0, 1]
    '''
    from collections.abc import Sequence
    import numpy as np
    import pydicom  # imported here so the exported function does not depend on a notebook-only import

    def _first(value):
        # WindowWidth / WindowCenter are multi-valued when the file stores several windows
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            return value[0]
        return value

    dcm = pydicom.dcmread(file)
    # float32 rather than int16: unsigned 16-bit pixel values above 32767 would wrap around in int16
    pixels = Tensor(dcm.pixel_array.astype(np.float32))
    pixels = windowed(pixels, _first(dcm.WindowWidth), _first(dcm.WindowCenter))
    img = Image(torch.stack([pixels, pixels, pixels]))    # to convert Tensor into Image object, 3 dimensions are needed
    img = img.flip_lr()
    img = img.rotate(90)
    return img

In [None]:
# Sample DICOM images from the Kaggle Data Science Bowl data used in this demo.
# Further samples in the same folder (IM-5022-0015, IM-14141-0011,
# IM-13811-0003, IM-5382-0008) can be appended to the list.
files = [
    "example/kaggle/sample_images/IM-13717-0026.dcm",
    "example/kaggle/sample_images/IM-7453-0024.dcm",
    "example/kaggle/sample_images/IM-4718-0021.dcm",
]

# Evaluation

If `true_masks` is `None`, the model predicts a mask for every image, saves it as a PNG in a newly created temporary directory and uses this mask as the truth for the evaluation. Afterwards it deletes the directory together with the predicted truths. It is therefore important to pass the images in an orientation in which the model makes good predictions. Here this is handled by the `dicom_to_Image` function, which reads the DICOM file, converts it to a fastai `Image` object, and flips and rotates the image into the correct position.

In [None]:
#export
def batch_results(images, model, eval_functions, true_masks=None, components=['bg','LV','MY','RV']):
    ''' Evaluation of the model's performance across multiple images and transformations

    Every evaluation function is applied to every image, and the per-image results
    of one evaluation function are concatenated into a single DataFrame.

    -images (list): paths of the DICOM files with which the model should be evaluated
    -model: model to be evaluated; `model.predict(image)[0]` must be a mask with a `save` method
    -eval_functions (list): evaluation functions (e.g. the eval_*_series functions from misas.core);
        each is called as `f(image, true_mask, model, components=components)` and must return a DataFrame
    -true_masks (list, optional): paths of PNG files with the true masks, in the same order as 'images'.
        If None, the model's prediction on each untransformed image is used as the truth.
    -components (list, optional): classes that will be evaluated by the eval functions (not modified)

    Raises: ValueError if 'true_masks' is given and does not have one entry per image

    Returns: list of pandas DataFrames, one per evaluation function (in the order of 'eval_functions'),
        each containing the rows of all images plus a 'File' column with the image path
        (an empty DataFrame if 'images' is empty)
    '''
    from tempfile import TemporaryDirectory

    if true_masks is not None and len(true_masks) != len(images):
        raise ValueError(f"got {len(true_masks)} true masks for {len(images)} images")

    results = []
    for eval_function in tqdm(eval_functions, leave=False):
        per_image = []
        for index, image_path in enumerate(images):
            if true_masks is None:
                # Use the model's own prediction as the truth; the temporary directory
                # is removed even if saving or reading the mask fails.
                with TemporaryDirectory() as tmp:
                    truth_path = os.path.join(tmp, "current_truth.png")
                    model.predict(dicom_to_Image(image_path))[0].save(truth_path)
                    true_mask = open_mask(truth_path)
            else:
                true_mask = open_mask(true_masks[index])
            # a freshly loaded image is passed so the evaluation never sees an image touched by predict
            df = eval_function(dicom_to_Image(image_path), true_mask, model, components=components)
            df["File"] = image_path
            per_image.append(df)
        results.append(pd.concat(per_image) if per_image else pd.DataFrame())
    return results

In this example we are passing a list of true masks to the function that will be used for evaluation. Since this is only for demonstration purposes, we will manually generate the list of true masks by predicting the truth for every image and saving it to a temporary directory and storing the path names in a list, from where `batch_results` can access it.

In [None]:
# Predict a mask for every image and store it as a PNG in one temporary directory,
# so batch_results can use these files as the true masks. No file handles stay open.
truth_dir = mkdtemp()
truths = []
for index, dicom_path in enumerate(files):
    truth_path = os.path.join(truth_dir, f"truth_{index}.png")
    model.predict(dicom_to_Image(dicom_path))[0].save(truth_path)
    truths.append(truth_path)

In [None]:
# Evaluate against the precomputed masks in `truths`. Other series can be added
# to the list: eval_bright_series, eval_crop_series, eval_contrast_series,
# eval_resize_series, eval_spike_series, eval_spike_pos_series, eval_zoom_series.
results_with_truths = batch_results(
    files,
    model,
    [eval_rotation_series, eval_dihedral_series],
    true_masks=truths,
    components=['bg', 'LV', 'MY', 'RV'],
)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

To demonstrate that the function can also evaluate the model's performance without a list of true masks, this example does not pass any PNGs with true masks to the function. The model therefore predicts a mask for every image, saves it as a PNG and uses it for the evaluation. In this case it is important that the images are passed to the function in an orientation in which the model can make a good prediction, since this prediction is used as the true mask for all the following transformations. Here this is handled by the `dicom_to_Image` function, which converts the image into the desired orientation.

In [None]:
# Evaluate without precomputed masks: batch_results lets the model predict the truth
# for each image itself. Other series can be added to the list: eval_rotation_series,
# eval_bright_series, eval_crop_series, eval_contrast_series, eval_resize_series,
# eval_spike_series, eval_spike_pos_series, eval_zoom_series.
results_without_truths = batch_results(
    files,
    model,
    [eval_dihedral_series],
    components=['bg', 'LV', 'MY', 'RV'],
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

# Plotting the results

In [None]:
#export
def plot_batch(df_results, plot_function=None):
    '''
    Creates and displays the plots with the data as returned by the batch_results function.

    Positional arguments:
    -df_results (list): one DataFrame per transformation, in the format returned by `batch_results`
    Keyword arguments:
    -plot_function (callable, optional): takes one DataFrame and returns an object with a
        `display()` method. Defaults to `plot_avg_and_dots`. Options in this module:
        - plot_avg_and_errorbars: plots the average of the dice score of all images across the parameters and shows the standard deviation as errorbars
        - plot_avg_and_dots: plots the average of the dice score and additionally shows the single datapoints instead of errorbars
        - plot_boxplot: shows the distribution of the dice scores as box plots

    Returns: list with the created plots (altair.FacetChart objects for the functions above)
    '''
    if plot_function is None:
        # Resolved at call time: plot_avg_and_dots is defined after this function,
        # so using it as a default value would fail when this cell/module is executed.
        plot_function = plot_avg_and_dots
    plots = [plot_function(df) for df in df_results]
    for plot in plots:
        plot.display()
    return plots

In [None]:
#export
def plot_avg_and_dots(df, draw_line=True):
    '''
    Plots the single data points per parameter value and, optionally, a line through their average

    One facet (column) is drawn per component. Only the columns at positions 2-4 of `df`
    (LV, MY, RV in the layout below) are plotted; 'bg' and 'File' are left out.

    Positional arguments:
    -df (pd.DataFrame object): columns: 'parameter', 'bg', 'LV', 'MY', 'RV', 'File'
    Keyword arguments:
    -draw_line (bool): if true, overlay a line connecting the average value per parameter value

    Returns: altair.FacetChart object
    '''
    melt_results = df.melt(id_vars=df.columns[0], value_vars=df.columns[2:5])
    x_field = melt_results.columns[0]
    dot_plot = (alt.Chart(melt_results)
                .mark_point()
                .encode(x=x_field, y="value", color=alt.Color("variable"))
                .properties(width=400, height=200)
                .interactive())
    if not draw_line:
        return dot_plot.facet(column="variable")
    avg_line_plot = (alt.Chart(melt_results)
                     .mark_line()
                     .encode(x=x_field, y="average(value)", color=alt.Color("variable"))
                     .properties(width=400, height=200)
                     .interactive())
    return alt.layer(dot_plot, avg_line_plot).facet(column="variable")

In [None]:
#export
def plot_avg_and_errorbars(df):
    '''
    Plots the average score per parameter value as a line, with the standard deviation as error bars.

    One facet (column) is drawn per component. Only the columns at positions 2-4 of `df`
    (LV, MY, RV in the layout below) are plotted; 'bg' and 'File' are left out.

    Positional arguments:
    -df (pd.DataFrame object): columns: 'parameter', 'bg', 'LV', 'MY', 'RV', 'File'

    Returns: altair.FacetChart object
    '''
    long_form = df.melt(id_vars=df.columns[0], value_vars=df.columns[2:5])
    param = long_form.columns[0]
    mean_line = (alt.Chart(long_form)
                 .mark_line()
                 .encode(x=param, y="average(value)", color=alt.Color("variable"))
                 .properties(width=400, height=200)
                 .interactive())
    spread = (alt.Chart(long_form)
              .mark_errorbar(extent='stdev')
              .encode(x=param, y="value", color=alt.Color("variable")))
    return alt.layer(mean_line, spread).facet(column="variable")

In [None]:
#export
def plot_boxplot(df):
    '''
    Plots the distribution of the scores per parameter value as box plots (whiskers span min to max)

    One facet (column) is drawn per component. Only the columns at positions 2-4 of `df`
    (LV, MY, RV in the layout below) are plotted; 'bg' and 'File' are left out.

    Positional arguments:
    -df (pd.DataFrame object): columns: 'parameter', 'bg', 'LV', 'MY', 'RV', 'File'

    Returns: altair.FacetChart object
    '''
    long_form = df.melt(id_vars=df.columns[0], value_vars=df.columns[2:5])
    boxes = alt.Chart(long_form).mark_boxplot(extent="min-max", size=5)
    boxes = boxes.encode(x=alt.X(long_form.columns[0]), y=alt.Y("value"), color=alt.Color("variable"))
    boxes = boxes.properties(width=400, height=200)
    return boxes.facet(column="variable").interactive()

In [None]:
# Show only the single data points (no average line) for every evaluated transformation
dots_only = partial(plot_avg_and_dots, draw_line=False)
plots = plot_batch(results_with_truths, plot_function=dots_only)