import os
import subprocess
import sys
from copy import copy

import numpy as np

from spinalcordtoolbox import utils as sct
from spinalcordtoolbox.types import Centerline
from spinalcordtoolbox.centerline.core import ParamCenterline
from spinalcordtoolbox.centerline.core import get_centerline
from spinalcordtoolbox.image import Image


# main
# =======================================================================================================================
def main(last_disc):
    """
    Pipeline for data processing.

    Fits the spinal cord centerline of one subject from its segmentation, attaches the labelled intervertebral discs to
    it, derives an average straight template centerline and disc positions, writes the template space into
    ``template-last_disc_<last_disc>/`` and finally runs ``sct_straighten_spinalcord`` onto that template.

    :param last_disc: integer value of the lowest disc until which the template will go; disc labels with a value
        greater than this are ignored (labels 48 to 52 are always kept)
    :raises subprocess.CalledProcessError: if ``sct_straighten_spinalcord`` exits with a non-zero status
    """
    # filenames
    fname_image = 'data/sub-01/anat/sub-01_T2w.nii.gz'
    fname_image_seg = 'output/data_processed/sub-01/anat/sub-01_T2w_seg.nii.gz'
    fname_image_discs = 'output/data_processed/sub-01/anat/sub-01_T2w_seg_labeled_discs.nii.gz'

    # obtaining centerline
    im_seg = Image(fname_image_seg).change_orientation('RPI')
    param_centerline = ParamCenterline(algo_fitting='linear', smooth=50, degree=None, minmax=None)

    # extracting intervertebral discs as [x, y, z, label value] in physical coordinates
    im_discs = Image(fname_image_discs).change_orientation('RPI')
    coord = im_discs.getNonZeroCoordinates(sorting='z', reverse_coord=True)
    coord_physical = []
    for c in coord:
        if c.value <= last_disc or c.value in [48, 49, 50, 51, 52]:
            c_p = list(im_discs.transfo_pix2phys([[c.x, c.y, c.z]])[0])
            c_p.append(c.value)
            coord_physical.append(c_p)

    # extracting centerline
    _, arr_ctl, arr_ctl_der, _ = get_centerline(im_seg, param=param_centerline, space='phys')
    centerline = Centerline(points_x=arr_ctl[0], points_y=arr_ctl[1], points_z=arr_ctl[2],
                            deriv_x=arr_ctl_der[0], deriv_y=arr_ctl_der[1], deriv_z=arr_ctl_der[2])
    centerline.compute_vertebral_distribution(coord_physical)
    list_centerline = [centerline]

    # computing average template centerline and vertebral distribution
    points_average_centerline, position_template_discs = average_centerline(list_centerline, last_disc)

    # generating the initial template space
    generate_initial_template_space(points_average_centerline=points_average_centerline,
                                    position_template_discs=position_template_discs,
                                    last_disc=last_disc)

    # straighten centerline; the arguments are passed as a list (no shell re-parsing of paths) and check=True makes a
    # failed straightening visible instead of being silently ignored
    path_template = f'template-last_disc_{last_disc}'
    cmd = [
        'sct_straighten_spinalcord',
        '-i', fname_image,
        '-s', fname_image_seg,
        '-dest', f'{path_template}/template_label-centerline.nii.gz',
        '-ldisc-input', fname_image_discs,
        '-ldisc-dest', f'{path_template}/template_labels-disc.nii.gz',
        '-ofolder', f'output-last_disc_{last_disc}',
        '-disable-straight2curved',
        '-param', 'threshold_distance=1',
    ]
    subprocess.run(cmd, check=True)


def _next_label(disc_label):
    """Return the name of the label that follows `disc_label` in ``Centerline.potential_list_labels``."""
    index_current_label = Centerline.potential_list_labels.index(Centerline.labels_regions[disc_label])
    return Centerline.regions_labels[Centerline.potential_list_labels[index_current_label + 1]]


