In [1]:
import os
import sys

current_directory = os.getcwd()
path = os.path.dirname(current_directory)
sys.path.append(path)
from Utils import show_image

import glob
import json
import SimpleITK as sitk 
import numpy as np
import math

from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd
from sitkIO import PushSitkImage

%matplotlib widget


In [2]:
def generate_json_list(data_dir, list_filename='imagelist_test_images.json'):
    """Load the image-list JSON stored in ``data_dir``.

    The JSON file must contain a ``'columns'`` key (list of column names) and an
    ``'images'`` key (list of rows, one per image, ordered like ``'columns'``).

    Args:
        data_dir: Directory that contains the JSON list file.
        list_filename: Name of the JSON list file inside ``data_dir``.
            Defaults to ``'imagelist_test_images.json'``.

    Returns:
        Tuple ``(dataset_image_list, cols)``. ``dataset_image_list`` is
        ``np.array(images, dtype=object)`` and ``cols`` is the list of column
        names, so a column is read as ``dataset_image_list[:, cols.index('tip')]``.
    """
    dataset_list_filename = os.path.join(data_dir, list_filename)
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding
    with open(dataset_list_filename, encoding='utf-8') as file_images:
        list_images = json.load(file_images)
    cols = list_images['columns']
    # dtype=object keeps heterogeneous entries (e.g. the 'filename' string next to
    # the 'tip'/'base' entries) intact
    dataset_image_list = np.array(list_images['images'], dtype=object)
    return (dataset_image_list, cols)

def get_physical_coordinates(name, dataset_image_list, cols):
    """Look up the stored tip and base coordinates of one image.

    Args:
        name: Value to match against the ``'filename'`` column.
        dataset_image_list: 2-D object array (rows = images), as returned by
            ``generate_json_list``.
        cols: Column names matching the second axis of ``dataset_image_list``.

    Returns:
        Tuple ``(tip, base)`` of numpy arrays built from the ``'tip'`` and
        ``'base'`` entries of the first row whose filename equals ``name``.

    Raises:
        IndexError: if no row has filename ``name``.
    """
    image_filename = dataset_image_list[:, cols.index('filename')]
    matching_rows = dataset_image_list[image_filename == name, :]
    # Fail with a readable message instead of "index 0 is out of bounds"
    if matching_rows.shape[0] == 0:
        raise IndexError('No entry with filename %r in the image list' % (name,))
    image_tip = matching_rows[0, cols.index('tip')]
    image_base = matching_rows[0, cols.index('base')]
    return (np.array(image_tip), np.array(image_base))

# Normalise a vector to unit length
def unit_vector(vector):
    """Return ``vector`` divided by its Euclidean (L2) norm.

    A zero-length input is not special-cased: the division yields NaN
    components.
    """
    length = np.linalg.norm(vector)
    return vector / length

# Calculate the unit direction of the projection of a vector onto a plane (n = normal vector)
def vector_projection_onto_plane(v, n):
    """Project ``v`` onto the plane with normal ``n`` and return the unit direction.

    ``n`` is normalised internally, so it does not need unit length. For a unit
    ``n`` this is exactly ``unit_vector(v - (v . n) n)``.

    Args:
        v: Vector to project (any non-zero length).
        n: Normal vector of the plane (any non-zero length).

    Returns:
        Unit vector along the in-plane component of ``v``. If ``v`` is parallel
        to ``n`` the in-plane component is zero and the result contains NaN.
    """
    # v - (v.n) n is only the projection when n has unit length
    n_hat = unit_vector(n)
    # Remove the component of v along the plane normal
    projection = v - np.dot(v, n_hat) * n_hat
    return unit_vector(projection)

