In [None]:
from __future__ import division, print_function

import argparse
import csv
import glob
import os
import re
import shutil
import tempfile
import unittest
import gspread as gspread
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymia
import torch
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.metrics import (DiceMetric, HausdorffDistanceMetric,
                           SurfaceDistanceMetric)
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.transforms import *
from monai.transforms import (AsDiscrete, AsDiscreted, Compose,
                              CropForegroundd, EnsureChannelFirstd, LoadImage,
                              LoadImaged, Orientationd, RandCropByPosNegLabeld,
                              Resized, ScaleIntensityRanged, Spacingd)
from monai.utils import first, set_determinism
from torch.utils import benchmark
from torch.utils.cpp_extension import load
import torchvision.transforms as T
from scipy import ndimage
import matplotlib.pyplot as plt
import einops
# import warp as wp


"""
Load the dataset location, the CUDA extension and the compute device.
"""
data_dir = "/data/OrganSegmentations"
lltm_cuda = load('lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True)
#help(lltm_cuda)
device = torch.device("cuda")

%matplotlib inline


In [None]:

"""
Set of functions used to simulate the output of an arbitrary algorithm and to perform the preprocessing required in a medical imaging pipeline.
"""

def dilatatee_inner(arr, smallK, rng=None):
    """
    Apply one step of binary dilation, optionally keeping only part of the new border.

    The mask is dilated once with ``scipy.ndimage.binary_dilation``.  When
    ``smallK`` is truthy, ``(n // 3) * 2`` of the ``n`` voxels added by that
    dilation (roughly two thirds, chosen at random) are switched off again, so
    the mask grows irregularly instead of uniformly.

    Args:
        arr: binary mask of any dimensionality (array-like).
        smallK: if truthy, keep only a random subset of the newly added
            border voxels; if falsy, keep the complete dilation.
        rng: optional ``numpy.random.Generator`` used to pick the voxels that
            are removed.  When omitted, a generator is seeded from NumPy's
            global random state, so seeding that state (``np.random.seed`` or
            any helper that calls it) makes the result reproducible.

    Returns:
        A new boolean ``numpy.ndarray`` with the shape of ``arr``; ``arr``
        itself is not modified.
    """
    arr_new = ndimage.binary_dilation(arr, iterations=1)

    if smallK:
        # One row of coordinates per voxel that the dilation has just added.
        inds = np.argwhere(np.logical_and(arr_new, np.logical_not(arr)))

        if rng is None:
            # An unseeded default_rng() would pull fresh OS entropy and give a
            # different mask on every run; derive the seed from the global
            # state instead so that a global seed is honoured.
            rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))
        rng.shuffle(inds, axis=0)

        k = (inds.shape[0] // 3) * 2
        # tuple(...T) yields one index array per axis, which works for masks
        # of any dimensionality rather than 3-D only.
        arr_new[tuple(inds[:k].T)] = False
    return arr_new

def dilatatee(arr,n_iter, smallK):
    """
    Grow a binary mask by running dilatatee_inner on it n_iter times in a row.

    arr is the starting mask, n_iter the number of dilation steps and smallK
    is forwarded unchanged to every dilatatee_inner call.  The mask produced
    by the last step is returned (arr itself when n_iter is 0).
    """
    grown = arr
    for _step in range(n_iter):
        grown = dilatatee_inner(grown, smallK)
    return grown

def execute_single_case():
    """
    Load one image/label pair, preprocess it and compute the per-voxel
    Hausdorff distance between the label-1 mask and a dilated copy of it.

    Uses the module-level ``data_dir``, ``lltm_cuda`` and ``device``.

    Returns:
        A tuple ``(res, labelBoolTensorA, labelBoolTensorB)``:

        * ``res`` - output of ``lltm_cuda.getHausdorffDistance_3Dres``
          reshaped to ``(WIDTH, HEIGHT, DEPTH)``;
        * ``labelBoolTensorA`` - boolean mask of label 1, still carrying the
          leading batch and channel axes of the loader output;
        * ``labelBoolTensorB`` - 3-D boolean tensor holding the dilated mask.

    Raises:
        FileNotFoundError: if ``data_dir`` contains no ``*.nii.gz`` file whose
            name contains "volume", or none whose name contains "label".
    """
    set_determinism(seed=0)

    # preprocessing
    val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(
            1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ])

    # choosing files: match on the file name only, so that a directory called
    # e.g. ".../labels/" cannot make every path look like a label
    images = sorted(
        glob.glob(os.path.join(data_dir, "*.nii.gz")))
    train_images = [p for p in images if "volume" in os.path.basename(p)]
    train_labels = [p for p in images if "label" in os.path.basename(p)]
    if not train_images or not train_labels:
        # Previously this case fell through and returned None, which only
        # surfaced as an unpacking error in the calling cell.
        raise FileNotFoundError(
            f"No 'volume'/'label' *.nii.gz pair found in {data_dir!r} "
            f"({len(train_images)} volume file(s), {len(train_labels)} label file(s))"
        )

    # only the first image/label pair (in sorted order) is processed
    data_dicts = [{"image": train_images[0], "label": train_labels[0]}]

    # get dataloader and pull its single batch
    check_ds = Dataset(data=data_dicts, transform=val_transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    dat = next(iter(check_loader))

    # in this case we will display just the liver so we had chosen its label
    ii = 1
    sizz = dat['image'].shape
    labelBoolTensorA = (dat['label'] == ii).bool()

    # simulate an imperfect segmentation: 20 irregular dilation steps
    # (random subset of the border) followed by 10 full ones
    labelBoolTensorB = dilatatee(labelBoolTensorA[0,0,:,:,:].detach().cpu().numpy(), 20, True)
    labelBoolTensorB = dilatatee(labelBoolTensorB, 10, False)
    labelBoolTensorB = torch.tensor(labelBoolTensorB)

    # axes 0 and 1 are the batch and channel axes
    WIDTH, HEIGHT, DEPTH = sizz[2], sizz[3], sizz[4]

    # invoke the Hausdorff distance function so we get the per-voxel HD value
    # that can be displayed
    # NOTE(review): the literal 1.0 sits where the other notebook cells pass
    # `robustness_percent`; the meaning of the final one-element bool tensor
    # is defined by the extension - confirm against its sources.
    res = lltm_cuda.getHausdorffDistance_3Dres(
        labelBoolTensorA[0,0,:,:,:].to(device),
        labelBoolTensorB.to(device),
        WIDTH, HEIGHT, DEPTH,
        1.0,
        torch.ones(1, dtype=bool),
    )
    res = res.reshape(WIDTH, HEIGHT, DEPTH)

    return res, labelBoolTensorA, labelBoolTensorB

In [None]:
"""
execution of main function
"""
# Keep all three outputs as notebook globals; the cells below plot and export them.
(res, labelBoolTensorA, labelBoolTensorB) = execute_single_case()


In [None]:
# Spatial extent of the mask: axes 0 and 1 are the ones dropped by the
# [0,0,:,:,:] indexing below, axes 2-4 are the volume itself.
sizz = labelBoolTensorA.shape
WIDTH, HEIGHT, DEPTH = sizz[2:5]
robustness_percent = 1.0

# Note: lltm_cuda.getHausdorffDistance_3Dres (per-voxel result) accepts the
# same argument list as the two calls below.
point = lltm_cuda.getHausdorffDistance(
    labelBoolTensorA[0,0,:,:,:].to(device),
    labelBoolTensorB.to(device),
    WIDTH,
    HEIGHT,
    DEPTH,
    robustness_percent,
    torch.ones(1, dtype=bool),
)
arr = lltm_cuda.getHausdorffDistance_FullResList(
    labelBoolTensorA[0,0,:,:,:].to(device),
    labelBoolTensorB.to(device),
    WIDTH,
    HEIGHT,
    DEPTH,
    robustness_percent,
    torch.ones(1, dtype=bool),
)
point

In [None]:
import matplotlib.pyplot as plt

# Distribution of the values returned by getHausdorffDistance_FullResList.
fig, ax = plt.subplots()
ax.hist(arr.detach().cpu().numpy())
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of arr')
plt.show()


In [None]:

# Bring the per-voxel result to NumPy and orient it for display.
res_now = res.detach().cpu().numpy()
res_now = np.rot90(res_now, 1)
res_now = np.flip(res_now, 1)

# Sum of the values per x / z position; the arg-max is the slice that carries
# the largest total.
x = einops.reduce(res_now, 'x y z -> x', 'sum')
z = einops.reduce(res_now, 'x y z -> z', 'sum')
slice_x = np.argmax(x)
slice_z = np.argmax(z)

# The displayed slice lies a fixed number of slices below the arg-max slice.
# Clamp at 0: a negative index would silently wrap around and show a slice
# from the opposite end of the volume.
SLICE_OFFSET = 40
display_z = max(int(slice_z) - SLICE_OFFSET, 0)
to_disp = res_now[:, :, display_z]

plt.imshow(to_disp)

In [None]:
import SimpleITK as sitk

# Save the volumes as NIfTI files to enable inspection in 3D Slicer.
output_dir = "/workspaces/Hausdorff_morphological/data"
# WriteImage does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

volumes_to_save = {
    "output.nii.gz": res.detach().cpu().numpy(),
    # labelBoolTensorA still carries the batch and channel axes; strip them so
    # the file is a 3-D volume like the other two and overlays them directly.
    "orig.nii.gz": labelBoolTensorA.detach().cpu().numpy()[0, 0].astype(np.uint8),
    "dilatated.nii.gz": labelBoolTensorB.detach().cpu().numpy().astype(np.uint8),
}
for file_name, volume in volumes_to_save.items():
    sitk.WriteImage(sitk.GetImageFromArray(volume), os.path.join(output_dir, file_name))

