In [22]:
import os
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import random
import glob
import nibabel as nib
import numpy as np
from scipy.ndimage import rotate
import csv
import SimpleITK as sitk
#from lungtumormask import mask as tumormask
from lungmask import mask as lungmask_fun
from skimage.measure import label, regionprops,shannon_entropy
from skimage.morphology import dilation,ball,erosion,remove_small_objects

from monai.utils import first, set_determinism
from monai.transforms import (
    SaveImage,
    ResizeWithPadOrCropd,
    MaskIntensityd,
    ScaleIntensityd,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    ResizeWithPadOrCrop,
    LoadImaged,
    Orientationd,
    FillHoles,
    RemoveSmallObjects,
    KeepLargestConnectedComponent,
    RandCropByPosNegLabeld,
    SaveImaged,
    CenterSpatialCropd,
    SpatialCropd,
    ScaleIntensityRanged,
    Spacingd,
    AsDiscrete,
    SpatialCrop,
    RandSpatialCropd,
    SpatialPadd,
    EnsureTyped,
    EnsureType,
    Invertd,
    DivisiblePadd,
    MapTransform,
    RandWeightedCropd,
    ToTensord,
    Transpose,
    ScaleIntensity,
)
from monai.networks.nets import UNet,VNet,SwinUNETR,UNETR,DynUNet
from monai.metrics import DiceMetric,SurfaceDiceMetric,HausdorffDistanceMetric,compute_surface_dice
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch,pad_list_data_collate

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

if True:
    get_ipython().run_line_magic('matplotlib', 'inline')
    print('plot in line')

device: cuda
plot in line


In [3]:
# class to transpose lung mask
class Create_sequences(MapTransform):
    def __init__(self, keys):
        super().__init__(keys)
        print(f"keys to transpose: {self.keys}")

    def __call__(self, dictionary):
        dictionary = dict(dictionary)
        for key in self.keys:
            data = dictionary[key]
            if key == 'lung':
                data = np.transpose(data, (0, 2, 3, 1))
                data = rotate(data, 270, axes=(1, 2), reshape=False)
                data = np.flip(data, 1)
                data[data == 2] = int(1)
                data[data != 1] = int(0)
            dictionary[key] = data

        return dictionary


In [4]:
def get_kernels_strides(patch_size, spacing):
    sizes, spacings = patch_size, spacing
    input_size = sizes
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        for idx, (i, j) in enumerate(zip(sizes, stride)):
            if i % j != 0:
                raise ValueError(
                    f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
                )
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides

In [13]:

# root_path = '/data/p308104/Nifti_Imgs_V0/' #UMCG data on peregrine
#root_path = '/data/p308104/MultipleBP/'
#root_path = '/home/umcg/OneDrive/MultipleBreathingP/'
#root_path = '/home/umcg/Desktop/AutomaticITV_code/SABR1322_Nifti/'
root_path = '/home/umcg/Desktop/AutomaticITV_code/SABR1322_Nifti_AllBP_V2/'

#preweight_path = '/data/p308104/weights/v9/'
preweight_path = '/home/umcg/Desktop/AutomaticITV_code/weights/v11/'
pretrained_path_Swin = preweight_path + 'best_SwinUnet_V11_UMCG_TestSet3.pth'
pretrained_path_Dyn = preweight_path + 'best_DynUnet_V11_UMCG_TestSet3.pth'
    
figures_path = '/home/umcg/Desktop/AutomaticITV_code/figures_folder_i/'

cache = False
lf_select = None # NOT Needed for testing

SelectModel = 1  # 0 Swin  - 1 Dyn
figures_folder_i = figures_path+'figures_SwinDyn_V11_ITV'
name_run = "TestRun" + str(SelectModel) + "LF" + str(lf_select) + "run1"
print(name_run)



TestRun1LFNonerun1


