Skip to content

Commit

Permalink
fix to preserve dtype of input array
Browse files Browse the repository at this point in the history
  • Loading branch information
kmaehashi committed Jan 11, 2018
1 parent 61b6f71 commit 6545add
Showing 1 changed file with 4 additions and 16 deletions.
20 changes: 4 additions & 16 deletions cupy/linalg/norms.py
Expand Up @@ -69,36 +69,24 @@ def norm(x, ord=None, axis=None, keepdims=False):
elif ord == 0:
# Zero norm
# Convert to Python float in accordance with NumPy
return (x != 0).sum(axis=axis, keepdims=keepdims, dtype='d')
return (x != 0).astype(x.real.dtype).sum(axis=axis, keepdims=keepdims)
elif ord == 1:
# special case for speedup
return abs(x).sum(axis=axis, keepdims=keepdims)
elif ord is None or ord == 2:
# special case for speedup
if issubclass(x.dtype.type, numpy.complexfloating):
s = abs(x)
s *= s
else:
s = x ** 2
s = (x.conj() * x).real
return cupy.sqrt(s.sum(axis=axis, keepdims=keepdims))
else:
try:
float(ord)
except TypeError:
raise ValueError("Invalid norm order for vectors.")

# Mirror Numpy behavior of casting to double for non-complex
# dtypes, and to float32 or float64 for complex dtypes and
# no reduction over all axes.
cast_dtype = 'd'
if issubclass(x.dtype.type, numpy.complexfloating):
if keepdims or tuple(sorted(axis)) != tuple(range(nd)):
cast_dtype = x.dtype.char.lower() # 'D'->'d' and 'F'->'f'

absx = abs(x).astype(cast_dtype)
absx = abs(x)
absx **= ord
ret = absx.sum(axis=axis, keepdims=keepdims)
ret **= (1.0 / ord)
ret **= cupy.reciprocal(ord, dtype=ret.dtype)
return ret
elif len(axis) == 2:
row_axis, col_axis = axis
Expand Down

0 comments on commit 6545add

Please sign in to comment.