From fca04f512758a268a7a731871360d4607bdaec92 Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 15 Oct 2018 15:42:28 -0400 Subject: [PATCH 1/3] Fixed unnecessary Pool creation in searchlight When pool_size=1 the creation of a multiprocessing pool is unnecessary and wasteful because data needs to be copied and sent to other process. This would double the needed memory for each MPI task. In addition, fork() can cause unpredictable behaviour in some MPI implementations, see: https://www.open-mpi.org/faq/?category=tuning#fork-warning --- brainiak/searchlight/searchlight.py | 37 +++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/brainiak/searchlight/searchlight.py b/brainiak/searchlight/searchlight.py index c829311e9..cf59e7645 100644 --- a/brainiak/searchlight/searchlight.py +++ b/brainiak/searchlight/searchlight.py @@ -420,18 +420,35 @@ def run_block_function(self, block_fn, extra_block_fn_params=None, processes = usable_cpus else: processes = min(pool_size, usable_cpus) - with Pool(processes) as pool: + + if processes > 1: + with Pool(processes) as pool: + for idx, block in enumerate(self.blocks): + result = pool.apply_async( + block_fn, + ([subproblem[idx] for subproblem in self.subproblems], + self.submasks[idx], + self.sl_rad, + self.bcast_var, + extra_block_fn_params)) + results.append((block[0], result)) + local_outputs = [(result[0], result[1].get()) + for result in results] + else: + # If we only are using one CPU core, no need to create a Pool, cause an underlying + # fork(), and send the data to that process. Just do it here in serial. This + # will save copying the memory and will stop a fork() which can cause problems in + # some MPI implementations. for idx, block in enumerate(self.blocks): - result = pool.apply_async( - block_fn, - ([subproblem[idx] for subproblem in self.subproblems], - self.submasks[idx], - self.sl_rad, - self.bcast_var, - extra_block_fn_params)) + subprob_list = [subproblem[idx] for subproblem in self.subproblems] + result = block_fn( + subprob_list, + self.submasks[idx], + self.sl_rad, + self.bcast_var, + extra_block_fn_params) results.append((block[0], result)) - local_outputs = [(result[0], result[1].get()) - for result in results] + local_outputs = [(result[0], result[1]) for result in results] # Collect results global_outputs = self.comm.gather(local_outputs) From ac2cf6e5ef92b8aea84b875e1f07351a36dfb59f Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 16 Oct 2018 12:51:13 -0400 Subject: [PATCH 2/3] Fixed some lint errors. --- brainiak/searchlight/searchlight.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/brainiak/searchlight/searchlight.py b/brainiak/searchlight/searchlight.py index cf59e7645..1b05f507e 100644 --- a/brainiak/searchlight/searchlight.py +++ b/brainiak/searchlight/searchlight.py @@ -435,12 +435,14 @@ def run_block_function(self, block_fn, extra_block_fn_params=None, local_outputs = [(result[0], result[1].get()) for result in results] else: - # If we only are using one CPU core, no need to create a Pool, cause an underlying - # fork(), and send the data to that process. Just do it here in serial. This - # will save copying the memory and will stop a fork() which can cause problems in - # some MPI implementations. + # If we only are using one CPU core, no need to create a Pool, + # cause an underlying fork(), and send the data to that process. + # Just do it here in serial. This will save copying the memory + # and will stop a fork() which can cause problems in some MPI + # implementations. for idx, block in enumerate(self.blocks): - subprob_list = [subproblem[idx] for subproblem in self.subproblems] + subprob_list = [subproblem[idx] + for subproblem in self.subproblems] result = block_fn( subprob_list, self.submasks[idx], From dbfe0543f86228727c1e23e428256631141c4c60 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 16 Oct 2018 17:41:37 -0400 Subject: [PATCH 3/3] Added test for searchlight with pool_size=1 --- tests/searchlight/test_searchlight.py | 30 +++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/searchlight/test_searchlight.py b/tests/searchlight/test_searchlight.py index f6da405fc..7d472e6f5 100644 --- a/tests/searchlight/test_searchlight.py +++ b/tests/searchlight/test_searchlight.py @@ -60,6 +60,36 @@ def test_searchlight_with_cube(): assert global_outputs[i, j, k] is None +def test_searchlight_with_cube_poolsize_1(): + sl = Searchlight(sl_rad=3) + comm = MPI.COMM_WORLD + rank = comm.rank + size = comm.size + dim0, dim1, dim2 = (50, 50, 50) + ntr = 30 + nsubj = 3 + mask = np.zeros((dim0, dim1, dim2), dtype=np.bool) + data = [np.empty((dim0, dim1, dim2, ntr), dtype=np.object) + if i % size == rank + else None + for i in range(0, nsubj)] + + # Put a spot in the mask + mask[10:17, 10:17, 10:17] = True + + sl.distribute(data, mask) + global_outputs = sl.run_searchlight(cube_sfn, pool_size=1) + + if rank == 0: + assert global_outputs[13, 13, 13] == 1.0 + global_outputs[13, 13, 13] = None + + for i in range(global_outputs.shape[0]): + for j in range(global_outputs.shape[1]): + for k in range(global_outputs.shape[2]): + assert global_outputs[i, j, k] is None + + def diamond_sfn(l, msk, myrad, bcast_var): assert not np.any(msk[~Diamond(3).mask_]) if np.all(msk[Diamond(3).mask_]):