# DeepEdit Inference Tutorial - Edited by Elan

DeepEdit is an algorithm that combines the power of two models in a single architecture. It allows the user to run inference either automatically, as a standard segmentation method (e.g., UNet), or interactively, segmenting part of an image using clicks (Sakinis et al.). DeepEdit aims both to improve the user experience and to facilitate the development of new active learning techniques.


This notebook shows the performance of a model trained to segment the spleen.

**Once the model is trained, we recommend importing the pretrained model into the [DeepEdit App in MONAI Label](https://github.com/Project-MONAI/MONAILabel/tree/main/sample-apps/radiology#deepedit) for full experience.**

Sakinis et al., Interactive segmentation of medical images through fully convolutional neural networks. (2019) https://arxiv.org/abs/1903.08205

In [None]:
# !python -c "import monai" || pip install -q "monai-weekly[nibabel tqdm]"
# !python -c "import matplotlib" || pip install -q matplotlib==3.5.2
# !pip install -q pytorch-ignite==0.4.8

# %matplotlib inline

#### Library versions used:

monai-weekly==0.9.dev2219 itk==5.2.1.post1 matplotlib==3.5.2 nibabel==3.2.2 numpy==1.22.3 pytorch-ignite==0.4.8 scikit-image==0.19.2 scipy==1.8.0 tensorboard==2.8.0 torch==1.11.0 tqdm==4.64.0

In [None]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import jit

import monai
from monai.config import print_config

from monai.apps.deepedit.transforms import (
    AddGuidanceSignalDeepEditd,
    AddGuidanceFromPointsDeepEditd,
    ResizeGuidanceMultipleLabelDeepEditd,
)


from monai.transforms import (
    Activationsd,
    AsDiscreted,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    Resized,
    ScaleIntensityRanged,
    SqueezeDimd,
    ToNumpyd,
    ToTensord,SaveImaged
)

# Print MONAI and dependency versions so results can be tied to an environment.
print_config()

### Plotting functions

In [None]:
def draw_points(guidance, slice_idx):
    """Overlay guidance click points on the current matplotlib axes.

    Args:
        guidance: iterable of points, or None to draw nothing. Each point is
            indexed as ``p[0]`` (plotted on the y-axis, i.e. image row) and
            ``p[1]`` (plotted on the x-axis, i.e. image column); any further
            coordinates are ignored.
        slice_idx: unused; kept so existing call sites keep working.
    """
    if guidance is None:
        return
    for p in guidance:
        # imshow puts array rows on the y-axis and columns on the x-axis,
        # hence (x, y) = (p[1], p[0]).
        # The marker size must be a keyword argument: matplotlib does not
        # accept MATLAB-style ('MarkerSize', 30) positional pairs and would
        # parse them as an extra (x, y) data group instead.
        plt.plot(p[1], p[0], "r+", markersize=30)


def show_image(image, label, guidance=None, slice_idx=None):
    """Display a 2D image with an optional label overlay and guidance points.

    The left panel shows ``image`` in grayscale. When ``label`` is given, its
    non-zero pixels are tinted on top with the 'jet' colormap. Guidance points
    are then drawn via ``draw_points``. When ``label`` is given, the right
    panel shows the label on its own.

    Args:
        image: 2D array to display.
        label: 2D label array, or None for an image-only plot.
        guidance: optional points forwarded to ``draw_points``.
        slice_idx: forwarded to ``draw_points``.
    """
    has_label = label is not None

    plt.figure("check", (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("image")
    plt.imshow(image, cmap="gray")
    if has_label:
        # Mask out background (0) so only labelled pixels get tinted.
        overlay = np.ma.masked_where(label == 0, label)
        plt.imshow(overlay, 'jet', interpolation='none', alpha=0.7)
    draw_points(guidance, slice_idx)
    plt.colorbar()

    if has_label:
        plt.subplot(1, 2, 2)
        plt.title("label")
        plt.imshow(label)
        plt.colorbar()
    plt.show()


def print_data(data):
    """Print a one-line summary for each entry of the ``data`` dictionary.

    'image_meta_dict' and 'label_meta_dict' entries are expanded with one line
    per meta key. Every other entry is summarised as:

    * its value, if its exact type is int, float, bool, str, dict or tuple;
    * otherwise its ``.shape``, if it has one;
    * otherwise its type.
    """
    plain_types = (int, float, bool, str, dict, tuple)
    for key, value in data.items():
        if key in ('image_meta_dict', 'label_meta_dict'):
            for meta_key in value:
                print('{} Meta:: {} => {}'.format(key, meta_key, value[meta_key]))
            continue

        if type(value) in plain_types:
            summary = value
        elif hasattr(value, 'shape'):
            summary = value.shape
        else:
            summary = type(value)
        print('Data key: {} = {}'.format(key, summary))

### Set working directory

In [None]:
# Work from the project root so the relative 'datasets/...', 'models/...' and
# 'output/...' paths used in later cells resolve. Set DEEPEDIT_WORKDIR to point
# elsewhere instead of editing the notebook; the default is the original path.
os.chdir(os.environ.get('DEEPEDIT_WORKDIR', '/workspace/abdominal-segmentation'))

### Download data if not available

In [None]:
# Download data and model

# resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/_image.nii.gz"
# dst = "_image.nii.gz"

# if not os.path.exists(dst):
#     monai.apps.download_url(resource, dst)

# resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/\
# download/0.8.1/pretrained_deepedit_dynunet-final.ts"
# dst = "pretrained_deepedit_dynunet-final.ts"
# if not os.path.exists(dst):
#     monai.apps.download_url(resource, dst)

### Define preprocessing transforms used during training to transform inference images

In [None]:
# Scan to run inference on, relative to the working directory set above.
input_image_path = 'datasets/Task09_Spleen/imagesTs/spleen_1.nii.gz'

In [None]:
# labels
# Label name -> integer value; get_network derives its channel counts from
# len(labels).
labels = {'spleen': 1,
          'background': 0
          }

# Pre Processing
# Size the image is resized to before being fed to the network.
spatial_size = [128, 128, 128]

input_image_path = 'datasets/Task09_Spleen/imagesTs/spleen_1.nii.gz'

# Input sample: image path plus user clicks per label.
# NOTE(review): the click coordinates appear to be voxel indices in the image
# as loaded (before Orientationd/Resized). The first two are plotted as
# (row, column) and the third selects the slice; confirm against
# AddGuidanceFromPointsDeepEditd.
data = {
    'image': input_image_path,
    'guidance': {'spleen': [[66, 180, 105], [66, 180, 145]], 'background': []},
}

# Slice shown in the post-loading preview plot.
# NOTE(review): the first spleen click lies on slice 105, so the click drawn on
# slice 15 is not actually on the displayed slice. The commented-out line
# below previews the click's own slice instead.
#slice_idx = original_slice_idx = data['guidance']['spleen'][0][2]
slice_idx = 15

pre_transforms = [
    # Loading the image
    LoadImaged(keys="image", reader="ITKReader"),
    # Ensure channel first
    EnsureChannelFirstd(keys="image"),
    # Change image orientation
    Orientationd(keys="image", axcodes="RAS"),
    # Scaling image intensity - works well for CT images
    ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    # DeepEdit Transforms for Inference #
    # Add guidance (points) in the form of tensors based on the user input
    AddGuidanceFromPointsDeepEditd(ref_image="image", guidance="guidance", label_names=labels),
    # Resize the image
    Resized(keys="image", spatial_size=spatial_size, mode="area"),
    # Resize the guidance based on the image resizing
    ResizeGuidanceMultipleLabelDeepEditd(guidance="guidance", ref_image="image"),
    # Add the guidance to the input image
    AddGuidanceSignalDeepEditd(keys="image", guidance="guidance"),
    # Convert image to tensor
    ToTensord(keys="image"),
]

original_image = None

# Going through each of the pre_transforms
# (applied one at a time so every intermediate image shape can be printed).
original_shape = None

for t in pre_transforms:
    tname = type(t).__name__
    data = t(data)
    image = data['image']
    label = data.get('label')
    guidance = data.get('guidance')

    print("{} => image shape: {}".format(tname, image.shape))

    if tname == 'LoadImaged':
        # Remember the freshly loaded volume and its shape; a later cell
        # resizes the prediction back to original_shape. This shape is
        # captured before EnsureChannelFirstd/Orientationd/Resized run.
        original_image = data['image']
        original_shape = original_image.shape
        label = None
        # Preview one slice with the first spleen click overlaid.
        tmp_image = image[:, :, slice_idx]
        show_image(tmp_image, label, [guidance['spleen'][0]], slice_idx)

# Model-resolution input (image + guidance signal) and the final guidance.
transformed_image = data['image']
guidance = data.get('guidance')

### Load the model from a TorchScript (TS) file — if you use the TS model, skip the next section (loading from a PT file)

In [None]:
# Evaluation
# Using TS model
# model_path = 'models/Task09_Spleen/imagesTr/pretrained_deepedit_dynunet-final.ts'
# model = jit.load(model_path)
# model.cuda()
# model.eval()

### Load model from PT File

In [None]:
def get_network(network, labels, spatial_size):
    """Build the (untrained) DeepEdit segmentation network.

    Args:
        network: architecture name. "unetr" selects UNETR; any other value
            falls back to DynUNet.
        labels: label mapping; its length sets the number of output channels,
            and the network takes ``len(labels) + 1`` input channels.
        spatial_size: input image size, used only by UNETR (``img_size``).

    Returns:
        The constructed network module.
    """
    n_out = len(labels)
    # NOTE(review): presumably the scanned image plus one guidance channel per
    # label, as produced by AddGuidanceSignalDeepEditd; confirm.
    n_in = n_out + 1

    if network != "unetr":
        return DynUNet(
            spatial_dims=3,
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=[3] * 6,
            strides=[1, 2, 2, 2, 2, [2, 2, 1]],
            upsample_kernel_size=[2, 2, 2, 2, [2, 2, 1]],
            norm_name="instance",
            deep_supervision=False,
            res_block=True,
        )

    return UNETR(
        spatial_dims=3,
        in_channels=n_in,
        out_channels=n_out,
        img_size=spatial_size,
        feature_size=64,
        hidden_size=1536,
        mlp_dim=3072,
        num_heads=48,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
    )

In [None]:
# Using PT model
import distutils.util
from monai.networks.nets import DynUNet, UNETR
import torch.distributed as dist

model_path = 'models/Task09_Spleen/imagesTr/pretrained_deepedit_dynunet-final.pt'

network = 'dynunet'
labels = {"spleen": 1, "background": 0}
spatial_size = [128, 128, 128]
use_gpu = True
# Fall back to CPU when CUDA is not available instead of crashing.
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
network = get_network(network, labels, spatial_size).to(device)
# map_location lets weights saved on a GPU load onto whichever device was
# chosen above. This expects the file to hold a bare state_dict.
# NOTE(review): if it is a full training checkpoint (a dict that also stores
# optimizer state, epoch, loss, ...), pass its model-weights entry instead.
network.load_state_dict(torch.load(model_path, map_location=device))
# The network already lives on `device`; just switch it to inference mode.
model = network.eval()

### Predict and display label and image overlay at model resolution

In [None]:
# Run inference on the device holding the model's weights (rather than
# assuming CUDA), adding a batch dimension of size 1.
model_device = next(model.parameters()).device
inputs = data['image'][None].to(model_device)
with torch.no_grad():
    outputs = model(inputs)
# Drop the batch dimension again so 'pred' is channel-first like 'image'.
outputs = outputs[0]
data['pred'] = outputs

# Logits -> softmax -> per-voxel class index (argmax) -> squeezed numpy label map.
post_transforms = [
    EnsureTyped(keys="pred"),
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="pred", argmax=True),
    SqueezeDimd(keys="pred", dim=0),
    ToNumpyd(keys="pred"),
]

# Shallow copy of the sample with the raw logits; the next cell post-processes
# it with an extra resize back to the original shape.
data_orig = data.copy()

for t in post_transforms:
    tname = type(t).__name__
    data = t(data)
    image = data['image']
    label = data['pred']
    print("{} => image shape: {}, pred shape: {}".format(tname, image.shape, label.shape))

# Preview slices 10, 50 and 90 at model resolution, skipping empty predictions.
for i in range(10, 110, 40):
    image = transformed_image[0, :, :, i]  # Taking the first channel which is the main image
    label = data['pred'][:, :, i]
    if np.sum(label) == 0:
        continue

    print("Final PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}".format(
        i, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))
    show_image(image, label)

### Convert the prediction back to the original shape of the scan image

In [None]:
print('Converting prediction image to original shape', original_shape)

# Same post-processing as before, but with a resize back to original_shape
# inserted after AsDiscreted (i.e. before SqueezeDimd/ToNumpyd). The image is
# area-resized; the discrete prediction uses nearest-neighbour so label
# values are preserved.
# NOTE(review): original_shape was recorded right after LoadImaged, before
# Orientationd; if reorientation permuted axes this target would not match.
resize_back = Resized(keys=["image", "pred"], spatial_size=original_shape, mode=["area", "nearest"])
head, tail = post_transforms[:-2], post_transforms[-2:]
post_transforms_resize = [*head, resize_back, *tail]

for t in post_transforms_resize:
    tname = type(t).__name__
    data_orig = t(data_orig)
    image, label = data_orig['image'], data_orig['pred']
    print("{} => image shape: {}, pred shape: {}".format(tname, image.shape, label.shape))

In [None]:
# Display Images
for i in range(5, original_shape[2], 5):
    image = data_orig['image'][0, :, :, i]  # Taking the first channel which is the main image
    label = data_orig['pred'][:, :, i]
    if np.sum(label) == 0:
        continue
       
    print("Final PLOT:: {} => image shape: {}, pred shape: {}; min: {}, max: {}, sum: {}".format(
        i, image.shape, label.shape, np.min(label), np.max(label), np.sum(label)))
    show_image(image, label)

### Saving using NiBabel - Does not work with Slicer3D, looks fine on ImageJ

In [None]:
#pip install nibabel

In [None]:
import nibabel as nib
output_dir = 'output/Task09_Spleen/imagesTs'
ext = 'NIB.nii.gz'  # NOTE(review): not used by this cell
os.makedirs(output_dir, exist_ok=True)

# Affine from the image meta dict after pre-processing.
# NOTE(review): the heading reports the result does not open correctly in
# 3D Slicer; confirm this affine matches the prediction's orientation/shape.
original_affine = data_orig['image_meta_dict']['affine']
print(original_affine)

nib_pred = nib.Nifti1Image(data_orig['pred'].astype(np.uint8), original_affine)
nib.save(nib_pred, os.path.join(output_dir, 'spleen_1_nib.nii.gz'))


### Saving using SaveImage Transform in MONAI - does not work with slicer3D, looks fine on ImageJ

In [None]:
# SaveImaged takes the output affine/filename from the '<key>_meta_dict' entry,
# so give the prediction the image's meta data.
data_orig['pred_meta_dict'] = data_orig['image_meta_dict']
sd = SaveImaged(keys=["pred"], output_dir=output_dir, output_postfix='SD', output_ext='.nii.gz',
                squeeze_end_dims=True,
                data_root_dir='datasets/Task09_Spleen/imagesTs',
                writer=None)
# SaveImaged expects channel-first data, but SqueezeDimd left 'pred' as a 3D
# volume without a channel axis. The writer would then treat the first spatial
# axis as channels and permute the volume. Save a channel-first copy instead;
# squeeze_end_dims drops the singleton channel again on write. data_orig['pred']
# itself stays 3D for the cells below.
save_input = dict(data_orig)
save_input['pred'] = data_orig['pred'][None]
sd(save_input)

### Saving using SITK - Works with slicer3D for the Task09_Spleen dataset, needs path of original input image

In [None]:
def fix_3D_shape(nparray, option='channel_smallest'):
    """Move the last axis of a 3D array to the front (H, W, C) -> (C, H, W) and cast to int16.

    The move only happens when the chosen heuristic decides the last axis is
    the "channel"/slice axis:

    * ``'channel_smallest'``: the last axis is the smallest one. On ties
      ``np.argmin`` picks the earliest axis, so no move happens.
    * ``'square_image'``: the first two axes have equal size.

    Arrays that fail the heuristic keep their layout; the int16 cast always applies.

    Args:
        nparray: 3D numpy array.
        option: heuristic name, one of the two listed above.

    Returns:
        The (possibly reordered) array as ``int16``.

    Raises:
        ValueError: if ``nparray`` is not 3D or ``option`` is unknown.
            ValueError subclasses Exception, so existing ``except Exception``
            handlers still catch it.
    """
    print('Fixing 3D shape...')
    if nparray.ndim != 3:
        raise ValueError('Not a 3D array', nparray.shape)

    if option == 'channel_smallest':
        print('Input array shape: ', nparray.shape)
        move_last_to_front = np.argmin(nparray.shape) == 2
    elif option == 'square_image':
        move_last_to_front = nparray.shape[0] == nparray.shape[1]
    else:
        # Previously an unknown option silently skipped the reorder.
        raise ValueError(
            "Unknown option {!r}; expected 'channel_smallest' or 'square_image'".format(option)
        )

    if move_last_to_front:
        nparray = np.moveaxis(nparray, 2, 0)
    print('Output array shape: ', nparray.shape)
    return nparray.astype('int16')

In [None]:
import SimpleITK as sitk
origImg = sitk.ReadImage(input_image_path)
print('Orig Imag shape: ', origImg.GetSize())

pred_u8 = data_orig['pred'].astype(np.uint8)
print('Label array shape: ', pred_u8.shape)

# Empirical axis fix-ups so the prediction lines up with the source scan:
# move the smallest (slice) axis first, rotate 90 degrees in the last two
# axes, then flip SimpleITK's second (y) axis.
# NOTE(review): tuned for Task09_Spleen; verify for other datasets.
pred_aligned = np.rot90(fix_3D_shape(pred_u8), k=1, axes=(1, 2))
sitkImg = sitk.Flip(sitk.GetImageFromArray(pred_aligned), [False, True, False])
print('Pred Image shape: ', sitkImg.GetSize())

# Copy origin/spacing/direction from the source scan (image sizes must match).
sitkImg.CopyInformation(origImg)
scan_out = os.path.join(output_dir, 'spleen_1_sitk.nii.gz')
sitk.WriteImage(sitkImg, scan_out)