Skip to content

Commit

Permalink
Merge pull request #34 from mikgroup/standardize-device-placememmt
Browse files Browse the repository at this point in the history
Standardize device placement
  • Loading branch information
frankong committed Jan 31, 2020
2 parents 904cded + 3323237 commit 348f77f
Show file tree
Hide file tree
Showing 10 changed files with 622 additions and 724 deletions.
488 changes: 242 additions & 246 deletions sigpy/block.py

Large diffs are not rendered by default.

169 changes: 67 additions & 102 deletions sigpy/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
__all__ = ['convolve', 'convolve_data_adjoint', 'convolve_filter_adjoint']


def convolve(data, filt,
mode='full', strides=None,
multi_channel=False):
def convolve(data, filt, mode='full', strides=None, multi_channel=False):
r"""Convolution that supports multi-dimensional and multi-channel inputs.
This function follows the signal processing definition of convolution.
Expand All @@ -34,12 +32,8 @@ def convolve(data, filt,
:math:`[..., c_o, p_1, ..., p_D]` otherwise.
"""
device = backend.get_device(data)
filt = backend.to_device(filt, device)
with device:
filt = filt.astype(data.dtype, copy=False)

if device == backend.cpu_device:
xp = backend.get_array_module(data)
if xp == np:
output = _convolve(data, filt, mode=mode, strides=strides,
multi_channel=multi_channel)
else: # pragma: no cover
Expand All @@ -53,19 +47,14 @@ def convolve(data, filt,
mode=mode, strides=strides,
multi_channel=multi_channel)
else:
data = backend.to_device(data)
filt = backend.to_device(filt)
output = _convolve(data, filt,
mode=mode, strides=strides,
multi_channel=multi_channel)
output = backend.to_device(output, device)
raise RuntimeError(
'cudnn must be installed to perform convolution on GPU.')

return output


def convolve_data_adjoint(output, filt, data_shape,
mode='full', strides=None,
multi_channel=False):
mode='full', strides=None, multi_channel=False):
"""Adjoint convolution operation with respect to data.
Args:
Expand All @@ -87,13 +76,10 @@ def convolve_data_adjoint(output, filt, data_shape,
:math:`[..., c_i, m_1, ..., m_D]` otherwise.
"""
device = backend.get_device(output)
data_shape = tuple(data_shape)
filt = backend.to_device(filt, device)
with device:
filt = filt.astype(output.dtype, copy=False)

if device == backend.cpu_device:
xp = backend.get_array_module(output)
if xp == np:
data = _convolve_data_adjoint(output, filt, data_shape,
mode=mode, strides=strides,
multi_channel=multi_channel)
Expand All @@ -110,19 +96,14 @@ def convolve_data_adjoint(output, filt, data_shape,
mode=mode, strides=strides,
multi_channel=multi_channel)
else:
filt = backend.to_device(filt)
output = backend.to_device(output)
data = _convolve_data_adjoint(output, filt, data_shape,
mode=mode, strides=strides,
multi_channel=multi_channel)
data = backend.to_device(data, device)
raise RuntimeError(
'cudnn must be installed to perform convolution on GPU.')

return data


def convolve_filter_adjoint(output, data, filt_shape,
mode='full', strides=None,
multi_channel=False):
mode='full', strides=None, multi_channel=False):
"""Adjoint convolution operation with respect to filter.
Args:
Expand All @@ -142,13 +123,9 @@ def convolve_filter_adjoint(output, data, filt_shape,
:math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
"""
device = backend.get_device(output)
filt_shape = tuple(filt_shape)
data = backend.to_device(data, device)
with device:
data = data.astype(output.dtype, copy=False)

if device == backend.cpu_device:
xp = backend.get_array_module(data)
if xp == np:
filt = _convolve_filter_adjoint(output, data, filt_shape,
mode=mode, strides=strides,
multi_channel=multi_channel)
Expand All @@ -165,12 +142,8 @@ def convolve_filter_adjoint(output, data, filt_shape,
mode=mode, strides=strides,
multi_channel=multi_channel)
else:
data = backend.to_device(data)
output = backend.to_device(output)
filt = _convolve_filter_adjoint(output, data, filt_shape,
mode=mode, strides=strides,
multi_channel=multi_channel)
filt = backend.to_device(filt, device)
raise RuntimeError(
'cudnn must be installed to perform convolution on GPU.')

return filt

