Skip to content

Commit

Permalink
revert cupy/cp back to numpy/np for easier patch application
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Sep 24, 2021
1 parent 05b7dd6 commit 7e8191c
Show file tree
Hide file tree
Showing 13 changed files with 256 additions and 256 deletions.
40 changes: 20 additions & 20 deletions cupy/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@
if TYPE_CHECKING:
from ._typing import PyCapsule, Device, Dtype

import cupy as cp
import cupy as np

from cupyx import array_api
from cupy import array_api


class Array:
"""
n-d array object for the array API namespace.
See the docstring of :py:obj:`cp.ndarray <cupy.ndarray>` for more
See the docstring of :py:obj:`np.ndarray <numpy.ndarray>` for more
information.
This is a wrapper around cupy.ndarray that restricts the usage to only
This is a wrapper around numpy.ndarray that restricts the usage to only
those things that are required by the array API namespace. Note,
attributes on this object that start with a single underscore are not part
of the API specification and should only be used internally. This object
Expand All @@ -70,9 +70,9 @@ def _new(cls, x, /):
"""
obj = super().__new__(cls)
# Note: The spec does not have array scalars, only 0-D arrays.
if isinstance(x, cp.generic):
if isinstance(x, np.generic):
# Convert the array scalar to a 0-D array
x = cp.asarray(x)
x = np.asarray(x)
if x.dtype not in _all_dtypes:
raise TypeError(
f"The array_api namespace does not support the dtype '{x.dtype}'"
Expand Down Expand Up @@ -135,8 +135,8 @@ def _check_allowed_dtypes(self, other, dtype_category, op):
# the type promoted operator does not match the left-hand side
# operand. For example,

# >>> a = cp.array(1, dtype=cp.int8)
# >>> a += cp.array(1, dtype=cp.int16)
# >>> a = np.array(1, dtype=np.int8)
# >>> a += np.array(1, dtype=np.int16)

# The spec explicitly disallows this.
if res_dtype != self.dtype:
Expand Down Expand Up @@ -178,7 +178,7 @@ def _promote_scalar(self, scalar):
# behavior for integers within the bounds of the integer dtype.
# Outside of those bounds we use the default NumPy behavior (either
# cast or raise OverflowError).
return Array._new(cp.array(scalar, self.dtype))
return Array._new(np.array(scalar, self.dtype))

@staticmethod
def _normalize_two_args(x1, x2):
Expand All @@ -188,10 +188,10 @@ def _normalize_two_args(x1, x2):
NumPy deviates from the spec type promotion rules in cases where one
argument is 0-dimensional and the other is not. For example:
>>> import cupy as cp
>>> a = cp.array([1.0], dtype=cp.float32)
>>> b = cp.array(1.0, dtype=cp.float64)
>>> cp.add(a, b) # The spec says this should be float64
>>> import numpy as np
>>> a = np.array([1.0], dtype=np.float32)
>>> b = np.array(1.0, dtype=np.float64)
>>> np.add(a, b) # The spec says this should be float64
array([2.], dtype=float32)
To fix this, we add a dimension to the 0-dimension array before passing it
Expand Down Expand Up @@ -294,9 +294,9 @@ def _validate_index(key, shape):

for idx in key:
if (
isinstance(idx, cp.ndarray)
isinstance(idx, np.ndarray)
and idx.dtype in _boolean_dtypes
or isinstance(idx, (bool, cp.bool_))
or isinstance(idx, (bool, np.bool_))
):
if len(key) == 1:
return key
Expand Down Expand Up @@ -979,7 +979,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
@property
def dtype(self) -> Dtype:
"""
Array API compatible wrapper for :py:meth:`cp.ndarray.dtype <cupy.ndarray.dtype>`.
Array API compatible wrapper for :py:meth:`np.ndarray.dtype <numpy.ndarray.dtype>`.
See its docstring for more information.
"""
Expand All @@ -992,7 +992,7 @@ def device(self) -> Device:
@property
def ndim(self) -> int:
"""
Array API compatible wrapper for :py:meth:`cp.ndarray.ndim <cupy.ndarray.ndim>`.
Array API compatible wrapper for :py:meth:`np.ndarray.ndim <numpy.ndarray.ndim>`.
See its docstring for more information.
"""
Expand All @@ -1001,7 +1001,7 @@ def ndim(self) -> int:
@property
def shape(self) -> Tuple[int, ...]:
"""
Array API compatible wrapper for :py:meth:`cp.ndarray.shape <cupy.ndarray.shape>`.
Array API compatible wrapper for :py:meth:`np.ndarray.shape <numpy.ndarray.shape>`.
See its docstring for more information.
"""
Expand All @@ -1010,7 +1010,7 @@ def shape(self) -> Tuple[int, ...]:
@property
def size(self) -> int:
"""
Array API compatible wrapper for :py:meth:`cp.ndarray.size <cupy.ndarray.size>`.
Array API compatible wrapper for :py:meth:`np.ndarray.size <numpy.ndarray.size>`.
See its docstring for more information.
"""
Expand All @@ -1019,7 +1019,7 @@ def size(self) -> int:
@property
def T(self) -> Array:
"""
Array API compatible wrapper for :py:meth:`cp.ndarray.T <cupy.ndarray.T>`.
Array API compatible wrapper for :py:meth:`np.ndarray.T <numpy.ndarray.T>`.
See its docstring for more information.
"""
Expand Down
10 changes: 5 additions & 5 deletions cupy/array_api/_constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import cupy as cp
import cupy as np

e = cp.e
inf = cp.inf
nan = cp.nan
pi = cp.pi
e = np.e
inf = np.inf
nan = np.nan
pi = np.pi

0 comments on commit 7e8191c

Please sign in to comment.