Skip to content

Commit

Permalink
Merge pull request #3481 from stuartarchibald/fix/3159
Browse files Browse the repository at this point in the history
Permit dtype argument as sole kwarg in np.eye
  • Loading branch information
seibert committed Nov 12, 2018
2 parents 2621fa9 + 94010af commit acd73ba
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 50 deletions.
62 changes: 27 additions & 35 deletions numba/targets/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3427,52 +3427,44 @@ def identity(n, dtype):
return impl_ret_new_ref(context, builder, sig.return_type, res)


@lower_builtin(np.eye, types.Integer)
def numpy_eye(context, builder, sig, args):

def eye(n):
return np.identity(n)

res = context.compile_internal(builder, eye, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)

@lower_builtin(np.eye, types.Integer, types.Integer)
def numpy_eye(context, builder, sig, args):

def eye(n, m):
return np.eye(n, m, 0, np.float64)

res = context.compile_internal(builder, eye, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)

@lower_builtin(np.eye, types.Integer, types.Integer,
types.Integer)
def numpy_eye(context, builder, sig, args):

def eye(n, m, k):
return np.eye(n, m, k, np.float64)
def _eye_none_handler(N, M):
pass

@extending.overload(_eye_none_handler)
def _eye_none_handler_impl(N, M):
if isinstance(M, types.NoneType):
def impl(N, M):
return N
else:
def impl(N, M):
return M
return impl

res = context.compile_internal(builder, eye, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)
@extending.overload(np.eye)
def numpy_eye(N, M=None, k=0, dtype=float):

@lower_builtin(np.eye, types.Integer, types.Integer,
types.Integer, types.DTypeSpec)
def numpy_eye(context, builder, sig, args):
if dtype is None or isinstance(dtype, types.NoneType):
dt = np.dtype(float)
elif isinstance(dtype, (types.DTypeSpec, types.Number)):
# dtype or instance of dtype
dt = as_dtype(getattr(dtype, 'dtype', dtype))
else:
dt = np.dtype(dtype)

def eye(n, m, k, dtype):
arr = np.zeros((n, m), dtype)
def impl(N, M=None, k=0, dtype=float):
_M = _eye_none_handler(N, M)
arr = np.zeros((N, _M), dt)
if k >= 0:
d = min(n, m - k)
d = min(N, _M - k)
for i in range(d):
arr[i, i + k] = 1
else:
d = min(n + k, m)
d = min(N + k, _M)
for i in range(d):
arr[i - k, i] = 1
return arr
return impl

res = context.compile_internal(builder, eye, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)

@lower_builtin(np.diag, types.Array)
def numpy_diag(context, builder, sig, args):
Expand Down
7 changes: 7 additions & 0 deletions numba/tests/test_dyn_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,13 @@ def func(n):
return np.eye(n)
self.check_outputs(func, [(1,), (3,)])

def test_eye_n_dtype(self):
# check None option, dtype class, instance of dtype class
for dt in (None, np.complex128, np.complex64(1)):
def func(n, dtype=dt):
return np.eye(n, dtype=dtype)
self.check_outputs(func, [(1,), (3,)])

def test_eye_n_m(self):
def func(n, m):
return np.eye(n, m)
Expand Down
15 changes: 0 additions & 15 deletions numba/typing/npydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,21 +623,6 @@ def _infer_dtype_from_inputs(inputs):
return dtype


@infer_global(np.eye)
class NdEye(CallableTemplate):

def generic(self):
def typer(N, M=None, k=None, dtype=None):
if dtype is None:
nb_dtype = types.float64
else:
nb_dtype = _parse_dtype(dtype)
if nb_dtype is not None:
return types.Array(ndim=2, dtype=nb_dtype, layout='C')

return typer


@infer_global(np.arange)
class NdArange(AbstractTemplate):

Expand Down

0 comments on commit acd73ba

Please sign in to comment.