In [6]:
def CreateLungMasks(root_path, CT_fpaths):
    # Get Lung mask and save it
    CT_path0 = CT_fpaths[0]
    CT_nii = nib.load(CT_path0)
    for ct in CT_fpaths:
        empty_header = nib.Nifti1Header()
        lung_path = ct[:-10] + '_LungMask.nii.gz'
        print('Creating Lung Mask: ', lung_path)
        input_image = sitk.ReadImage(ct, imageIO='NiftiImageIO')
        lungmask = lungmask_fun.apply(input_image)  # default model is U-net(R231)
        lungmask_ni = nib.Nifti1Image(lungmask, CT_nii.affine, empty_header)
        nib.save(lungmask_ni, lung_path)
    return 0

In [7]:
def LookSortFiles(root_path, all_patientdir):
    CTALL_fpaths =[]
    lungALL_fpaths =[]
    itv_fpaths = []
    gtv_fpaths = []    
    for patient_path in all_patientdir:
        ct_miss = 0
        gtv_miss = 0
        itv_miss = 0
        lung_miss = 0
        print(patient_path)
        for root, dirs, files in os.walk(root_path + patient_path, topdown=False):
            for f in files:
                if "_gtv" in f.lower():
                    gtv_fpaths.append(os.path.join(root_path, patient_path, f))
                    gtv_miss +=1
                if "_igtv" in f.lower() or "itv" in f.lower():
                    itv_fpaths.append(os.path.join(root_path, patient_path, f))
                    itv_miss +=1
                if "0%" in f.lower() and not("ave" in f.lower()):
                    if "ct" in f.lower():                            
                        CTALL_fpaths.append(os.path.join(root_path, patient_path, f))
                        ct_miss +=1
                    if "lung" in f.lower():
                        lungALL_fpaths.append(os.path.join(root_path, patient_path, f))
                        lung_miss +=1
            for i in range(len(CTALL_fpaths)-1):
                
                gtv_fpaths.append(itv_fpaths[-1])
                itv_fpaths.append(itv_fpaths[-1])

    print('ct: ',ct_miss,"Lungs: ",lung_miss,"GTV Miss: ", gtv_miss, "ITV Miss: ",itv_miss)
    if False:
        CreateLungMasks(root_path,CTALL_fpaths)
    
    CTALL_fpaths = np.sort(CTALL_fpaths)
    lungALL_fpaths = np.sort(lungALL_fpaths)
    itv_fpaths  = np.sort(itv_fpaths)
    gtv_fpaths = np.sort(gtv_fpaths)
    return CTALL_fpaths, itv_fpaths,gtv_fpaths, lungALL_fpaths

In [8]:
##MAIN
px_ = '0070683/'
all_patientdir = []
all_patientdir.append(px_)
#all_patientdir = os.listdir(root_path)
all_patientdir.sort()
print(len(all_patientdir),'in',name_run,all_patientdir)
CTALL_fpaths, itv_fpaths,gtv_fpaths, lungALL_fpaths = LookSortFiles(root_path, all_patientdir)
    
#Create data dictionat
data_dicts = [
    {"image": image_name,"lung":lung_name,"GTV": gtv_name,"ITV":itv_name}
    for image_name,lung_name,gtv_name,itv_name in zip(CTALL_fpaths,lungALL_fpaths,gtv_fpaths,itv_fpaths)
]
val_files =data_dicts[:]
print('CT val len:',len(val_files))


1 in TestRun1LFNonerun1 ['0070683/']
0070683/
ct:  10 Lungs:  10 GTV Miss:  0 ITV Miss:  2
CT val len: 9


In [14]:
num_workers=0
# HU are -1000 air , 0 water , usually normal tissue are around 0, top values should be around 100, bones are around 1000
minmin_CT = -1024
maxmax_CT = 200 
#Create Compose functions for preprocessing of train and validation
set_determinism(seed=0)
image_keys = ["image","lung","GTV","ITV"]
p = .5 #Data aug transform probability
size = 96
image_size = (size,size,size)
pin_memory = True if num_workers > 0 else False  # Do not change

