Skip to content

Commit

Permalink
replace most numpy/np by cupy/cp
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Sep 5, 2021
1 parent 370c7db commit e3104ec
Show file tree
Hide file tree
Showing 13 changed files with 255 additions and 255 deletions.
42 changes: 21 additions & 21 deletions cupyx/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@
if TYPE_CHECKING:
from ._typing import PyCapsule, Device, Dtype

import numpy as np
import cupy as cp

from numpy import array_api
from cupy import array_api


class Array:
"""
n-d array object for the array API namespace.
See the docstring of :py:obj:`np.ndarray <numpy.ndarray>` for more
See the docstring of :py:obj:`cp.ndarray <cupy.ndarray>` for more
information.
This is a wrapper around numpy.ndarray that restricts the usage to only
This is a wrapper around cupy.ndarray that restricts the usage to only
those things that are required by the array API namespace. Note,
attributes on this object that start with a single underscore are not part
of the API specification and should only be used internally. This object
Expand All @@ -70,9 +70,9 @@ def _new(cls, x, /):
"""
obj = super().__new__(cls)
# Note: The spec does not have array scalars, only 0-D arrays.
if isinstance(x, np.generic):
if isinstance(x, cp.generic):
# Convert the array scalar to a 0-D array
x = np.asarray(x)
x = cp.asarray(x)
if x.dtype not in _all_dtypes:
raise TypeError(
f"The array_api namespace does not support the dtype '{x.dtype}'"
Expand All @@ -99,7 +99,7 @@ def __repr__(self: Array, /) -> str:
"""
Performs the operation __repr__.
"""
return f"Array({np.array2string(self._array, separator=', ')}, dtype={self.dtype.name})"
return f"Array({cp.array2string(self._array, separator=', ')}, dtype={self.dtype.name})"

# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
Expand Down Expand Up @@ -135,8 +135,8 @@ def _check_allowed_dtypes(self, other, dtype_category, op):
# the type promoted operator does not match the left-hand side
# operand. For example,

# >>> a = np.array(1, dtype=np.int8)
# >>> a += np.array(1, dtype=np.int16)
# >>> a = cp.array(1, dtype=cp.int8)
# >>> a += cp.array(1, dtype=cp.int16)

# The spec explicitly disallows this.
if res_dtype != self.dtype:
Expand Down Expand Up @@ -178,7 +178,7 @@ def _promote_scalar(self, scalar):
# behavior for integers within the bounds of the integer dtype.
# Outside of those bounds we use the default NumPy behavior (either
# cast or raise OverflowError).
return Array._new(np.array(scalar, self.dtype))
return Array._new(cp.array(scalar, self.dtype))

@staticmethod
def _normalize_two_args(x1, x2):
Expand All @@ -188,10 +188,10 @@ def _normalize_two_args(x1, x2):
NumPy deviates from the spec type promotion rules in cases where one
argument is 0-dimensional and the other is not. For example:
>>> import numpy as np
>>> a = np.array([1.0], dtype=np.float32)
>>> b = np.array(1.0, dtype=np.float64)
>>> np.add(a, b) # The spec says this should be float64
>>> import cupy as cp
>>> a = cp.array([1.0], dtype=cp.float32)
>>> b = cp.array(1.0, dtype=cp.float64)
>>> cp.add(a, b) # The spec says this should be float64
array([2.], dtype=float32)
To fix this, we add a dimension to the 0-dimension array before passing it
Expand Down Expand Up @@ -294,9 +294,9 @@ def _validate_index(key, shape):

for idx in key:
if (
isinstance(idx, np.ndarray)
isinstance(idx, cp.ndarray)
and idx.dtype in _boolean_dtypes
or isinstance(idx, (bool, np.bool_))
or isinstance(idx, (bool, cp.bool_))
):
if len(key) == 1:
return key
Expand Down Expand Up @@ -982,7 +982,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
@property
def dtype(self) -> Dtype:
"""
Array API compatible wrapper for :py:meth:`np.ndarray.dtype <numpy.ndarray.dtype>`.
Array API compatible wrapper for :py:meth:`cp.ndarray.dtype <cupy.ndarray.dtype>`.
See its docstring for more information.
"""
Expand All @@ -995,7 +995,7 @@ def device(self) -> Device:
@property
def ndim(self) -> int:
"""
Array API compatible wrapper for :py:meth:`np.ndarray.ndim <numpy.ndarray.ndim>`.
Array API compatible wrapper for :py:meth:`cp.ndarray.ndim <cupy.ndarray.ndim>`.
See its docstring for more information.
"""
Expand All @@ -1004,7 +1004,7 @@ def ndim(self) -> int:
@property
def shape(self) -> Tuple[int, ...]:
"""
Array API compatible wrapper for :py:meth:`np.ndarray.shape <numpy.ndarray.shape>`.
Array API compatible wrapper for :py:meth:`cp.ndarray.shape <cupy.ndarray.shape>`.
See its docstring for more information.
"""
Expand All @@ -1013,7 +1013,7 @@ def shape(self) -> Tuple[int, ...]:
@property
def size(self) -> int:
"""
Array API compatible wrapper for :py:meth:`np.ndarray.size <numpy.ndarray.size>`.
Array API compatible wrapper for :py:meth:`cp.ndarray.size <cupy.ndarray.size>`.
See its docstring for more information.
"""
Expand All @@ -1022,7 +1022,7 @@ def size(self) -> int:
@property
def T(self) -> Array:
"""
Array API compatible wrapper for :py:meth:`np.ndarray.T <numpy.ndarray.T>`.
Array API compatible wrapper for :py:meth:`cp.ndarray.T <cupy.ndarray.T>`.
See its docstring for more information.
"""
Expand Down
10 changes: 5 additions & 5 deletions cupyx/array_api/_constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import cupy as cp

e = np.e
inf = np.inf
nan = np.nan
pi = np.pi
e = cp.e
inf = cp.inf
nan = cp.nan
pi = cp.pi
58 changes: 29 additions & 29 deletions cupyx/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Sequence
from ._dtypes import _all_dtypes

import numpy as np
import cupy as cp


def _check_valid_dtype(dtype):
Expand Down Expand Up @@ -46,7 +46,7 @@ def asarray(
copy: Optional[bool] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
Array API compatible wrapper for :py:func:`cp.asarray <cupy.asarray>`.
See its docstring for more information.
"""
Expand All @@ -58,17 +58,17 @@ def asarray(
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
if copy is False:
# Note: copy=False is not yet implemented in np.asarray
# Note: copy=False is not yet implemented in cp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
if copy is True:
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return Array._new(cp.array(obj._array, copy=True, dtype=dtype))
return obj
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
# Give a better error message in this case. NumPy would convert this
# to an object array. TODO: This won't handle large integers in lists.
raise OverflowError("Integer out of bounds for array dtypes")
res = np.asarray(obj, dtype=dtype)
res = cp.asarray(obj, dtype=dtype)
return Array._new(res)


Expand All @@ -82,7 +82,7 @@ def arange(
device: Optional[Device] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.
Array API compatible wrapper for :py:func:`cp.arange <cupy.arange>`.
See its docstring for more information.
"""
Expand All @@ -91,7 +91,7 @@ def arange(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
return Array._new(cp.arange(start, stop=stop, step=step, dtype=dtype))


def empty(
Expand All @@ -101,7 +101,7 @@ def empty(
device: Optional[Device] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
Array API compatible wrapper for :py:func:`cp.empty <cupy.empty>`.
See its docstring for more information.
"""
Expand All @@ -110,14 +110,14 @@ def empty(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.empty(shape, dtype=dtype))
return Array._new(cp.empty(shape, dtype=dtype))


def empty_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.
Array API compatible wrapper for :py:func:`cp.empty_like <cupy.empty_like>`.
See its docstring for more information.
"""
Expand All @@ -126,7 +126,7 @@ def empty_like(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.empty_like(x._array, dtype=dtype))
return Array._new(cp.empty_like(x._array, dtype=dtype))


def eye(
Expand All @@ -139,7 +139,7 @@ def eye(
device: Optional[Device] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.
Array API compatible wrapper for :py:func:`cp.eye <cupy.eye>`.
See its docstring for more information.
"""
Expand All @@ -148,7 +148,7 @@ def eye(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
return Array._new(cp.eye(n_rows, M=n_cols, k=k, dtype=dtype))


def from_dlpack(x: object, /) -> Array:
Expand All @@ -164,7 +164,7 @@ def full(
device: Optional[Device] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.full <numpy.full>`.
Array API compatible wrapper for :py:func:`cp.full <cupy.full>`.
See its docstring for more information.
"""
Expand All @@ -175,7 +175,7 @@ def full(
raise ValueError(f"Unsupported device {device!r}")
if isinstance(fill_value, Array) and fill_value.ndim == 0:
fill_value = fill_value._array
res = np.full(shape, fill_value, dtype=dtype)
res = cp.full(shape, fill_value, dtype=dtype)
if res.dtype not in _all_dtypes:
# This will happen if the fill value is not something that NumPy
# coerces to one of the acceptable dtypes.
Expand All @@ -192,7 +192,7 @@ def full_like(
device: Optional[Device] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.
Array API compatible wrapper for :py:func:`cp.full_like <cupy.full_like>`.
See its docstring for more information.
"""
Expand All @@ -201,7 +201,7 @@ def full_like(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = np.full_like(x._array, fill_value, dtype=dtype)
res = cp.full_like(x._array, fill_value, dtype=dtype)
if res.dtype not in _all_dtypes:
# This will happen if the fill value is not something that NumPy
# coerces to one of the acceptable dtypes.
Expand All @@ -220,7 +220,7 @@ def linspace(
endpoint: bool = True,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.
Array API compatible wrapper for :py:func:`cp.linspace <cupy.linspace>`.
See its docstring for more information.
"""
Expand All @@ -229,20 +229,20 @@ def linspace(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
return Array._new(cp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))


def meshgrid(*arrays: Sequence[Array], indexing: str = "xy") -> List[Array, ...]:
"""
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
Array API compatible wrapper for :py:func:`cp.meshgrid <cupy.meshgrid>`.
See its docstring for more information.
"""
from ._array_object import Array

return [
Array._new(array)
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
for array in cp.meshgrid(*[a._array for a in arrays], indexing=indexing)
]


Expand All @@ -253,7 +253,7 @@ def ones(
device: Optional[Device] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.
Array API compatible wrapper for :py:func:`cp.ones <cupy.ones>`.
See its docstring for more information.
"""
Expand All @@ -262,14 +262,14 @@ def ones(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.ones(shape, dtype=dtype))
return Array._new(cp.ones(shape, dtype=dtype))


def ones_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.
Array API compatible wrapper for :py:func:`cp.ones_like <cupy.ones_like>`.
See its docstring for more information.
"""
Expand All @@ -278,7 +278,7 @@ def ones_like(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.ones_like(x._array, dtype=dtype))
return Array._new(cp.ones_like(x._array, dtype=dtype))


def zeros(
Expand All @@ -288,7 +288,7 @@ def zeros(
device: Optional[Device] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.
Array API compatible wrapper for :py:func:`cp.zeros <cupy.zeros>`.
See its docstring for more information.
"""
Expand All @@ -297,14 +297,14 @@ def zeros(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.zeros(shape, dtype=dtype))
return Array._new(cp.zeros(shape, dtype=dtype))


def zeros_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.
Array API compatible wrapper for :py:func:`cp.zeros_like <cupy.zeros_like>`.
See its docstring for more information.
"""
Expand All @@ -313,4 +313,4 @@ def zeros_like(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.zeros_like(x._array, dtype=dtype))
return Array._new(cp.zeros_like(x._array, dtype=dtype))

0 comments on commit e3104ec

Please sign in to comment.