Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add parallelization to gibbs denoising #2250

Merged
merged 12 commits into from Oct 14, 2020
63 changes: 44 additions & 19 deletions dipy/denoise/gibbs.py
@@ -1,4 +1,7 @@

from functools import partial
from multiprocessing import Pool

import numpy as np


Expand Down Expand Up @@ -220,7 +223,7 @@ def _gibbs_removal_2d(image, n_points=3, G0=None, G1=None):
return imagec


def gibbs_removal(vol, slice_axis=2, n_points=3, inplace=True):
def gibbs_removal(vol, slice_axis=2, n_points=3, inplace=True, num_processes=1):
j1c marked this conversation as resolved.
Show resolved Hide resolved
"""Suppresses Gibbs ringing artefacts of images volumes.

Parameters
Expand All @@ -237,6 +240,12 @@ def gibbs_removal(vol, slice_axis=2, n_points=3, inplace=True):
If True, the input data is replaced with results. Otherwise, returns
a new array.
Default is set to True.
num_processes : int, optional
Split the calculation to a pool of children processes. This only
applies to 3D or 4D `data` arrays. If a positive integer then it
defines the size of the multiprocessing pool that will be used. If 0,
then the size of the pool will equal the number of cores available.
Default is set to 1.

Returns
-------
Expand All @@ -262,33 +271,40 @@ def gibbs_removal(vol, slice_axis=2, n_points=3, inplace=True):
"""
nd = vol.ndim

# check matrix dimension
if nd > 4:
raise ValueError("Data have to be a 4D, 3D or 2D matrix")
elif nd < 2:
raise ValueError("Data is not an image")

if not isinstance(inplace, bool):
raise TypeError("inplace must be a boolean.")

if not isinstance(num_processes, int):
raise TypeError("num_processes must be an int.")
else:
if num_processes < 0:
raise ValueError("num_processes must be >= 0.")
j1c marked this conversation as resolved.
Show resolved Hide resolved

# check the axis corresponding to different slices
# 1) This axis cannot be larger than 2
if slice_axis > 2:
raise ValueError("Different slices have to be organized along" +
"one of the 3 first matrix dimensions")

# 2) If this is not 2, swap axes so that different slices are ordered
# along axis 2. Note that swapping is not required if data is already a
# single image
elif slice_axis < 2 and nd > 2:
vol = np.swapaxes(vol, slice_axis, 2)
# 2) Reorder axis to allow iteration over the first axis
elif nd == 3:
vol = np.moveaxis(vol, slice_axis, 0)
elif nd == 4:
vol = np.moveaxis(vol, (slice_axis, 3), (0, 1))

# check matrix dimension
if nd == 4:
inishap = vol.shape
vol = vol.reshape((inishap[0], inishap[1], inishap[2] * inishap[3]))
elif nd > 4:
raise ValueError("Data have to be a 4D, 3D or 2D matrix")
elif nd < 2:
raise ValueError("Data is not an image")
vol = vol.reshape((inishap[0] * inishap[1], inishap[2], inishap[3]))

# Produce weigthing functions for 2D Gibbs removal
# Produce weighting functions for 2D Gibbs removal
shap = vol.shape
G0, G1 = _weights(shap[:2])
G0, G1 = _weights(shap[-2:])

# Copy data if not inplace
if not inplace:
Expand All @@ -298,14 +314,23 @@ def gibbs_removal(vol, slice_axis=2, n_points=3, inplace=True):
if nd == 2:
vol[:, :] = _gibbs_removal_2d(vol, n_points=n_points, G0=G0, G1=G1)
else:
for vi in range(shap[2]):
vol[:, :, vi] = _gibbs_removal_2d(vol[:, :, vi], n_points=n_points,
G0=G0, G1=G1)
if num_processes == 0:
pool = Pool()
else:
pool = Pool(num_processes)

partial_func = partial(
_gibbs_removal_2d, n_points=n_points, G0=G0, G1=G1
)
vol[:, :, :] = pool.map(partial_func, vol)
pool.close()
pool.join()

# Reshape data to original format
if nd == 3:
vol = np.moveaxis(vol, 0, slice_axis)
if nd == 4:
vol = vol.reshape(inishap)
if slice_axis < 2 and nd > 2:
vol = np.swapaxes(vol, slice_axis, 2)
vol = np.moveaxis(vol, (0, 1), (slice_axis, 3))

return vol
29 changes: 29 additions & 0 deletions dipy/denoise/tests/test_gibbs.py
Expand Up @@ -38,6 +38,29 @@ def setup_module():
image_cor = _gibbs_removal_2d(image_gibbs)


def test_parallel():
# Only relevant for 3d or 4d inputs

# Make input data
input_2d = image_gibbs.copy()
input_3d = np.stack([input_2d, input_2d], axis=2)
input_4d = np.stack([input_3d, input_3d], axis=3)

# Test 3d case
output_3d_parallel = gibbs_removal(input_3d, inplace=False, num_processes=2)
output_3d_no_parallel = gibbs_removal(
input_3d, inplace=False, num_processes=1
)
assert_array_almost_equal(output_3d_parallel, output_3d_no_parallel)

# Test 4d case
output_4d_parallel = gibbs_removal(input_4d, inplace=False, num_processes=2)
output_4d_no_parallel = gibbs_removal(
input_4d, inplace=False, num_processes=1
)
assert_array_almost_equal(output_4d_parallel, output_4d_no_parallel)


def test_inplace():
# Make input data
input_2d = image_gibbs.copy()
Expand Down Expand Up @@ -159,6 +182,12 @@ def test_gibbs_errors():
assert_raises(ValueError, gibbs_removal, np.ones((2)))
assert_raises(ValueError, gibbs_removal, np.ones((2, 2, 2)), 3)
assert_raises(TypeError, gibbs_removal, image_gibbs.copy(), inplace="True")
assert_raises(
TypeError, gibbs_removal, image_gibbs.copy(), num_processes="1"
)
assert_raises(
ValueError, gibbs_removal, image_gibbs.copy(), num_processes=-1
)


def test_gibbs_subfunction():
Expand Down