Skip to content

Commit

Permalink
Fix axpy and xpay datatype issues
Browse files Browse the repository at this point in the history
  • Loading branch information
frankong committed Jun 5, 2020
1 parent 4033dea commit b6882ae
Showing 1 changed file with 4 additions and 52 deletions.
56 changes: 4 additions & 52 deletions sigpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
"""Utility functions.
"""
import numpy as np
import numba as nb

from sigpy import backend, config
from sigpy import backend


__all__ = ['prod', 'vec', 'split', 'rss', 'resize',
Expand Down Expand Up @@ -433,16 +432,7 @@ def axpy(y, a, x):
x (array): Input array.
"""
device = backend.get_device(y)
xp = device.xp

if xp == np:
_axpy(y, a, x, out=y)
else:
if np.isscalar(a):
a = backend.to_device(a, device)

_axpy_cuda(a, x, y)
y += a * x


def xpay(y, a, x):
Expand All @@ -453,43 +443,5 @@ def xpay(y, a, x):
a (scalar or array): Input scalar.
x (array): Input array.
"""
device = backend.get_device(y)
xp = device.xp

if xp == np:
_xpay(y, a, x, out=y)
else:
if np.isscalar(a):
a = backend.to_device(a, device)

_xpay_cuda(a, x, y)


@nb.vectorize(nopython=True, cache=True) # pragma: no cover
def _axpy(y, a, x):
return a * x + y


@nb.vectorize(nopython=True, cache=True) # pragma: no cover
def _xpay(y, a, x):
return x + a * y


if config.cupy_enabled: # pragma: no cover
import cupy as cp

_axpy_cuda = cp.ElementwiseKernel(
'S a, T x',
'T y',
"""
y += (T) a * x;
""",
name='axpy')

_xpay_cuda = cp.ElementwiseKernel(
'S a, T x',
'T y',
"""
y = x + (T) a * y;
""",
name='xpay')
y *= a
y += x

0 comments on commit b6882ae

Please sign in to comment.