In [1]:
from BIDS import BIDS_Global_info, POI, NII
import numpy as np
from BIDS.POI_plotter import visualize_pois
from scipy.ndimage import binary_erosion
import pandas as pd

In [2]:
bids_surgery_poi = BIDS_Global_info(['/home/daniel/Data/Implants/dataset-implants'], additional_key='seq')

[!] seq is not in list of legal keys. This name 'sub-005_ses-ANONH1JI2F1R6_seq-204_ct.nii.gz' is invalid. Legal keys are: ['sub', 'ses', 'sequ', 'acq', 'task', 'chunk', 'hemi', 'sample', 'ce', 'trc', 'stain', 'rec', 'proc', 'mod', 'recording', 'res', 'dir', 'echo', 'flip', 'inv', 'mt', 'part', 'space', 'seg', 'source', 'ovl', 'run', 'label', 'split', 'den', 'desc', 's', 'e', 'q']. 
For use see https://bids-specification.readthedocs.io/en/stable/99-appendices/09-entities.html


In [3]:
def get_poi(container) -> POI:
    poi_query = container.new_query(flatten=True)
    poi_query.filter_format('poi')
    poi_query.filter('desc', 'local')
    poi_candidate = poi_query.candidates[0]
    
    poi = poi_candidate.open_ctd()
    return poi

def get_ct(container):
    ct_query = container.new_query(flatten=True)
    ct_query.filter_format('ct')
    ct_query.filter_filetype('nii.gz') # only nifti files
    ct_candidate = ct_query.candidates[0]

    ct = ct_candidate.open_nii()
    return ct

def get_subreg(container):
    subreg_query = container.new_query(flatten=True)
    subreg_query.filter_format('msk')
    subreg_query.filter_filetype('nii.gz') # only nifti files
    subreg_query.filter('seg', 'subreg')
    subreg_candidate = subreg_query.candidates[0]
    
    subreg = subreg_candidate.open_nii()
    return subreg

def get_vertseg(container):
    vertseg_query = container.new_query(flatten=True)
    vertseg_query.filter_format('msk')
    vertseg_query.filter_filetype('nii.gz') # only nifti files
    vertseg_query.filter('seg', 'vert')
    vertseg_candidate = vertseg_query.candidates[0]

    vertseg = vertseg_candidate.open_nii()
    return vertseg

def get_files(container):
    return get_poi(container), get_ct(container), get_subreg(container), get_vertseg(container)

In [4]:
def get_bounding_box(mask, vert, margin=5):
        """
        Get the bounding box of a given vertebra in a mask.

        Args:
            mask (numpy.ndarray): The mask to search for the vertex.
            vert (int): The vertebra to search for in the mask.
            margin (int, optional): The margin to add to the bounding box. Defaults to 2.

        Returns:
            tuple: A tuple containing the minimum and maximum values for the x, y, and z axes of the bounding box.
        """
        indices = np.where(mask == vert)
        x_min = np.min(indices[0]) - margin
        x_max = np.max(indices[0]) + margin
        y_min = np.min(indices[1]) - margin
        y_max = np.max(indices[1]) + margin
        z_min = np.min(indices[2]) - margin
        z_max = np.max(indices[2]) + margin

        # Make sure the bounding box is within the mask
        x_min = max(0, x_min)
        x_max = min(mask.shape[0], x_max)
        y_min = max(0, y_min)
        y_max = min(mask.shape[1], y_max)
        z_min = max(0, z_min)
        z_max = min(mask.shape[2], z_max)

        return x_min, x_max, y_min, y_max, z_min, z_max

In [5]:
def make_cutouts(subject_name: str, poi: POI, ct: NII, subreg: NII, vertseg: NII) -> list[dict]:
    cutouts = []
    vertebrae = set([key[0] for key in poi.keys()])
    vertseg_arr = vertseg.get_array()
    for vert in vertebrae:
        x_min, x_max, y_min, y_max, z_min, z_max = get_bounding_box(vertseg_arr, vert)
        cutouts.append(
            {   
                'subject': subject_name,
                'vertebra': vert,
                'ct_nii': ct.apply_crop_slice(ex_slice = (slice(x_min, x_max), slice(y_min, y_max), slice(z_min, z_max))),
                'subreg_nii': subreg.apply_crop_slice(ex_slice = (slice(x_min, x_max), slice(y_min, y_max), slice(z_min, z_max))),
                'vertseg_nii': vertseg.apply_crop_slice(ex_slice = (slice(x_min, x_max), slice(y_min, y_max), slice(z_min, z_max))),
                'poi': poi.crop_centroids(o_shift = (slice(x_min, x_max), slice(y_min, y_max), slice(z_min, z_max)))
            }
        )
    return cutouts

