Description
Describe the issue:
The type annotations used for scalar/int binary operations like np.float32(1) * 2 imply that the scalar types are not closed under e.g. multiplication with int:
reveal_type(np.int8(1)) # signedinteger[_8Bit]
reveal_type(np.int8(1) * np.int8(1)) # signedinteger[_8Bit]
reveal_type(np.int8(1) * 1) # signedinteger[_8Bit] | signedinteger[_32Bit | _64Bit]
As far as I can tell mixed operations with int don't actually promote the type:
>>> np.int8(1) * 128
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
OverflowError: Python integer 128 out of bounds for int8
It comes from here:
Lines 3768 to 3776 in a7eda47
class signedinteger(integer[_NBit1]):
    def __init__(self, value: _ConvertibleToInt = ..., /) -> None: ...
    __add__: _SignedIntOp[_NBit1]
    __radd__: _SignedIntOp[_NBit1]
    __sub__: _SignedIntOp[_NBit1]
    __rsub__: _SignedIntOp[_NBit1]
    __mul__: _SignedIntOp[_NBit1]
    __rmul__: _SignedIntOp[_NBit1]
And that uses:
numpy/numpy/_typing/_callable.pyi
Lines 207 to 222 in a7eda47
@type_check_only
class _SignedIntOp(Protocol[_NBit1]):
    @overload
    def __call__(self, other: bool, /) -> signedinteger[_NBit1]: ...
    @overload
    def __call__(self, other: int, /) -> signedinteger[_NBit1] | int_: ...
    @overload
    def __call__(self, other: float, /) -> floating[_NBit1] | float64: ...
    @overload
    def __call__(
        self, other: complex, /
    ) -> complexfloating[_NBit1, _NBit1] | complex128: ...
    @overload
    def __call__(
        self, other: signedinteger[_NBit2], /
    ) -> signedinteger[_NBit1] | signedinteger[_NBit2]: ...
I think that the problematic overload is:
def __call__(self, other: int, /) -> signedinteger[_NBit1] | int_: ...
Is there a reason that | int_ is needed there?
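The runtime behaviour described above can be checked with a short script. This is only a sketch and assumes NumPy 2.x, where Python ints are treated as weakly typed in scalar promotion (NEP 50); older NumPy versions may behave differently. The variable names are illustrative.

```python
import numpy as np

# Assumption: NumPy 2.x (NEP 50 promotion rules for Python scalars).
scalars = [np.int8(1), np.int16(1), np.int32(1), np.float32(1.0), np.complex64(1)]

# Map each scalar type name to the type name of (scalar * python_int).
result_types = {type(x).__name__: type(x * 2).__name__ for x in scalars}
for name, result in result_types.items():
    print(f"{name} * int -> {result}")

# An out-of-range Python int raises instead of promoting to a wider type.
try:
    np.int8(1) * 128
    overflowed = False
except OverflowError:
    overflowed = True
print("int8 * 128 raises OverflowError:", overflowed)
```

If every entry maps a type to itself, the scalar types are closed under multiplication with int at runtime, which is what the `| int_` part of the annotation does not reflect.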
Reproduce the code example:
from __future__ import annotations
import numpy as np
from typing import Protocol, Self, reveal_type
class MultiplyWithInt(Protocol):
    def __mul__(self, other: int, /) -> Self:
        ...
a: MultiplyWithInt = 1
b: MultiplyWithInt = 1.0
c: MultiplyWithInt = 1j
d: MultiplyWithInt = np.uint8(1)
e: MultiplyWithInt = np.uint16(1)
f: MultiplyWithInt = np.uint32(1)
g: MultiplyWithInt = np.uint64(1)
h: MultiplyWithInt = np.int8(1) # type check error
i: MultiplyWithInt = np.int16(1) # type check error
j: MultiplyWithInt = np.int32(1) # type check error
k: MultiplyWithInt = np.int64(1)
l: MultiplyWithInt = np.float32(1.0) # type check error
m: MultiplyWithInt = np.float64(1.0)
n: MultiplyWithInt = np.complex64(1) # type check error
o: MultiplyWithInt = np.complex128(1)
reveal_type(np.uint8(1)) # unsignedinteger[_8Bit]
reveal_type(np.uint8(1) * 1) # Any
reveal_type(np.uint8(1) * np.uint8(1)) # unsignedinteger[_8Bit]
reveal_type(np.int8(1)) # signedinteger[_8Bit]
reveal_type(np.int8(1) * 1) # signedinteger[_8Bit] | signedinteger[_32Bit | _64Bit]
reveal_type(np.int8(1) * np.int8(1)) # signedinteger[_8Bit]
Error message:
No response
Python and NumPy Versions:
Python 3.12
NumPy 2.2.1
Runtime Environment:
No response
Context for the issue:
I'm trying to write generically typed code with rings like:
from typing import Protocol, Self, Literal
type _PositiveInteger = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
class RingElement(Protocol):
    """Elements supporting ring operations."""
    def __pos__(self) -> Self: ...
    def __neg__(self) -> Self: ...
    def __add__(self, other: Self, /) -> Self: ...
    def __mul__(self, other: Self | int, /) -> Self: ...
    def __rmul__(self, other: int, /) -> Self: ...
    def __pow__(self, other: _PositiveInteger, /) -> Self: ...
The allowance for multiplication with int is so that with this protocol you can have code like 2*x + y*2. Both mypy and pyright think that some of numpy's scalar types are incompatible with this protocol because they are not closed under multiplication with int.
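To make the intended use concrete, here is a minimal sketch of a generic function written against a trimmed copy of the protocol. The helper name `twice_sum` and the type variable `R` are hypothetical, and only standard-library types are used so that it runs without NumPy. `Self` is imported under `TYPE_CHECKING` so the snippet also runs on Python versions before 3.11.

```python
from __future__ import annotations

from fractions import Fraction
from typing import TYPE_CHECKING, Protocol, TypeVar

if TYPE_CHECKING:
    from typing import Self  # only needed by the type checker


class RingElement(Protocol):
    """Trimmed copy of the protocol above (only what this example uses)."""

    def __add__(self, other: Self, /) -> Self: ...
    def __mul__(self, other: Self | int, /) -> Self: ...
    def __rmul__(self, other: int, /) -> Self: ...


R = TypeVar("R", bound=RingElement)


def twice_sum(x: R, y: R) -> R:
    """Hypothetical helper: compute 2*x + y*2 for any ring element."""
    return 2 * x + y * 2


print(twice_sum(3, 4))
print(twice_sum(Fraction(1, 2), Fraction(1, 3)))
```

The goal is for a call such as `twice_sum(np.int8(1), np.int8(2))` to type check in the same way, which requires the annotated result of `np.int8 * int` to be `np.int8` itself.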