In [None]:
import numpy as np
import pydicom
import os
import cv2
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from tensorflow import keras

In [34]:
#define all the functions

#a function for loading dicom images
#input: a path to the folder containing dicom slices of a dynamic pet image with dimensions 220*220*380*24
#output: a 3d image of 220*220*380
def read_pet_image(path_to_dcm_folder):
    """Load a dynamic PET series from a DICOM folder and collapse it into one 3D image.

    Every file in the folder whose name ends with '.dcm' is read. The slices are
    grouped into time frames by their AcquisitionTime; within a frame they are
    stacked along the third axis in order of ImagePositionPatient[2] and that
    axis is then flipped. The frames are finally combined into one weighted sum.

    Parameters
    ----------
    path_to_dcm_folder : str
        Path to the folder containing all the DICOM slices of one dynamic PET image.

    Returns
    -------
    numpy.ndarray
        3D float array. The first two axes are the in-plane axes of the slices
        (swapped relative to ``pixel_array``), the third axis runs over the
        slices of one time frame. The shape is taken from the data, so the
        220*220*380*24 series this notebook was written for still gives
        (220, 220, 380).

    Raises
    ------
    ValueError
        If the folder contains no '.dcm' files (previously an all-zero image
        was returned silently).
    """
    #keep only the DICOM files (i.e. the names with a .dcm ending)
    dcm_filenames=[name for name in os.listdir(path_to_dcm_folder) if name[-4:]=='.dcm']
    if not dcm_filenames:
        raise ValueError('No .dcm files found in {}'.format(path_to_dcm_folder))
    #group the slices by acquisition time: time -> list of (location, pixel data)
    #grouping while reading avoids copying the whole series into one array for every frame
    frames={}
    for filename in dcm_filenames:
        ds=pydicom.dcmread(os.path.join(path_to_dcm_folder,filename))
        if 'RescaleSlope' in ds and 'RescaleIntercept' in ds:
            pixels=ds.pixel_array*ds.RescaleSlope+ ds.RescaleIntercept
        else:
            pixels=ds.pixel_array
        frame_key=str(ds.AcquisitionTime)
        frames.setdefault(frame_key,[]).append((float(ds.ImagePositionPatient[2]),pixels))
    #the unique acquisition times in ascending (string) order
    unique_times=sorted(frames)
    #NOTE(review): the times are converted with float() directly. If AcquisitionTime
    #is an HHMMSS-formatted DICOM time, differences across minute/hour boundaries are
    #not true durations. Left unchanged because it alters the image intensities - confirm.
    total_time=float(unique_times[-1])-float(unique_times[0])
    img=None
    #build the 3D time frame for every unique time value and accumulate the weighted sum
    for i in range(len(unique_times)):
        frame=frames[unique_times[i]]
        #for duplicated locations the slice that was read last wins
        merged=dict(frame)
        ordered_slice_list=[merged[loc] for loc in sorted(loc for loc,_ in frame)]
        img_3d=np.moveaxis(np.array(ordered_slice_list),0,2)
        img_3d=np.moveaxis(img_3d,0,1)
        img_3d=np.flip(img_3d,axis=2)
        if len(unique_times)==1:
            #a single frame has no time differences to weight with: use it as it is
            k=1.0
        elif i>0:
            k=(float(unique_times[i])-float(unique_times[i-1]))/total_time
        else:
            #the first frame has no predecessor, so the first interval is used for it as well
            #NOTE(review): because of this the weights sum to 1+(t1-t0)/(t_last-t0), not 1 - confirm intended
            k=(float(unique_times[1])-float(unique_times[0]))/total_time
        if img is None:
            #initialise the image with the shape of the first frame
            img=np.zeros(img_3d.shape)
        img+=k*img_3d
    return(img)

def preprocess_slice(img,img_height):
    """Resize a 2D slice to img_height x img_height and convert it to float16.

    Parameters
    ----------
    img : numpy.ndarray
        The 2D slice to resize.
    img_height : int
        Side length of the square output; used for both dimensions.

    Returns
    -------
    numpy.ndarray
        The resized slice with dtype float16.
    """
    target_size=(img_height,img_height)
    resized=cv2.resize(img,target_size)
    return resized.astype('float16')

