Skip to content

TYP: Make array _ShapeType bound and covariant #26081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions doc/release/upcoming_changes/26081.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
``ndarray`` shape-type parameter is now covariant and bound to ``tuple[int, ...]``
----------------------------------------------------------------------------------
Static typing for ``ndarray`` is a long-term effort that continues
with this change. It is a generic type with type parameters for
the shape and the data type. Previously, the shape type parameter could be
any value. This change restricts it to a tuple of ints, as one would expect
from using ``ndarray.shape``. Further, the shape-type parameter has been
changed from invariant to covariant. This change also applies to the subtypes
of ``ndarray``, e.g. ``numpy.ma.MaskedArray``. See the
`typing docs <https://typing.readthedocs.io/en/latest/reference/generics.html#variance-of-generic-types>`_
for more information.
2 changes: 1 addition & 1 deletion doc/release/upcoming_changes/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ So for example: ``123.new_feature.rst`` would have the content::
The ``my_new_feature`` option is now available for `my_favorite_function`.
To use it, write ``np.my_favorite_function(..., my_new_feature=True)``.

``highlight`` is usually formatted as bulled points making the fragment
``highlight`` is usually formatted as bullet points making the fragment
``* This is a highlight``.

Note the use of single-backticks to get an internal link (assuming
Expand Down
43 changes: 22 additions & 21 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1500,10 +1500,9 @@ _DType = TypeVar("_DType", bound=dtype[Any])
_DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any])
_FlexDType = TypeVar("_FlexDType", bound=dtype[flexible])

# TODO: Set the `bound` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=tuple[int, ...])
_ShapeType2 = TypeVar("_ShapeType2", bound=tuple[int, ...])
_Shape2DType_co = TypeVar("_Shape2DType_co", covariant=True, bound=tuple[int, int])
_NumberType = TypeVar("_NumberType", bound=number[Any])

if sys.version_info >= (3, 12):
Expand Down Expand Up @@ -1553,7 +1552,7 @@ class _SupportsImag(Protocol[_T_co]):
@property
def imag(self) -> _T_co: ...