val_transforms = Compose(
    [
        LoadImaged(keys=image_keys),
        EnsureChannelFirstd(keys=image_keys),
        Orientationd(keys=["image","GTV","ITV"], axcodes="RAS"),
        #Spacingd(keys=["image","label"], pixdim=(1,1,1),mode=("bilinear","nearest")),
        ScaleIntensityRanged(keys=["image"], a_min=minmin_CT, a_max=maxmax_CT,b_min=0.0, b_max=1.0, clip=True,),
        Create_sequences(keys=image_keys),
        CropForegroundd(keys=image_keys, source_key="lung",k_divisible = size),
        MaskIntensityd(keys=["image"], mask_key="lung"),
        ToTensord(keys=image_keys),
    ]
)

# Check the images after the preprocessing
if cache:  # Cache
    #train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=num_workers)
    #train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=num_workers)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0,
                          num_workers=int(num_workers // 2))
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=int(num_workers // 2), pin_memory=pin_memory)
else:
    #train_ds = Dataset(data=train_files, transform=train_transforms)
    #train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)  # ,collate_fn=pad_list_data_collate)


keys to transpose: ('image', 'lung', 'GTV', 'ITV')


In [10]:
#Check the images after the preprocessing
if False:
    figsize = (18, 9)
    check_ds =Dataset(data=val_files, transform=val_transforms)
    check_loader = DataLoader(check_ds, batch_size=1,num_workers=0)
    for batch_data in check_loader:
        #batch_data = first(check_loader)
        image,lung, label_ITV = (batch_data["image"][0][0],batch_data["lung"][0][0],batch_data["ITV"][0][0])
        px = batch_data["image"].to('cpu').meta["filename_or_obj"][0].split('/')[-2]
        print("Px:", px)
        print(f"image shape: {image.shape},lung shape: {lung.shape}, label shape: {label_ITV.shape}")
        count = 0
        for i in range(label_ITV.shape[-1]):
            if torch.sum(label_ITV[:,:,i])>0:
                count+=1
                fig = plt.figure('Instance = {}'.format(0), figsize=figsize)
                plt.subplot(1,2,1),plt.imshow(np.rot90(image[:,:,i]),cmap='gray'),plt.axis('off')
                plt.subplot(1,2,2),plt.imshow(np.rot90(image[:,:,i]),cmap='gray'),plt.axis('off')
                plt.contour(np.rot90(lung[:, :,i]),colors='yellow')
                plt.contour(np.rot90(label_ITV[:,:,i]),colors='red')
                plt.tight_layout(),plt.show()
            if count >0: 
                break


In [15]:
# Create the model
spatial_dims = 3
max_epochs = 250
in_channels = 1
out_channels = 2  # including background
lr = 1e-3  # 1e-4
weight_decay = 1e-5
T_0 = 40  # Cosine scheduler

task_id = "06"
deep_supr_num = 1  # when is 3 shape of outputs/labels dont match
patch_size = image_size
spacing = [1, 1, 1]
kernels, strides = get_kernels_strides(patch_size, spacing)

print("MODEL SwinDyn")
modelSwin = SwinUNETR(
    image_size,
    in_channels, out_channels,
    use_checkpoint=True,
    feature_size=48,
    # spatial_dims=spatial_dims
).to(device)
task_id = "06"
modelDyn = DynUNet(
    spatial_dims=3,
    in_channels=in_channels,
    out_channels=out_channels,
    kernel_size=kernels,
    strides=strides,
    upsample_kernel_size=strides[1:],
    norm_name="instance",
    deep_supervision=False,  # when is 3 shape of outputs/labels dont match
    deep_supr_num=deep_supr_num,
).to(device)


#metrics, no definition of :
#NO Loss Function
#NO Optimizer
# Load pretrained model
if pretrained_path_Swin is not(None):
        modelSwin.load_state_dict(torch.load(pretrained_path_Swin, map_location=torch.device(device)))
        print('Using Swin pretrained weights!')
if pretrained_path_Dyn is not(None):
    modelDyn.load_state_dict(torch.load(pretrained_path_Dyn, map_location=torch.device(device)))
    print('Using Dyn pretrained weights!')


MODEL SwinDyn
Using Swin pretrained weights!
Using Dyn pretrained weights!


In [16]:
##TESTING
#Define PostTranforms
out_channels = 2  # including background
post_transforms = Compose(
    [
        EnsureType(),
        AsDiscrete(argmax=True, threshold=0.9),
        #FillHoles(applied_labels=1, connectivity=0),
        #RemoveSmallObjects(min_size=64, connectivity=3, independent_channels=True),
        ScaleIntensity(minv=0.0, maxv=1.0),
        KeepLargestConnectedComponent(applied_labels=None,is_onehot=False,connectivity=2,num_components=1),
    ]
)
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, threshold=0.9),
                     ScaleIntensity(minv=0.0, maxv=1.0)])

