In [None]:
import os, glob
from pathlib import Path
import numpy as np
import nibabel as nb
from skimage.measure import regionprops, label
from matplotlib import pyplot as plt
import matplotlib.colors as colors
from matplotlib.lines import Line2D
import math
from skimage import filters

# Segmentation source to analyse: 'reg', 'dl', 'manual' or 'nnunet'.
METHOD = 'nnunet'

# Paths - segmentation results.
# Every method defines `input_path`; the other names are method-specific:
#   reg    -> main_path, input_path_ima (cropped images), output_path
#   dl     -> main_path
#   manual -> main_path
#   nnunet -> lab_path (label maps), output_path
if METHOD == 'reg':
    main_path = '/mnt/sda1/Repos/a-eye/Data/SHIP_dataset/'
    input_path = main_path + 'non_labeled_dataset_nifti_reg_2/'
    input_path_ima = main_path + 'non_labeled_dataset_nifti_cropped/'
    output_path = '/mnt/sda1/Repos/a-eye/Output/axial_length/atlas/'
elif METHOD == 'dl':
    main_path = '/mnt/sda1/Repos/a-eye/a-eye_segmentation/3D_multilabel/experiment_0/'
    input_path = main_path + 'test_orig_reg-cropped_non-labeled/'
elif METHOD == 'manual':
    main_path = '/mnt/sda1/Repos/a-eye/a-eye_preprocessing/ANTs/'
    input_path = main_path + 'a123/'
elif METHOD == 'nnunet':
    input_path = '/mnt/sda1/Repos/a-eye/a-eye_segmentation/deep_learning/nnUNet/nnUNet/nnUNet_inference/non_labeled_dataset_nifti_nnunet/'
    lab_path = '/mnt/sda1/Repos/a-eye/a-eye_segmentation/deep_learning/nnUNet/nnUNet/nnUNet_inference/no_postprocessing/'
    output_path = '/mnt/sda1/Repos/a-eye/Output/axial_length/nnunet/'
else:
    # Fail here instead of with a NameError on `input_path` further down.
    raise ValueError(f"Unknown METHOD {METHOD!r}; expected 'reg', 'dl', 'manual' or 'nnunet'")

# Subject entries to process and per-subject result containers.
# The directory is listed once so that `len_path` always matches
# `list_results`; entries are processed in sorted order.
list_results = sorted(os.listdir(input_path))
len_path = len(list_results)
name_subject = [None] * len_path        # subject identifier per result slot
axial_length_total = np.zeros(len_path)  # axial length per result slot
outliers_list_1 = []  # condition 1
outliers_list_2 = []  # condition 2
outliers_list_3 = []  # condition 3
outliers_list_4 = []  # condition 4
outliers_dict_2 = {}
outliers_dict_3 = {}
outliers_dict_4 = {}
arr_diff = []  # gradient differences at which the cornea walk stopped

# ---------------------------------------------------------------------------
# Per-subject axial length extraction.
#
# For every entry of `list_results` the label volume (and the intensity image
# when the method provides one) is loaded. The axial slices between the lens
# centroid and the optic-nerve centroid are scanned. On each slice a line
# through the lens centroid, parallel to array axis 1, is intersected with the
# lens and globe labels. The distance between the two extreme intersection
# points, extended by a cornea estimate (image-profile walk) and by the
# distance to the intraconal-fat label, is reported as the axial length.
#
# Label values used below: 1 = lens, 2 = globe, 3 = optic nerve,
# 4 = intraconal fat.
#
# Outlier bookkeeping:
#   condition 1 - lens, nerve or globe missing from the whole volume
#   condition 2 - a slice lacks lens, nerve, intraconal-fat or globe voxels
#   condition 3 - the measurement line misses the lens or the globe
#   condition 4 - the resulting length is outside (15, 30) mm
# ---------------------------------------------------------------------------

SUBJECT_FILTER = None   # set to one entry of list_results to process only it
                        # (e.g. 'AEye_2022160102320_0000.nii.gz' for nnunet)
