Skip to content

Commit

Permalink
Merge pull request #43 from chcodes/fix-scalars-to-cuda
Browse files Browse the repository at this point in the history
Fixed scalars being wrong input to ElementwiseKernels
  • Loading branch information
frankong committed Apr 27, 2020
2 parents cbbe549 + a659737 commit 4033dea
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
12 changes: 10 additions & 2 deletions sigpy/thresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ def soft_thresh(lamda, input):
array: soft-thresholded result.
"""
xp = backend.get_array_module(input)
device = backend.get_device(input)
xp = device.xp
if xp == np:
return _soft_thresh(lamda, input)
else: # pragma: no cover
if np.isscalar(lamda):
lamda = backend.to_device(lamda, device)

return _soft_thresh_cuda(lamda, input)


Expand All @@ -44,10 +48,14 @@ def hard_thresh(lamda, input):
array: hard-thresholded result.
"""
xp = backend.get_array_module(input)
device = backend.get_device(input)
xp = device.xp
if xp == np:
return _hard_thresh(lamda, input)
else: # pragma: no cover
if np.isscalar(lamda):
lamda = backend.to_device(lamda, device)

return _hard_thresh_cuda(lamda, input)


Expand Down
18 changes: 13 additions & 5 deletions sigpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,15 +429,19 @@ def axpy(y, a, x):
Args:
y (array): Output array.
a (scalar): Input scalar.
a (scalar or array): Input scalar.
x (array): Input array.
"""
xp = backend.get_array_module(y)
device = backend.get_device(y)
xp = device.xp

if xp == np:
_axpy(y, a, x, out=y)
else:
if np.isscalar(a):
a = backend.to_device(a, device)

_axpy_cuda(a, x, y)


Expand All @@ -446,14 +450,18 @@ def xpay(y, a, x):
Args:
y (array): Output array.
a (scalar): Input scalar.
a (scalar or array): Input scalar.
x (array): Input array.
"""
xp = backend.get_array_module(y)
device = backend.get_device(y)
xp = device.xp

if xp == np:
_xpay(y, a, x, out=y)
else:
if np.isscalar(a):
a = backend.to_device(a, device)

_xpay_cuda(a, x, y)


Expand Down Expand Up @@ -484,4 +492,4 @@ def _xpay(y, a, x):
"""
y = x + (T) a * y;
""",
name='axpy')
name='xpay')

0 comments on commit 4033dea

Please sign in to comment.