# Testing the model
print(len(val_loader))

predicted_ITV = []
predicted_ITV_nopost = []
best_blob = True
count=0
modelSwin.eval()
modelDyn.eval()
with torch.no_grad():
    for val_data in val_loader:
    
        #val_inputs, val_labels = (
        #    val_data["image"].to(device),
        #    val_data["label"].to(device),)
        val_inputs = val_data["image"].to(device)
        roi_size = image_size
        sw_batch_size = 1
        px = val_data["image"].to('cpu').meta["filename_or_obj"][0].split('/')[-2]
        bp = val_data["image"].to('cpu').meta["filename_or_obj"][0].split('/')[-1].split('=')[-1]
        print('Px: ', 'BP: ',bp," Count: ",px,count+1)
        count+=1

        val_outputs_Swin = sliding_window_inference(val_inputs, roi_size, sw_batch_size, modelSwin)
        val_outPost_Swin = [post_transforms(i) for i in decollate_batch(val_outputs_Swin)]
        val_outputs_Swin = [post_pred(i) for i in decollate_batch(val_outputs_Swin)]

        val_outputs_Dyn = sliding_window_inference(val_inputs, roi_size, sw_batch_size, modelDyn)
        val_outPost_Dyn = [post_transforms(i) for i in decollate_batch(val_outputs_Dyn)]
        val_outputs_Dyn = [post_pred(i) for i in decollate_batch(val_outputs_Dyn)]

        val_outPost = torch.logical_or(val_outPost_Swin[0], val_outPost_Dyn[0])        
        #val_outPost = torch.add(val_outPost_Swin[0], val_outPost_Dyn[0])        
        
        if best_blob: 

            out_3dnp = val_outPost[0].detach().cpu().numpy()
            out_3dnp = out_3dnp.squeeze()
            out_3dnp = dilation(out_3dnp, ball(2))
            label_out_3dnp = label(out_3dnp)
            props = regionprops(label_out_3dnp,val_inputs[0].detach().cpu().numpy().squeeze())

            print("num de blobs predicted: ",len(props))
            for n in range(len(props)):
                r = props[n]
                #print('prediction bbox',r.bbox,"size",len(r.coords))
                patch = np.zeros(out_3dnp.shape)
                for j in range(len(r.coords)): 
                    patch[r.coords[j][0],r.coords[j][1],r.coords[j][2]]=1
                #Create different matrixes, one for each blob to send to metrics
                predicted_blobn = np.zeros(out_3dnp.shape)  
                predicted_blobn[label_out_3dnp==n+1]=1
                #print("Bounding Box: ",r.bbox)
                #print("Values: ",r.axis_major_length/r.axis_minor_length,"Feret: ",r.feret_diameter_max)
                #print("Intensity:",r.intensity_max,r.intensity_min,r.intensity_mean)
                #print("Entropy:", shannon_entropy(predicted_blobn))

                if len(props)==1:
                    bestBlob = np.expand_dims(predicted_blobn, 0)
                    tensor_blobn = torch.from_numpy(bestBlob)
                elif n==0:
                    ratio_zero = r.axis_major_length/r.axis_minor_length
                    minint_zero = r.intensity_min
                    entr_zero = shannon_entropy(predicted_blobn)
                    blob_zero = predicted_blobn
                elif n>0: 
                    ratio_curr = r.axis_major_length/r.axis_minor_length
                    minint_curr = r.intensity_min
                    entr_curr = shannon_entropy(predicted_blobn)
                    if (ratio_curr>3) or (ratio_zero>3):
                        print("Selected by ratio")
                        if ratio_curr<ratio_zero:
                            bestBlob = np.expand_dims(predicted_blobn, 0)
                            tensor_blobn = torch.from_numpy(bestBlob)
                        else:
                            bestBlob = np.expand_dims(blob_zero, 0)
                            tensor_blobn = torch.from_numpy(bestBlob)
                    else:
                        if minint_zero<0.0001: 
                            bestBlob = np.expand_dims(predicted_blobn, 0)
                            tensor_blobn = torch.from_numpy(bestBlob)
                        elif minint_curr<0.0001: 
                            bestBlob = np.expand_dims(blob_zero, 0)
                            tensor_blobn = torch.from_numpy(bestBlob)
                        elif entr_curr<entr_zero:
                            bestBlob = np.expand_dims(predicted_blobn, 0)
                            tensor_blobn = torch.from_numpy(bestBlob)
                        else:
                            bestBlob = np.expand_dims(blob_zero, 0)
                            tensor_blobn = torch.from_numpy(bestBlob)            
            
        if best_blob:
            predicted_ITV.append(tensor_blobn)
        else:
            predicted_ITV.append(val_outPost)
        predicted_ITV_nopost.append(val_outPost)