def _measured_level_length(dist_discs, disc_label):
    """
    Return the length of level `disc_label` measured in one subject, or None when it cannot be measured.

    PMJ is measured against PMG and PMG against C1 (both must be present in `dist_discs`). Any other label is measured
    against the label that follows it (see `_next_label`); None is returned when that following label is absent.

    :param dist_discs: mapping label name -> distance from C1 for one subject (``Centerline.distance_from_C1label``)
    :param disc_label: label name to measure
    """
    if disc_label == 'PMJ':
        return abs(dist_discs[disc_label] - dist_discs['PMG'])
    if disc_label == 'PMG':
        return abs(dist_discs[disc_label] - dist_discs['C1'])
    next_label = _next_label(disc_label)
    if next_label in dist_discs:
        return abs(dist_discs[disc_label] - dist_discs[next_label])
    return None


def average_centerline(list_centerline, last_disc, use_label_ref=None):
    """
    This function computes the average centerline and vertebral distribution, that will be used to create the final
    template space.

    :param list_centerline: list of Centerline objects, for all subjects
    :param last_disc: integer value containing the lowest disc until which the template will go; only the first
        `last_disc` labels of each subject's ``distance_from_C1label`` are measured
    :param use_label_ref: optional label name used as the z reference of the disc positions; by default 'PMG' is used
        when it was measured, otherwise 'C1'
    :return: points_average_centerline: list of points (x, y, z) of the average spinal cord and brainstem centerline
             position_template_discs: dictionary mapping each label name to its (x, y, z) position in the template
    :raises ValueError: if `use_label_ref` was not measured, or if neither PMG nor C1 was measured
    """
    list_dist_discs = [centerline.distance_from_C1label for centerline in list_centerline]

    # generating custom list of average vertebral lengths
    new_vert_length = {}
    for dist_discs in list_dist_discs:
        previous_length = None
        for i, disc_label in enumerate(dist_discs):
            if (i + 1) > last_disc:
                continue
            length = _measured_level_length(dist_discs, disc_label)
            if length is None:
                # The following label is missing, so this level cannot be measured directly: reuse the length of the
                # previously processed level as an extrapolation (this keeps e.g. the last labelled level in the
                # template). Nothing to extrapolate from yet -> skip the level.
                if previous_length is None:
                    continue
                length = previous_length
            previous_length = length
            new_vert_length.setdefault(disc_label, []).append(length)

    average_vert_length = {disc_label: np.mean(lengths) for disc_label, lengths in new_vert_length.items()}

    # computing length of each vertebral level, falling back on the average length when a subject cannot provide it
    length_vertebral_levels = {}
    for dist_discs in list_dist_discs:
        for disc_label in new_vert_length:
            length = _measured_level_length(dist_discs, disc_label) if disc_label in dist_discs else None
            if length is None:
                length = average_vert_length[disc_label]
            length_vertebral_levels.setdefault(disc_label, []).append(length)

    # averaging the length of vertebral levels: label -> [label, mean, std]
    average_length = {
        disc_label: [disc_label, np.mean(lengths), np.std(lengths)]
        for disc_label, lengths in length_vertebral_levels.items()
    }

    # computing distances of discs from C1, based on average length (PMG and PMJ get negative distances)
    distances_discs_from_C1 = {'C1': 0.0}
    if 'PMG' in average_length:
        distances_discs_from_C1['PMG'] = -average_length['PMG'][1]
    if 'PMJ' in average_length:
        distances_discs_from_C1['PMJ'] = -average_length['PMG'][1] - average_length['PMJ'][1]
    for disc_number in range(last_disc + 1):
        if disc_number not in [0, 1, 48, 50]:
            previous_label = Centerline.regions_labels[disc_number - 1]
            distances_discs_from_C1[Centerline.regions_labels[disc_number]] = \
                distances_discs_from_C1[previous_label] + average_length[previous_label][1]

    # calculating discs average distances from C1, sorted by increasing distance
    average_distances = sorted(
        ([disc_label, np.mean(distance), np.std(distance)] for disc_label, distance in distances_discs_from_C1.items()),
        key=lambda x: x[1],
    )

    # sample `number_of_points_between_levels` points per level, in increasing order of distance from C1
    number_of_points_between_levels = 100
    points_average_centerline = []
    label_points = []
    average_positions_from_C1 = {}
    disc_position_in_centerline = {}
    for i, (disc_label, mean_distance, _) in enumerate(average_distances):
        average_positions_from_C1[disc_label] = mean_distance
        disc_position_in_centerline[disc_label] = i * number_of_points_between_levels
        for j in range(number_of_points_between_levels):
            relative_position = float(j) / float(number_of_points_between_levels)
            if disc_label in ['PMJ', 'PMG']:
                relative_position = 1.0 - relative_position
            list_coordinates = [centerline.get_closest_to_relative_position(disc_label, relative_position)
                                for centerline in list_centerline]
            # average all coordinates, ignoring subjects for which nothing was found
            valid = [item for item in list_coordinates if item is not None]
            average_coord = np.array(valid).mean() if valid else np.nan
            # NOTE(review): only the number of entries of points_average_centerline is used below (the returned
            # centerline is rebuilt from average_positions_from_C1); the averaged values themselves are discarded.
            points_average_centerline.append(average_coord)
            label_points.append(disc_label)

    # create final template space
    if use_label_ref is not None:
        label_ref = use_label_ref
        if label_ref not in length_vertebral_levels:
            raise ValueError('ERROR: the reference label passed in argument ' + label_ref + ' should be present in the images.')
    elif 'PMG' in length_vertebral_levels:
        label_ref = 'PMG'
    elif 'C1' in length_vertebral_levels:
        label_ref = 'C1'
    else:
        raise ValueError('ERROR: the images should always have C1 label.')

    coord_ref = np.array([0.0, 0.0, 0.0])
    position_template_discs = {}
    for disc, position in average_positions_from_C1.items():
        coord_disc = coord_ref.copy()
        coord_disc[2] -= position - average_positions_from_C1[label_ref]
        position_template_discs[disc] = coord_disc
    # NOTE(review): disc positions are shifted by the position of `label_ref`, whereas the centerline points below are
    # placed relative to C1 only; both agree only when label_ref == 'C1' -- confirm this offset is intended.

    # straight template centerline along z: each point is placed between its level and the next one
    points_average_centerline_template = []
    for i, current_label in enumerate(label_points):
        if current_label not in average_length:
            continue
        length_current_label = average_length[current_label][1]
        relative_position_from_disc = float(i - disc_position_in_centerline[current_label]) / float(number_of_points_between_levels)
        temp_point = np.copy(coord_ref)
        next_label = _next_label(current_label)
        if next_label not in average_positions_from_C1:
            # no position for the next label: extend using this level's own average length
            temp_point[2] = coord_ref[2] - average_positions_from_C1[current_label] - relative_position_from_disc * length_current_label
        else:
            temp_point[2] = coord_ref[2] - average_positions_from_C1[current_label] - abs(
                relative_position_from_disc * (average_positions_from_C1[current_label] - average_positions_from_C1[next_label]))
        points_average_centerline_template.append(temp_point)

    return points_average_centerline_template, position_template_discs