# Returns the unsigned angle between two vectors, in degrees, in [0, 180]
# Source: https://stackoverflow.com/a/70789545/19347752
# Based on: https://people.eecs.berkeley.edu/%7Ewkahan/MathH110/Cross.pdf (page 15)
def angle_between_vectors(v1, v2):
    """Return the unsigned angle between ``v1`` and ``v2`` in degrees.

    Uses Kahan's formula ``angle = 2 * atan(|u1 - u2| / |u1 + u2|)`` on the unit
    vectors ``u1`` and ``u2``. Unlike ``arccos`` of the dot product, it stays
    accurate for nearly parallel and nearly antiparallel inputs.

    Returns:
        Angle in degrees in ``[0, 180]``; antiparallel vectors give 180.
    """
    v1_u = unit_vector(v1)
    v2_u = unit_vector(v2)
    y = v1_u - v2_u
    x = v1_u + v2_u
    # arctan2 handles |x| == 0 (antiparallel vectors) without dividing by zero
    a0 = 2 * np.arctan2(np.linalg.norm(y), np.linalg.norm(x))
    # Clamp to [0, pi] against rounding, as in Kahan's formulation
    return np.rad2deg(np.clip(a0, 0.0, np.pi))
    
# Calculate the euclidean distance between two arrays along their first axis
def euclidean_distance_3d(X, Y):
    """Return the Euclidean distance between ``X`` and ``Y`` reduced over axis 0.

    The squared differences are summed over the first axis, so the coordinate
    axis must come first. Two 3-vectors give a scalar distance, and arrays of
    shape ``(3, ...)`` give a point-wise distance map of shape ``(...)``.
    NOTE(review): assumes callers put the coordinate axis first — confirm.

    Works with numpy arrays; torch tensors also accept ``axis=`` as an alias of
    ``dim=`` in ``Tensor.sum``.
    """
    squared_diff = (X - Y) ** 2
    # `axis=` is understood by both numpy and torch; `dim=` raises TypeError on ndarrays
    sum_squared_diff = squared_diff.sum(axis=0)
    distance = np.sqrt(sum_squared_diff)
    return distance

def get_direction(sitk_output, label_value):
    """Estimate the main axis of the largest connected component of a label.

    The voxels equal to ``label_value`` are split into connected components. For
    the largest component (by voxel count), the longest axis of its oriented
    bounding box (OBB) is taken as the direction.

    Args:
        sitk_output: SimpleITK label image.
        label_value: Label whose voxels form the object (cast to ``int``).

    Returns:
        ``(direction, length)``:
            ``direction`` is a unit 3-vector (numpy array) along the longest OBB
            axis, with its sign chosen so the third component is non-negative.
            ``length`` is the OBB size along that axis, as reported by
            ``GetOrientedBoundingBoxSize``.
        Returns ``None`` if no voxel has ``label_value``.
    """
    # Binary mask of the requested label, split into connected components (computed once)
    sitk_label = (sitk_output == int(label_value))
    sitk_shaft = sitk.ConnectedComponent(sitk_label)

    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.SetComputeOrientedBoundingBox(True)
    stats.Execute(sitk_shaft)

    labels = stats.GetLabels()
    # No component <=> no voxel carries label_value
    if not labels:
        return None

    # Largest component; max() keeps the first label on ties
    largest = max(labels, key=stats.GetNumberOfPixels)
    obb_size = stats.GetOrientedBoundingBoxSize(largest)
    obb_dir = stats.GetOrientedBoundingBoxDirection(largest)

    # The longest OBB axis is taken as the main insertion axis
    i_axis = obb_size.index(max(obb_size))
    # NOTE(review): this takes row i_axis of the flattened (row-major) 3x3 direction
    # matrix; confirm against SimpleITK whether the OBB axes are its rows or columns.
    obb_vec = unit_vector(obb_dir[3 * i_axis:(3 * i_axis + 3)])
    # Always choose the sense that is positive along the third (S) axis
    obb_vec = obb_vec * math.copysign(1, obb_vec[2])
    # Alternative (not used): skeletonise the largest component with
    # sitk.BinaryThinning and pass it to getShaftCoordinates to get tip/base.
    return (unit_vector(obb_vec), obb_size[i_axis])

