Merge pull request #88 from mikgroup/fix-convolve-1d
Fix 1D convolve in CUDA
frankong committed Jul 1, 2021
2 parents ca5f9db + 618c454 commit 62f3ecd
Showing 3 changed files with 163 additions and 81 deletions.
42 changes: 42 additions & 0 deletions sigpy/conv.py
@@ -14,6 +14,7 @@ def convolve(data, filt, mode='full', strides=None, multi_channel=False):
r"""Convolution that supports multi-dimensional and multi-channel inputs.
This function follows the signal processing definition of convolution.
Note that the cuDNN version only supports inputs with D=1, 2 or 3.
Args:
data (array): data array of shape:
@@ -57,6 +58,8 @@ def convolve_data_adjoint(output, filt, data_shape,
mode='full', strides=None, multi_channel=False):
"""Adjoint convolution operation with respect to data.
Note that the cuDNN version only supports inputs with D=1, 2 or 3.
Args:
output (array): output array of shape
:math:`[..., p_1, ..., p_D]` if multi_channel is False,
@@ -106,6 +109,8 @@ def convolve_filter_adjoint(output, data, filt_shape,
mode='full', strides=None, multi_channel=False):
"""Adjoint convolution operation with respect to filter.
Note that the cuDNN version only supports inputs with D=1, 2 or 3.
Args:
output (array): output array of shape:
:math:`[..., p_1, ..., p_D]` if multi_channel is False,
@@ -325,6 +330,18 @@ def _convolve_cuda(data, filt,
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt.shape,
mode, strides, multi_channel)

if D == 1:
return _convolve_cuda(
xp.expand_dims(data, -1),
xp.expand_dims(filt, -1),
mode=mode,
strides=list(strides) + [1] if strides is not None else None,
multi_channel=multi_channel).squeeze(-1)
elif D > 3:
raise ValueError(
f'cuDNN convolution only supports 1, 2, or 3D, got {D}.')

dilations = (1, ) * D
groups = 1
auto_tune = True
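
The D == 1 branch above works around a cuDNN limitation by appending a singleton trailing axis, reusing the existing 2D path, and squeezing that axis back out of the result. A minimal NumPy sketch of the same padding trick, for illustration only (the actual code dispatches to cuDNN through CuPy and sigpy's backend helpers, not SciPy):

import numpy as np
from scipy import signal

def convolve_1d_via_2d(data, filt):
    # Append a trailing singleton axis so a 2D routine can process 1D input,
    # then drop that axis from the result.
    data2 = np.expand_dims(data, -1)             # (n,) -> (n, 1)
    filt2 = np.expand_dims(filt, -1)             # (m,) -> (m, 1)
    out2 = signal.convolve2d(data2, filt2, mode='full')
    return out2.squeeze(-1)                      # (n + m - 1, 1) -> (n + m - 1,)

x = np.array([0.0, 1.0, 0.0])
h = np.ones(3)
assert np.allclose(convolve_1d_via_2d(x, h), np.convolve(x, h))
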
@@ -359,6 +376,19 @@ def _convolve_data_adjoint_cuda(output, filt, data_shape,
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data_shape, filt.shape,
mode, strides, multi_channel)

if D == 1:
return _convolve_data_adjoint_cuda(
xp.expand_dims(output, -1),
xp.expand_dims(filt, -1),
list(data_shape) + [1],
mode=mode,
strides=list(strides) + [1] if strides is not None else None,
multi_channel=multi_channel).squeeze(-1)
elif D > 3:
raise ValueError(
f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

dilations = (1, ) * D
groups = 1
auto_tune = True
@@ -392,6 +422,18 @@ def _convolve_filter_adjoint_cuda(output, data, filt_shape,
D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt_shape,
mode, strides, multi_channel)
if D == 1:
return _convolve_filter_adjoint_cuda(
xp.expand_dims(output, -1),
xp.expand_dims(data, -1),
list(filt_shape) + [1],
mode=mode,
strides=list(strides) + [1] if strides is not None else None,
multi_channel=multi_channel).squeeze(-1)
elif D > 3:
raise ValueError(
f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

dilations = (1, ) * D
groups = 1
auto_tune = True
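
The two adjoint wrappers above apply the same padding and additionally append a trailing 1 to data_shape or filt_shape so the recursive call sees consistent 2D shapes. A hedged usage sketch of the public entry points on 1D inputs, with signatures taken from the docstrings in this diff (assumes CuPy and a CUDA-capable device are available):

import cupy as cp
from sigpy import conv

data = cp.zeros(8, dtype=cp.float32)
data[4] = 1                                      # 1D dirac
filt = cp.ones(3, dtype=cp.float32)

output = conv.convolve(data, filt, mode='full')
grad_data = conv.convolve_data_adjoint(output, filt, data.shape, mode='full')
grad_filt = conv.convolve_filter_adjoint(output, data, filt.shape, mode='full')
print(output.shape, grad_data.shape, grad_filt.shape)   # (10,) (8,) (3,)
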
2 changes: 1 addition & 1 deletion sigpy/util.py
@@ -55,7 +55,7 @@ def prod(shape):
Product.
"""
return np.prod(shape, dtype=np.long)
return np.prod(shape, dtype=np.int64)


def vec(inputs):
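
np.long was a deprecated alias for the built-in int and has been removed from recent NumPy releases, so prod now requests an explicit 64-bit integer dtype. A quick check of the replacement:

import numpy as np

# The explicit dtype keeps the product as a 64-bit integer across NumPy
# versions; the old np.long alias no longer exists in newer releases.
shape = (2, 3, 4)
print(np.prod(shape, dtype=np.int64))   # 24
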
200 changes: 120 additions & 80 deletions tests/test_conv.py
@@ -19,93 +19,133 @@ def test_convolve_valid(self):
if config.cupy_enabled:
devices.append(backend.Device(0))

for device in devices:
xp = device.xp
with device:
for dtype in dtypes:
with self.subTest(dtype=dtype, device=device):
data = util.dirac([1, 3], device=device, dtype=dtype)
filt = xp.ones([1, 3], dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(output, [[1]], atol=1e-5)

data = util.dirac([1, 3], device=device, dtype=dtype)
filt = xp.ones([1, 2], dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(output, [[1, 1]], atol=1e-5)

data = util.dirac([1, 1, 3], device=device,
dtype=dtype)
filt = xp.ones([2, 1, 1, 3], dtype=dtype)
output = backend.to_device(
conv.convolve(data, filt,
mode=mode,
multi_channel=True),
backend.cpu_device)
npt.assert_allclose(output, [[[1]],
[[1]]], atol=1e-5)

data = util.dirac([1, 1, 3], device=device,
dtype=dtype)
filt = xp.ones([2, 1, 1, 3], dtype=dtype)
strides = [1, 2]
output = backend.to_device(
conv.convolve(data, filt,
mode=mode, strides=strides,
multi_channel=True),
backend.cpu_device)
npt.assert_allclose(output, [[[1]],
[[1]]], atol=1e-5)
for D in [1, 2, 3]:
for device in devices:
xp = device.xp
with device:
for dtype in dtypes:
with self.subTest(D=D, dtype=dtype, device=device):
data = util.dirac([3] + [1] * (D - 1),
device=device, dtype=dtype)
filt = xp.ones([3] + [1] * (D - 1), dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(
output,
np.ones([1] * D), atol=1e-5)

data = util.dirac([3] + [1] * (D - 1),
device=device, dtype=dtype)
filt = xp.ones([2] + [1] * (D - 1), dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(
output,
np.ones([2] + [1] * (D - 1)), atol=1e-5)

data = util.dirac([1, 3] + [1] * (D - 1),
device=device,
dtype=dtype)
filt = xp.ones([2, 1, 3] + [1] * (D - 1),
dtype=dtype)
output = backend.to_device(
conv.convolve(data, filt,
mode=mode,
multi_channel=True),
backend.cpu_device)
npt.assert_allclose(
output,
np.ones([2, 1] + [1] * (D - 1)),
atol=1e-5)

data = util.dirac([1, 3] + [1] * (D - 1),
device=device,
dtype=dtype)
filt = xp.ones([2, 1, 3] + [1] * (D - 1),
dtype=dtype)
strides = [2] + [1] * (D - 1)
output = backend.to_device(
conv.convolve(data, filt,
mode=mode, strides=strides,
multi_channel=True),
backend.cpu_device)
npt.assert_allclose(
output,
np.ones([2, 1] + [1] * (D - 1)),
atol=1e-5)
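
The rewritten test loops over D = 1, 2, 3 and builds every shape from a single length-3 axis padded with singleton axes, so one body exercises the 1D, 2D, and 3D cuDNN paths. The shape pattern, spelled out in plain Python (nothing sigpy-specific):

for D in [1, 2, 3]:
    data_shape = [3] + [1] * (D - 1)             # [3], [3, 1], [3, 1, 1]
    mc_filt_shape = [2, 1, 3] + [1] * (D - 1)    # channel axes (assumed output, input), then spatial axes
    print(D, data_shape, mc_filt_shape)
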

def test_convolve_full(self):
mode = 'full'
devices = [backend.cpu_device]
if config.cupy_enabled:
devices.append(backend.Device(0))

for device in devices:
xp = device.xp
with device:
for dtype in dtypes:
with self.subTest(dtype=dtype, device=device):
data = util.dirac([1, 3], device=device, dtype=dtype)
filt = xp.ones([1, 3], dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(output, [[0, 1, 1, 1, 0]],
atol=1e-5)

data = util.dirac([1, 3], device=device, dtype=dtype)
filt = xp.ones([1, 2], dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(output, [[0, 1, 1, 0]], atol=1e-5)

data = util.dirac([1, 1, 3], device=device,
dtype=dtype)
filt = xp.ones([2, 1, 1, 3], dtype=dtype)
output = backend.to_device(
conv.convolve(data, filt,
mode=mode,
multi_channel=True),
backend.cpu_device)
npt.assert_allclose(output, [[[0, 1, 1, 1, 0]],
[[0, 1, 1, 1, 0]]],
atol=1e-5)

data = util.dirac([1, 1, 3], device=device,
dtype=dtype)
filt = xp.ones([2, 1, 1, 3], dtype=dtype)
strides = [1, 2]
output = backend.to_device(
conv.convolve(data, filt,
mode=mode,
strides=strides,
multi_channel=True))
npt.assert_allclose(output, [[[0, 1, 0]],
[[0, 1, 0]]], atol=1e-5)
for D in [1, 2, 3]:
for device in devices:
xp = device.xp
with device:
for dtype in dtypes:
with self.subTest(D=D, dtype=dtype, device=device):
data = util.dirac(
[3] + [1] * (D - 1),
device=device, dtype=dtype)
filt = xp.ones(
[3] + [1] * (D - 1), dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(
output,
np.array([0, 1, 1, 1, 0]).reshape(
[5] + [1] * (D - 1)),
atol=1e-5)

data = util.dirac(
[3] + [1] * (D - 1),
device=device, dtype=dtype)
filt = xp.ones(
[2] + [1] * (D - 1), dtype=dtype)
output = backend.to_device(conv.convolve(
data, filt, mode=mode))
npt.assert_allclose(
output,
np.array([0, 1, 1, 0]).reshape(
[4] + [1] * (D - 1)),
atol=1e-5)

data = util.dirac(
[1, 3] + [1] * (D - 1),
device=device, dtype=dtype)
filt = xp.ones(
[2, 1, 3] + [1] * (D - 1), dtype=dtype)
output = backend.to_device(
conv.convolve(data, filt,
mode=mode,
multi_channel=True),
backend.cpu_device)
npt.assert_allclose(
output,
np.array([[0, 1, 1, 1, 0],
[0, 1, 1, 1, 0]]).reshape(
[2, 5] + [1] * (D - 1)),
atol=1e-5)

data = util.dirac([1, 3] + [1] * (D - 1),
device=device,
dtype=dtype)
filt = xp.ones([2, 1, 3] + [1] * (D - 1),
dtype=dtype)
strides = [2] + [1] * (D - 1)
output = backend.to_device(
conv.convolve(data, filt,
mode=mode,
strides=strides,
multi_channel=True))
npt.assert_allclose(
output,
np.array([[0, 1, 0],
[0, 1, 0]]).reshape(
[2, 3] + [1] * (D - 1)),
atol=1e-5)

def test_convolve_data_adjoint_valid(self):
mode = 'valid'