MAX_SUBJECTS = 10       # stop after this many subjects pass condition 1; None = all
USE_SOBEL = True        # cornea walk on the Sobel edge map (True) or raw intensities
PLOT_FIGURES = True
SAVE_FIGURES = False    # needs `output_path` (only defined for 'reg' and 'nnunet')
SOBEL_DIFF_THRESHOLD = 153.57      # np.mean(arr_diff) for 100 cases
INTENSITY_DIFF_THRESHOLD = 179.92  # np.mean(arr_diff) for 100 cases


def _crop_nnunet(arr):
    """Keep the index ranges [n0/2:, n1/2:, :n2/2] of a 3-D volume.

    Presumably this isolates the region of one eye (an earlier version called
    it the "quadrant right eye") - TODO confirm orientation.
    """
    h0, h1, h2 = (int(np.around(n / 2)) for n in arr.shape[:3])
    return arr[h0:, h1:, :h2]


def _load_subject(folder1, subject):
    """Return ``(lab_arr, voxel_size, ima_arr)`` for one entry of list_results.

    ``ima_arr`` is None for METHOD 'dl' and 'manual', which have no image path.
    """
    ima_arr = None
    if METHOD == 'reg':
        lab = nb.load(f'{input_path}{folder1}/labels.nii.gz')
        ima_arr = nb.load(f'{input_path_ima}{folder1}_cropped.nii.gz').get_fdata()
    elif METHOD == 'dl':
        lab = nb.load(f'{input_path}{folder1}')
    elif METHOD == 'manual':
        lab = nb.load(f'{input_path}{folder1}/input/{folder1}_labels_cropped.nii.gz')
    elif METHOD == 'nnunet':
        lab = nb.load(f'{lab_path}AEye_{subject}.nii.gz')
        ima_arr = _crop_nnunet(nb.load(f'{input_path}{folder1}').get_fdata())
    else:
        raise ValueError(f'Unknown METHOD: {METHOD!r}')
    lab_arr = lab.get_fdata()
    if METHOD == 'nnunet':
        lab_arr = _crop_nnunet(lab_arr)
    return lab_arr, lab.header.get_zooms(), ima_arr


def _line_mask(shape, x0, y0, x1, y1):
    """Rasterise the line through (x0, y0) and (x1, y1) onto a mask of ``shape[:2]``.

    The line is sampled parametrically with t in [-n, n] (10*n samples, n =
    ceil of the diagonal of ``shape``), so it spans the whole slice.
    """
    n_points = int(np.ceil(math.dist([0, 0, 0], shape[:3])))
    t = np.linspace(-n_points, n_points, n_points * 10)
    rows = np.around((x0 - x1) * t + x0).astype(int)
    cols = np.around((y0 - y1) * t + y0).astype(int)
    inside = (rows >= 0) & (rows < shape[0]) & (cols >= 0) & (cols < shape[1])
    mask = np.zeros(shape[:2])
    mask[rows[inside], cols[inside]] = 1
    return mask


def _cornea_extra_voxels(grad_arr, use_sobel):
    """Walk a profile starting at the anterior lens edge and count cornea voxels.

    Each value is compared with the next one (the last one with a sentinel of
    3000). Returns ``(vox, diff)`` where ``diff`` is the difference that
    stopped the walk, or None if the walk reached the end of the profile.
    """
    vox = 0
    n_drops = 0
    for v, current in enumerate(grad_arr):
        next_val = grad_arr[v + 1] if v != len(grad_arr) - 1 else 3000
        diff = next_val - current
        if use_sobel:
            if diff < SOBEL_DIFF_THRESHOLD:
                if n_drops > 0:  # second drop ends the walk
                    return vox, diff
                n_drops += 1
            vox += 1
        elif (diff > 0 and next_val >= 100) or (0 <= diff <= INTENSITY_DIFF_THRESHOLD):
            vox += 1
        else:
            return vox, diff
    return vox, None


def _fat_extra_voxels(profile):
    """Number of leading entries of ``profile`` before the first label 4 (0 if none)."""
    hits = np.flatnonzero(profile == 4)
    return int(hits[0]) if hits.size else 0