class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
__hash__: ClassVar[None]
@property
def base(self) -> None | NDArray[Any]: ...
Expand All @@ -1563,14 +1562,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def size(self) -> int: ...
@property
def real(
self: ndarray[_ShapeType, dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
) -> ndarray[_ShapeType, _dtype[_ScalarType]]: ...
self: ndarray[_ShapeType_co, dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
) -> ndarray[_ShapeType_co, _dtype[_ScalarType]]: ...
@real.setter
def real(self, value: ArrayLike) -> None: ...
@property
def imag(
self: ndarray[_ShapeType, dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
) -> ndarray[_ShapeType, _dtype[_ScalarType]]: ...
self: ndarray[_ShapeType_co, dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
) -> ndarray[_ShapeType_co, _dtype[_ScalarType]]: ...
@imag.setter
def imag(self, value: ArrayLike) -> None: ...
def __new__(
Expand All @@ -1591,11 +1590,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __array__(
self, dtype: None = ..., /, *, copy: None | bool = ...
) -> ndarray[_ShapeType, _DType_co]: ...
) -> ndarray[_ShapeType_co, _DType_co]: ...
@overload
def __array__(
self, dtype: _DType, /, *, copy: None | bool = ...
) -> ndarray[_ShapeType, _DType]: ...
) -> ndarray[_ShapeType_co, _DType]: ...

def __array_ufunc__(
self,
Expand Down Expand Up @@ -1646,12 +1645,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __getitem__(self: NDArray[void], key: str) -> NDArray[Any]: ...
@overload
def __getitem__(self: NDArray[void], key: list[str]) -> ndarray[_ShapeType, _dtype[void]]: ...
def __getitem__(self: NDArray[void], key: list[str]) -> ndarray[_ShapeType_co, _dtype[void]]: ...

@property
def ctypes(self) -> _ctypes[int]: ...
@property
def shape(self) -> _Shape: ...
def shape(self) -> _ShapeType_co: ...
@shape.setter
def shape(self, value: _ShapeLike) -> None: ...
@property
Expand Down Expand Up @@ -3786,7 +3785,7 @@ _MemMapModeKind: TypeAlias = L[
"write", "w+",
]

class memmap(ndarray[_ShapeType, _DType_co]):
class memmap(ndarray[_ShapeType_co, _DType_co]):
__array_priority__: ClassVar[float]
filename: str | None
offset: int
Expand Down Expand Up @@ -3824,7 +3823,7 @@ class memmap(ndarray[_ShapeType, _DType_co]):
def __array_finalize__(self, obj: object) -> None: ...
def __array_wrap__(
self,
array: memmap[_ShapeType, _DType_co],
array: memmap[_ShapeType_co, _DType_co],
context: None | tuple[ufunc, tuple[Any, ...], int] = ...,
return_scalar: builtins.bool = ...,
) -> Any: ...
Expand Down Expand Up @@ -3927,7 +3926,9 @@ class poly1d:
k: None | _ArrayLikeComplex_co | _ArrayLikeObject_co = ...,
) -> poly1d: ...

class matrix(ndarray[_ShapeType, _DType_co]):


class matrix(ndarray[_Shape2DType_co, _DType_co]):
__array_priority__: ClassVar[float]
def __new__(
subtype,
Expand Down Expand Up @@ -3963,13 +3964,13 @@ class matrix(ndarray[_ShapeType, _DType_co]):
@overload
def __getitem__(self: NDArray[void], key: str, /) -> matrix[Any, dtype[Any]]: ...
@overload
def __getitem__(self: NDArray[void], key: list[str], /) -> matrix[_ShapeType, dtype[void]]: ...
def __getitem__(self: NDArray[void], key: list[str], /) -> matrix[_Shape2DType_co, dtype[void]]: ...

def __mul__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
def __rmul__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
def __imul__(self, other: ArrayLike, /) -> matrix[_ShapeType, _DType_co]: ...
def __imul__(self, other: ArrayLike, /) -> matrix[_Shape2DType_co, _DType_co]: ...
def __pow__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
def __ipow__(self, other: ArrayLike, /) -> matrix[_ShapeType, _DType_co]: ...
def __ipow__(self, other: ArrayLike, /) -> matrix[_Shape2DType_co, _DType_co]: ...

@overload
def sum(self, axis: None = ..., dtype: DTypeLike = ..., out: None = ...) -> Any: ...
Expand Down Expand Up @@ -4065,14 +4066,14 @@ class matrix(ndarray[_ShapeType, _DType_co]):
@property
def I(self) -> matrix[Any, Any]: ...
@property
def A(self) -> ndarray[_ShapeType, _DType_co]: ...
def A(self) -> ndarray[_Shape2DType_co, _DType_co]: ...
@property
def A1(self) -> ndarray[Any, _DType_co]: ...
@property
def H(self) -> matrix[Any, _DType_co]: ...
def getT(self) -> matrix[Any, _DType_co]: ...
def getI(self) -> matrix[Any, Any]: ...
def getA(self) -> ndarray[_ShapeType, _DType_co]: ...
def getA(self) -> ndarray[_Shape2DType_co, _DType_co]: ...
def getA1(self) -> ndarray[Any, _DType_co]: ...
def getH(self) -> matrix[Any, _DType_co]: ...

Expand Down
32 changes: 16 additions & 16 deletions numpy/_core/defchararray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ from numpy import (
int_,
object_,
_OrderKACF,
_ShapeType,
_ShapeType_co,
_CharDType,
_SupportsBuffer,
)
Expand All @@ -35,7 +35,7 @@ from numpy._core.multiarray import compare_chararrays as compare_chararrays
_SCT = TypeVar("_SCT", str_, bytes_)
_CharArray = chararray[Any, dtype[_SCT]]

class chararray(ndarray[_ShapeType, _CharDType]):
class chararray(ndarray[_ShapeType_co, _CharDType]):
@overload
def __new__(
subtype,
Expand Down Expand Up @@ -436,20 +436,20 @@ class chararray(ndarray[_ShapeType, _CharDType]):
) -> _CharArray[bytes_]: ...

def zfill(self, width: _ArrayLikeInt_co) -> chararray[Any, _CharDType]: ...
def capitalize(self) -> chararray[_ShapeType, _CharDType]: ...
def title(self) -> chararray[_ShapeType, _CharDType]: ...
def swapcase(self) -> chararray[_ShapeType, _CharDType]: ...
def lower(self) -> chararray[_ShapeType, _CharDType]: ...
def upper(self) -> chararray[_ShapeType, _CharDType]: ...
def isalnum(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def isalpha(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def isdigit(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def islower(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def isspace(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def istitle(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def isupper(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def isnumeric(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def isdecimal(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
def capitalize(self) -> chararray[_ShapeType_co, _CharDType]: ...
def title(self) -> chararray[_ShapeType_co, _CharDType]: ...
def swapcase(self) -> chararray[_ShapeType_co, _CharDType]: ...
def lower(self) -> chararray[_ShapeType_co, _CharDType]: ...
def upper(self) -> chararray[_ShapeType_co, _CharDType]: ...
def isalnum(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def isalpha(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def isdigit(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def islower(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def isspace(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def istitle(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def isupper(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def isnumeric(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
def isdecimal(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...

__all__: list[str]

Expand Down
6 changes: 3 additions & 3 deletions numpy/_core/records.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ from numpy import (
void,
_ByteOrder,
_SupportsBuffer,
_ShapeType,
_ShapeType_co,
_DType_co,
_OrderKACF,
)
Expand Down Expand Up @@ -49,7 +49,7 @@ class record(void):
@overload
def __getitem__(self, key: list[str]) -> record: ...

class recarray(ndarray[_ShapeType, _DType_co]):
class recarray(ndarray[_ShapeType_co, _DType_co]):
# NOTE: While not strictly mandatory, we're demanding here that arguments
# for the `format_parser`- and `dtype`-based dtype constructors are
# mutually exclusive
Expand Down Expand Up @@ -114,7 +114,7 @@ class recarray(ndarray[_ShapeType, _DType_co]):
@overload
def __getitem__(self, indx: str) -> NDArray[Any]: ...
@overload
def __getitem__(self, indx: list[str]) -> recarray[_ShapeType, dtype[record]]: ...
def __getitem__(self, indx: list[str]) -> recarray[_ShapeType_co, dtype[record]]: ...
@overload
def field(self, attr: int | str, val: None = ...) -> Any: ...
@overload
Expand Down
8 changes: 3 additions & 5 deletions numpy/ma/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ from numpy import (
angle as angle
)

# TODO: Set the `bound` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType_co = TypeVar("_ShapeType_co", bound=tuple[int, ...], covariant=True)
_DType_co = TypeVar("_DType_co", bound=dtype[Any], covariant=True)

__all__: list[str]
Expand Down Expand Up @@ -165,7 +163,7 @@ class MaskedIterator:
def __setitem__(self, index, value): ...
def __next__(self): ...

class MaskedArray(ndarray[_ShapeType, _DType_co]):
class MaskedArray(ndarray[_ShapeType_co, _DType_co]):
__array_priority__: Any
def __new__(cls, data=..., mask=..., dtype=..., copy=..., subok=..., ndmin=..., fill_value=..., keep_mask=..., hard_mask=..., shrink=..., order=...): ...
def __array_finalize__(self, obj): ...
Expand Down Expand Up @@ -300,7 +298,7 @@ class MaskedArray(ndarray[_ShapeType, _DType_co]):
def __reduce__(self): ...
def __deepcopy__(self, memo=...): ...

class mvoid(MaskedArray[_ShapeType, _DType_co]):
class mvoid(MaskedArray[_ShapeType_co, _DType_co]):
def __new__(
self,
data,
Expand Down
6 changes: 2 additions & 4 deletions numpy/ma/mrecords.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ from numpy.ma import MaskedArray

__all__: list[str]

# TODO: Set the `bound` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=tuple[int, ...])
_DType_co = TypeVar("_DType_co", bound=dtype[Any], covariant=True)

class MaskedRecords(MaskedArray[_ShapeType, _DType_co]):
class MaskedRecords(MaskedArray[_ShapeType_co, _DType_co]):
def __new__(
cls,
shape,
Expand Down
6 changes: 6 additions & 0 deletions numpy/typing/tests/data/fail/shape.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Any
import numpy as np

# test bounds of _ShapeType_co

np.ndarray[tuple[str, str], Any] # E: Value of type variable
18 changes: 18 additions & 0 deletions numpy/typing/tests/data/pass/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any, NamedTuple

import numpy as np
from typing_extensions import assert_type


# Subtype of tuple[int, int]
class XYGrid(NamedTuple):
x_axis: int
y_axis: int

arr: np.ndarray[XYGrid, Any] = np.empty(XYGrid(2, 2))

# Test variance of _ShapeType_co
def accepts_2d(a: np.ndarray[tuple[int, int], Any]) -> None:
return None

accepts_2d(arr)
15 changes: 15 additions & 0 deletions numpy/typing/tests/data/reveal/shape.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any, NamedTuple

import numpy as np
from typing_extensions import assert_type


# Subtype of tuple[int, int]
class XYGrid(NamedTuple):
x_axis: int
y_axis: int

arr: np.ndarray[XYGrid, Any]

# Test shape property matches shape typevar
assert_type(arr.shape, XYGrid)