Expand Down Expand Up @@ -329,28 +302,25 @@ def _convolve_filter_adjoint(output, data, filt_shape,
def _complex(func, data1, data2, *kargs, **kwargs):
"""Helper function to convert func to support complex floats.
"""
device = backend.get_device(data1)
xp = device.xp
with device:
data1r = xp.real(data1)
data1i = xp.imag(data1)
data2r = xp.real(data2)
data2i = xp.imag(data2)

outputr = func(data1r, data2r, *kargs, **kwargs)
outputr -= func(data1i, data2i, *kargs, **kwargs)
outputi = func(data1i, data2r, *kargs, **kwargs)
outputi += func(data1r, data2i, *kargs, **kwargs)

output = outputr + 1j * outputi
output = output.astype(data1.dtype, copy=False)
return output
xp = backend.get_array_module(data1)
data1r = xp.real(data1)
data1i = xp.imag(data1)
data2r = xp.real(data2)
data2i = xp.imag(data2)

outputr = func(data1r, data2r, *kargs, **kwargs)
outputr -= func(data1i, data2i, *kargs, **kwargs)
outputi = func(data1i, data2r, *kargs, **kwargs)
outputi += func(data1r, data2i, *kargs, **kwargs)

output = outputr + 1j * outputi
output = output.astype(data1.dtype, copy=False)
return output

def _convolve_cuda(data, filt,
mode='full', strides=None,
multi_channel=False):
device = backend.get_device(data)
xp = device.xp
xp = backend.get_array_module(data)

D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt.shape,
Expand All @@ -364,29 +334,27 @@ def _convolve_cuda(data, filt,
elif mode == 'valid':
pads = (0, ) * D

with device:
data = data.reshape((B, c_i) + m)
filt = filt.reshape((c_o, c_i) + n)
output = xp.empty((B, c_o) + p, dtype=data.dtype)
filt = util.flip(filt, axes=range(-D, 0))
cudnn.convolution_forward(data, filt, None, output,
pads, s, dilations, groups,
auto_tune=auto_tune,
tensor_core=tensor_core)

# Reshape.
if multi_channel:
output = output.reshape(b + (c_o, ) + p)
else:
output = output.reshape(b + p)
data = data.reshape((B, c_i) + m)
filt = filt.reshape((c_o, c_i) + n)
output = xp.empty((B, c_o) + p, dtype=data.dtype)
filt = util.flip(filt, axes=range(-D, 0))
cudnn.convolution_forward(data, filt, None, output,
pads, s, dilations, groups,
auto_tune=auto_tune,
tensor_core=tensor_core)

# Reshape.
if multi_channel:
output = output.reshape(b + (c_o, ) + p)
else:
output = output.reshape(b + p)

return output

def _convolve_data_adjoint_cuda(output, filt, data_shape,
mode='full', strides=None,
multi_channel=False):
device = backend.get_device(output)
xp = device.xp
xp = backend.get_array_module(output)

D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data_shape, filt.shape,
Expand All @@ -401,27 +369,25 @@ def _convolve_data_adjoint_cuda(output, filt, data_shape,
elif mode == 'valid':
pads = (0, ) * D

with device:
output = output.reshape((B, c_o) + p)
filt = filt.reshape((c_o, c_i) + n)
data = xp.empty((B, c_i) + m, dtype=output.dtype)
filt = util.flip(filt, axes=range(-D, 0))
cudnn.convolution_backward_data(filt, output, None, data,
pads, s, dilations, groups,
deterministic=deterministic,
auto_tune=auto_tune,
tensor_core=tensor_core)
output = output.reshape((B, c_o) + p)
filt = filt.reshape((c_o, c_i) + n)
data = xp.empty((B, c_i) + m, dtype=output.dtype)
filt = util.flip(filt, axes=range(-D, 0))
cudnn.convolution_backward_data(filt, output, None, data,
pads, s, dilations, groups,
deterministic=deterministic,
auto_tune=auto_tune,
tensor_core=tensor_core)

# Reshape.
data = data.reshape(data_shape)
# Reshape.
data = data.reshape(data_shape)

return data

def _convolve_filter_adjoint_cuda(output, data, filt_shape,
mode='full', strides=None,
multi_channel=False):
device = backend.get_device(data)
xp = device.xp
xp = backend.get_array_module(data)

D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
data.shape, filt_shape,
Expand All @@ -436,16 +402,15 @@ def _convolve_filter_adjoint_cuda(output, data, filt_shape,
elif mode == 'valid':
pads = (0, ) * D

with device:
data = data.reshape((B, c_i) + m)
output = output.reshape((B, c_o) + p)
filt = xp.empty((c_o, c_i) + n, dtype=output.dtype)
cudnn.convolution_backward_filter(data, output, filt,
pads, s, dilations, groups,
deterministic=deterministic,
auto_tune=auto_tune,
tensor_core=tensor_core)
filt = util.flip(filt, axes=range(-D, 0))
filt = filt.reshape(filt_shape)
data = data.reshape((B, c_i) + m)
output = output.reshape((B, c_o) + p)
filt = xp.empty((c_o, c_i) + n, dtype=output.dtype)
cudnn.convolution_backward_filter(data, output, filt,
pads, s, dilations, groups,
deterministic=deterministic,
auto_tune=auto_tune,
tensor_core=tensor_core)
filt = util.flip(filt, axes=range(-D, 0))
filt = filt.reshape(filt_shape)

return filt

0 comments on commit 348f77f

Please sign in to comment.