In [6]:
for subject, container in bids_surgery_poi.enumerate_subjects():
    poi, ct, subreg, vertseg = get_files(container)
    subject = container.name
    cutouts = make_cutouts(subject, poi, ct, subreg, vertseg)

(47.8593738, -199.9874866, -86.0299643)
(56.0624989, -214.8312368, -102.0299488)
(59.968749, -224.987487, -118.0299333)
(67.7812491, -234.7531122, -140.8299112)
(74.0312493, -241.7843623, -168.8298841)
(74.8124993, -248.8156124, -192.8298609)
(72.8593742, -249.9874875, -214.0298403)
(72.0781242, -250.7687375, -232.0298229)
(63.8749991, -250.7687375, -252.0298035)
(56.0624989, -248.0343624, -288.0297687)
(65.0468741, -246.8624874, -320.8297369)
(65.8281241, -248.0343624, -334.0297241)
(39.9999991, -237.6156303, -195.9589257)
(39.2187491, -229.4125052, -227.958437)
(41.1718742, -222.381255, -260.757936)
(31.015624, -218.475005, -287.9575206)
(27.1093739, -216.5218799, -314.7571112)
(50.6093745, -271.2499996, 359.5)
(49.6835933, -277.113281, 334.5)
(47.523437, -281.1249998, 310.5)
(45.6718744, -282.3593748, 285.5)
(39.4999993, -279.2734373, 263.5)
(32.402343, -266.3124995, 237.5)
(29.6249992, -257.3632806, 214.5)
(38.5742181, -249.339843, 188.5)
(41.6601556, -244.0937492, 163.5)
(40.42578

In [10]:
cutout = cutouts[5]
cutout

{'subject': '007',
 'vertebra': 7,
 'ct_nii': <BIDS.nii_wrapper.NII at 0x7f6a53cba680>,
 'subreg_nii': <BIDS.nii_wrapper.NII at 0x7f6a53cba920>,
 'vertseg_nii': <BIDS.nii_wrapper.NII at 0x7f6a53cb9c30>,
 'poi': POI(centroids={2: {90: (68.245, -145.401, 28.251), 91: (75.912, -153.898, 181.001), 92: (2.855, -176.184, 74.401), 93: (2.058, -170.762, 137.877)}, 3: {90: (69.61, -111.171, 30.185), 91: (70.883, -116.598, 172.032), 92: (2.657, -128.19299999999998, 83.327), 93: (1.17, -118.733, 118.765)}, 4: {90: (61.717, -91.432, 30.763), 91: (67.54, -84.103, 178.826), 92: (2.049, -70.617, 90.344), 93: (4.151, -75.353, 118.988)}, 5: {90: (74.16, -58.516, 40.451), 91: (75.749, -60.437, 173.471), 92: (8.176, -23.13, 80.883), 93: (14.072, -26.806, 124.99)}, 6: {90: (88.05, -17.962, 38.066), 91: (90.601, -13.384, 178.078), 92: (19.623, 15.136, 77.066), 93: (23.339, 8.707, 127.969)}, 7: {90: (101.411, 15.396, 39.948), 91: (93.574, 13.269, 173.652), 92: (32.407, 48.132, 78.346), 93: (33.605, 51.272, 

In [8]:
cutout['ct_nii'].save('ct.nii.gz')
cutout['subreg_nii'].save('subreg.nii.gz')
cutout['vertseg_nii'].save('vertseg.nii.gz')
cutout['poi'].save('poi.json')

[96m[*] Save ct.nii.gz as int16[0m[0m
[96m[*] Save subreg.nii.gz as uint8[0m[0m
[96m[*] Save vertseg.nii.gz as uint8[0m[0m
[96m[*] Centroids saved: poi.json in format POI[0m[0m


In [9]:
visualize_pois(ctd = cutout['poi'], seg_vert = cutout['vertseg_nii'], vert_idx_list=[cutout['vertebra']])

[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m


100%|██████████| 1/1 [00:00<00:00, 10.81it/s]


Widget(value="<iframe src='http://localhost:34411/index.html?ui=P_0x7f6a541ef040_0&reconnect=auto' style='widt…

In [23]:
def calc_sfc_distances(poi: POI, vertseg: NII, vert: int) -> list:
    """
    Calculate the shortest distance between the surface of a vertebra and the points of interest.

    Args:
        poi (POI): The points of interest.
        vertseg (NII): The segmentation of the vertebrae.
        vert (int): The vertebra to calculate the distances for.

    Returns:
        list: A list of distances for each point of interest.
    """
    sfc_distances = {}
    vertseg_arr = vertseg.rescale((1,1,1)).get_array().copy()
    vert_msk = vertseg_arr == vert
    vertseg_arr[vert_msk] = 1
    vertseg_arr[~vert_msk] = 0

    eroded = binary_erosion(vertseg_arr, iterations=2)
    vertseg_arr[eroded] = 0
    for v_idx, p_idx, coords in poi.rescale((1,1,1)).items():
        x, y, z = coords
        if v_idx != vert:
            continue
        sfc_coords = np.where(vertseg_arr == 1)
        distances = np.sqrt((sfc_coords[0] - x)**2 + (sfc_coords[1] - y)**2 + (sfc_coords[2] - z)**2)
        sfc_distances[p_idx] = np.min(distances)

    return sfc_distances

In [24]:
calc_sfc_distances(cutout['poi'], cutout['vertseg_nii'], cutout['vertebra'])

[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m


{90: 4.838407860422128,
 91: 4.100942145656975,
 92: 0.3146065300604878,
 93: 1.1497855588658308}

In [26]:
for cutout in cutouts:
    sfc_distances = calc_sfc_distances(cutout['poi'], cutout['vertseg_nii'], cutout['vertebra'])
    for p_idx, distance in sfc_distances.items():
        cutout['d_sfc_' + str(p_idx)] = distance

[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m
[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m
[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m
[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m
[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m
[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m
[0m[*] Rescaled centroid coordinates to spacing (x, y, z) = (1, 1, 1) mm[0m[0m


In [32]:
cutouts_df = pd.DataFrame(cutouts)
cutouts_df.describe()

Unnamed: 0,vertebra,d_sfc_90,d_sfc_91,d_sfc_92,d_sfc_93
count,7.0,7.0,7.0,7.0,7.0
mean,5.0,4.010787,4.755451,1.662263,1.854666
std,2.160247,0.926042,1.25047,1.337239,0.855068
min,2.0,2.908072,3.826778,0.262286,0.945657
25%,3.5,3.198414,4.115813,0.708831,1.325927
50%,5.0,4.19437,4.230253,1.148416,1.627674
75%,6.5,4.637863,4.762888,2.435806,2.149505
max,8.0,5.300514,7.473724,3.935867,3.458464


In [12]:
# analyze all pois for all subjects
poi_iso_pir = []
for subject, container in bids_surgery_poi.enumerate_subjects():
    poi = get_poi(container)
    poi.rescale_().reorient_()

    poi_iso_pir.append({
        'subject': container.name,
        'poi': poi
    })

In [15]:
def sanity_check_pois(poi: POI):
    """
    Check the points of interest for sanity.

    Args:
        poi (POI): The points of interest to check.

    Returns:
        bool: True if the points of interest are sane, False otherwise.
    """
    unique_vertebrae = set([key[0] for key in poi.keys()])

    for v_idx in unique_vertebrae:
        # 90 is always left of 91, in PIR space that means a smaller R value
        assert poi[v_idx, 90][2] < poi[v_idx, 91][2], "POI 90 has coordinates " + str(poi[v_idx, 90]) + " and POI 91 has coordinates " + str(poi[v_idx, 91])

        # 92 is always left of 93, in PIR space that means a smaller R value
        assert poi[v_idx, 92][2] < poi[v_idx, 93][2], "POI 92 has coordinates " + str(poi[v_idx, 92]) + " and POI 93 has coordinates " + str(poi[v_idx, 93])
                                                                                                                                             
        # 92 is always anterior of 90, in PIR space that means a smaller P value
        assert poi[v_idx, 92][0] < poi[v_idx, 90][0], "POI 92 has coordinates " + str(poi[v_idx, 92]) + " and POI 90 has coordinates " + str(poi[v_idx, 90])

        # 93 is always anterior of 91, in PIR space that means a smaller P value
        assert poi[v_idx, 93][0] < poi[v_idx, 91][0], "POI 93 has coordinates " + str(poi[v_idx, 93]) + " and POI 91 has coordinates " + str(poi[v_idx, 91])

In [16]:
for dict in poi_iso_pir:
    sanity_check_pois(dict['poi'])

The most basic sanity checks are successful, however the surface distance is really worrying. 