def _overlay_mask(ax, mask_2d, color):
    """Draw the non-zero voxels of a 2-D mask on ``ax`` in a single colour."""
    masked = np.ma.masked_where(mask_2d == 0, mask_2d)
    ax.imshow(masked.T, origin='lower', interpolation='none', alpha=0.4,
              cmap=colors.ListedColormap([color]))


def _draw_measurement(ax, extreme_lens, extreme_globe, vox, vox2, com_lens=None):
    """Draw the measured segment and both extensions (in voxel coordinates).

    The centroid and extreme-point markers are drawn only when ``com_lens`` is
    given. Plot x = array axis 0, plot y = array axis 1 (images are shown
    transposed with origin='lower').
    """
    lx, ly = extreme_lens
    gx, gy = extreme_globe
    lens_edge = ly + 0.5   # extreme points are voxel centres
    globe_edge = gy - 0.5
    if com_lens is not None:
        ax.plot(int(np.around(com_lens[0])), int(np.around(com_lens[1])), '+b', markersize=10)
        ax.plot(lx, lens_edge, '+y', markersize=10)
        ax.plot(gx, globe_edge, '+y', markersize=10)
    ax.plot((lx, gx), (lens_edge, globe_edge), '-y', linewidth=1)
    # Cornea extension
    ax.plot(lx, lens_edge + vox, '+c', markersize=10)
    ax.plot((lx, lx), (lens_edge + vox, ly), '-c', linewidth=1)
    # Intraconal-fat extension
    ax.plot(gx, globe_edge - vox2, '+c', markersize=10)
    ax.plot((gx, gx), (globe_edge - vox2, gy), '-c', linewidth=1)


def _plot_slice(subject, s, axial_length, ima_slice, masks_2d, com_lens,
                extreme_lens, extreme_globe, vox, vox2):
    """Show image, image + labels and Sobel map of one slice with the measurement."""
    k = 1  # aspect ratio
    fig, ax = plt.subplots(1, 3, figsize=(16 * k, 9 * k))
    fig.patch.set_facecolor('white')
    fig.suptitle(f'Automatic axial length extraction. Subject: {subject}. Slice: {s}. Axial length: {axial_length} mm')
    legend_elements = [
        Line2D([0], [0], color='y', lw=2, label='Axial length'),
        Line2D([], [], color='y', label='Extreme points', marker='+', markersize=5, linestyle='None'),
        Line2D([], [], color='b', label='Lens centroid', marker='+', markersize=5, linestyle='None'),
        Line2D([], [], color='c', label='Added distance', marker='+', markersize=5, linestyle='None'),
    ]
    fig.legend(handles=legend_elements, loc='lower right')

    ax[0].set_title('Original image')
    ax[0].imshow(ima_slice.T, origin='lower', cmap='gist_gray', interpolation='none')
    _draw_measurement(ax[0], extreme_lens, extreme_globe, vox, vox2, com_lens)

    ax[1].set_title('Original image + labels')
    ax[1].imshow(ima_slice.T, origin='lower', cmap='gist_gray', interpolation='none')
    for mask_2d, color in zip(masks_2d, ('red', 'lime', 'blue', 'yellow')):
        _overlay_mask(ax[1], mask_2d, color)
    _draw_measurement(ax[1], extreme_lens, extreme_globe, vox, vox2, com_lens)

    ax[2].set_title('Sobel filter')
    ax[2].imshow(filters.sobel(ima_slice.T), origin='lower', cmap='gist_gray', interpolation='none')
    _draw_measurement(ax[2], extreme_lens, extreme_globe, vox, vox2)

    fig.tight_layout()  # after all artists exist, so the layout accounts for them
    if SAVE_FIGURES:
        fig.savefig(f'{output_path}examples/{subject}_{s}.png')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop


i = 0  # next result slot; only advanced for subjects passing condition 1
for folder1 in list_results:
    if SUBJECT_FILTER is not None and folder1 != SUBJECT_FILTER:
        continue

    # Subject's name
    if METHOD == 'dl':
        name_subject[i] = folder1.split('_')[0]
    elif METHOD == 'nnunet':
        name_subject[i] = folder1.split('_')[1]
    else:  # reg, manual
        name_subject[i] = str(folder1)
    subject = name_subject[i]
    print(f'subject: {subject}')

    lab_arr, voxel_size, ima_arr = _load_subject(folder1, subject)

    # Label masks (*1 converts bool to int, as regionprops expects labels)
    lens = (lab_arr == 1) * 1
    globe = (lab_arr == 2) * 1
    eyeball = np.logical_or(lab_arr == 1, lab_arr == 2) * 1
    nerve = (lab_arr == 3) * 1

    # Condition 1: lens, nerve and globe must all be segmented. `i` is not
    # advanced, so the next subject reuses (overwrites) slot i.
    if np.count_nonzero(lens) == 0 or np.count_nonzero(nerve) == 0 or np.count_nonzero(globe) == 0:
        outliers_list_1.append(subject)
        continue

    # 3-D centroids (voxel indices)
    com_lens = np.rint(regionprops(lens)[0].centroid).astype(int)
    print(f'com_lens: {np.around(com_lens)}')
    com_eyeball = np.rint(regionprops(eyeball)[0].centroid).astype(int)
    print(f'com_eyeball: {np.around(com_eyeball)}')
    com_nerve = np.rint(regionprops(nerve)[0].centroid).astype(int)
    print(f'com_nerve: {np.around(com_nerve)}')

    # Axial slices from the lens centroid towards the nerve centroid (inclusive)
    step = 1 if com_nerve[2] >= com_lens[2] else -1
    slices = np.arange(com_lens[2], com_nerve[2] + step, step, dtype=int)
    print(f'slices: {slices} \n')

    n_slices = len(slices)
    axial_length_slices = np.zeros(n_slices)
    lens_vox = np.zeros(n_slices)
    globe_vox = np.zeros(n_slices)
    nerve_vox = np.zeros(n_slices)
    int_fat_vox = np.zeros(n_slices)

    for p, s in enumerate(slices):
        lab_2d = lab_arr[:, :, s]
        lens_2d = (lab_2d == 1) * 1
        globe_2d = (lab_2d == 2) * 1
        nerve_2d = (lab_2d == 3) * 1
        int_fat_2d = (lab_2d == 4) * 1
        lens_vox[p] = np.count_nonzero(lens_2d)
        globe_vox[p] = np.count_nonzero(globe_2d)
        nerve_vox[p] = np.count_nonzero(nerve_2d)
        int_fat_vox[p] = np.count_nonzero(int_fat_2d)
        print(f'slice {s} -- voxels -- lens: {int(lens_vox[p])} | nerve: {int(nerve_vox[p])} | int fat: {int(int_fat_vox[p])} | globe: {int(globe_vox[p])}')

        # Condition 2: every structure must be present on this slice
        if not (lens_vox[p] > 0 and nerve_vox[p] > 0 and int_fat_vox[p] > 0 and globe_vox[p] > 0):
            axial_length_slices[p] = 0
            outliers_list_2.append(subject)
            outliers_dict_2[str(subject)] = s
            continue

        # 2-D lens centroid and shape
        props_lens_2d = regionprops(lens_2d)[0]
        com_lens_2d = props_lens_2d.centroid
        print(f'com_lens: {np.around(com_lens_2d)}')
        print(f'axis minor lens = {props_lens_2d.axis_minor_length} \naxis major lens = {props_lens_2d.axis_major_length}')

        # Line through the lens centroid parallel to array axis 1
        x0_lens, y0_lens = com_lens_2d
        line = _line_mask(lab_arr.shape, x0_lens, y0_lens, x0_lens, y0_lens - 1)
        print(f'Number of points of the line in the image space: {np.count_nonzero(line)}')

        inter_lens = np.logical_and(lens_2d, line) * 1
        inter_globe = np.logical_and(globe_2d, line) * 1
        print(f'Number of intersection points in lens: {np.count_nonzero(inter_lens)}')
        print(f'Number of instersection points in globe: {np.count_nonzero(inter_globe)}')

        # Condition 3: the line must intersect both the lens and the globe
        # (an empty globe intersection would otherwise crash np.argmin below)
        if np.count_nonzero(inter_lens) == 0 or np.count_nonzero(inter_globe) == 0:
            axial_length_slices[p] = 0
            outliers_list_3.append(subject)
            outliers_dict_3[str(subject)] = s
            continue

        # Anterior-most lens point (max axis-1 index), posterior-most globe point (min)
        coords_lens = np.argwhere(inter_lens == 1)
        extreme_inter_lens = coords_lens[np.argmax(coords_lens[:, 1])]
        print(f'Lens extreme intersection point: {extreme_inter_lens}')
        coords_globe = np.argwhere(inter_globe == 1)
        extreme_inter_globe = coords_globe[np.argmin(coords_globe[:, 1])]
        print(f'Globe extreme intersection point: {extreme_inter_globe}')
        lx, ly = extreme_inter_lens
        gx, gy = extreme_inter_globe

        # Extension 1: cornea, from the image profile beyond the lens edge
        if ima_arr is None:
            raise ValueError(f"METHOD {METHOD!r} loads no intensity image; the cornea extension needs one")
        if USE_SOBEL:
            grad_arr = filters.sobel(ima_arr[:, :, s])[lx, ly:lab_arr.shape[1] - 1]
        else:
            grad_arr = ima_arr[lx, ly:lab_arr.shape[1] - 1, s]
        print(f'grad_arr = {grad_arr}')
        vox, stop_diff = _cornea_extra_voxels(grad_arr, USE_SOBEL)
        if stop_diff is not None:
            arr_diff.append(stop_diff)
        print(f'Number of mm to add to the axial length (due to the cornea) = {vox}')

        # Extension 2: from the globe edge back until the intraconal fat label
        # (the profile stops before column 0, as in the original)
        grad_arr2 = lab_arr[gx, np.arange(gy - 1, 0, -1), s]
        print(f'grad_arr2 = {grad_arr2}')
        vox2 = _fat_extra_voxels(grad_arr2)
        print(f'Number of mm to add to the axial length (due to intraconal fat) = {vox2}')

        # Extreme points are voxel centres, so one voxel (half on each side) is
        # added; vox/vox2 are voxel counts. Convert voxels to mm with the
        # spacing along axis 1 (identical to the old formula for 1 mm voxels).
        axial_length = (ly - gy + 1 + vox + vox2) * voxel_size[1]
        print(f'Axial length = {axial_length} mm \n')

        # Condition 4: plausible axial length range
        if not 15 < axial_length < 30:
            axial_length_slices[p] = 0
            outliers_list_4.append(subject)
            outliers_dict_4[str(subject)] = s
            continue

        axial_length_total[i] = axial_length  # the last valid slice wins
        if PLOT_FIGURES:
            _plot_slice(subject, s, axial_length, ima_arr[:, :, s],
                        (lens_2d, globe_2d, nerve_2d, int_fat_2d), com_lens_2d,
                        extreme_inter_lens, extreme_inter_globe, vox, vox2)

    # Best-slice selection is intentionally left to the clinician.
    print('\n')

    i += 1
    if MAX_SUBJECTS is not None and i >= MAX_SUBJECTS:
        break

# De-duplicated outlier lists (names kept from the original notebook)
outliers_list1_clean = list(dict.fromkeys(outliers_list_1))
outliers_list2_clean = list(dict.fromkeys(outliers_list_2))
outliers_list_3_clean = list(dict.fromkeys(outliers_list_3))
outliers_list_4_clean = list(dict.fromkeys(outliers_list_4))

In [None]:
# Magnitude of the mean gradient difference at which the cornea walk stopped;
# the threshold comments in the main loop cite np.mean(arr_diff) over 100 cases.
abs(np.mean(arr_diff))

In [None]:
# Non-zero (labelled) voxels in the most recently loaded label volume.
n_labelled_voxels = np.count_nonzero(lab_arr)
print(n_labelled_voxels)