Skip to content

Commit

Permalink
Merge pull request #1745 from frheault/change_num_thread_recobundles
Browse files Browse the repository at this point in the history
Adjust number of threads for SLR in Recobundles
  • Loading branch information
jchoude committed Mar 1, 2019
2 parents 61dbe1d + 88dc62b commit 73260b6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 19 deletions.
37 changes: 19 additions & 18 deletions dipy/segment/bundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,12 @@ def __init__(self, streamlines, greater_than=50, less_than=1000000,
self.orig_indices = np.array(list(range(0, len(streamlines))))
self.filtered_indices = np.array(self.orig_indices[map_ind])
self.streamlines = Streamlines(streamlines[map_ind])
print("target brain streamlines length = ", len(streamlines))
print("After refining target brain streamlines length = ",
len(self.streamlines))
self.nb_streamlines = len(self.streamlines)
self.verbose = verbose
if self.verbose:
print("target brain streamlines length = ", len(streamlines))
print("After refining target brain streamlines length = ",
len(self.streamlines))

self.start_thr = [40, 25, 20]
if rng is None:
Expand Down Expand Up @@ -182,6 +183,7 @@ def recognize(self, model_bundle, model_clust_thr,
reduction_thr=10,
reduction_distance='mdf',
slr=True,
slr_num_threads=None,
slr_metric=None,
slr_x0=None,
slr_bounds=None,
Expand Down Expand Up @@ -234,8 +236,8 @@ def recognize(self, model_bundle, model_clust_thr,
print('## Recognize given bundle ## \n')

model_centroids = self._cluster_model_bundle(
model_bundle,
model_clust_thr=model_clust_thr)
model_bundle,
model_clust_thr=model_clust_thr)

neighb_streamlines, neighb_indices = self._reduce_search_space(
model_centroids,
Expand All @@ -246,7 +248,6 @@ def recognize(self, model_bundle, model_clust_thr,
return Streamlines([]), []

if slr:

transf_streamlines, slr1_bmd = self._register_neighb_to_model(
model_bundle,
neighb_streamlines,
Expand All @@ -255,8 +256,8 @@ def recognize(self, model_bundle, model_clust_thr,
bounds=slr_bounds,
select_model=slr_select[0],
select_target=slr_select[1],
method=slr_method)

method=slr_method,
num_threads=slr_num_threads)
else:
transf_streamlines = neighb_streamlines

Expand Down Expand Up @@ -336,12 +337,12 @@ def refine(self, model_bundle, pruned_streamlines, model_clust_thr,
print('## Refine recognize given bundle ## \n')

model_centroids = self._cluster_model_bundle(
model_bundle,
model_clust_thr=model_clust_thr)
model_bundle,
model_clust_thr=model_clust_thr)

pruned_model_centroids = self._cluster_model_bundle(
pruned_streamlines,
model_clust_thr=model_clust_thr)
pruned_streamlines,
model_clust_thr=model_clust_thr)

neighb_streamlines, neighb_indices = self._reduce_search_space(
pruned_model_centroids,
Expand Down Expand Up @@ -404,11 +405,11 @@ def evaluate_results(self, model_bundle, pruned_streamlines, slr_select):

spruned_streamlines = Streamlines(pruned_streamlines)
recog_centroids = self._cluster_model_bundle(
spruned_streamlines,
model_clust_thr=1.25)
spruned_streamlines,
model_clust_thr=1.25)
mod_centroids = self._cluster_model_bundle(
model_bundle,
model_clust_thr=1.25)
model_bundle,
model_clust_thr=1.25)
recog_centroids = Streamlines(recog_centroids)
model_centroids = Streamlines(mod_centroids)
ba_value = ba_analysis(recog_centroids, model_centroids, threshold=10)
Expand Down Expand Up @@ -502,14 +503,14 @@ def _register_neighb_to_model(self, model_bundle, neighb_streamlines,
metric=None, x0=None, bounds=None,
select_model=400, select_target=600,
method='L-BFGS-B',
nb_pts=20):
nb_pts=20, num_threads=None):

if self.verbose:
print('# Local SLR of neighb_streamlines to model')
t = time()

if metric is None or metric == 'symmetric':
metric = BundleMinDistanceMetric()
metric = BundleMinDistanceMetric(num_threads=num_threads)
if metric == 'asymmetric':
metric = BundleMinDistanceAsymmetricMetric()
if metric == 'diagonal':
Expand Down
31 changes: 30 additions & 1 deletion dipy/segment/tests/test_refine_rb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import nibabel as nib
from numpy.testing import assert_equal, run_module_suite
from numpy.testing import (assert_equal,
assert_almost_equal,
run_module_suite)
from dipy.data import get_fnames
from dipy.segment.bundles import RecoBundles
from dipy.tracking.distances import bundles_distances_mam
Expand Down Expand Up @@ -79,6 +81,33 @@ def test_rb_disable_slr():
assert_equal(row.min(), 0)


def test_rb_slr_threads():

rng_multi = np.random.RandomState(42)
rb_multi = RecoBundles(f, greater_than=0, clust_thr=10,
rng=np.random.RandomState(42))
rec_trans_multi_threads, _ = rb_multi.recognize(model_bundle=f2,
model_clust_thr=5.,
reduction_thr=10,
slr=True,
slr_num_threads=None)

rb_single = RecoBundles(f, greater_than=0, clust_thr=10,
rng=np.random.RandomState(42))
rec_trans_single_thread, _ = rb_single.recognize(model_bundle=f2,
model_clust_thr=5.,
reduction_thr=10,
slr=True,
slr_num_threads=1)

D = bundles_distances_mam(rec_trans_multi_threads, rec_trans_single_thread)

# check if the bundle is recognized correctly
# multi-threading prevent an exact match
for row in D:
assert_almost_equal(row.min(), 0, decimal=4)


def test_rb_no_verbose_and_mam():

rb = RecoBundles(f, greater_than=0, clust_thr=10, verbose=False)
Expand Down

0 comments on commit 73260b6

Please sign in to comment.