9
Px:  BP:  4D thorax 2.0  2.0  Br38  3  0% iMAR_ct.nii.gz  Count:  0070683 1
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  10% iMAR_ct.nii.gz  Count:  0070683 2
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  20% iMAR_ct.nii.gz  Count:  0070683 3
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  30% iMAR_ct.nii.gz  Count:  0070683 4
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  40% iMAR_ct.nii.gz  Count:  0070683 5
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  50% iMAR_ct.nii.gz  Count:  0070683 6
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  60% iMAR_ct.nii.gz  Count:  0070683 7
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  70% iMAR_ct.nii.gz  Count:  0070683 8
num de blobs predicted:  1
Px:  BP:  4D thorax 2.0  2.0  Br38  3  80% iMAR_ct.nii.gz  Count:  0070683 9
num de blobs predicted:  1


In [19]:
def rescaleITV(predicted_ITV):

    predicted_ITV_rescale = []

    post_rescale = Compose([ResizeWithPadOrCrop(spatial_size=(384,384,192),method="symmetric"),AsDiscrete(threshold=0.1),ScaleIntensity(minv=0.0, maxv=1.0)])
    for k in range(len(predicted_ITV)):
        temp_ITV_tensor = post_rescale(predicted_ITV[k])
        predicted_ITV_rescale.append(temp_ITV_tensor)

        if k ==0: 
            ITV_tensor_10BP  = temp_ITV_tensor
            ITV_tensor_2BP   = temp_ITV_tensor
        elif k==4:
            ITV_tensor_2BP   = torch.add(ITV_tensor_2BP, temp_ITV_tensor)
        elif k!=0 and k!=5:
            ITV_tensor_10BP  = torch.add(ITV_tensor_10BP, temp_ITV_tensor)


        print(ITV_tensor_10BP.shape,ITV_tensor_2BP.shape,temp_ITV_tensor.shape)
        return ITV_tensor_10BP,ITV_tensor_2BP

ITV_tensor_10BP,ITV_tensor_2BP = rescaleITV(predicted_ITV)
ITV_tensor_10BP_noPost,ITV_tensor_2BPnoPost = rescaleITV(predicted_ITV_nopost)
    

(1, 384, 384, 192) (1, 384, 384, 192) (1, 384, 384, 192)
(1, 384, 384, 192) (1, 384, 384, 192) (1, 384, 384, 192)