#function for post-processing predictions by taking means of five consecutive predictions
def postProcess_new(predictions,patientIndexes_test):
    """Smooth slice-wise predictions with a five-point mean.

    For every position the mean is taken over the prediction itself and its two
    predecessors and two successors. A neighbour that falls outside the sequence
    or belongs to a different patient contributes 0, and the sum is always
    divided by five, so values next to a boundary are pulled towards zero.

    Parameters
    ----------
    predictions : sequence of float
        The predictions in slice order.
    patientIndexes_test : sequence
        For every prediction the index of the patient it belongs to.

    Returns
    -------
    numpy.ndarray
        The smoothed predictions, one per input prediction.
    """
    n_predictions=len(predictions)
    smoothed=[]
    for centre in range(n_predictions):
        window=[]
        #offsets in the order: second previous, previous, current, next, second next
        for offset in (-2,-1,0,1,2):
            neighbour=centre+offset
            if offset==0:
                usable=True
            else:
                usable=(0<=neighbour<n_predictions
                        and patientIndexes_test[neighbour]==patientIndexes_test[centre])
            window.append(predictions[neighbour] if usable else 0)
        smoothed.append(np.mean(window))
    return np.array(smoothed)

#a function for predicting the 3d bounding box
#inputs: 3d pet image and three classification models
#output: 3d binary mask of the 3d bounding box
def predict_3d_boundingbox(img,classifier_sag,classifier_cor,classifier_trans):
    """Predict the 3D bounding box of the target with three slice-wise classifiers.

    The image is cut into 2D slices along each of its three axes. Every slice is
    resized to 128x128 with preprocess_slice, classified by the model of that
    view, and the predictions are smoothed with postProcess_new. Slices with a
    smoothed prediction of at least 0.5 vote for all of their voxels; the
    bounding box consists of the voxels that received a vote from all three views.

    Parameters
    ----------
    img : numpy.ndarray
        The 3D PET image.
    classifier_sag, classifier_cor, classifier_trans :
        Models with a ``predict(slices, verbose=0)`` method, applied to the
        slices along axis 0, axis 1 and axis 2 respectively.

    Returns
    -------
    numpy.ndarray
        3D integer mask with the shape of img: 1 inside the bounding box, 0 elsewhere.
    """
    #(axis that is sliced, classifier of that view) in the order sagittal, coronal, transaxial
    view_classifiers=[(0,classifier_sag),(1,classifier_cor),(2,classifier_trans)]
    votes=np.zeros(img.shape)
    for axis,classifier in view_classifiers:
        #one index tuple per slice, e.g. (j, :, :) for axis 0
        plane_selectors=[]
        for j in range(img.shape[axis]):
            selector=[slice(None)]*3
            selector[axis]=j
            plane_selectors.append(tuple(selector))
        slices=np.array([preprocess_slice(img[selector],128) for selector in plane_selectors])
        predicted_labels=classifier.predict(slices,verbose=0)[:,0]
        #all the slices come from the same image, hence the constant patient index
        predicted_labels=postProcess_new(predicted_labels,[0]*len(predicted_labels))
        for j,selector in enumerate(plane_selectors):
            if predicted_labels[j]>=0.5:
                votes[selector]+=1
    #keep only the voxels on which all three views agree
    return np.array(votes>2.5,dtype=int)

