# Model Interpretation through Sensitivity Analysis for Segmentation

## Setup

In [None]:
#default_exp core

In [None]:
# IPython setup: auto-reload edited modules before running code and render
# matplotlib figures inline in the notebook.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

Force execution on the CPU rather than the GPU. (Note: the imports cell below does not itself select a device.)

In [None]:
#exporti
from fastai.vision import *
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [None]:
#exporti
def dice(rotatedPrediction, trueMask, component=1):
    """Return the Dice coefficient of one label between two masks.

    Computes ``2 * |P & T| / (|P| + |T|)``, where ``P`` / ``T`` are the pixels
    whose value equals ``component`` in the prediction / ground truth.

    Args:
        rotatedPrediction: predicted mask; its ``.data`` tensor is compared
            element-wise with ``component``.
        trueMask: ground-truth mask with a ``.data`` tensor of the same shape
            (assumed — shapes are not checked here).
        component: label value to score.

    Returns:
        The Dice score as a Python float. Returns ``1.0`` when neither mask
        contains ``component`` (both agree that it is absent).
    """
    pred = rotatedPrediction.data == component
    msk = trueMask.data == component
    total = pred.sum() + msk.sum()
    if total == 0:
        # Previously this path kept a plain int and crashed on `.item()`.
        return 1.0
    intersect = (pred & msk).sum().float()
    return (2 * intersect / total).item()

## Sensitivity Analysis

### Rotation

In [None]:
#export
def plot_rotation_series(image_function, model, start=0, end=180, num=5):
    """Plot model predictions for an image rotated through a series of angles.

    ``num`` angles evenly spaced from ``start`` to ``end`` (inclusive, via
    ``np.linspace``) are rounded to whole degrees. For each angle a fresh image
    from ``image_function`` is resized to 256, rotated, and shown in its own
    subplot, with ``model.predict(img)[0]`` passed to ``show`` as ``y``.

    Args:
        image_function: zero-argument callable returning a new image object.
        model: object whose ``predict(img)[0]`` is the prediction to overlay.
        start: first rotation angle in degrees.
        end: last rotation angle in degrees.
        num: number of angles / subplots.
    """
    # squeeze=False keeps `axs` 2-D, so num == 1 still yields an iterable row
    # of axes instead of a bare Axes object.
    _, axs = plt.subplots(1, num, figsize=(16, 6), squeeze=False)
    for deg, ax in zip(np.linspace(start, end, num), axs[0]):
        # Round once and use the same value for the transform and the title,
        # so the label always matches the angle that was actually applied.
        angle = int(round(deg))
        img = image_function().resize(256).rotate(degrees=angle)
        img.show(ax=ax, title=f'degrees={angle}', y=model.predict(img)[0])

In [None]:
#export
def rotation_series(image_function, mask_function, model, step_size=5):
    """Measure how segmentation quality changes as the input is rotated.

    For every angle in ``range(0, 360, step_size)`` a fresh image is resized to
    256, rotated by that angle and passed through ``model.predict``. The
    predicted mask is rotated back by the same angle and compared with the
    (unrotated, resized) ground truth using ``dice`` for components 1 and 2.

    Args:
        image_function: zero-argument callable returning the input image.
        mask_function: zero-argument callable returning the ground-truth mask.
        model: object whose ``predict(image)[0]`` is the predicted mask.
        step_size: angular step in degrees.

    Returns:
        pandas.DataFrame with one row per angle and columns ``deg``,
        ``diceLV`` (component 1) and ``diceMY`` (component 2).
    """
    reference = mask_function().resize(256)
    rows = []
    for angle in tqdm(range(0, 360, step_size)):
        # A fresh image per angle; NOTE(review): the image transforms appear to
        # act on the image object itself, so reusing one would stack rotations.
        rotated = image_function().resize(256).rotate(degrees=angle)
        predicted = model.predict(rotated)[0]
        # Cast to float before resampling it back; presumably the inverse
        # rotation needs floating-point pixels — confirm against fastai.
        predicted._px = predicted._px.float()
        restored = predicted.rotate(degrees=-angle)
        rows.append([
            angle,
            dice(restored, reference, component=1),
            dice(restored, reference, component=2),
        ])
    return pd.DataFrame(rows, columns=['deg', 'diceLV', 'diceMY'])

