Skip to content

Commit

Permalink
Fixed unnecessary Pool creation in searchlight (#386)
Browse files Browse the repository at this point in the history
When pool_size=1, the creation of a multiprocessing pool is unnecessary
and wasteful because data needs to be copied and sent to another process.
This would double the needed memory for each MPI task. In addition,
fork() can cause unpredictable behaviour in some MPI implementations,
see:

https://www.open-mpi.org/faq/?category=tuning#fork-warning
  • Loading branch information
davidt0x authored and mihaic committed Oct 16, 2018
1 parent 40a8d15 commit 8994231
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
39 changes: 29 additions & 10 deletions brainiak/searchlight/searchlight.py
Expand Up @@ -420,18 +420,37 @@ def run_block_function(self, block_fn, extra_block_fn_params=None,
processes = usable_cpus
else:
processes = min(pool_size, usable_cpus)
with Pool(processes) as pool:

if processes > 1:
with Pool(processes) as pool:
for idx, block in enumerate(self.blocks):
result = pool.apply_async(
block_fn,
([subproblem[idx] for subproblem in self.subproblems],
self.submasks[idx],
self.sl_rad,
self.bcast_var,
extra_block_fn_params))
results.append((block[0], result))
local_outputs = [(result[0], result[1].get())
for result in results]
else:
# If we only are using one CPU core, no need to create a Pool,
# cause an underlying fork(), and send the data to that process.
# Just do it here in serial. This will save copying the memory
# and will stop a fork() which can cause problems in some MPI
# implementations.
for idx, block in enumerate(self.blocks):
result = pool.apply_async(
block_fn,
([subproblem[idx] for subproblem in self.subproblems],
self.submasks[idx],
self.sl_rad,
self.bcast_var,
extra_block_fn_params))
subprob_list = [subproblem[idx]
for subproblem in self.subproblems]
result = block_fn(
subprob_list,
self.submasks[idx],
self.sl_rad,
self.bcast_var,
extra_block_fn_params)
results.append((block[0], result))
local_outputs = [(result[0], result[1].get())
for result in results]
local_outputs = [(result[0], result[1]) for result in results]

# Collect results
global_outputs = self.comm.gather(local_outputs)
Expand Down
30 changes: 30 additions & 0 deletions tests/searchlight/test_searchlight.py
Expand Up @@ -60,6 +60,36 @@ def test_searchlight_with_cube():
assert global_outputs[i, j, k] is None


def test_searchlight_with_cube_poolsize_1():
    """Run a searchlight over a cubic ROI with pool_size=1.

    Exercises the serial code path that skips multiprocessing Pool
    creation (and the underlying fork()) when only one process is
    requested, and checks that the cube searchlight function fires
    exactly at the cube's center.
    """
    sl = Searchlight(sl_rad=3)
    comm = MPI.COMM_WORLD
    rank = comm.rank
    size = comm.size
    dim0, dim1, dim2 = (50, 50, 50)
    ntr = 30
    nsubj = 3
    # NOTE: the original used np.bool / np.object, aliases deprecated in
    # NumPy 1.20 and removed in 1.24 — use the builtin types instead.
    mask = np.zeros((dim0, dim1, dim2), dtype=bool)
    # Round-robin subject assignment: each MPI rank allocates only the
    # subjects it owns; the rest are None placeholders.
    data = [np.empty((dim0, dim1, dim2, ntr), dtype=object)
            if i % size == rank
            else None
            for i in range(0, nsubj)]

    # Put a spot in the mask
    mask[10:17, 10:17, 10:17] = True

    sl.distribute(data, mask)
    global_outputs = sl.run_searchlight(cube_sfn, pool_size=1)

    # Only rank 0 receives the gathered result; other ranks get None.
    if rank == 0:
        assert global_outputs[13, 13, 13] == 1.0
        global_outputs[13, 13, 13] = None

        for i in range(global_outputs.shape[0]):
            for j in range(global_outputs.shape[1]):
                for k in range(global_outputs.shape[2]):
                    assert global_outputs[i, j, k] is None


def diamond_sfn(l, msk, myrad, bcast_var):
assert not np.any(msk[~Diamond(3).mask_])
if np.all(msk[Diamond(3).mask_]):
Expand Down

0 comments on commit 8994231

Please sign in to comment.