Skip to content

Commit

Permalink
added option to have different source/target for oriented graph paths
Browse files Browse the repository at this point in the history
  • Loading branch information
mpaquette committed Nov 16, 2021
1 parent 0f3c688 commit 25477f9
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 36 deletions.
6 changes: 3 additions & 3 deletions compute_shortest_path_naive_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
Compute shortest paths and connectivity matrices for naive graph.
Needs a naive graph, a mask, a label map and a target_type.
The label map is expected to have integer value above 0 for each source/target.
The label map is expected to have integer value above 0 for each source.
If flag label_target isn't used, we use same label map as target as the sources.
We intersect the label map with the mask and discard what's outside.
For target_type 'COM', we compute the center-of-mass of each label (what remains after the intersection) and we find the closest voxel inside the mask to that center-of-mass.
Expand Down Expand Up @@ -77,7 +78,7 @@ def main():

source_roipath = args.label_source
target_roipath = args.label_target
if len(target_roipath) == 0:
if len(target_roipath) == 0: # not set
print('Using same labels for source and target')
target_roipath = source_roipath

Expand All @@ -93,7 +94,6 @@ def main():
mask = mask_img.get_fdata().astype(np.bool)



source_label_map = nib.load(source_roipath).get_fdata().astype(np.int)
print('Source label map has {:} greater than zero voxel'.format((source_label_map>0).sum()))
tmp_source_label_max = np.max(source_label_map)
Expand Down
113 changes: 80 additions & 33 deletions compute_shortest_path_oriented_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Needs a oriented graph, a mask, a label map and a target_type.
The label map is expected to have integer value above 0 for each source/target.
If flag label_target isn't used, we use same label map as target as the sources.
We intersect the label map with the mask and discard what's outside.
For target_type 'COM', we compute the center-of-mass of each label (what remains after the intersection) and we find the closest voxel inside the mask to that center-of-mass.
Expand Down Expand Up @@ -52,13 +53,15 @@ def buildArgsParser():

p.add_argument('graph', type=str, default=[],
help='Path of the naive graph (pickled).')
p.add_argument('label', type=str, default=[],
help='Path of the label map.')
p.add_argument('label_source', type=str, default=[],
help='Path of the sources label map. Will be used for targets too if not specified.')
p.add_argument('mask', type=str, default=[],
help='Path of the mask file.')
p.add_argument('target', choices=('COM', 'ROI'))
p.add_argument('output', type=str, default=[],
help='Base path of the output matrix.')
p.add_argument('--label_target', type=str, default='',
help='Path of the target label map.')
p.add_argument('--savepath', type=str, default=None,
help='Output the paths as tck file if file name is given.')
return p
Expand All @@ -72,7 +75,13 @@ def main():
args = parser.parse_args()

graph_fname = args.graph
roipath = args.label

source_roipath = args.label_source
target_roipath = args.label_target
if len(target_roipath) == 0: # not set
print('Using same labels for source and target')
target_roipath = source_roipath

mask_fname = args.mask
out_basefname = args.output

Expand All @@ -84,43 +93,71 @@ def main():
affine = mask_img.affine
mask = mask_img.get_fdata().astype(np.bool)

label_map = nib.load(roipath).get_fdata().astype(np.int)
print('Label map has {:} greater than zero voxel'.format((label_map>0).sum()))
tmp_label_max = np.max(label_map)

label_map[np.logical_not(mask)] = 0
print('Label map has {:} greater than zero voxel inside mask'.format((label_map>0).sum()))
source_label_map = nib.load(source_roipath).get_fdata().astype(np.int)
print('Source label map has {:} greater than zero voxel'.format((source_label_map>0).sum()))
tmp_source_label_max = np.max(source_label_map)

source_label_map[np.logical_not(mask)] = 0
print('Source label map has {:} greater than zero voxel inside mask'.format((source_label_map>0).sum()))

for i in range(1, tmp_source_label_max+1):
if (source_label_map==i).sum() == 0:
print('Source label map has no voxel inside mask for label = {:}'.format(i))


for i in range(1, tmp_label_max+1):
if (label_map==i).sum() == 0:
print('Label map has no voxel inside mask for label = {:}'.format(i))
if target_roipath == source_roipath:
target_label_map = source_label_map.copy()
else:
target_label_map = nib.load(target_roipath).get_fdata().astype(np.int)
print('Target label map has {:} greater than zero voxel'.format((target_label_map>0).sum()))
tmp_target_label_max = np.max(target_label_map)

target_label_map[np.logical_not(mask)] = 0
print('Target label map has {:} greater than zero voxel inside mask'.format((target_label_map>0).sum()))

for i in range(1, tmp_target_label_max+1):
if (target_label_map==i).sum() == 0:
print('Target label map has no voxel inside mask for label = {:}'.format(i))


if args.target == 'COM':
print('Using Center-of-Mass as sources/targets')