### Cropping

In [None]:
#export
def plot_crop_series(image_function, model, start=256, end=56, num=5):
    """Plot model predictions for progressively tighter center crops.

    ``num`` crop sizes evenly spaced from ``start`` to ``end`` (inclusive, via
    ``np.linspace``) are rounded to whole pixels. For each size a fresh image
    from ``image_function`` is resized to 256, cropped to that size, padded
    back to 256 with zeros, and shown in its own subplot with
    ``model.predict(...)[0]`` passed to ``show`` as ``y``.

    Args:
        image_function: zero-argument callable returning a new image object.
        model: object whose ``predict(img)[0]`` is the prediction to overlay.
        start: first (largest) crop size in pixels.
        end: last (smallest) crop size in pixels.
        num: number of crop sizes / subplots.
    """
    # squeeze=False keeps `axs` 2-D, so num == 1 still yields an iterable row
    # of axes instead of a bare Axes object.
    _, axs = plt.subplots(1, num, figsize=(16, 6), squeeze=False)
    for pxls, ax in zip(np.linspace(start, end, num), axs[0]):
        # Round (not truncate) so e.g. 105.9999 from linspace becomes 106.
        size = int(round(pxls))
        # Use the returned image at every step, matching crop_series.
        croppedImage = (image_function()
                        .resize(256)
                        .crop(size)
                        .crop_pad(256, padding_mode='zeros'))
        croppedImage.show(ax=ax, title=f'pixels={size}',
                          y=model.predict(croppedImage)[0])

In [None]:
#export
def crop_series(image_function, mask_function, model, step_size=5):
    """Measure how segmentation quality changes under tighter center crops.

    For every crop size in ``range(256, 32, -step_size)`` a fresh image is
    resized to 256, cropped to that size, padded back to 256 with zeros and
    passed through ``model.predict``. The prediction is compared with the
    resized ground-truth mask using ``dice`` for components 1 and 2.

    Args:
        image_function: zero-argument callable returning the input image.
        mask_function: zero-argument callable returning the ground-truth mask.
        model: object whose ``predict(image)[0]`` is the predicted mask.
        step_size: decrement between successive crop sizes, in pixels.

    Returns:
        pandas.DataFrame with one row per crop size and columns ``pxls``,
        ``diceLV`` (component 1) and ``diceMY`` (component 2).
    """
    # The ground truth is never cropped, so load and resize it once instead of
    # on every iteration (as rotation_series already does).
    trueMask = mask_function().resize(256)
    results = []
    for pxls in tqdm(range(256, 32, -step_size)):
        # Fresh image each iteration so crops from earlier sizes don't carry over.
        image = image_function().resize(256)
        croppedImage = image.crop(pxls).crop_pad(256, padding_mode='zeros')
        prediction = model.predict(croppedImage)[0]
        # Kept from the original: cast predicted pixels to float before scoring.
        prediction._px = prediction._px.float()

        diceLV = dice(prediction, trueMask, component=1)
        diceMY = dice(prediction, trueMask, component=2)
        results.append([pxls, diceLV, diceMY])

    return pd.DataFrame(results, columns=['pxls', 'diceLV', 'diceMY'])

In [None]:
# Run nbdev's notebook-to-module exporter (its output below lists the converted
# notebooks); the `#export` / `#exporti` cells above are the ones meant for the
# library module named by `#default_exp`.
from nbdev.export import notebook2script
notebook2script()

Converted 01_local_interpret.ipynb.
Converted index.ipynb.