In [20]:
post_label = Compose([EnsureType(),ResizeWithPadOrCrop(spatial_size=(384,384,192),method="symmetric"),AsDiscrete(threshold=0.5),ScaleIntensity(minv=0.0, maxv=1.0)])
#Create Label tensors
#val_GTV,val_ITV = val_data["GTV"].to(device),val_data["ITV"].to(device)
#val_GTV = [post_label(i) for i in decollate_batch(val_GTV)]
val_ITV = val_data["ITV"].to(device)
val_ITV = [post_label(i) for i in decollate_batch(val_ITV)]

lbl_3dnp = val_ITV[0].detach().cpu().numpy()
lbl_3dnp = lbl_3dnp.squeeze()
lbl_4d = np.expand_dims(lbl_3dnp, 0)
tensor_label = torch.from_numpy(lbl_4d)
label_img = label(lbl_3dnp)
regions = regionprops(label_img)
print("num de blobs in label: ",len(regions))
for i in range(len(regions)):
    r = regions[i]
    print('label bbox ',r.bbox,"size",len(r.coords))
    lbl_bbox = r.bbox
    
print(tensor_label.shape)


num de blobs in label:  1
label bbox  (211, 133, 102, 236, 156, 115) size 3273
torch.Size([1, 384, 384, 192])


In [36]:
import nibabel as nib

def saveNifti(np_image,path_to_save,filename):
    if not os.path.exists(path_to_save):
        os.makedirs(path_to_save)
    converted_array = np.array(np_image, dtype=np.float32)
    affine = np.eye(4)
    nifti_file = nib.Nifti1Image(converted_array, affine)
    nib.save(nifti_file, path_to_save+filename) # Here you put the path + the extionsion 'nii' or 'nii.gz'
    return 0
    
#path_to_save = os.path.join(figures_folder_i, px)
path_to_save = os.path.join(root_path, px)
saveNifti(ITV_tensor_10BP, path_to_save,'/predictedITV_AllBP.Nii')
saveNifti(ITV_tensor_10BP_noPost.detach().cpu().numpy(),path_to_save , '/predictedITV_AllBP_NoPost.Nii')
saveNifti(tensor_label, path_to_save , '/ITV_Label.Nii')



0

In [None]:
def TumorTrajectory(predicted_ITV_interp):
    xx = np.zeros(10,int)
    yy = np.zeros(10,int)
    zz = np.zeros(10,int)
    for l in range(len(predicted_ITV_interp)):
        labeledGTV = label(predicted_ITV_interp[l].detach().cpu().numpy().squeeze())
        propsGTV = regionprops(labeledGTV)
        for m in range(len(propsGTV)):
            #print("GTV #",l," -Centroid: ",propsGTV[m].centroid)
            xx[l],yy[l],zz[l] = propsGTV[m].centroid_local
    ax = plt.figure().add_subplot(projection='3d')
    plt.subplot(1,1,1),plt.plot(xx,yy,zz,label='x,y,z')
    ax.legend()
    plt.show()
    plt.subplot(1,3,1),plt.plot(xx)
    plt.subplot(1,3,2),plt.plot(yy)
    plt.subplot(1,3,3),plt.plot(zz)
    plt.show()
    print('Limits: ',xx.max()-xx.min(),yy.max()-yy.min(),zz.max()-zz.min())
    return
TumorTrajectory(predicted_ITV_rescale)