# Given a binary skeleton image (one voxel wide), return the physical coordinates of
# its two end points: the end nearer the image centre is the tip, the other the base.
def getShaftCoordinates(sitk_line):
    """Find the tip and base of a one-voxel-wide line in a SimpleITK image.

    The two end points are the pair of value-1 voxels that are farthest apart
    (in voxel-index units). Whichever end is closer to the centre of the index
    grid (``size / 2``) is the tip; on a tie the second end point is the tip.

    Args:
        sitk_line: SimpleITK image whose line voxels have value 1.

    Returns:
        ``(tip, base)`` physical points from ``TransformIndexToPhysicalPoint``.
    """
    # Voxel indices of the line, in numpy (z, y, x) order
    line_zyx = np.argwhere(sitk.GetArrayFromImage(sitk_line) == 1)

    # All pairwise index-space distances; the farthest pair are the two ends.
    # Memory grows as O(n^2) in the number of line voxels.
    offsets = line_zyx[:, None, :] - line_zyx[None, :, :]
    pair_dist = np.linalg.norm(offsets, axis=-1)
    idx_a, idx_b = np.unravel_index(np.argmax(pair_dist), pair_dist.shape)

    # SimpleITK expects (x, y, z) index tuples of Python ints
    end_a = tuple(int(c) for c in line_zyx[idx_a][::-1])
    end_b = tuple(int(c) for c in line_zyx[idx_b][::-1])

    grid_centre = np.array(sitk_line.GetSize()) / 2.0
    dist_a = np.linalg.norm(np.array(end_a) - grid_centre)
    dist_b = np.linalg.norm(np.array(end_b) - grid_centre)

    tip_idx, base_idx = (end_a, end_b) if dist_a < dist_b else (end_b, end_a)
    tip = sitk_line.TransformIndexToPhysicalPoint(tip_idx)
    base = sitk_line.TransformIndexToPhysicalPoint(base_idx)
    return (tip, base)
    
# Get the centroid of the largest connected component of a label, in SimpleITK
# physical (LPS) coordinates
def get_centroid(sitk_output, label_value):
    """Return the centroid of the largest connected component of ``label_value``.

    Args:
        sitk_output: SimpleITK label image.
        label_value: Label to locate (cast to ``int``).

    Returns:
        numpy array ``[x, y, z]`` with the centroid in the image's physical
        coordinate system (LPS, i.e. not 3D Slicer RAS; RAS would be
        ``[-x, -y, z]``), or ``None`` if no voxel carries ``label_value``.
    """
    # Binary mask of the requested label, split into connected components
    sitk_label = (sitk_output == int(label_value))
    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(sitk.ConnectedComponent(sitk_label))

    labels = stats.GetLabels()
    # No component <=> no voxel carries label_value (avoids copying the volume to numpy)
    if not labels:
        return None

    # Largest component; max() keeps the first label on ties
    largest = max(labels, key=stats.GetNumberOfPixels)
    centerLPS = stats.GetCentroid(largest)
    return np.array([centerLPS[0], centerLPS[1], centerLPS[2]])


In [3]:
data_dir = os.getcwd()
prefix = 'test'
caseNumber = 713
label_path = os.path.join(data_dir, 'test_labels')
file_prefix = 'SyntheticImage_'
label_type1 = 'multi_label'
label_type2 = 'M_seg'

# Load the JSON list with the stored tip/base positions (used by the REAL comparison below)
(image_list, cols_list) = generate_json_list(label_path)