# compute center-of-mass -ish voxel for each roi
# add nodes
g.add_vertices(['COM_{}_source'.format(i) for i in range(1, label_map.max()+1)])
g.add_vertices(['COM_{}_target'.format(i) for i in range(1, label_map.max()+1)])
g.add_vertices(['COM_{}_source'.format(i) for i in range(1, source_label_map.max()+1)])
g.add_vertices(['COM_{}_target'.format(i) for i in range(1, target_label_map.max()+1)])

source_vertex = []
target_vertex = []
edges_to_add = []
for i in range(1, label_map.max()+1):

source_vertex = []
for i in range(1, source_label_map.max()+1):
# get roi
roi_mask = (label_map==i)
roi_mask = (source_label_map==i)
# get COM vox
COM = mask_COM(roi_mask)
# get vertex id of all 26 node at that vox
COM_vertex_cone = [vox2vertex[COM+(i_inc,)] for i_inc in range(26)]
# add new vertex to converge them all there
new_vert_id_source = g.vs['name'].index('COM_{}_source'.format(i))
new_vert_id_target = g.vs['name'].index('COM_{}_target'.format(i))
source_vertex.append(new_vert_id_source)
target_vertex.append(new_vert_id_target)
# create IN and OUT edge for all node at COM
# create IN edge for all node at COM
edges_to_add += [(new_vert_id_source, i_vert) for i_vert in COM_vertex_cone]

target_vertex = []
for i in range(1, target_label_map.max()+1):
# get roi
roi_mask = (target_label_map==i)
# get COM vox
COM = mask_COM(roi_mask)
# get vertex id of all 26 node at that vox
COM_vertex_cone = [vox2vertex[COM+(i_inc,)] for i_inc in range(26)]
# add new vertex to converge them all there
new_vert_id_target = g.vs['name'].index('COM_{}_target'.format(i))
# source_vertex.append(new_vert_id_source)
target_vertex.append(new_vert_id_target)
# create OUT edge for all node at COM
edges_to_add += [(i_vert, new_vert_id_target) for i_vert in COM_vertex_cone]

g.add_edges(edges_to_add,
Expand All @@ -133,32 +170,42 @@ def main():
elif args.target == 'ROI':
print('Using ROI nodes as sources/targets')

rois_vertex_cone = [mask2vertex_cone(label_map==i, vox2vertex) for i in range(1, label_map.max()+1)]
source_rois_vertex_cone = [mask2vertex_cone(source_label_map==i, vox2vertex) for i in range(1, source_label_map.max()+1)]

if target_roipath == source_roipath:
target_rois_vertex_cone = source_rois_vertex_cone
else:
target_rois_vertex_cone = [mask2vertex_cone(target_label_map==i, vox2vertex) for i in range(1, target_label_map.max()+1)]

start_time = time()
# compute center-of-mass -ish voxel for each roi
# add nodes
g.add_vertices(['ROI_{}_source'.format(i) for i in range(1, label_map.max()+1)])
g.add_vertices(['ROI_{}_target'.format(i) for i in range(1, label_map.max()+1)])
g.add_vertices(['ROI_{}_source'.format(i) for i in range(1, source_label_map.max()+1)])
g.add_vertices(['ROI_{}_target'.format(i) for i in range(1, target_label_map.max()+1)])


source_vertex = []
target_vertex = []
edges_to_add = []
for i in range(1, label_map.max()+1):
ROI_vertex_cone = [vert for voxvert in rois_vertex_cone[i-1] for vert in voxvert]

source_vertex = []
for i in range(1, source_label_map.max()+1):
ROI_vertex_cone = [vert for voxvert in source_rois_vertex_cone[i-1] for vert in voxvert]
# add new vertex to converge them all there
new_vert_id_source = g.vs['name'].index('ROI_{}_source'.format(i))
new_vert_id_target = g.vs['name'].index('ROI_{}_target'.format(i))
source_vertex.append(new_vert_id_source)
target_vertex.append(new_vert_id_target)
# create IN and OUT edge for all node at COM
# create IN edge for all node at ROI
edges_to_add += [(new_vert_id_source, i_vert) for i_vert in ROI_vertex_cone]

target_vertex = []
for i in range(1, target_label_map.max()+1):
ROI_vertex_cone = [vert for voxvert in target_rois_vertex_cone[i-1] for vert in voxvert]
# add new vertex to converge them all there
new_vert_id_target = g.vs['name'].index('ROI_{}_target'.format(i))
target_vertex.append(new_vert_id_target)
# create OUT edge for all node at ROI
edges_to_add += [(i_vert, new_vert_id_target) for i_vert in ROI_vertex_cone]


# TODO replace the big weight hack with 2 ROI nodes, a source and a target with unidirectional free edges
# edge of zero could give loops
# instead we put very very expensive nodes, and we can remove it when counting

g.add_edges(edges_to_add,
{'neg_log':[0]*len(edges_to_add)})
end_time = time()
Expand Down

0 comments on commit 25477f9

Please sign in to comment.