In [None]:
#Metrics
def metrics(Metrics_tensor,tensor_label):
    dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
    hausdorff_metric = HausdorffDistanceMetric(include_background=False, reduction="mean", percentile=95)
    surfDice_metric191 = SurfaceDiceMetric(class_thresholds=np.linspace(3, 3, 191), include_background=False)
    surfDice_metric287 = SurfaceDiceMetric(class_thresholds=np.linspace(3,3,287), include_background=False)
    surfDice_metric383 = SurfaceDiceMetric(class_thresholds=np.linspace(3, 3, 383), include_background=False)
    
    voxTP = 0
    voxFN = 0
    voxFP = 0

    #DICE
    dice_metric(y_pred=Metrics_tensor, y=tensor_label) #[0][1:2, :, :, :] [0]
    dice1 = dice_metric.aggregate().item()
    print('Dice :', dice1)
    dice_metric.reset()

    hausdorff_metric(y_pred=Metrics_tensor, y=tensor_label)
    hausd1 = hausdorff_metric.aggregate().item()
    hausdorff_metric.reset()
    print('Hausdorff:', hausd1)

    #print("Shape:",tensor_blobn.shape,tensor_blobn[0].shape[1])
    if Metrics_tensor.shape[1]==288:
        surfDice_metric287(y_pred=Metrics_tensor, y=tensor_label)
        sdice1 = surfDice_metric287.aggregate().item()
        surfDice_metric287.reset()
    elif Metrics_tensor.shape[1]==192:
        surfDice_metric191(y_pred=Metrics_tensor, y=tensor_label)
        sdice1 = surfDice_metric191.aggregate().item()
        surfDice_metric191.reset()
    else:
        surfDice_metric383(y_pred=Metrics_tensor, y=tensor_label)
        sdice1 = surfDice_metric383.aggregate().item()
        surfDice_metric383.reset()
    print('Surface dice:', sdice1)

    out_3dnp = Metrics_tensor.detach().cpu().numpy()
    out_3dnp = out_3dnp.squeeze()
    label_out_3dnp = label(out_3dnp)
    props = regionprops(label_out_3dnp)

    lbl_3dnp = tensor_label.detach().cpu().numpy()
    lbl_3dnp = lbl_3dnp.squeeze()
    label_lbl = label(lbl_3dnp)
    regions = regionprops(label_lbl)
    print("num de blobs in label: ",len(regions))
    for i in range(len(regions)):
        r = regions[i]
        lbl_bbox = r.bbox

    print("num de blobs predicted: ",len(props))
    for n in range(len(props)):
        r = props[n]
        print("Blob Area:",r.area)
        TP = False
        for j in range(len(r.coords)):
            if (r.coords[j][0]>lbl_bbox[0] and r.coords[j][0]<lbl_bbox[3]):
                if (r.coords[j][1]>lbl_bbox[1] and r.coords[j][1]<lbl_bbox[4]):
                    if (r.coords[j][2]>lbl_bbox[2] and r.coords[j][2]<lbl_bbox[5]):
                        TP=True
                        voxTP+=1
        if TP:
            sumPredicted =np.sum(out_3dnp)
            sumGroundT=np.sum(lbl_3dnp)
            voxFP=abs(sumPredicted-voxTP)
            voxFN=abs(sumGroundT-voxTP)

            print("Sensitivity: ",voxTP/(voxTP+voxFP))
            print("Precision: ",voxTP/(voxTP+voxFN))
            
        
        return out_3dnp 

In [None]:
def postITV(ITV_tensor,binarize):
    tensor_blobn=[]
    #Delete outliers of slices
    out_3dnp = ITV_tensor.detach().cpu().numpy()
    out_3dnp = out_3dnp.squeeze()
    out_3dnp[:,:,:2] = 0
    out_3dnp[:,:,180:] = 0
    out_3dnp[out_3dnp>=1] = 1
    out_3dnp[out_3dnp<1] = 0
    label_out_3dnp = label(out_3dnp)
    props = regionprops(label_out_3dnp)
    for n in range(len(props)):
        r = props[n]        
        predicted_blobn = np.zeros(out_3dnp.shape)  
        predicted_blobn[label_out_3dnp==n+1]=1
        #predicted_blobn = dilation(predicted_blobn, ball(3))
        predicted_blobn = np.expand_dims(predicted_blobn, 0)
        tensor_blobn.append(torch.from_numpy(predicted_blobn))
    return tensor_blobn

print("All breathing phases:")
ITV_tensor_post = postITV(ITV_tensor_10BP,binarize=True)
print(len(ITV_tensor_post)," Blobs found")
for p in range(len(ITV_tensor_post)):
    print("Stats for blob #",p+1)
    out_3dnp = metrics(ITV_tensor_post[p],tensor_label)
    path_to_save = os.path.join(figures_folder_i, px)
    saveNifti(np_image=out_3dnp, path_to_save=path_to_save +'/predictedITV_AllBP_'+str(p)+'.Nii')