# The two label files of the selected case that are compared below
case_stem = file_prefix + str(caseNumber).zfill(3) + "_"
label1 = glob.glob(os.path.join(label_path, case_stem + label_type1 + '.nii.gz'))
label2 = glob.glob(os.path.join(label_path, case_stem + label_type2 + '.nii.gz'))
labels = [label1, label2]
# glob returns [] for a missing file; fail here with a clear message instead of inside MONAI
missing = [label_type for label_type, found in zip((label_type1, label_type2), labels) if not found]
if missing:
    raise FileNotFoundError('No %s label file for case %s in %s' % (', '.join(missing), caseNumber, label_path))

test_files = [{'label': label_name} for label_name in labels]

# Define Transforms: load each label and reorient it to LIP axis codes
load_array = [
    LoadImaged(keys=['label'], image_only=False),
    EnsureChannelFirstd(keys=['label'], channel_dim='no_channel'),
    Orientationd(keys=['label'], axcodes='LIP'),
]

loadTest = Compose(load_array)
original = loadTest(test_files)
sitkTransform = PushSitkImage(resample=False, output_dtype=np.float32, print_log=False)

N = len(original)
tip_label = []
dir_3d_label = []
dir_cor_label = []
dir_sag_label = []

# Plane normals used for the 2D projections
n_COR = np.array([0, 1, 0])  # P
n_SAG = np.array([1, 0, 0])  # L
for i in range(N):
    original_dict = original[i]
    label = original_dict['label']
    filename = os.path.basename(original_dict['label_meta_dict']['filename_or_obj'])

    # Tip coordinates: centroid of the largest component of label 2
    sitk_label = sitkTransform(label)
    show_image(sitk_label, title='Label image')
    tip_label.append(get_centroid(sitk_label, 2))

    print('Tip coordinates - LABEL ' + str(i))
    print(tip_label[i])

    # Direction vector: longest OBB axis of the largest component of label 1
    direction = get_direction(sitk_label, 1)
    if direction is None:
        raise ValueError('Label 1 (shaft) is empty in ' + filename)
    (dir_3d, size) = direction
    dir_cor_label.append(vector_projection_onto_plane(dir_3d, n_COR))
    dir_sag_label.append(vector_projection_onto_plane(dir_3d, n_SAG))

    dir_3d_label.append(dir_3d)
    print('Needle direction (3D/COR/SAG) - LABEL ' + str(i))
    print(dir_3d_label[i])
    print(dir_cor_label[i])
    print(dir_sag_label[i])


# (tip_real, base_real) = get_physical_coordinates(filename, image_list, cols_list)
# print('Tip coordinates - REAL')
# print(tip_real)

# v_real = unit_vector(tip_real-base_real)
# print('Needle direction - REAL')
# print(v_real)

# Angular difference between the two labels' needle directions
err_angle_3d = angle_between_vectors(dir_3d_label[0], dir_3d_label[1])
err_angle_cor = angle_between_vectors(dir_cor_label[0], dir_cor_label[1])
err_angle_sag = angle_between_vectors(dir_sag_label[0], dir_sag_label[1])
print('err angle - 3D/COR/SAG')
print(err_angle_3d)
print(err_angle_cor)
print(err_angle_sag)
          
                
                

interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

Tip coordinates - LABEL 0
[ 98.4375      18.39999962 -26.953125  ]
proj
[0.12257973 0.         0.99245867]
proj
[0.         0.         0.99245867]
Needle direction (3D/COR/SAG) - LABEL 0
[0.12257973 0.         0.99245867]
[0.12257973 0.         0.99245867]
[0. 0. 1.]


interactive(children=(IntSlider(value=1, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

Tip coordinates - LABEL 1
[ 98.61111111  18.39999962 -27.734375  ]
proj
[0.70710678 0.         0.70710678]
proj
[0.         0.         0.70710678]
Needle direction (3D/COR/SAG) - LABEL 1
[0.70710678 0.         0.70710678]
[0.70710678 0.         0.70710678]
[0. 0. 1.]
err angle - 3D/COR/SAG
37.95899050156752
37.95899050156752
0.0