def _in_bounds(coord_pix, shape):
    """Return True when the first three voxel indices of `coord_pix` lie inside an array of the given `shape`."""
    return all(0 <= coord_pix[axis] < shape[axis] for axis in range(3))


def generate_initial_template_space(points_average_centerline, position_template_discs, last_disc, algo_fitting='linear', smooth=50, degree=None, minmax=None):
    """
    This function generates the initial template space, on which all images will be registered.

    :param points_average_centerline: list of points (x, y, z) of the average spinal cord and brainstem centerline
    :param position_template_discs: dictionary mapping each label name to its (x, y, z) position in the template
    :param last_disc: integer value of the lowest disc; used to name the output folder ``template-last_disc_<n>/``
    :param algo_fitting: forwarded to ParamCenterline for fitting the saved template centerline
    :param smooth: forwarded to ParamCenterline
    :param degree: forwarded to ParamCenterline
    :param minmax: forwarded to ParamCenterline
    :return: NIFTI files in RPI orientation (template space, template centerline, template disc positions) & .npz file
             of template Centerline object
    """
    # initializing variables
    path_template = 'template-last_disc_' + str(last_disc) + '/'
    os.makedirs(path_template, exist_ok=True)
    x_size_of_template_space, y_size_of_template_space = 201, 201
    spacing = 0.5

    # creating template space: z covers the centerline extent plus 15 extra slices
    first_point, last_point = points_average_centerline[0], points_average_centerline[-1]
    size_template_z = int(abs(first_point[2] - last_point[2]) / spacing) + 15
    template_space = Image([x_size_of_template_space, y_size_of_template_space, size_template_z])
    template_space.data = np.zeros((x_size_of_template_space, y_size_of_template_space, size_template_z))
    template_space.hdr.set_data_dtype('float32')
    origin = [last_point[0] + x_size_of_template_space * spacing / 2.0,
              last_point[1] - y_size_of_template_space * spacing / 2.0,
              last_point[2] - spacing]
    hdr_map = template_space.hdr.as_analyze_map()
    hdr_map['dim'] = [3.0, x_size_of_template_space, y_size_of_template_space, size_template_z, 1.0, 1.0, 1.0, 1.0]
    hdr_map['qoffset_x'] = origin[0]
    hdr_map['qoffset_y'] = origin[1]
    hdr_map['qoffset_z'] = origin[2]
    hdr_map['srow_x'][-1] = origin[0]
    hdr_map['srow_y'][-1] = origin[1]
    hdr_map['srow_z'][-1] = origin[2]
    hdr_map['srow_x'][0] = -spacing
    hdr_map['srow_y'][1] = spacing
    hdr_map['srow_z'][2] = spacing
    template_space.hdr.set_sform(template_space.hdr.get_sform())
    template_space.hdr.set_qform(template_space.hdr.get_sform())
    template_space.save(path_template + 'template_space.nii.gz', dtype='uint8')
    print(f'\nSaving template space in {template_space.orientation} orientation as {path_template}template_space.nii.gz\n')

    # generate template centerline as an image
    image_centerline = template_space.copy()
    for coord in points_average_centerline:
        coord_pix = image_centerline.transfo_phys2pix([coord])[0]
        if _in_bounds(coord_pix, image_centerline.data.shape):
            image_centerline.data[int(coord_pix[0]), int(coord_pix[1]), int(coord_pix[2])] = 1
    image_centerline.save(path_template + 'template_label-centerline.nii.gz', dtype='float32')
    print(f'\nSaving template centerline in {template_space.orientation} orientation as {path_template}template_label-centerline.nii.gz\n')

    # generate template discs position; coord_physical collects [x, y, z, label value] for the npz centerline below
    coord_physical = []
    image_discs = template_space.copy()
    for disc, coord in position_template_discs.items():
        label = Centerline.labels_regions[disc]
        coord_pix = image_discs.transfo_phys2pix([coord])[0]
        coord_physical.append(coord.tolist() + [label])
        if _in_bounds(coord_pix, image_discs.data.shape):
            image_discs.data[int(coord_pix[0]), int(coord_pix[1]), int(coord_pix[2])] = label
        else:
            sct.printv(str(coord_pix))
            sct.printv('ERROR: the disc label ' + str(disc) + ' is not in the template image.')
    image_discs.save(path_template + 'template_labels-disc.nii.gz', dtype='uint8')
    print(f'\nSaving disc positions in {image_discs.orientation} orientation as {path_template}template_labels-disc.nii.gz\n')

    # generate template centerline as a npz file
    param_centerline = ParamCenterline(algo_fitting=algo_fitting, smooth=smooth, degree=degree, minmax=minmax)
    # centerline params of original template centerline had options that you cannot just provide `get_centerline`
    # with anymore (algo_fitting = 'nurbs', nurbs_pts_number = 4000, all_slices = False, phys_coordinates = True,
    # remove_outliers = True)
    _, arr_ctl, arr_ctl_der, _ = get_centerline(image_centerline, param=param_centerline, space='phys')
    centerline_template = Centerline(points_x=arr_ctl[0], points_y=arr_ctl[1], points_z=arr_ctl[2],
                                     deriv_x=arr_ctl_der[0], deriv_y=arr_ctl_der[1], deriv_z=arr_ctl_der[2])
    centerline_template.compute_vertebral_distribution(coord_physical)
    centerline_template.save_centerline(fname_output=path_template + 'template_label-centerline')
    print(f'\nSaving template centerline as .npz file (saves all Centerline object information, not just coordinates) as {path_template}template_label-centerline.npz\n')


# =======================================================================================================================
# Start program
# =======================================================================================================================
if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f'Usage: {sys.argv[0]} <last_disc>')
    main(int(sys.argv[1]))