-
-
Notifications
You must be signed in to change notification settings - Fork 9.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Add proper dtype-support to np.flatiter
#17981
Conversation
if TYPE_CHECKING or HAVE_PROTOCOL: | ||
class _SupportsArray(Protocol): | ||
@overload |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Turns out the protocol already works fine for both positional-only and positional-or-keyword arguments, even without these overloads.
@overload | ||
def __array__(self: flatiter[ndarray[Any, _DType]], __dtype: None = ...) -> ndarray[Any, _DType]: ... | ||
@overload | ||
def __array__(self, __dtype: DTypeLike) -> ndarray[Any, dtype[Any]]: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding overloads for all DTypeLike
objects is quite frankly too much work.
For now just return dtype[Any]
if a dtype-like object (besides None
) is provided.
np.flatiter
np.flatiter
Addresses numpy#17981 (comment) Co-Authored-By: Charles Harris <charlesr.harris@gmail.com>
Thanks Bas. |
This PR adds proper dtype-support to
np.flatiter
.The methods of the latter now successfully return an array/scalar of the appropriate dtype (when applicable).
Note that In order to get
__array__
to work, thenpt._SupportsArray
protocol needed some minor restructuring,the latter now only caring about the default value of its
dtype
parameter (i.e.dtype=None
).Concrete implementations of the protocol are responsible for adding any and all
DTypeLike
-based overloads.Examples
Given the
np.int64
arrayi8_array
:next(i8_array.flat)
andi8_array.flat[0]
are now inferred asnp.int64
.i8_array.flat.__array__()
is now inferred asnp.ndarray[Any, np.dtype[np.int64]]
.