print('')
print("TWO breathing phases:")
BP2_ITV_tensor_post = postITV(ITV_tensor_2BP,binarize=True)
for p in range(len(BP2_ITV_tensor_post)):
    out_3dnp = metrics(BP2_ITV_tensor_post[p],tensor_label)
    path_to_save = os.path.join(figures_folder_i, px)
    saveNifti(np_image=out_3dnp, path_to_save=path_to_save +'/predicted_TwoBP_'+str(p)+'.Nii')


In [None]:
def plotOutput(ITV_Final_Tensor,predicted_ITV,tensor_label,name,px,minmin_CT_plot,maxmax_CT_plot):
    post_rescale_image = Compose(
        [
            LoadImaged(keys=image_keys),
            EnsureChannelFirstd(keys=image_keys),
            Orientationd(keys=["image","GTV","ITV"], axcodes="RAS"),
            #Spacingd(keys=["image","label"], pixdim=(1,1,1),mode=("bilinear","nearest")),
            ScaleIntensityRanged(keys=["image"], a_min=minmin_CT_plot, a_max=maxmax_CT_plot,b_min=0.0, b_max=1.0, clip=True,),
            Create_sequences(keys=image_keys),
            CropForegroundd(keys=image_keys, source_key="lung",k_divisible = size),
            ResizeWithPadOrCropd(keys = ["image"],spatial_size=(384,384,192),mode="symmetric"),
            ToTensord(keys=image_keys),
        ])
    figsize = (18, 8)
    #image
    val_ds = Dataset(data=val_files, transform=post_rescale_image)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)  # ,collate_fn=pad_list_data_collate)
    batch_data = first(val_loader)
    image,lung, label_ITV = (batch_data["image"][0][0],batch_data["lung"][0][0],batch_data["ITV"][0][0])
    image = image.detach().cpu().numpy()
    image = image.squeeze()

    #ITV
    lblimg = post_rescale(tensor_label)
    lblimg = lblimg.detach().cpu().numpy()
    lblimg = lblimg.squeeze()


    #Prediction
    breathingphases = len(predicted_ITV)
    out_3dnp = ITV_Final_Tensor.detach().cpu().numpy()
    out_3dnp = out_3dnp.squeeze()


    for i in range(lbl_3dnp.shape[2]):
        if (np.sum(lbl_3dnp[:, :,i],)>0) or (np.sum(out_3dnp[:, :,i],)>0):
                fig = plt.figure('Instance = {}'.format(0), figsize=figsize)
                ax = fig.add_subplot(121)
                ax.imshow(np.rot90(image[96:288,96:288, i]),cmap='gray'),plt.axis('off')
                ax = fig.add_subplot(122)
                ax.imshow(np.rot90(image[96:288,96:288, i]),cmap='gray'),plt.axis('off')
                for bp in range(breathingphases):
                    ax.contour(np.rot90(predicted_ITV[bp].detach().cpu().numpy()[0,96:288,96:288,i]),colors='red')
                ax.contour(np.rot90(lblimg[96:288,96:288,i]),colors='yellow')
                ax.contour(np.rot90(out_3dnp[96:288,96:288,i]),colors='blue')
                ax.text(8, 10, 'Yellow ITV Label', style='normal',color='white',fontsize=15)
                ax.text(8, 25, 'Blue ITV Prediction', style='normal',color='white',fontsize=15)
                ax.text(8, 40, ['Red GTV Prediction',i], style='normal',color='white',fontsize=15)
                plt.show()
                #if not os.path.exists(os.path.join(figures_folder_i, px)):
                #                os.makedirs(os.path.join(figures_folder_i, px))
                #plt.savefig(os.path.join(figures_folder_i, px, 'FullV_final_V11{}.png'.format(i,name)))
                #plt.clf()
    return 0
minmin_CT_plot = -1024
maxmax_CT_plot = 2000
plotOutput(ITV_tensor_post[0],predicted_ITV_rescale,tensor_label,"FullBP",px,minmin_CT_plot,maxmax_CT_plot)

In [None]:
figures_folder_i