def chooseIndices(i_1):
    """Choose the start indices of the 64-wide windows used to tile one axis.

    Parameters
    ----------
    i_1 : array-like of 0/1
        1D indicator of the positions along the axis that belong to the bounding box.

    Returns
    -------
    list of int
        Window start indices; every window is [start, start+64).
        * no flagged position: empty list (previously an IndexError)
        * extent < 64: one window centred on the flagged range and clamped so
          that it stays inside the axis at both ends (previously it was only
          clamped at 0 and could run past the end of the axis)
        * extent < 128: two windows, anchored at the first and the last flagged position
        * extent < 192: as above plus a window starting at the midpoint
        * otherwise: windows tiling the whole axis; the fixed tables for the
          lengths 220 and 380 are kept, other lengths (previously None) get
          evenly spaced windows.
    """
    window=64
    flags=np.asarray(i_1)
    axis_length=len(flags)
    i_1s=np.flatnonzero(flags==1)
    if i_1s.size==0:
        return([])
    first=int(i_1s[0])
    last=int(i_1s[-1])
    extent=last-first
    if extent<window:
        centred=round((last+first)/2-window//2)
        #clamp at the upper end first and at zero last, so that axes shorter than the window still give 0
        return([max(0,min(centred,axis_length-window))])
    elif extent<2*window:
        #NOTE(review): the last window is [last-64, last), so the index `last` itself is not
        #covered by any window. Left unchanged because it changes the segmentation output - confirm.
        return([first,last-window])
    elif extent<3*window:
        return([first,round((last+first)/2),last-window])
    else:
        #fixed tilings for the two axis lengths of the 220*220*380 images
        if axis_length==220:
            return([0,55,101,156])
        if axis_length==380:
            return([0,63,126,189,252,315])
        #any other length: the smallest number of evenly spaced windows that covers the axis
        #(extent >= 192 implies axis_length >= 193, so there are at least four windows)
        n_windows=-(-axis_length//window)
        stride=(axis_length-window)/(n_windows-1)
        return([round(k*stride) for k in range(n_windows)])

#a function for predicting the segmentation mask
#inputs: 3d pet image, mask of the 3d bounding box, and three segmentation models
#output: 3d binary mask        
def predict_segmentation_mask(img,mask_for_3dbb,segmentator_sag,segmentator_cor,segmentator_trans):
    """Predict the segmentation mask inside the 3D bounding box with three 2D models.

    For every view the slices that intersect the bounding box are tiled with
    64x64 windows chosen by chooseIndices. Each window that contains part of the
    bounding box is passed to the segmentation model of that view; overlapping
    window predictions are averaged and thresholded at 0.5 to give the mask of
    the view. A voxel belongs to the final mask when at least two of the three
    view masks contain it.

    Parameters
    ----------
    img : numpy.ndarray
        The 3D PET image.
    mask_for_3dbb : numpy.ndarray
        Binary 3D mask of the bounding box, indexed like img.
    segmentator_sag, segmentator_cor, segmentator_trans :
        Models with a ``predict(batch, verbose=0)`` method, applied to the
        slices along axis 0, axis 1 and axis 2 respectively.

    Returns
    -------
    numpy.ndarray
        3D integer mask with the shape of img (1 = segmented, 0 = background).
    """
    tile=64
    #(axis that is sliced, model of that view) in the order coronal, sagittal, transaxial
    view_models=[(1,segmentator_cor),(0,segmentator_sag),(2,segmentator_trans)]
    votes=np.zeros(img.shape)
    for axis,model in view_models:
        prediction_sum=np.zeros(img.shape)
        prediction_count=np.zeros(img.shape)
        for j in range(mask_for_3dbb.shape[axis]):
            plane_selector=[slice(None)]*3
            plane_selector[axis]=j
            bb_plane=mask_for_3dbb[tuple(plane_selector)]
            #skip the slices that do not touch the bounding box
            if np.max(bb_plane)!=1:
                continue
            for row_start in chooseIndices(np.max(bb_plane,axis=1)):
                for col_start in chooseIndices(np.max(bb_plane,axis=0)):
                    #3D index of the window: the two in-plane ranges with j inserted at the sliced axis
                    window=[slice(row_start,row_start+tile),slice(col_start,col_start+tile)]
                    window.insert(axis,j)
                    window=tuple(window)
                    if np.max(mask_for_3dbb[window])==1:
                        patch=img[window].astype('float16')
                        pred=model.predict(np.array([patch]),verbose=0)[0][:,:,0]
                        prediction_sum[window]+=pred
                        prediction_count[window]+=1
        #voxels without any prediction keep the value 0 (avoid dividing by zero)
        prediction_count[prediction_count==0]=1
        view_mask=np.array(prediction_sum/prediction_count>0.5,dtype=int)
        votes+=view_mask
    #majority vote over the three views
    return np.array(votes>1.5,dtype=int)

#functions for visualisation

def get_all_edges(bool_img):
    """
    Collect the boundary edges of all the True pixels of a 2D boolean image.

    A side of a True pixel is a boundary edge when the pixel on its other side
    is False or lies outside the image. The returned array has the dimension
    (n, 2, 2): edge k connects the points edges[k, 0, :] and edges[k, 1, :].
    The indices of a pixel also denote the coordinates of its lower left corner.
    An image without True pixels gives an empty array of shape (0, 2, 2).
    """
    n_first,n_second=bool_img.shape[0],bool_img.shape[1]

    def is_background(i,j):
        #positions outside the image count as background
        if i<0 or j<0 or i>=n_first or j>=n_second:
            return True
        return not bool_img[i,j]

    #(offset to the neighbour, offsets of the two end points of the shared edge)
    #in the order North, East, South, West
    sides=(
        ((0,1),((0,1),(1,1))),
        ((1,0),((1,0),(1,1))),
        ((0,-1),((0,0),(1,0))),
        ((-1,0),((0,0),(0,1))),
    )
    edges=[]
    ii,jj=np.nonzero(bool_img)
    for i,j in zip(ii,jj):
        for (di,dj),((a0,b0),(a1,b1)) in sides:
            if is_background(i+di,j+dj):
                edges.append(np.array([[i+a0,j+b0],
                                       [i+a1,j+b1]]))
    if edges:
        return np.array(edges)
    return np.zeros((0,2,2))


def close_loop_edges(edges):
    """
    Chain the edges produced by 'get_all_edges' into outlines around objects.

    Starting from the first unused edge, the outline is extended with the first
    remaining edge that shares the current end point, until no remaining edge
    matches; the start point is then appended once more and the next outline is
    started. A list with one (k, 2) array of points per outline is returned,
    so several disconnected objects give several outlines.
    It is expected that every edge is part of exactly one loop (not necessarily the same one).
    """
    outlines=[]
    remaining=edges
    while remaining.size!=0:
        #start a new outline with the first remaining edge
        path=[remaining[0,0],remaining[0,1]]
        remaining=np.delete(remaining,0,axis=0)

        while remaining.size!=0:
            #look for a remaining edge with an end point equal to the current end of the outline
            shares_end=(remaining==path[-1]).all(axis=2)
            edge_ids,end_ids=np.nonzero(shares_end)
            if edge_ids.size==0:
                #nothing continues this outline: close it and go on with the next one
                path.append(path[0])
                break
            edge_id=edge_ids[0]
            matched_end=end_ids[0]
            #continue with the other end point of the matched edge
            path.append(remaining[edge_id,1-matched_end,:])
            remaining=np.delete(remaining,edge_id,axis=0)

        outlines.append(np.array(path))

    return outlines


def plot_outlines(bool_img, ax=None, **kwargs):
    """
    Draw the outlines of the True regions of a 2D boolean image onto an axes.

    bool_img: the 2D image whose True regions are outlined.
    ax: the axes to draw on; the current axes (plt.gca()) when None.
    kwargs: passed on to matplotlib's LineCollection (e.g. lw, color).
    """
    target_ax = plt.gca() if ax is None else ax
    pixel_edges = get_all_edges(bool_img=bool_img)
    # shift by half a pixel to convert indices to coordinates; TODO adjust according to image extent
    pixel_edges = pixel_edges - 0.5
    loops = close_loop_edges(edges=pixel_edges)
    target_ax.add_collection(LineCollection(loops, **kwargs))

In [None]:
#load an image: point path_to_dcm_folder to the DICOM folder of the scan and read it into a 3D array
path_to_dcm_folder='D:/img/koveri/Data/koveri0001/PET_DYN_220_Rest_cardiac/1.2.246.10.8282559.10.102300.1.2.118819549558574.716470114599'
img=read_pet_image(path_to_dcm_folder)

#pre-process the image: cap the values at 50000 and divide by 50000, which maps the interval [0, 50000] onto [0,1]
img=np.minimum(img,50000)/50000

In [36]:
#load and compile models for an organ of interest
organ='brain'

def load_compiled_model(filename_template,view,loss):
    """Load the saved model of `organ` for one view and compile it with the adam optimizer and the given loss."""
    model=keras.models.load_model(filename_template.format(organ,view))
    model.compile(optimizer='adam',loss=loss)
    return model

#slice-wise classifiers, one per view
classifier_sag=load_compiled_model('classifier_{}_{}.keras','sagittal','BinaryCrossentropy')
classifier_cor=load_compiled_model('classifier_{}_{}.keras','coronal','BinaryCrossentropy')
classifier_trans=load_compiled_model('classifier_{}_{}.keras','transaxial','BinaryCrossentropy')
#segmentation models, one per view
segmentator_sag=load_compiled_model('segmentator_{}_{}.keras','sagittal','Dice')
segmentator_cor=load_compiled_model('segmentator_{}_{}.keras','coronal','Dice')
segmentator_trans=load_compiled_model('segmentator_{}_{}.keras','transaxial','Dice')

In [37]:
#run the two-step pipeline
#step 1: slice-wise classification in all three views gives the 3d bounding box
mask_for_3dbb=predict_3d_boundingbox(
    img,
    classifier_sag=classifier_sag,
    classifier_cor=classifier_cor,
    classifier_trans=classifier_trans,
)
#step 2: multi-view segmentation restricted to the 3d bounding box
predicted_mask=predict_segmentation_mask(
    img,
    mask_for_3dbb,
    segmentator_sag=segmentator_sag,
    segmentator_cor=segmentator_cor,
    segmentator_trans=segmentator_trans,
)



In [None]:
#check what the mask looks like: show its mean projection along each of the three axes
for projection_axis in (0,1,2):
    plt.imshow(np.transpose(np.mean(predicted_mask,axis=projection_axis)),cmap='gray')
    plt.show()

In [None]:
#plot the outlines of the predicted mask and the 3d bounding box onto the pet image
#target size passed to cv2.resize for the sagittal and coronal projections
#NOTE(review): 2.80/1.65 is presumably the ratio of slice spacing to in-plane pixel spacing - confirm
display_size=(220,round(2.80/1.65*380))

def show_with_outlines(background,bb_2d,mask_2d):
    """Show a 2D image in grayscale and draw the outlines of the bounding box and of the mask on it in white."""
    plt.imshow(background,cmap='gray')
    plot_outlines(bb_2d.T,lw=1,color='white')
    plot_outlines(mask_2d.T,lw=1,color='white')
    plt.show()

#sagittal view: mean projection of the image and maximum projections of the masks along axis 0
img_sag=cv2.resize(np.array(np.transpose(np.mean(img,axis=0))),display_size)
bb_sag=cv2.resize(np.array(np.transpose(np.max(mask_for_3dbb,axis=0)),dtype='uint8'),display_size)
mask_sag=cv2.resize(np.array(np.transpose(np.max(predicted_mask,axis=0)),dtype='uint8'),display_size)
show_with_outlines(img_sag,bb_sag,mask_sag)

#coronal view: the same along axis 1
img_cor=cv2.resize(np.transpose(np.mean(img,axis=1)),display_size)
bb_cor=cv2.resize(np.array(np.transpose(np.max(mask_for_3dbb,axis=1)),dtype='uint8'),display_size)
mask_cor=cv2.resize(np.array(np.transpose(np.max(predicted_mask,axis=1)),dtype='uint8'),display_size)
show_with_outlines(img_cor,bb_cor,mask_cor)

#transaxial view: projections along axis 2, shown without resizing
img_trans=np.transpose(np.mean(img,axis=2))
bb_trans=np.transpose(np.max(mask_for_3dbb,axis=2))
mask_trans=np.transpose(np.max(predicted_mask,axis=2))
show_with_outlines(img_trans,bb_trans,mask_trans)